From 153ef756bad2f76ee8a3b7d173802f9df8b8a4f1 Mon Sep 17 00:00:00 2001
From: Sebastian Bordt
Date: Fri, 12 Apr 2024 16:34:41 +0200
Subject: [PATCH] contiguous chat completion

---
 tabmemcheck/chat_completion.py | 105 ++++++++++++++++++++++++++++++---
 tabmemcheck/llm.py             |  19 ++++--
 2 files changed, 110 insertions(+), 14 deletions(-)

diff --git a/tabmemcheck/chat_completion.py b/tabmemcheck/chat_completion.py
index ea54840..5e13f87 100644
--- a/tabmemcheck/chat_completion.py
+++ b/tabmemcheck/chat_completion.py
@@ -215,26 +215,113 @@ def row_completion(
 ####################################################################################
 
 
+def build_contiguous_query(
+    text: str, prefix_length: int, suffix_length: int, few_shot: int, rng
+):
+    query_length = (prefix_length + suffix_length) * (1 + few_shot)
+    # the length of the string must be at least (prefix_length + suffix_length) * (1 + few_shot)
+    assert (
+        len(text) >= query_length
+    ), "The provided string is too short for the specified prefix and suffix lengths."
+    # choose a random sub-string of length query_length
+    idx = rng.integers(low=0, high=len(text) - query_length)
+    s_query = text[idx : idx + query_length]
+    # construct few-shot examples
+    few_shot_examples = []
+    for i_fs in range(few_shot):
+        offset = (prefix_length + suffix_length) * i_fs
+        few_shot_examples.append(
+            (
+                [s_query[offset : offset + prefix_length]],
+                [
+                    s_query[
+                        offset + prefix_length : offset + prefix_length + suffix_length
+                    ]
+                ],
+            )
+        )
+    # prefix and suffix
+    prefix = s_query[
+        query_length - prefix_length - suffix_length : query_length - suffix_length
+    ]
+    suffix = s_query[query_length - suffix_length :]
+    return few_shot_examples, prefix, suffix
+
+
 def chat_completion(
     llm: LLM_Interface,
-    strings: list[str],
-    system_prompt: str = "You are a helpful assistant that complets the user's input.",
-    few_shot=5,
+    strings: str | list[str],
+    system_prompt: str = "You are a helpful assistant.",
+    prefix_length: int = None,
+    suffix_length: int = None,
+    few_shot=5,  # integer, or list [str, ..., str] or [[str,..,str], ..., [str,..,str]]
+    contiguous=False,
     num_queries=10,
     print_levenshtein=False,
     out_file=None,
     rng=None,
 ):
-    """Basic completion with a chat model and a list of strings."""
-    # randomly split the strings into prefixes and suffixes, then use prefix_suffix_chat_completion
+    """General-purpose chat completion."""
     if rng is None:
         rng = np.random.default_rng()
 
+    if isinstance(strings, str):
+        strings = [strings]
+
+    def prefix_suffix_split(s):
+        if prefix_length is not None:
+            return s[:prefix_length], s[prefix_length:]
+        else:  # randomly split the string into prefix and suffix
+            idx = rng.integers(low=int(len(s) / 3), high=int(2 * len(s) / 3))
+            return s[:idx], s[idx:]
+
+    if contiguous:
+        # few-shot has to be an integer
+        assert isinstance(
+            few_shot, int
+        ), "For contiguous chat completion, few_shot must be an integer."
+        # both prefix_length and suffix_length have to be specified
+        assert (
+            prefix_length is not None and suffix_length is not None
+        ), "For contiguous chat completion, both prefix_length and suffix_length have to be specified."
+
+        prefixes, suffixes, responses = [], [], []
+        for _ in range(num_queries):
+            # select a random string and build the query
+            few_shot_examples, prefix, suffix = build_contiguous_query(
+                rng.choice(strings), prefix_length, suffix_length, few_shot, rng
+            )
+            # send query
+            prefix, suffix, response = prefix_suffix_chat_completion(
+                llm,
+                [prefix],
+                [suffix],
+                system_prompt,
+                few_shot=few_shot_examples,
+                num_queries=1,
+                print_levenshtein=print_levenshtein,
+                out_file=out_file,
+                rng=rng,
+            )
+            prefixes.append(prefix)
+            suffixes.append(suffix)
+            responses.append(response)
+        return prefixes, suffixes, responses
+
+    # non-contiguous
     prefixes = []
     suffixes = []
-    for s in strings:
-        idx = rng.integers(low=int(len(s) / 3), high=int(2 * len(s) / 3))
-        prefixes.append(s[:idx])
-        suffixes.append(s[idx:])
+    for s_query in strings:  # fixed prefix length specified by the user
+        prefix, suffix = prefix_suffix_split(s_query)
+        prefixes.append(prefix)
+        suffixes.append(suffix)
+
+    # few shot list
+    if isinstance(few_shot, list):
+        if len(few_shot) > 0:
+            if isinstance(few_shot[0], list):  # list of lists
+                few_shot = [[prefix_suffix_split(s) for s in fs] for fs in few_shot]
+                few_shot = [([x[0] for x in fs], [x[1] for x in fs]) for fs in few_shot]
+            else:  # list of strings
+                few_shot = [prefix_suffix_split(s) for s in few_shot]
+                few_shot = [([fs[0]], [fs[1]]) for fs in few_shot]
+
     return prefix_suffix_chat_completion(
         llm,
         prefixes,
diff --git a/tabmemcheck/llm.py b/tabmemcheck/llm.py
index 2191633..a3ff05c 100644
--- a/tabmemcheck/llm.py
+++ b/tabmemcheck/llm.py
@@ -108,13 +108,15 @@ class OpenAILLM(LLM_Interface):
     client: OpenAI = None
     model: str = None
 
-    def __init__(self, client, model=None):
+    def __init__(self, client, model, chat_mode=None):
         super().__init__()
         self.client = client
         self.model = model
         # auto-detect chat models
         if "gpt-3.5" in model or "gpt-4" in model:
             self.chat_mode = True
+        if chat_mode is not None:
+            self.chat_mode = chat_mode
 
     @retry(
         retry=retry_if_not_exception_type(openai.BadRequestError),
@@ -153,17 +155,20 @@ def chat_completion(self, messages, temperature, max_tokens):
         )
         # we return the completion string or "" if there is an invalid response/query
         try:
-            response = response.choices[0].message.content
+            response_content = response.choices[0].message.content
         except:
             print(f"Invalid response {response}")
-            response = ""
-        return response
+            response_content = ""
+        if response_content is None:
+            print(f"Invalid response {response}")
+            response_content = ""
+        return response_content
 
     def __repr__(self) -> str:
         return f"{self.model}"
 
 
-def openai_setup(model: str, azure: bool = False):
+def openai_setup(model: str, azure: bool = False, *args, **kwargs):
     """Setup an OpenAI language model.
 
     :param model: The name of the model (e.g. "gpt-3.5-turbo-0613").
@@ -197,6 +202,8 @@ def openai_setup(model: str, azure: bool = False):
                 if "AZURE_OPENAI_VERSION" in os.environ
                 else None
             ),
+            *args,
+            **kwargs,
         )
     else:  # openai api
         client = OpenAI(
@@ -206,6 +213,8 @@ def openai_setup(model: str, azure: bool = False):
             organization=(
                 os.environ["OPENAI_API_ORG"] if "OPENAI_API_ORG" in os.environ else None
            ),
+            *args,
+            **kwargs,
        )
 
     # the llm
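Usage sketch for the new contiguous mode (illustrative, not part of the patch): with contiguous=True, each of the num_queries queries samples one random substring of length (prefix_length + suffix_length) * (1 + few_shot) from the input text; its leading chunks become the few-shot prefix/suffix examples and the final chunk is the actual query. The sketch below assumes the module paths shown in the diff, that openai_setup returns the ready-to-use LLM wrapper, and uses a placeholder model name and input file:

    import numpy as np

    from tabmemcheck.chat_completion import chat_completion
    from tabmemcheck.llm import openai_setup

    # placeholder model name; openai_setup reads credentials from the environment
    llm = openai_setup("gpt-3.5-turbo-0613")

    # placeholder input; a plain str is now accepted and wrapped in a list internally.
    # the text must contain at least (50 + 50) * (1 + 5) = 600 characters.
    text = open("data.txt").read()

    prefixes, suffixes, responses = chat_completion(
        llm,
        text,
        prefix_length=50,  # required when contiguous=True
        suffix_length=50,  # required when contiguous=True
        few_shot=5,        # must be an integer when contiguous=True
        contiguous=True,
        num_queries=10,
        rng=np.random.default_rng(0),  # optional, for reproducibility
    )

Each query samples its own substring, so the returned lists hold one entry per query.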