11 changes: 11 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,11 @@
repos:
  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
        name: isort (python)
        args: ["--profile", "black"]
  - repo: https://github.com/psf/black
    rev: 22.6.0
    hooks:
      - id: black
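This config wires isort and black into pre-commit at pinned revisions; `--profile black` makes isort's import style agree with black so the two hooks never fight over the same lines. After `pre-commit install`, both hooks run on staged files at each commit, and `pre-commit run --all-files` applies them repo-wide, which is presumably what produced the reformatting in the files below. A minimal sketch of the import reordering isort performs, using the imports this PR touches in `seed_gathering/benchmark_data.py`:

```python
# Before isort, the stdlib imports were unsorted:
# import json
# import itertools
# from datasets import load_dataset

# After isort --profile black: alphabetized, stdlib group first,
# third-party packages separated by a blank line.
import itertools
import json

from datasets import load_dataset
```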
3 changes: 3 additions & 0 deletions evaluation/text2code.py
@@ -2,6 +2,7 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, TypedDict, cast

from evalplus.data import get_human_eval_plus, get_mbpp_plus, write_jsonl
from tqdm.auto import tqdm
from transformers import HfArgumentParser
@@ -10,6 +11,7 @@
from star_align.prompt_template import SC2_INSTRUCT_PROMPT as PROMPT_TEMPLATE
from star_align.utils import chunked


class Text2CodeProblem(TypedDict):
    id: str
    instruction: str
@@ -25,6 +27,7 @@ def get_humaneval_raw_problems() -> list[dict]:
    problems = get_human_eval_plus()
    return list(problems.values())


def map_mbpp_problem(p: dict) -> Text2CodeProblem:
    id = p["task_id"]
    prompt = p["prompt"]
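All three additions to `evaluation/text2code.py` are blank lines: isort inserts one blank line between the stdlib and third-party import groups, and black enforces PEP 8's two blank lines before top-level definitions. A small illustration of both conventions (not part of the diff):

```python
from typing import TypedDict  # stdlib import group

from tqdm.auto import tqdm  # third-party group, one blank line above


class Text2CodeProblem(TypedDict):  # two blank lines before a top-level definition
    id: str
    instruction: str
```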
54 changes: 38 additions & 16 deletions seed_gathering/benchmark_data.py
@@ -1,16 +1,19 @@
"""data to filter out of the dataset"""
import json
import itertools
import json
from pathlib import Path

from datasets import load_dataset


TEST_IDS = list(range(11, 511))

# HumanEval solutions that are considered simple/generic enough to be kept in the training dataset
HUMAN_EVAL_STRINGS_OK = [
    'return x + y', 'return len(string)', 'return n**2', 'return " ".join(strings)']
    "return x + y",
    "return len(string)",
    "return n**2",
    'return " ".join(strings)',
]

DS_1000_PATH = Path("/data/ds-1000/ds1000_data/")

@@ -35,7 +38,7 @@ def load_ds_1000():


def load_mbpp():
    dataset = load_dataset("mbpp", "sanitized", split="train")
    dataset = load_dataset("mbpp", "sanitized", split="train")
return dataset


@@ -57,16 +60,16 @@ def extract_docstring(prompt: str) -> str:
            return prompt.split('"""')[3].strip()
        else:
            raise ValueError()
    elif '\'\'\'' in prompt:
        assert prompt.count('\'\'\'') == 2
        return prompt.split('\'\'\'')[1].strip()
    elif "'''" in prompt:
        assert prompt.count("'''") == 2
        return prompt.split("'''")[1].strip()
    else:
        raise ValueError()


def human_eval_docstrings():
    ds = load_dataset("openai_humaneval", split="test")
    docstrings = [extract_docstring(v['prompt']) for v in ds]
    docstrings = [extract_docstring(v["prompt"]) for v in ds]
    return docstrings


@@ -75,17 +78,32 @@ def apps_solutions():
    Solutions column contains a list of strings
    """
    ds = load_dataset("codeparrot/apps", split="test")
    solutions = [sample["solutions"]
                 for sample in ds if len(sample["solutions"]) > 0]
    res = itertools.chain.from_iterable(
        json.loads(sample) for sample in solutions)
    solutions = [sample["solutions"] for sample in ds if len(sample["solutions"]) > 0]
    res = itertools.chain.from_iterable(json.loads(sample) for sample in solutions)
    return list(res)


def multipl_e_docstrings():
    languages = [
        "cpp", "cs", "d", "go", "java", "jl", "js", "lua", "php", "pl", "py", "r",
        "rb", "rkt", "rs", "scala", "sh", "swift", "ts"
        "cpp",
        "cs",
        "d",
        "go",
        "java",
        "jl",
        "js",
        "lua",
        "php",
        "pl",
        "py",
        "r",
        "rb",
        "rkt",
        "rs",
        "scala",
        "sh",
        "swift",
        "ts",
    ]
    # languages = ["py", "java", "js"]
    src_datas = ["humaneval", "mbpp"]
@@ -97,7 +115,8 @@ def multipl_e_docstrings():
                if src_data == "mbpp" and variation == "-remove":
                    continue
                ds = load_dataset(
                    "nuprl/MultiPL-E", f"{src_data}-{lang}{variation}", split="test")
                    "nuprl/MultiPL-E", f"{src_data}-{lang}{variation}", split="test"
                )
                data += [sample["prompt"].strip() for sample in ds]
    return data

@@ -115,7 +134,10 @@ def filter_out():
"mbpp_solutions": mbpp_solutions(),
"human_eval_docstrings": human_eval_docstrings(),
"human_eval_solutions": [
s for s in load_dataset_column("openai_humaneval", "canonical_solution", "test")
s
for s in load_dataset_column(
"openai_humaneval", "canonical_solution", "test"
)
if s not in HUMAN_EVAL_STRINGS_OK
],
}
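Taken together, `filter_out()` returns a name-to-strings map of benchmark content (MBPP solutions, HumanEval docstrings and solutions, among others) that must not leak into the seed dataset. A minimal sketch of how that map could drive substring-based decontamination; the `decontaminate` helper, its import path, and the plain `in` check are illustrative assumptions, not part of this PR:

```python
from benchmark_data import filter_out  # assumed import path


def decontaminate(samples: list[str]) -> list[str]:
    """Drop every training sample that contains a benchmark string verbatim."""
    benchmark_strings = [s for strings in filter_out().values() for s in strings]
    return [
        sample
        for sample in samples
        if not any(bench in sample for bench in benchmark_strings)
    ]
```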