diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index bf4caf7..58c32e9 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -710,22 +710,56 @@ class Llama:
                 # We want to avoid yielding any characters from
                 # the generated text if they are part of a stop
                 # sequence.
-                longest = 0
+                first_stop_position = 0
                 for s in stop_sequences:
                     for i in range(len(s), 0, -1):
                         if all_text.endswith(s[:i]):
-                            if i > longest:
-                                longest = i
+                            if i > first_stop_position:
+                                first_stop_position = i
                             break
 
-                offset = 0
+                token_end_position = 0
                 remaining_tokens = completion_tokens[returned_tokens:]
                 remaining_length = len(self.detokenize(remaining_tokens))
                 for token in remaining_tokens:
-                    offset += len(self.detokenize([token]))
-                    # Check if stop sequence is not in the token
-                    if offset >= (remaining_length - longest - 1):
+                    token_end_position += len(self.detokenize([token]))
+                    # Check if stop sequence is in the token
+                    if token_end_position >= (remaining_length - first_stop_position - 1):
                         break
+                    logprobs_or_none: Optional[CompletionLogprobs] = None
+                    if logprobs is not None:
+                        token_str = self.detokenize([token]).decode(
+                            "utf-8", errors="ignore"
+                        )
+                        text_offset = len(prompt) + len(
+                            self.detokenize(completion_tokens[:returned_tokens])
+                        )
+                        token_offset = len(prompt_tokens) + returned_tokens
+                        logits = self.eval_logits[token_offset - 1]
+                        current_logprobs = Llama.logits_to_logprobs(logits)
+                        sorted_logprobs = list(
+                            sorted(
+                                zip(current_logprobs, range(len(current_logprobs))),
+                                reverse=True,
+                            )
+                        )
+                        top_logprob = {
+                            self.detokenize([llama_cpp.llama_token(i)]).decode(
+                                "utf-8", errors="ignore"
+                            ): logprob
+                            for logprob, i in sorted_logprobs[:logprobs]
+                        }
+                        top_logprob.update({token_str: current_logprobs[int(token)]})
+                        logprobs_or_none = {
+                            "tokens": [
+                                self.detokenize([token]).decode(
+                                    "utf-8", errors="ignore"
+                                )
+                            ],
+                            "text_offset": [text_offset],
+                            "token_logprobs": [sorted_logprobs[int(token)][0]],
+                            "top_logprobs": [top_logprob],
+                        }
                     returned_tokens += 1
                     yield {
                         "id": completion_id,
@@ -738,7 +772,7 @@ class Llama:
                                 "utf-8", errors="ignore"
                             ),
                             "index": 0,
-                            "logprobs": None,
+                            "logprobs": logprobs_or_none,
                             "finish_reason": None,
                         }
                     ],
@@ -766,13 +800,48 @@ class Llama:
             else:
                 end = len(all_text)
 
-            offset = 0
+            token_end_position = 0
             for token in remaining_tokens:
-                offset += len(self.detokenize([token]))
-                if offset >= end:
+                token_end_position += len(self.detokenize([token]))
+
+                logprobs_or_none: Optional[CompletionLogprobs] = None
+                if logprobs is not None:
+                    token_str = self.detokenize([token]).decode(
+                        "utf-8", errors="ignore"
+                    )
+                    text_offset = len(prompt) + len(
+                        self.detokenize(completion_tokens[:returned_tokens])
+                    )
+                    token_offset = len(prompt_tokens) + returned_tokens - 1
+                    logits = self.eval_logits[token_offset]
+                    current_logprobs = Llama.logits_to_logprobs(logits)
+                    sorted_logprobs = list(
+                        sorted(
+                            zip(current_logprobs, range(len(current_logprobs))),
+                            reverse=True,
+                        )
+                    )
+                    top_logprob = {
+                        self.detokenize([llama_cpp.llama_token(i)]).decode(
+                            "utf-8", errors="ignore"
+                        ): logprob
+                        for logprob, i in sorted_logprobs[:logprobs]
+                    }
+                    top_logprob.update({token_str: current_logprobs[int(token)]})
+                    logprobs_or_none = {
+                        "tokens": [
+                            self.detokenize([token]).decode("utf-8", errors="ignore")
+                        ],
+                        "text_offset": [text_offset],
+                        "token_logprobs": [sorted_logprobs[int(token)][0]],
+                        "top_logprobs": [top_logprob],
+                    }
+
+                if token_end_position >= end:
                     last_text = self.detokenize([token])
-                    if offset == end - 1:
+                    if token_end_position == end - 1:
                         break
+                    returned_tokens += 1
                     yield {
                         "id": completion_id,
                         "object": "text_completion",
@@ -781,10 +850,10 @@ class Llama:
                         "choices": [
                             {
                                 "text": last_text[
-                                    : len(last_text) - (offset - end)
+                                    : len(last_text) - (token_end_position - end)
                                 ].decode("utf-8", errors="ignore"),
                                 "index": 0,
-                                "logprobs": None,
+                                "logprobs": logprobs_or_none,
                                 "finish_reason": finish_reason,
                             }
                         ],
@@ -802,7 +871,7 @@ class Llama:
                             "utf-8", errors="ignore"
                         ),
                         "index": 0,
-                        "logprobs": None,
+                        "logprobs": logprobs_or_none,
                         "finish_reason": finish_reason
                         if returned_tokens == len(completion_tokens)
                         else None,
@@ -821,13 +890,19 @@ class Llama:
 
         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
-            text_offset = 0
+            text_offset = 0 if echo else len(prompt)
+            token_offset = 0 if echo else len(prompt_tokens[1:])
             text_offsets: List[int] = []
-            token_logprobs: List[float] = []
+            token_logprobs: List[Optional[float]] = []
             tokens: List[str] = []
-            top_logprobs: List[Dict[str, float]] = []
+            top_logprobs: List[Optional[Dict[str, float]]] = []
+
+            if echo:
+                # Remove leading BOS token
+                all_tokens = prompt_tokens[1:] + completion_tokens
+            else:
+                all_tokens = completion_tokens
 
-            all_tokens = prompt_tokens + completion_tokens
             all_token_strs = [
                 self.detokenize([token]).decode("utf-8", errors="ignore")
                 for token in all_tokens
@@ -835,7 +910,7 @@ class Llama:
             all_logprobs = [
                 Llama.logits_to_logprobs(list(map(float, row)))
                 for row in self.eval_logits
-            ]
+            ][token_offset:]
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs
             ):
@@ -848,14 +923,20 @@ class Llama:
                     )
                 )
                 token_logprobs.append(sorted_logprobs[int(token)][0])
-                top_logprob = {
+                top_logprob: Optional[Dict[str, float]] = {
                     self.detokenize([llama_cpp.llama_token(i)]).decode(
                         "utf-8", errors="ignore"
                     ): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
-                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
+                top_logprob.update({token_str: logprobs_token[int(token)]})
                 top_logprobs.append(top_logprob)
+            # Weird idiosyncrasy of the OpenAI API where
+            # token_logprobs and top_logprobs are null for
+            # the first token.
+            if echo and len(all_tokens) > 0:
+                token_logprobs[0] = None
+                top_logprobs[0] = None
             logprobs_or_none = {
                 "tokens": tokens,
                 "text_offset": text_offsets,
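
A minimal usage sketch of what this change enables (not part of the diff): requesting logprobs on a streamed completion so each yielded chunk's "logprobs" field is populated instead of None. The model path and prompt are placeholders, and logits_all=True is assumed so per-token logits are retained for the logprob computation.

# Usage sketch, assuming a llama-cpp-python build that includes this change.
from llama_cpp import Llama

# Placeholder model path; logits_all=True keeps logits for every evaluated
# token, which the logprobs code above reads from self.eval_logits.
llm = Llama(model_path="./models/7B/ggml-model.bin", logits_all=True)

for chunk in llm(
    "Q: Name the planets in the solar system. A:",
    max_tokens=16,
    logprobs=5,  # report the top-5 alternatives per sampled token
    stream=True,
):
    choice = chunk["choices"][0]
    # With this change, "logprobs" carries tokens, text_offset,
    # token_logprobs, and top_logprobs for each streamed token.
    print(choice["text"], choice["logprobs"])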