From 3dbb3fd3f6b29858f550e42b3973722fe908f8d2 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 28 Mar 2023 04:03:57 -0400 Subject: [PATCH] Add support for stream parameter. Closes #1 --- examples/high_level_api_streaming.py | 20 ++++ llama_cpp/llama.py | 142 ++++++++++++++++++++------- 2 files changed, 129 insertions(+), 33 deletions(-) create mode 100644 examples/high_level_api_streaming.py diff --git a/examples/high_level_api_streaming.py b/examples/high_level_api_streaming.py new file mode 100644 index 0000000..d744090 --- /dev/null +++ b/examples/high_level_api_streaming.py @@ -0,0 +1,20 @@ +import json +import argparse + +from llama_cpp import Llama + +parser = argparse.ArgumentParser() +parser.add_argument("-m", "--model", type=str, default=".//models/...") +args = parser.parse_args() + +llm = Llama(model_path=args.model) + +stream = llm( + "Question: What are the names of the planets in the solar system? Answer: ", + max_tokens=48, + stop=["Q:", "\n"], + stream=True, +) + +for output in stream: + print(json.dumps(output, indent=2)) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index cd16bca..268c27d 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -88,7 +88,7 @@ class Llama: True, ) if n_tokens < 0: - raise RuntimeError(f"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}") + raise RuntimeError(f'Failed to tokenize: text="{text}" n_tokens={n_tokens}') return list(tokens[:n_tokens]) def detokenize(self, tokens: List[int]) -> bytes: @@ -105,7 +105,6 @@ class Llama: output += llama_cpp.llama_token_to_str(self.ctx, token) return output - def _eval(self, tokens: List[int], n_past): rc = llama_cpp.llama_eval( self.ctx, @@ -137,12 +136,12 @@ class Llama: top_p=top_p, top_k=top_k, temp=temp, - repeat_penalty=repeat_penalty + repeat_penalty=repeat_penalty, ) yield token self._eval([token], len(past_tokens) + i) - def __call__( + def _call( self, prompt: str, suffix: Optional[str] = None, @@ -154,34 +153,11 @@ class Llama: stop: List[str] = [], repeat_penalty: float = 1.1, top_k: int = 40, + stream: bool = False, ): - """Generate text from a prompt. - - Args: - prompt: The prompt to generate text from. - suffix: A suffix to append to the generated text. If None, no suffix is appended. - max_tokens: The maximum number of tokens to generate. - temperature: The temperature to use for sampling. - top_p: The top-p value to use for sampling. - logprobs: The number of logprobs to return. If None, no logprobs are returned. - echo: Whether to echo the prompt. - stop: A list of strings to stop generation when encountered. - repeat_penalty: The penalty to apply to repeated tokens. - top_k: The top-k value to use for sampling. - - Raises: - ValueError: If the requested tokens exceed the context window. - RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt. - - Returns: - Response object containing the generated text. - """ completion_id = f"cmpl-{str(uuid.uuid4())}" - created= int(time.time()) - text = b"" + created = int(time.time()) completion_tokens = [] - last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n) - prompt_tokens = self.tokenize(prompt.encode("utf-8")) if len(prompt_tokens) + max_tokens > llama_cpp.llama_n_ctx(self.ctx): @@ -198,14 +174,15 @@ class Llama: stop = [s.encode("utf-8") for s in stop] finish_reason = None - for token in self._generate(prompt_tokens, max_tokens, top_p, top_k, temperature, repeat_penalty): + for token in self._generate( + prompt_tokens, max_tokens, top_p, top_k, temperature, repeat_penalty + ): if token == llama_cpp.llama_token_eos(): finish_reason = "stop" break - text += self.detokenize([token]) - last_n_tokens.append(token) completion_tokens.append(token) + text = self.detokenize(completion_tokens) any_stop = [s for s in stop if s in text] if len(any_stop) > 0: first_stop = any_stop[0] @@ -213,9 +190,55 @@ class Llama: finish_reason = "stop" break + if stream: + start = len(self.detokenize(completion_tokens[:-1])) + longest = 0 + for s in stop: + for i in range(len(s), 0, -1): + if s[-i:] == text[-i:]: + if i > longest: + longest = i + break + yield { + "id": completion_id, + "object": "text_completion", + "created": created, + "model": self.model_path, + "choices": [ + { + "text": text[start : len(text) - longest].decode("utf-8"), + "index": 0, + "logprobs": None, + "finish_reason": None, + } + ], + } + if finish_reason is None: finish_reason = "length" + if stream: + if finish_reason == "stop": + start = len(self.detokenize(completion_tokens[:-1])) + text = text[start:].decode("utf-8") + else: + text = "" + yield { + "id": completion_id, + "object": "text_completion", + "created": created, + "model": self.model_path, + "choices": [ + { + "text": text, + "index": 0, + "logprobs": None, + "finish_reason": finish_reason, + } + ], + } + return + text = text.decode("utf-8") if echo: @@ -229,7 +252,7 @@ class Llama: self.ctx, )[:logprobs] - return { + yield { "id": completion_id, "object": "text_completion", "created": created, @@ -249,5 +272,58 @@ class Llama: }, } + def __call__( + self, + prompt: str, + suffix: Optional[str] = None, + max_tokens: int = 16, + temperature: float = 0.8, + top_p: float = 0.95, + logprobs: Optional[int] = None, + echo: bool = False, + stop: List[str] = [], + repeat_penalty: float = 1.1, + top_k: int = 40, + stream: bool = False, + ): + """Generate text from a prompt. + + Args: + prompt: The prompt to generate text from. + suffix: A suffix to append to the generated text. If None, no suffix is appended. + max_tokens: The maximum number of tokens to generate. + temperature: The temperature to use for sampling. + top_p: The top-p value to use for sampling. + logprobs: The number of logprobs to return. If None, no logprobs are returned. + echo: Whether to echo the prompt. + stop: A list of strings to stop generation when encountered. + repeat_penalty: The penalty to apply to repeated tokens. + top_k: The top-k value to use for sampling. + stream: Whether to stream the results. + + Raises: + ValueError: If the requested tokens exceed the context window. + RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt. + + Returns: + Response object containing the generated text. + """ + call = self._call( + prompt=prompt, + suffix=suffix, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + logprobs=logprobs, + echo=echo, + stop=stop, + repeat_penalty=repeat_penalty, + top_k=top_k, + stream=stream, + ) + if stream: + return call + return next(call) + def __del__(self): llama_cpp.llama_free(self.ctx)