diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index e151c95..51bb8b2 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -63,6 +63,11 @@ class Llama:
         self.params.embedding = embedding
 
         self.last_n_tokens_size = last_n_tokens_size
+        self.last_n_tokens_data = deque(
+            [llama_cpp.llama_token(0)] * self.last_n_tokens_size,
+            maxlen=self.last_n_tokens_size,
+        )
+        self.tokens_consumed = 0
         self.n_batch = n_batch
         self.n_threads = n_threads or multiprocessing.cpu_count()
 
@@ -115,6 +120,67 @@ class Llama:
             output += llama_cpp.llama_token_to_str(self.ctx, token)
         return output
 
+    def reset(self):
+        """Reset the model state."""
+        self.last_n_tokens_data.extend(
+            [llama_cpp.llama_token(0)] * self.last_n_tokens_size
+        )
+        self.tokens_consumed = 0
+
+    def eval(self, tokens: Sequence[llama_cpp.llama_token]):
+        """Evaluate a list of tokens.
+
+        Args:
+            tokens: The list of tokens to evaluate.
+        """
+        assert self.ctx is not None
+        n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
+        for i in range(0, len(tokens), self.n_batch):
+            batch = tokens[i : min(len(tokens), i + self.n_batch)]
+            n_past = min(n_ctx - len(batch), self.tokens_consumed)
+            return_code = llama_cpp.llama_eval(
+                ctx=self.ctx,
+                tokens=(llama_cpp.llama_token * len(batch))(*batch),
+                n_tokens=llama_cpp.c_int(len(batch)),
+                n_past=llama_cpp.c_int(n_past),
+                n_threads=llama_cpp.c_int(self.n_threads),
+            )
+            if int(return_code) != 0:
+                raise RuntimeError(f"llama_eval returned {return_code}")
+            self.last_n_tokens_data.extend(batch)
+            self.tokens_consumed += len(batch)
+
+    def sample(
+        self,
+        top_k: int,
+        top_p: float,
+        temp: float,
+        repeat_penalty: float,
+    ):
+        """Sample a token from the model.
+
+        Args:
+            top_k: The top-k sampling parameter.
+            top_p: The top-p sampling parameter.
+            temp: The temperature parameter.
+            repeat_penalty: The repeat penalty parameter.
+
+        Returns:
+            The sampled token.
+        """
+        assert self.ctx is not None
+        return llama_cpp.llama_sample_top_p_top_k(
+            ctx=self.ctx,
+            last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
+                *self.last_n_tokens_data
+            ),
+            last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size),
+            top_k=llama_cpp.c_int(top_k),
+            top_p=llama_cpp.c_float(top_p),
+            temp=llama_cpp.c_float(temp),
+            repeat_penalty=llama_cpp.c_float(repeat_penalty),
+        )
+
     def generate(
         self,
         tokens: Sequence[llama_cpp.llama_token],
@@ -125,7 +191,7 @@ class Llama:
     ) -> Generator[
         llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
     ]:
-        """Generate tokens.
+        """Create a generator of tokens from a prompt.
 
         Examples:
             >>> llama = Llama("models/ggml-7b.bin")
@@ -149,37 +215,14 @@ class Llama:
             top_p = 0.0
             top_k = 1
         assert self.ctx is not None
-        n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
-        n_tokens = 0
-        last_n_tokens = deque(
-            [llama_cpp.llama_token(0)] * self.last_n_tokens_size,
-            maxlen=self.last_n_tokens_size,
-        )
+        self.reset()
         while True:
-            for i in range(0, len(tokens), self.n_batch):
-                batch = tokens[i : min(len(tokens), i + self.n_batch)]
-                n_past = min(n_ctx - len(batch), n_tokens)
-                return_code = llama_cpp.llama_eval(
-                    ctx=self.ctx,
-                    tokens=(llama_cpp.llama_token * len(batch))(*batch),
-                    n_tokens=llama_cpp.c_int(len(batch)),
-                    n_past=llama_cpp.c_int(n_past),
-                    n_threads=llama_cpp.c_int(self.n_threads),
-                )
-                if int(return_code) != 0:
-                    raise RuntimeError(f"llama_eval returned {return_code}")
-                last_n_tokens.extend(batch)
-                n_tokens += len(batch)
-            token = llama_cpp.llama_sample_top_p_top_k(
-                ctx=self.ctx,
-                last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
-                    *last_n_tokens
-                ),
-                last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size),
-                top_k=llama_cpp.c_int(top_k),
-                top_p=llama_cpp.c_float(top_p),
-                temp=llama_cpp.c_float(temp),
-                repeat_penalty=llama_cpp.c_float(repeat_penalty),
+            self.eval(tokens)
+            token = self.sample(
+                top_k=top_k,
+                top_p=top_p,
+                temp=temp,
+                repeat_penalty=repeat_penalty,
             )
             tokens_or_none = yield token
            tokens = [token]
@@ -197,7 +240,8 @@ class Llama:
         """
         assert self.ctx is not None
         tokens = self.tokenize(input.encode("utf-8"))
-        next(self.generate(tokens, top_k=0, top_p=0.0, temp=1.0, repeat_penalty=1.0))
+        self.reset()
+        self.eval(tokens)
         n_tokens = len(tokens)
         embedding = llama_cpp.llama_get_embeddings(self.ctx)[
             : llama_cpp.llama_n_embd(self.ctx)
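
A minimal sketch of driving the extracted `reset()` / `eval()` / `sample()` methods directly, without going through `generate()`. The model path, prompt, token count, and sampling values below are illustrative only; the constructor and `tokenize`/`detokenize` calls are the ones already shown in the class docstring above.

```python
# Sketch only: exercises the refactored low-level API from this diff.
# Path, prompt, and sampling parameters are illustrative, not prescriptive.
from llama_cpp import Llama

llama = Llama("models/ggml-7b.bin")

prompt = llama.tokenize(b"Q: Name the planets in the solar system. A: ")
llama.reset()        # clear last_n_tokens_data and tokens_consumed
llama.eval(prompt)   # evaluate the prompt in n_batch-sized chunks

for _ in range(32):
    token = llama.sample(top_k=40, top_p=0.95, temp=0.8, repeat_penalty=1.1)
    print(llama.detokenize([token]))
    llama.eval([token])  # feed the sampled token back before sampling again
```

This mirrors what `generate()` now does internally: `eval()` advances `tokens_consumed` and fills the `last_n_tokens_data` window, while `sample()` only reads that state, which is also why `embed()` can switch to a plain `reset()` + `eval()` pair instead of pulling one token from `generate()`.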