From be11967857847fa6a6367e4b8bde56394a4f8fd2 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Tue, 12 Dec 2023 13:53:59 -0800
Subject: [PATCH] Working simple python thread metrics

---
 .../src/bindings/transformers/transformers.py | 68 +++++++++++++++++--
 1 file changed, 63 insertions(+), 5 deletions(-)

diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py
index 7d46d0636..eed9cede7 100644
--- a/pgml-extension/src/bindings/transformers/transformers.py
+++ b/pgml-extension/src/bindings/transformers/transformers.py
@@ -3,6 +3,8 @@
 import shutil
 import time
 import queue
+import sys
+import json
 
 import datasets
 from InstructorEmbedding import INSTRUCTOR
@@ -40,7 +42,7 @@
     TrainingArguments,
     Trainer,
 )
-from threading import Thread
+import threading
 
 __cache_transformer_by_model_id = {}
 __cache_sentence_transformer_by_name = {}
@@ -62,6 +64,26 @@
 }
 
 
+class WorkerThreads:
+    def __init__(self):
+        self.worker_threads = {}
+
+    def delete_thread(self, id):
+        del self.worker_threads[id]
+
+    def update_thread(self, id, value):
+        self.worker_threads[id] = value
+
+    def get_thread(self, id):
+        if id in self.worker_threads:
+            return self.worker_threads[id]
+        else:
+            return None
+
+
+worker_threads = WorkerThreads()
+
+
 class PgMLException(Exception):
     pass
 
@@ -105,6 +127,12 @@ def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs):
         self.token_cache = []
         self.text_index_cache = []
 
+    def set_worker_thread_id(self, id):
+        self.worker_thread_id = id
+
+    def get_worker_thread_id(self):
+        return self.worker_thread_id
+
     def put(self, values):
         if self.skip_prompt and self.next_tokens_are_prompt:
             self.next_tokens_are_prompt = False
@@ -149,6 +177,22 @@ def __next__(self):
         return value
 
 
+def streaming_worker(worker_threads, model, **kwargs):
+    thread_id = threading.get_native_id()
+    try:
+        worker_threads.update_thread(
+            thread_id, json.dumps({"model": model.name_or_path})
+        )
+    except:
+        worker_threads.update_thread(thread_id, "Error setting data")
+    try:
+        model.generate(**kwargs)
+    except BaseException as error:
+        print(f"Error in streaming_worker: {error}", file=sys.stderr)
+    finally:
+        worker_threads.delete_thread(thread_id)
+
+
 class GGMLPipeline(object):
     def __init__(self, model_name, **task):
         import ctransformers
@@ -185,7 +229,7 @@ def do_work():
                 self.q.put(x)
             self.done = True
 
-        thread = Thread(target=do_work)
+        thread = threading.Thread(target=do_work)
         thread.start()
 
     def __iter__(self):
@@ -283,7 +327,13 @@ def stream(self, input, timeout=None, **kwargs):
                 input, add_generation_prompt=True, tokenize=False
             )
             input = self.tokenizer(input, return_tensors="pt").to(self.model.device)
-            generation_kwargs = dict(input, streamer=streamer, **kwargs)
+            generation_kwargs = dict(
+                input,
+                worker_threads=worker_threads,
+                model=self.model,
+                streamer=streamer,
+                **kwargs,
+            )
         else:
             streamer = TextIteratorStreamer(
                 self.tokenizer,
@@ -292,9 +342,17 @@ def stream(self, input, timeout=None, **kwargs):
             input = self.tokenizer(input, return_tensors="pt", padding=True).to(
                 self.model.device
             )
-            generation_kwargs = dict(input, streamer=streamer, **kwargs)
-        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+            generation_kwargs = dict(
+                input,
+                worker_threads=worker_threads,
+                model=self.model,
+                streamer=streamer,
+                **kwargs,
+            )
+        # thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+        thread = threading.Thread(target=streaming_worker, kwargs=generation_kwargs)
         thread.start()
+        streamer.set_worker_thread_id(thread.native_id)
         return streamer
 
     def __call__(self, inputs, **kwargs):
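
A minimal, self-contained sketch of the thread-metrics pattern introduced above: a WorkerThreads registry keyed by the OS-level thread id (threading.get_native_id()), populated by the worker before it starts generating and cleared in a finally block when it exits. Here fake_generate and the {"model": "demo"} payload are illustrative stand-ins for model.generate(**kwargs) and the real model metadata, not code from this repository.

import json
import threading
import time


class WorkerThreads:
    """Registry mapping a worker's native (OS) thread id to a JSON metrics blob."""

    def __init__(self):
        self.worker_threads = {}

    def update_thread(self, id, value):
        self.worker_threads[id] = value

    def delete_thread(self, id):
        del self.worker_threads[id]

    def get_thread(self, id):
        return self.worker_threads.get(id)


worker_threads = WorkerThreads()


def fake_generate(seconds):
    # Stand-in for model.generate(**kwargs).
    time.sleep(seconds)


def streaming_worker(worker_threads, seconds):
    thread_id = threading.get_native_id()
    worker_threads.update_thread(thread_id, json.dumps({"model": "demo"}))
    try:
        fake_generate(seconds)
    finally:
        # Always unregister, even if generation raises.
        worker_threads.delete_thread(thread_id)


thread = threading.Thread(target=streaming_worker, args=(worker_threads, 0.5))
thread.start()
time.sleep(0.1)  # give the worker a moment to register itself
print(worker_threads.get_thread(thread.native_id))  # '{"model": "demo"}'
thread.join()
print(worker_threads.get_thread(thread.native_id))  # None: entry cleaned up

Note the ordering: registration happens inside the worker, so there is a short window after thread.start() where get_thread() still returns None. The patch sidesteps this for consumers by recording thread.native_id on the streamer via set_worker_thread_id() in the spawning thread, so callers can look up the worker's entry whenever they need it.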