diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..bd8e0be
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,11 @@
+repos:
+  - repo: https://github.com/pycqa/isort
+    rev: 5.12.0
+    hooks:
+      - id: isort
+        name: isort (python)
+        args: ["--profile", "black"]
+  - repo: https://github.com/psf/black
+    rev: 22.6.0
+    hooks:
+      - id: black
diff --git a/evaluation/text2code.py b/evaluation/text2code.py
index 58dd789..d48c3ea 100644
--- a/evaluation/text2code.py
+++ b/evaluation/text2code.py
@@ -2,6 +2,7 @@
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Literal, TypedDict, cast
+
 from evalplus.data import get_human_eval_plus, get_mbpp_plus, write_jsonl
 from tqdm.auto import tqdm
 from transformers import HfArgumentParser
@@ -10,6 +11,7 @@
 from star_align.prompt_template import SC2_INSTRUCT_PROMPT as PROMPT_TEMPLATE
 from star_align.utils import chunked

+
 class Text2CodeProblem(TypedDict):
     id: str
     instruction: str
@@ -25,6 +27,7 @@ def get_humaneval_raw_problems() -> list[dict]:
     problems = get_human_eval_plus()
     return list(problems.values())

+
 def map_mbpp_problem(p: dict) -> Text2CodeProblem:
     id = p["task_id"]
     prompt = p["prompt"]
diff --git a/seed_gathering/benchmark_data.py b/seed_gathering/benchmark_data.py
index bb494d6..1e2652e 100644
--- a/seed_gathering/benchmark_data.py
+++ b/seed_gathering/benchmark_data.py
@@ -1,16 +1,19 @@
 """data to filter out of the dataset"""
-import json
 import itertools
+import json
 from pathlib import Path

 from datasets import load_dataset

-
 TEST_IDS = list(range(11, 511))

 # HumanEval solutions that are considered simple/generic enough to be kept in the training dataset
 HUMAN_EVAL_STRINGS_OK = [
-    'return x + y', 'return len(string)', 'return n**2', 'return ''.join(strings)']
+    "return x + y",
+    "return len(string)",
+    "return n**2",
+    "return " ".join(strings)",
+]

 DS_1000_PATH = Path("/data/ds-1000/ds1000_data/")

@@ -35,7 +38,7 @@ def load_ds_1000():


 def load_mbpp():
-    dataset = load_dataset("mbpp", "sanitized", split="train") 
+    dataset = load_dataset("mbpp", "sanitized", split="train")
     return dataset


@@ -57,16 +60,16 @@ def extract_docstring(prompt: str) -> str:
             return prompt.split('"""')[3].strip()
         else:
             raise ValueError()
-    elif '\'\'\'' in prompt:
-        assert prompt.count('\'\'\'') == 2
-        return prompt.split('\'\'\'')[1].strip()
+    elif "'''" in prompt:
+        assert prompt.count("'''") == 2
+        return prompt.split("'''")[1].strip()
     else:
         raise ValueError()


 def human_eval_docstrings():
     ds = load_dataset("openai_humaneval", split="test")
-    docstrings = [extract_docstring(v['prompt']) for v in ds]
+    docstrings = [extract_docstring(v["prompt"]) for v in ds]
     return docstrings


@@ -75,17 +78,32 @@ def apps_solutions():
     """
     Solutions column contains a list of strings
     """
     ds = load_dataset("codeparrot/apps", split="test")
-    solutions = [sample["solutions"]
-                 for sample in ds if len(sample["solutions"]) > 0]
-    res = itertools.chain.from_iterable(
-        json.loads(sample) for sample in solutions)
+    solutions = [sample["solutions"] for sample in ds if len(sample["solutions"]) > 0]
+    res = itertools.chain.from_iterable(json.loads(sample) for sample in solutions)
     return list(res)


 def multipl_e_docstrings():
     languages = [
-        "cpp", "cs", "d", "go", "java", "jl", "js", "lua", "php", "pl", "py", "r",
-        "rb", "rkt", "rs", "scala", "sh", "swift", "ts"
+        "cpp",
+        "cs",
+        "d",
+        "go",
+        "java",
+        "jl",
+        "js",
+        "lua",
+        "php",
+        "pl",
+        "py",
+        "r",
+        "rb",
+        "rkt",
+        "rs",
+        "scala",
+        "sh",
+        "swift",
+        "ts",
     ]
