From 6153baab2d2ac7a2c6ce9caa60474d84cf78dca6 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Fri, 14 Apr 2023 09:59:33 -0400
Subject: [PATCH] Clean up logprobs implementation

---
 llama_cpp/llama.py | 106 +++++++++++++++++----------------------------
 1 file changed, 39 insertions(+), 67 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index ae25137..ecfd2f4 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -351,55 +351,19 @@ class Llama:
         else:
             stop_sequences = []
 
-        text_offset = 0
-        text_offsets: List[int] = []
-        token_logprobs: List[float] = []
-        tokens: List[str] = []
-        top_logprobs: List[Dict[str, float]] = []
-
-        self.reset()
-        self.eval(prompt_tokens)
-
         if logprobs is not None and self.params.logits_all is False:
             raise ValueError(
                 "logprobs is not supported for models created with logits_all=False"
             )
 
-        if logprobs is not None:
-            token_strs = [
-                self.detokenize([token]).decode("utf-8") for token in prompt_tokens
-            ]
-            logprobs_all = [
-                [Llama.logit_to_logprob(logit) for logit in row]
-                for row in self.all_logits
-            ]
-            for token, token_str, logprobs_token in zip(
-                prompt_tokens, token_strs, logprobs_all
-            ):
-                text_offsets.append(text_offset)
-                text_offset += len(token_str)
-                tokens.append(token_str)
-                sorted_logprobs = list(
-                    sorted(
-                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
-                    )
-                )
-                token_logprobs.append(sorted_logprobs[int(token)][0])
-                top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
-                    for logprob, i in sorted_logprobs[:logprobs]
-                }
-                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
-                top_logprobs.append(top_logprob)
-
         finish_reason = "length"
-        while True:
-            token = self.sample(
-                top_k=top_k,
-                top_p=top_p,
-                temp=temperature,
-                repeat_penalty=repeat_penalty,
-            )
+        for token in self.generate(
+            prompt_tokens,
+            top_k=top_k,
+            top_p=top_p,
+            temp=temperature,
+            repeat_penalty=repeat_penalty,
+        ):
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
@@ -443,34 +407,10 @@ class Llama:
                     ],
                 }
 
-            if logprobs is not None:
-                # TODO: Confirm wether this should happen before or after
-                # next eval.
-                token_str = self.detokenize([token]).decode("utf-8")
-                text_offsets.append(text_offset)
-                text_offset += len(token_str)
-                tokens.append(token_str)
-                logprobs_token = [
-                    Llama.logit_to_logprob(logit) for logit in self.all_logits[-1]
-                ]
-                sorted_logprobs = list(
-                    sorted(
-                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
-                    )
-                )
-                token_logprobs.append(sorted_logprobs[int(token)][0])
-                top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
-                    for logprob, i in sorted_logprobs[:logprobs]
-                }
-                top_logprob.update({token_str: logprobs_token[int(token)]})
-                top_logprobs.append(top_logprob)
-
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "length"
                 break
-            self.eval([token])
 
             if stream:
                 yield {
@@ -499,6 +439,38 @@ class Llama:
 
         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
+            text_offset = 0
+            text_offsets: List[int] = []
+            token_logprobs: List[float] = []
+            tokens: List[str] = []
+            top_logprobs: List[Dict[str, float]] = []
+
+            all_tokens = prompt_tokens + completion_tokens
+            all_token_strs = [
+                self.detokenize([token]).decode("utf-8") for token in all_tokens
+            ]
+            all_logprobs = [
+                [Llama.logit_to_logprob(logit) for logit in row]
+                for row in self.all_logits
+            ]
+            for token, token_str, logprobs_token in zip(
+                all_tokens, all_token_strs, all_logprobs
+            ):
+                text_offsets.append(text_offset)
+                text_offset += len(token_str)
+                tokens.append(token_str)
+                sorted_logprobs = list(
+                    sorted(
+                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
+                    )
+                )
+                token_logprobs.append(sorted_logprobs[int(token)][0])
+                top_logprob = {
+                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+                    for logprob, i in sorted_logprobs[:logprobs]
+                }
+                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
+                top_logprobs.append(top_logprob)
             logprobs_or_none = {
                 "tokens": tokens,
                 "text_offset": text_offsets,
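Usage sketch (not part of the patch): after this change, per-token logprobs are assembled in one pass over `self.all_logits` once generation finishes, instead of being interleaved with the old eval/sample loop. The snippet below is a minimal illustration of how the cleaned-up path might be exercised; the model path is a placeholder, and the `token_logprobs`/`top_logprobs` result keys are assumed from the OpenAI-style lists the final hunk builds.

    from llama_cpp import Llama

    # logits_all=True is required: the patch raises ValueError when
    # logprobs is requested for a model created with logits_all=False.
    llm = Llama(model_path="./models/7B/ggml-model.bin", logits_all=True)

    completion = llm(
        "Q: What is the capital of France? A:",
        max_tokens=16,
        logprobs=5,  # keep the 5 most likely alternatives per position
    )

    lp = completion["choices"][0]["logprobs"]
    # Parallel lists over prompt_tokens + completion_tokens, built in one
    # pass by the new post-generation block in the patch.
    for tok, off, tok_lp in zip(
        lp["tokens"], lp["text_offset"], lp["token_logprobs"]
    ):
        print(f"offset={off:<4} logprob={tok_lp:8.3f} token={tok!r}")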