From 350a1769e188629c484bb127415cd82751a86597 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Mon, 1 May 2023 14:47:55 -0400
Subject: [PATCH] Update sampling api

---
 llama_cpp/llama.py     | 119 ++++++++++++++++++++++++++++++++++-------
 llama_cpp/llama_cpp.py |  22 ++++++++---
 2 files changed, 113 insertions(+), 28 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 4e3c3aa..b38f2bb 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -127,7 +127,9 @@ class Llama:
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
         self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
-        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
+        self.eval_logits: Deque[List[llama_cpp.c_float]] = deque(
+            maxlen=n_ctx if logits_all else 1
+        )
 
         self.cache: Optional[LlamaCache] = None
 
@@ -236,17 +238,90 @@ class Llama:
             )
             if int(return_code) != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
+            # Save tokens
             self.eval_tokens.extend(batch)
-            if self.params.logits_all:
-                n_vocab = llama_cpp.llama_n_vocab(self.ctx)
-                cols = int(n_vocab)
-                rows = n_tokens
-                logits_view = llama_cpp.llama_get_logits(self.ctx)
-                logits = [
-                    [logits_view[i * cols + j] for j in range(cols)]
-                    for i in range(rows)
-                ]
-                self.eval_logits.extend(logits)
+            # Save logits
+            rows = n_tokens if self.params.logits_all else 1
+            n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+            cols = int(n_vocab)
+            logits_view = llama_cpp.llama_get_logits(self.ctx)
+            logits: List[List[llama_cpp.c_float]] = [
+                [logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
+            ]
+            self.eval_logits.extend(logits)
+
+    def _sample_top_p_top_k(
+        self,
+        last_n_tokens_data,  # type: llama_cpp.Array[llama_cpp.llama_token]
+        last_n_tokens_size: llama_cpp.c_int,
+        top_k: llama_cpp.c_int,
+        top_p: llama_cpp.c_float,
+        temp: llama_cpp.c_float,
+        repeat_penalty: llama_cpp.c_float,
+    ):
+        assert self.ctx is not None
+        assert len(self.eval_logits) > 0
+        n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
+        logits = self.eval_logits[-1]
+        data = (llama_cpp.llama_token_data * n_vocab)(
+            *[
+                llama_cpp.llama_token_data(
+                    id=llama_cpp.llama_token(i),
+                    logit=logits[i],
+                    p=llama_cpp.c_float(0.0),
+                )
+                for i in range(n_vocab)
+            ]
+        )
+        size = llama_cpp.c_size_t(n_vocab)
+        sorted = False
+        candidates = llama_cpp.llama_token_data_array(
+            data=data,
+            size=size,
+            sorted=sorted,
+        )
+        llama_cpp.llama_sample_repetition_penalty(
+            ctx=self.ctx,
+            last_tokens_data=last_n_tokens_data,
+            last_tokens_size=last_n_tokens_size,
+            candidates=llama_cpp.ctypes.pointer(candidates),
+            penalty=repeat_penalty,
+        )
+        if temp == 0.0:
+            return llama_cpp.llama_sample_token_greedy(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+            )
+        else:
+            llama_cpp.llama_sample_top_k(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                k=top_k,
+            )
+            llama_cpp.llama_sample_tail_free(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                z=llama_cpp.c_float(1.0),
+            )
+            llama_cpp.llama_sample_typical(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                p=llama_cpp.c_float(1.0),
+            )
+            llama_cpp.llama_sample_top_p(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                p=top_p,
+            )
+            llama_cpp.llama_sample_temperature(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                temp=temp,
+            )
+            return llama_cpp.llama_sample_token(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+            )
 
     def sample(
         self,
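Note: the hunk above replaces the old single `llama_sample_top_p_top_k` C call with an explicit chain of samplers applied to one candidates array: repetition penalty first, then either a greedy pick when `temp == 0.0`, or top-k, tail-free, typical, top-p, and temperature before the final draw. Tail-free and typical sampling are effectively disabled here (`z=1.0` and `p=1.0` are no-ops). The pure-Python sketch below mirrors the active stages of that chain on a plain list of logits; it is illustrative only, uses no ctypes bindings, and every name in it is hypothetical.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_chain(logits, last_tokens, top_k=40, top_p=0.95,
                 temp=0.80, repeat_penalty=1.1):
    # Candidate list of (token_id, logit), playing the role of
    # llama_token_data_array in the patch.
    candidates = list(enumerate(logits))

    # Repetition penalty: shrink positive logits (and grow negative
    # ones) for tokens seen in the recent-token window.
    recent = set(last_tokens)
    candidates = [
        (i, (l / repeat_penalty if l > 0 else l * repeat_penalty)
         if i in recent else l)
        for i, l in candidates
    ]

    # temp == 0.0 short-circuits to a greedy argmax, as
    # llama_sample_token_greedy does in the hunk above.
    if temp == 0.0:
        return max(candidates, key=lambda c: c[1])[0]

    # Top-k: keep only the k highest-logit candidates.
    candidates.sort(key=lambda c: c[1], reverse=True)
    candidates = candidates[:top_k]

    # Top-p (nucleus): keep the smallest prefix of the sorted list
    # whose cumulative probability reaches top_p (never fewer than one).
    probs = softmax([l for _, l in candidates])
    cum, keep = 0.0, len(candidates)
    for n, p in enumerate(probs, start=1):
        cum += p
        if cum >= top_p:
            keep = n
            break
    candidates = candidates[:keep]

    # Temperature-scaled multinomial draw over what survives.
    probs = softmax([l / temp for _, l in candidates])
    return random.choices([i for i, _ in candidates], weights=probs, k=1)[0]

print(sample_chain([2.0, 1.5, 0.3, -1.0], last_tokens=[0]))
```

Keeping every stage as an in-place filter over one candidates array is what lets the real bindings thread a single `llama_token_data_array` pointer through the whole chain.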
@@ -270,8 +345,7 @@ class Llama:
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
             0, self.last_n_tokens_size - len(self.eval_tokens)
         ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
-        return llama_cpp.llama_sample_top_p_top_k(
-            ctx=self.ctx,
+        return self._sample_top_p_top_k(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
             ),
@@ -470,15 +544,15 @@ class Llama:
             all_text = self.detokenize(completion_tokens)
 
             # Contains multi-byte UTF8
-            for k,char in enumerate(all_text[-3:]):
+            for k, char in enumerate(all_text[-3:]):
                 k = 3 - k
-                for num,pattern in [(2, 192), (3, 224), (4, 240)]:
+                for num, pattern in [(2, 192), (3, 224), (4, 240)]:
                     # Bitwise AND check
-                    if (num > k and pattern & char == pattern):
+                    if num > k and pattern & char == pattern:
                         multibyte_fix = num - k
 
             # Stop incomplete bytes from passing
-            if (multibyte_fix > 0):
+            if multibyte_fix > 0:
                 multibyte_fix -= 1
                 continue
 
@@ -531,7 +605,9 @@ class Llama:
                     "model": self.model_path,
                     "choices": [
                         {
-                            "text": text[returned_characters:].decode("utf-8", errors="ignore"),
+                            "text": text[returned_characters:].decode(
+                                "utf-8", errors="ignore"
+                            ),
                             "index": 0,
                             "logprobs": None,
                             "finish_reason": finish_reason,
@@ -558,7 +634,8 @@ class Llama:
 
             all_tokens = prompt_tokens + completion_tokens
             all_token_strs = [
-                self.detokenize([token]).decode("utf-8", errors="ignore") for token in all_tokens
+                self.detokenize([token]).decode("utf-8", errors="ignore")
+                for token in all_tokens
             ]
             all_logprobs = [
                 [Llama.logit_to_logprob(logit) for logit in row]
@@ -577,7 +654,9 @@ class Llama:
                 )
                 token_logprobs.append(sorted_logprobs[int(token)][0])
                 top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8", errors="ignore"): logprob
+                    self.detokenize([llama_cpp.llama_token(i)]).decode(
+                        "utf-8", errors="ignore"
+                    ): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
                 top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
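Note: the `multibyte_fix` hunk guards the streaming path, where detokenized output can end partway through a multi-byte UTF-8 character. The bit patterns 192 (0b110xxxxx), 224 (0b1110xxxx), and 240 (0b11110xxx) identify lead bytes of 2-, 3-, and 4-byte sequences. The self-contained sketch below applies the same table, slightly generalized so it also handles buffers shorter than three bytes; the function name is hypothetical, not part of this patch.

```python
def missing_utf8_bytes(buf: bytes) -> int:
    """Count continuation bytes still missing at the end of buf.

    Mirrors the (2, 192), (3, 224), (4, 240) table in the hunk above:
    a byte b is the lead byte of an n-byte sequence when
    b & pattern == pattern; if it sits k bytes from the end with
    n > k, then n - k bytes of that character are still missing.
    """
    missing = 0
    tail = buf[-3:]
    for idx, byte in enumerate(tail):
        k = len(tail) - idx  # distance from the end of the buffer
        for num, pattern in [(2, 192), (3, 224), (4, 240)]:
            if num > k and byte & pattern == pattern:
                missing = num - k
    return missing

assert missing_utf8_bytes("héllo".encode()) == 0   # complete text
assert missing_utf8_bytes("é".encode()[:1]) == 1   # half of a 2-byte char
assert missing_utf8_bytes("😀".encode()[:2]) == 2  # 2 of 4 bytes present
```

Holding those bytes back until the sequence completes is what keeps `decode("utf-8", errors="ignore")` from silently dropping a half-emitted character.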
diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 95e9cfa..b4717bf 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -495,7 +495,9 @@ _lib.llama_sample_softmax.restype = None
 
 
 # @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-def llama_sample_top_k(ctx: llama_context_p, candidates, k: c_int, min_keep: c_int):
+def llama_sample_top_k(
+    ctx: llama_context_p, candidates, k: c_int, min_keep: c_size_t = c_size_t(1)
+):
     return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
 
 
@@ -503,13 +505,15 @@ _lib.llama_sample_top_k.argtypes = [
     llama_context_p,
     llama_token_data_array_p,
     c_int,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_top_k.restype = None
 
 
 # @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-def llama_sample_top_p(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
+def llama_sample_top_p(
+    ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
+):
     return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)
 
 
@@ -517,14 +521,14 @@ _lib.llama_sample_top_p.argtypes = [
     llama_context_p,
     llama_token_data_array_p,
     c_float,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_top_p.restype = None
 
 
 # @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
 def llama_sample_tail_free(
-    ctx: llama_context_p, candidates, z: c_float, min_keep: c_int
+    ctx: llama_context_p, candidates, z: c_float, min_keep: c_size_t = c_size_t(1)
 ):
     return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)
 
 
@@ -533,13 +537,15 @@ _lib.llama_sample_tail_free.argtypes = [
     llama_context_p,
     llama_token_data_array_p,
     c_float,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_tail_free.restype = None
 
 
 # @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
-def llama_sample_typical(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
+def llama_sample_typical(
+    ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
+):
     return _lib.llama_sample_typical(ctx, candidates, p, min_keep)
 
 
@@ -547,7 +553,7 @@ _lib.llama_sample_typical.argtypes = [
     llama_context_p,
     llama_token_data_array_p,
     c_float,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_typical.restype = None
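Note: all five bindings above change `min_keep` from `c_int` to `c_size_t`, matching the `size_t` parameter in llama.h, and give it a default of `c_size_t(1)` so callers that never pass `min_keep` (including `_sample_top_p_top_k` earlier in this patch) keep working. The sketch below shows the same wrap-and-declare pattern against libc's `strnlen`, which also takes a `size_t`; `strnlen` is only a stand-in chosen because it exists on any POSIX system, and is not part of this patch.

```python
import ctypes
import ctypes.util

libc = ctypes.CDLL(ctypes.util.find_library("c"))

# size_t strnlen(const char *s, size_t maxlen) -- declaring argtypes
# and restype makes ctypes marshal maxlen as a genuine size_t, just as
# the sampler bindings above do for min_keep.
libc.strnlen.argtypes = [ctypes.c_char_p, ctypes.c_size_t]
libc.strnlen.restype = ctypes.c_size_t

def strnlen(s: bytes, maxlen: ctypes.c_size_t = ctypes.c_size_t(8)):
    # Same wrapper shape as llama_sample_top_p et al.: a c_size_t
    # default baked into the Python signature, so call sites that
    # predate the new parameter need not change.
    return libc.strnlen(s, maxlen)

print(strnlen(b"hello"))        # 5
print(strnlen(b"hello world"))  # 8, capped by the default maxlen
```

The `argtypes` declaration is the load-bearing part: without it, ctypes passes Python ints as C `int`, which would truncate `size_t` arguments on LP64 platforms where `int` is 32-bit but `size_t` is 64-bit.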