languages = ["py", "java", "js"] src_datas = ["humaneval", "mbpp"] @@ -97,7 +115,8 @@ def multipl_e_docstrings(): if src_data == "mbpp" and variation == "-remove": continue ds = load_dataset( - "nuprl/MultiPL-E", f"{src_data}-{lang}{variation}", split="test") + "nuprl/MultiPL-E", f"{src_data}-{lang}{variation}", split="test" + ) data += [sample["prompt"].strip() for sample in ds] return data @@ -115,7 +134,10 @@ def filter_out(): "mbpp_solutions": mbpp_solutions(), "human_eval_docstrings": human_eval_docstrings(), "human_eval_solutions": [ - s for s in load_dataset_column("openai_humaneval", "canonical_solution", "test") + s + for s in load_dataset_column( + "openai_humaneval", "canonical_solution", "test" + ) if s not in HUMAN_EVAL_STRINGS_OK ], } diff --git a/seed_gathering/filter_dataset.py b/seed_gathering/filter_dataset.py index aff5a8f..a1ca6c7 100644 --- a/seed_gathering/filter_dataset.py +++ b/seed_gathering/filter_dataset.py @@ -1,29 +1,31 @@ -import datasets +import argparse import os -from tree_sitter_parser import global_parser, LANGUAGE, does_have_return, make_parser +import random + import benchmark_data -from tqdm import tqdm +import datasets import torch -import argparse +from tqdm import tqdm +from tree_sitter_parser import LANGUAGE, does_have_return, global_parser, make_parser from vllm import LLM, SamplingParams -import random parser = argparse.ArgumentParser() -parser.add_argument('--dataset', type=str, required=True) -parser.add_argument('--model', type=str, - default="bigcode/starcoder2-15b") -parser.add_argument('--batch-size', type=int, default=512) -parser.add_argument('--sample-size', type=int, default=None) -parser.add_argument('--num-gpus', type=int, default=1) -parser.add_argument('--content_col', type=str, default="content") -parser.add_argument('--push', type=str, required=True) +parser.add_argument("--dataset", type=str, required=True) +parser.add_argument("--model", type=str, default="bigcode/starcoder2-15b") +parser.add_argument("--batch-size", type=int, default=512) +parser.add_argument("--sample-size", type=int, default=None) +parser.add_argument("--num-gpus", type=int, default=1) +parser.add_argument("--content_col", type=str, default="content") +parser.add_argument("--push", type=str, required=True) args = parser.parse_args() random.seed(42) -FN_BLOCK_QUERY = LANGUAGE.query(""" +FN_BLOCK_QUERY = LANGUAGE.query( + """ (function_definition body: (block) @fn-block) -""") +""" +) def template_few_shot(code, answer, rationale): @@ -95,7 +97,8 @@ def numeric_column_to_int(series): "Yes", "The docstring does seem to match the implementation! The function loops through the columns of a df and coerces it as explained.", ), - ('''def __trans_df_into_dict(data): + ( + '''def __trans_df_into_dict(data): """Converte DataFrame to dictionary. 

     Args:
@@ -113,9 +116,9 @@ def numeric_column_to_int(series):
     fname_dict = dict(zip(data["en_name_f"], data["jp_name_f"]))
     lname_dict = dict(zip(data["en_name_l"], data["jp_name_l"]))
     return fullname_dict, fname_dict, lname_dict''',
-     "No",
-     "The function __trans_df_into_dict does indeed convert a dataframe into a dictionary; however, it converts various columns that were not described in the docstring.\nFor instance, nowhere in the docstring does it mention handling Japanese characters or the names of the columns.",
-     ),
+        "No",
+        "The function __trans_df_into_dict does indeed convert a dataframe into a dictionary; however, it converts various columns that were not described in the docstring.\nFor instance, nowhere in the docstring does it mention handling Japanese characters or the names of the columns.",
+    ),
     (
         '''def inchesToMeters(inches):
     """Convert inches to meters."""
@@ -123,7 +126,8 @@ def numeric_column_to_int(series):
         "Yes",
         "inchesToMeters is a very simple function; the docstring concisely explains its purpose, which is converting inches to meters.",
     ),
-    ('''def square_crop(im, target_size=None):
+    (
+        '''def square_crop(im, target_size=None):
     """ Crop image to `target_size`. If that's None the image is squared to the smallest size
     """

@@ -137,10 +141,11 @@ def numeric_column_to_int(series):
     dy = (h - target_size) / 2

     return im.crop((dx, dy, dx + target_size, dy + target_size))''',
-     "Yes",
-     "Following the standard description for docstrings for functions and methods, the square_crop function description tells exactly what the function does."
-     ),
-    ('''def _setup_motifs_files(args):
+        "Yes",
+        "Following the standard description for docstrings for functions and methods, the square_crop function description tells exactly what the function does.",
+    ),
+    (
+        '''def _setup_motifs_files(args):
     """convenience fn, make sure setup is same across
     multiplicity/orientation/spacing workflows
     """
@@ -156,10 +161,11 @@ def numeric_column_to_int(series):
         args.inputs["inference"][args.cluster]["scanmotifs_late_dir"])

     return motifs_files''',
-     "No",
-     "The docstring for _setup_motifs_files just says this is a convenience function. There is definitely not enough information to re-implement this function from the docstring alone.",
-     ),
-    ('''def trip(u, v):
+        "No",
+        "The docstring for _setup_motifs_files just says this is a convenience function. There is definitely not enough information to re-implement this function from the docstring alone.",
+    ),
+    (
+        '''def trip(u, v):
     """
     Returns the scalar triple product of vectors u and v and z axis.
     The convention is z dot (u cross v). Dotting with the z axis simplifies
@@ -173,9 +179,9 @@ def numeric_column_to_int(series):
     Essentially trip is the z component of the cross product of u x v
     """
     return (u[0] * v[1] - u[1] * v[0])''',
-     "Yes",
-     "The docstring for the trip function is very detailed and describes the function's purpose and the mathematical formula used to calculate the scalar triple product.",
-     )
+        "Yes",
+        "The docstring for the trip function is very detailed and describes the function's purpose and the mathematical formula used to calculate the scalar triple product.",
+    ),
 ]


@@ -227,22 +233,31 @@ def chunkify(lst, n):
Running pre-filtering...") BAD_WORDS = ["todo", "fixme", "bug"] -BAD_IMPORTS = ["argparse", "os", "subprocess", "sys", "setuptools", - "distutils", "matplotlib", "seaborn"] -BAD_IMPORTS = [f"import {b}" for b in BAD_IMPORTS] + \ - [f"from {b}" for b in BAD_IMPORTS] +BAD_IMPORTS = [ + "argparse", + "os", + "subprocess", + "sys", + "setuptools", + "distutils", + "matplotlib", + "seaborn", +] +BAD_IMPORTS = [f"import {b}" for b in BAD_IMPORTS] + [f"from {b}" for b in BAD_IMPORTS] BAD_SUBSTRINGS = BAD_WORDS + BAD_IMPORTS bench_filter = benchmark_data.filter_out() -all_bench = bench_filter["human_eval_docstrings"] + \ - bench_filter["human_eval_solutions"] + \ - bench_filter["mbpp_docstrings"] + \ - bench_filter["mbpp_solutions"] +all_bench = ( + bench_filter["human_eval_docstrings"] + + bench_filter["human_eval_solutions"] + + bench_filter["mbpp_docstrings"] + + bench_filter["mbpp_solutions"] +) def pre_filtering(ex): code = ex[args.content_col] - code_bytes = code.encode('utf-8') + code_bytes = code.encode("utf-8") # filter out bad substrings lower = code.lower() @@ -277,11 +292,14 @@ def pre_filtering(ex): # get the docstring, filter if not a docstring exp = block.children[0] - if not exp.type == 'expression_statement' and not exp.children[0].type == 'string': + if ( + not exp.type == "expression_statement" + and not exp.children[0].type == "string" + ): return False docstring = exp.children[0] - docstring_text = docstring.text.decode('utf-8') + docstring_text = docstring.text.decode("utf-8") if not docstring_text.startswith('"""') and not docstring_text.endswith('"""'): return False except Exception as e: @@ -294,8 +312,12 @@ def pre_filtering(ex): threads = os.cpu_count() - 1 # type: ignore dataset = dataset.filter(pre_filtering, num_proc=threads) -model = LLM(args.model, dtype=auto_dtype(), - gpu_memory_utilization=0.95, tensor_parallel_size=args.num_gpus) +model = LLM( + args.model, + dtype=auto_dtype(), + gpu_memory_utilization=0.95, + tensor_parallel_size=args.num_gpus, +) tokenizer = model.get_tokenizer() if args.sample_size is not None: @@ -309,31 +331,34 @@ def pre_filtering(ex): def unindent(s): lines = s.splitlines() non_blank_lines = [line for line in lines if line.strip()] - min_indent = min(len(line) - len(line.lstrip()) - for line in non_blank_lines) if non_blank_lines else 0 - unindented_lines = [line[min_indent:] if len( - line) >= min_indent else line for line in lines] - return '\n'.join(unindented_lines) + min_indent = ( + min(len(line) - len(line.lstrip()) for line in non_blank_lines) + if non_blank_lines + else 0 + ) + unindented_lines = [ + line[min_indent:] if len(line) >= min_indent else line for line in lines + ] + return "\n".join(unindented_lines) def py_extract_docstring(code): first_doc = code.find('"""') assert first_doc != -1 first_doc = first_doc + 3 - second_doc = code[first_doc+1:].find('"""') + second_doc = code[first_doc + 1 :].find('"""') assert second_doc != -1 second_doc = second_doc + first_doc + 1 doc = code[first_doc:second_doc] doc = unindent(doc).strip() - code = code[:first_doc-3] + code[second_doc+3:] + code = code[: first_doc - 3] + code[second_doc + 3 :] return doc, code # this is such a hack, but it works dummy = 'def dummy(): \n """\n """\n pass' dummy_prompt = prompt_fmt(dummy) -few_shot_toks = len(tokenizer.encode( - dummy_prompt)) - len(tokenizer.encode(dummy)) +few_shot_toks = len(tokenizer.encode(dummy_prompt)) - len(tokenizer.encode(dummy)) print(f"Few-shot prompt has {few_shot_toks} tokens") prompts = [] for ex in tqdm(dataset, 
@@ -349,8 +374,9 @@ def py_extract_docstring(code):

 responses = []
 for chunk in tqdm(chunkify(prompts, args.batch_size), desc="Generating responses"):
-    outs = model.generate(chunk, SamplingParams(
-        temperature=0.0, stop="\n", max_tokens=5))
+    outs = model.generate(
+        chunk, SamplingParams(temperature=0.0, stop="\n", max_tokens=5)
+    )
     contents = [o.outputs[0].text for o in outs]
     for c in contents:
         yes_count = c.lower().count("yes")
@@ -365,6 +391,8 @@ def py_extract_docstring(code):

 new_ds = dataset.filter(
     # horrible hack!
-    lambda ex, i: responses[i] and "def dummy()" not in ex[args.content_col], with_indices=True)
+    lambda ex, i: responses[i] and "def dummy()" not in ex[args.content_col],
+    with_indices=True,
+)
 print(f"Filtered {len(dataset) - len(new_ds)} examples")
 new_ds.push_to_hub(args.push, private=True)
diff --git a/seed_gathering/generate_from_the_stack.py b/seed_gathering/generate_from_the_stack.py
index e7409ef..34111eb 100644
--- a/seed_gathering/generate_from_the_stack.py
+++ b/seed_gathering/generate_from_the_stack.py
@@ -1,11 +1,12 @@
-from tree_sitter_parser import LANGUAGE, make_parser, node_to_string
-import datasets
 import os
 import signal
 from multiprocessing import Pool

+import datasets
+from tree_sitter_parser import LANGUAGE, make_parser, node_to_string

-TOPLEVEL_DOCSTRING_QUERY = LANGUAGE.query("""
+TOPLEVEL_DOCSTRING_QUERY = LANGUAGE.query(
+    """
 (
   (function_definition
     name: (identifier)
@@ -18,7 +19,8 @@
   (#eq? @docstring.start "\\\"\\\"\\\"")
   (#eq? @docstring.end "\\\"\\\"\\\"")
 )
-""")
+"""
+)


 def get_fns_with_docstrings(src, tree):
@@ -87,16 +89,22 @@ def main(args):
                 print(f"Processing chunk {i // CHUNK_SIZE}")
                 # divide the chunk into NUM_WORKERS chunks
                 subchunk_size = len(chunk) // args.num_workers
-                subchunks = [chunk[i:i + subchunk_size]
-                             for i in range(0, len(chunk), subchunk_size)]
+                subchunks = [
+                    chunk[i : i + subchunk_size]
+                    for i in range(0, len(chunk), subchunk_size)
+                ]
                 new_funs_iter = p.imap(
-                    process_chunk, [(i, subchunk) for i, subchunk in enumerate(subchunks)])
+                    process_chunk,
+                    [(i, subchunk) for i, subchunk in enumerate(subchunks)],
+                )
                 print("Getting new functions")
                 len_before = len(funs)
                 while True:
                     try:
+
                         def timeout_handler(_, __):
                             raise KeyboardInterrupt  # it's fineeeeeee
+
                         signal.signal(signal.SIGALRM, timeout_handler)
                         signal.alarm(60)
                         funs.update(next(new_funs_iter))
@@ -117,7 +125,8 @@ def timeout_handler(_, __):
                     PARSERS = [make_parser() for _ in range(args.num_workers)]

                 print(
-                    f"Done processing chunk {i // CHUNK_SIZE}. Got {len(funs) - len_before} new functions")
+                    f"Done processing chunk {i // CHUNK_SIZE}. Got {len(funs) - len_before} new functions"
+                )
                 chunk = []

             except Exception as e:
@@ -129,10 +138,7 @@ def timeout_handler(_, __):

         p.close()

-    new_ds_dict = {
-        "content": list(funs),
-        "id": list(range(len(funs)))
-    }
+    new_ds_dict = {"content": list(funs), "id": list(range(len(funs)))}
     new_ds = datasets.Dataset.from_dict(new_ds_dict)
     new_ds.push_to_hub(args.push, private=True)
@@ -140,10 +146,10 @@ def timeout_handler(_, __):
 if __name__ == "__main__":
     import argparse
+
     parser = argparse.ArgumentParser()
     parser.add_argument("--num_workers", type=int, default=os.cpu_count())
-    parser.add_argument("--dataset", type=str,
-                        default="bigcode/the-stack-dedup")
+    parser.add_argument("--dataset", type=str, default="bigcode/the-stack-dedup")
     parser.add_argument("--data_dir", type=str, default="data/python")
     parser.add_argument("--push", type=str, required=True)
     args = parser.parse_args()
diff --git a/seed_gathering/high_quality_subset.py b/seed_gathering/high_quality_subset.py
index 94df17e..a86d24c 100644
--- a/seed_gathering/high_quality_subset.py
+++ b/seed_gathering/high_quality_subset.py
@@ -1,18 +1,20 @@
-import datasets
-import subprocess
-import tempfile
-import signal
+import argparse
 import hashlib
 import os
-import argparse
-from typing import List, Dict
-from tqdm import tqdm
+import signal
+import subprocess
+import tempfile
+from typing import Dict, List

+import datasets
+from tqdm import tqdm
 from tree_sitter_parser import LANGUAGE, global_parser

-RETURN_QUERY = LANGUAGE.query("""
+RETURN_QUERY = LANGUAGE.query(
+    """
 (return_statement) @return
-""")
+"""
+)


 def does_have_return(src):
@@ -100,8 +102,10 @@ def infer_imports(code: str) -> str:
     import autoimport

     try:
+
         def handler(signum, frame):
             raise Exception("Timeout")
+
         signal.signal(signal.SIGALRM, handler)
         signal.alarm(10)
         inferred = autoimport.fix_code(code)
@@ -114,17 +118,17 @@ def handler(signum, frame):


 def main(args):
-    ds = datasets.load_dataset(args.dataset,
-                               data_dir="data", split="train")
+    ds = datasets.load_dataset(args.dataset, data_dir="data", split="train")

     print("Filtering to only functions with return statements")
-    ds = ds.filter(lambda ex: does_have_return(
-        ex["content"]), num_proc=os.cpu_count())
+    ds = ds.filter(lambda ex: does_have_return(ex["content"]), num_proc=os.cpu_count())

     if args.infer_imports:
         print("Inferring imports for functions")
-        ds = ds.map(lambda ex: {"content": infer_imports(
-            ex["content"])}, num_proc=os.cpu_count())
+        ds = ds.map(
+            lambda ex: {"content": infer_imports(ex["content"])},
+            num_proc=os.cpu_count(),
+        )

     batch = []
     max_i = len(ds) - 1
@@ -162,14 +166,20 @@ def main(args):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--dataset", type=str,
-                        help="Points to dataset of python functions with docstrings. Columns: 'content'",
-                        required=True)
     parser.add_argument(
-        "--push", type=str, required=True, help="Repo to push the dataset to")
+        "--dataset",
+        type=str,
Columns: 'content'", + required=True, + ) + parser.add_argument( + "--push", type=str, required=True, help="Push to this dataset to which repo" + ) parser.add_argument( - "--infer-imports", action="store_true", help="Infer imports for functions") + "--infer-imports", action="store_true", help="Infer imports for functions" + ) parser.add_argument( - "--batch-size", type=int, default=250, help="Batch size for typechecking") + "--batch-size", type=int, default=250, help="Batch size for typechecking" + ) args = parser.parse_args() main(args) diff --git a/seed_gathering/tree_sitter_parser.py b/seed_gathering/tree_sitter_parser.py index 3021064..fca8623 100644 --- a/seed_gathering/tree_sitter_parser.py +++ b/seed_gathering/tree_sitter_parser.py @@ -1,17 +1,14 @@ from tree_sitter import Language, Parser -Language.build_library( - 'build/lang.so', - [ - './tree-sitter-python' - ] -) -LANGUAGE = Language('build/lang.so', 'python') +Language.build_library("build/lang.so", ["./tree-sitter-python"]) +LANGUAGE = Language("build/lang.so", "python") -QUERY = LANGUAGE.query(""" +QUERY = LANGUAGE.query( + """ (function_definition name: (identifier) @fn-name) -""") +""" +) global_parser = Parser() @@ -29,7 +26,7 @@ def get_fn_name(code, parser=global_parser): def node_to_string(src: bytes, node): - return src[node.start_byte:node.end_byte].decode("utf8") + return src[node.start_byte : node.end_byte].decode("utf8") def make_parser(): @@ -38,9 +35,11 @@ def make_parser(): return _parser -RETURN_QUERY = LANGUAGE.query(""" +RETURN_QUERY = LANGUAGE.query( + """ (return_statement) @return -""") +""" +) def does_have_return(src, parser=global_parser): diff --git a/src/star_align/llm_wrapper.py b/src/star_align/llm_wrapper.py index 669d3e0..1f8b8b8 100644 --- a/src/star_align/llm_wrapper.py +++ b/src/star_align/llm_wrapper.py @@ -330,6 +330,7 @@ def complete( decoded_outputs=output_strings, ) + class SupportedModelKeys(Enum): # StarCoder-based models STARCODER_15B = "bigcode/starcoder" diff --git a/src/star_align/prompt_template.py b/src/star_align/prompt_template.py index 5ad5e72..4d3b41f 100644 --- a/src/star_align/prompt_template.py +++ b/src/star_align/prompt_template.py @@ -18,4 +18,4 @@ {%- endif %} {%- endif %} {%- endfor %} -{{'### Response\n'}}""" \ No newline at end of file +{{'### Response\n'}}""" diff --git a/src/star_align/self_ossinstruct.py b/src/star_align/self_ossinstruct.py index 9de09eb..6a34e01 100644 --- a/src/star_align/self_ossinstruct.py +++ b/src/star_align/self_ossinstruct.py @@ -249,9 +249,7 @@ def get_ossinstruct_fewshots() -> Fewshot: system_prompt = splits[0].strip() # "I->R", "E->S", "I->I", "PI->PI", "S->C" sys_pattern = r"### System: I->R|### System: C->I|### System: S->C" - _, i_r, c_i, s_c = list( - map(str.strip, re.split(sys_pattern, system_prompt)) - ) + _, i_r, c_i, s_c = list(map(str.strip, re.split(sys_pattern, system_prompt))) # system_prompt = re.split(r"### System: Instruction", system_prompt)[1] # instruction_system_prompt, response_system_prompt = system_prompt.split( # "### System: Response" diff --git a/src/star_align/train.py b/src/star_align/train.py index 31405de..e2e466d 100644 --- a/src/star_align/train.py +++ b/src/star_align/train.py @@ -13,7 +13,8 @@ get_model_context, pad_sequences, ) -from star_align.prompt_template import CHAT_TEMPLATE, SC2_INSTRUCT_PROMPT as PROMPT_TEMPLATE +from star_align.prompt_template import CHAT_TEMPLATE +from star_align.prompt_template import SC2_INSTRUCT_PROMPT as PROMPT_TEMPLATE from star_align.utils import N_CORES @@ -29,6 +30,7 
@@ -29,6 +30,7 @@ class ModelArguments:
 # Ignored index in CrossEntropyLoss
 IGNORED_INDEX = -100

+
 def map_dataset(
     examples: dict[str, list[str]],
     args: "Args",