diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 4b3d36a..01a2e70 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -480,7 +480,7 @@ class Llama:
         Returns:
             The detokenized string.
         """
-        return self.tokenizer_.detokenize(tokens, prev_tokens)
+        return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens)
 
     def set_cache(self, cache: Optional[BaseLlamaCache]):
         """Set the cache.
@@ -1016,13 +1016,13 @@ class Llama:
             grammar=grammar,
         ):
             if token == self._token_eos:
-                text = self.detokenize(completion_tokens)
+                text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
                 finish_reason = "stop"
                 break
 
             completion_tokens.append(token)
 
-            all_text = self.detokenize(completion_tokens)
+            all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
 
             # Contains multi-byte UTF8
             for k, char in enumerate(all_text[-3:]):
@@ -1046,7 +1046,7 @@ class Llama:
 
             if stream:
                 remaining_tokens = completion_tokens[returned_tokens:]
-                remaining_text = self.detokenize(remaining_tokens)
+                remaining_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
                 remaining_length = len(remaining_text)
 
                 # We want to avoid yielding any characters from
@@ -1068,17 +1068,17 @@ class Llama:
                     for token in remaining_tokens:
                         if token == self.token_bos():
                             continue
-                        token_end_position += len(self.detokenize([token]))
+                        token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))
                         # Check if stop sequence is in the token
                         if token_end_position > (
                             remaining_length - first_stop_position
                         ):
                             break
-                        token_str = self.detokenize([token]).decode(
+                        token_str = self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
                             "utf-8", errors="ignore"
                         )
                         text_offset = len(prompt) + len(
-                            self.detokenize(completion_tokens[:returned_tokens]).decode(
+                            self.detokenize(completion_tokens[:returned_tokens], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
                                 "utf-8", errors="ignore"
                             )
                         )
@@ -1100,7 +1100,7 @@ class Llama:
                         top_logprob.update({token_str: current_logprobs[int(token)]})
                         logprobs_or_none = {
                             "tokens": [
-                                self.detokenize([token]).decode(
+                                self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
                                     "utf-8", errors="ignore"
                                 )
                             ],
@@ -1116,7 +1116,7 @@ class Llama:
                             "model": model_name,
                             "choices": [
                                 {
-                                    "text": self.detokenize([token]).decode(
+                                    "text": self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
                                         "utf-8", errors="ignore"
                                     ),
                                     "index": 0,
@@ -1130,7 +1130,7 @@ class Llama:
                     decode_success = False
                     for i in range(1, len(remaining_tokens) + 1):
                         try:
-                            bs = self.detokenize(remaining_tokens[:i])
+                            bs = self.detokenize(remaining_tokens[:i], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
                             ts = bs.decode("utf-8")
                             decode_success = True
                             break
@@ -1165,14 +1165,14 @@ class Llama:
                     }
 
             if len(completion_tokens) >= max_tokens:
-                text = self.detokenize(completion_tokens)
+                text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
                 finish_reason = "length"
                 break
 
         if stopping_criteria is not None and stopping_criteria(
             self._input_ids, self._scores[-1, :]
         ):
-            text = self.detokenize(completion_tokens)
+            text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
             finish_reason = "stop"
 
         if self.verbose:
@@ -1180,7 +1180,7 @@ class Llama:
 
         if stream:
             remaining_tokens = completion_tokens[returned_tokens:]
-            all_text = self.detokenize(remaining_tokens)
+            all_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
             any_stop = [s for s in stop_sequences if s in all_text]
             if len(any_stop) > 0:
                 end = min(all_text.index(stop) for stop in any_stop)
@@ -1189,7 +1189,7 @@ class Llama:
 
             token_end_position = 0
             for token in remaining_tokens:
-                token_end_position += len(self.detokenize([token]))
+                token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))
 
                 logprobs_or_none: Optional[CompletionLogprobs] = None
                 if logprobs is not None:
@@ -1199,7 +1199,7 @@ class Llama:
                         "utf-8", errors="ignore"
                     )
                     text_offset = len(prompt) + len(
-                        self.detokenize(completion_tokens[:returned_tokens])
+                        self.detokenize(completion_tokens[:returned_tokens], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
                     )
                     token_offset = len(prompt_tokens) + returned_tokens - 1
                     logits = self._scores[token_offset, :]
@@ -1313,8 +1313,8 @@ class Llama:
             all_tokens = completion_tokens
 
         all_token_strs = [
-            self.detokenize([token]).decode("utf-8", errors="ignore")
-            for token in all_tokens
+            self.detokenize([token], prev_tokens=all_tokens[:i]).decode("utf-8", errors="ignore")
+            for i, token in enumerate(all_tokens)
         ]
         all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
         # TODO: may be able to change this loop to use np.take_along_dim
@@ -1339,7 +1339,7 @@ class Llama:
             )
             token_logprobs.append(logprobs_token[int(token)])
             top_logprob: Optional[Dict[str, float]] = {
-                self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
+                self.detokenize([i], prev_tokens=all_tokens[:idx]).decode("utf-8", errors="ignore"): logprob
                 for logprob, i in sorted_logprobs[:logprobs]
             }
             top_logprob.update({token_str: logprobs_token[int(token)]})
@@ -1594,6 +1594,8 @@ class Llama:
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
         logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
     ) -> Union[
         CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
     ]:
diff --git a/llama_cpp/llama_tokenizer.py b/llama_cpp/llama_tokenizer.py
index c2aad47..5a8b13b 100644
--- a/llama_cpp/llama_tokenizer.py
+++ b/llama_cpp/llama_tokenizer.py
@@ -16,12 +16,23 @@ class BaseLlamaTokenizer(abc.ABC):
     def tokenize(
         self, text: bytes, add_bos: bool = True, special: bool = True
     ) -> List[int]:
+        """Tokenize the text into tokens.
+
+        Args:
+            text: The text to tokenize.
+            add_bos: Whether to add a beginning of sequence token.
+            special: Whether to tokenize text literally or as special tokens."""
         raise NotImplementedError
 
     @abc.abstractmethod
     def detokenize(
         self, tokens: List[int], prev_tokens: Optional[List[int]] = None
     ) -> bytes:
+        """Detokenize the tokens into text.
+
+        Args:
+            tokens: The tokens to detokenize.
+            prev_tokens: If tokens is a continuation of a previous sequence, the previous tokens."""
         raise NotImplementedError
 
 
@@ -37,10 +48,7 @@ class LlamaTokenizer(BaseLlamaTokenizer):
     def detokenize(
         self, tokens: List[int], prev_tokens: Optional[List[int]] = None
     ) -> bytes:
-        if prev_tokens is not None:
-            return self._model.detokenize(tokens[len(prev_tokens) :])
-        else:
-            return self._model.detokenize(tokens)
+        return self._model.detokenize(tokens)
 
     def encode(
         self, text: str, add_bos: bool = True, special: bool = True
@@ -72,7 +80,7 @@ class LlamaHFTokenizer(BaseLlamaTokenizer):
         self, tokens: List[int], prev_tokens: Optional[List[int]] = None
     ) -> bytes:
         if prev_tokens is not None:
-            text = self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
+            text = self.hf_tokenizer.decode(prev_tokens + tokens).encode("utf-8", errors="ignore")
             prev_text = self.hf_tokenizer.decode(prev_tokens).encode(
                 "utf-8", errors="ignore"
             )
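
Why the LlamaHFTokenizer change decodes prev_tokens + tokens and then strips the already-decoded prefix: SentencePiece-style HF tokenizers are context-sensitive, so a token can decode differently at the start of a sequence than after other tokens (the leading-space marker is dropped at position 0). The sketch below illustrates that contract; the VOCAB table and the toy_decode helper are hypothetical stand-ins for a real hf_tokenizer.decode, and only the prefix-stripping logic mirrors the patched detokenize.

from typing import List, Optional

# Hypothetical SentencePiece-style vocabulary: "\u2581" marks a word
# boundary and is rendered as a space, except at the sequence start.
VOCAB = {1: "\u2581Hello", 2: "\u2581world", 3: "!"}

def toy_decode(ids: List[int]) -> str:
    # Stand-in for hf_tokenizer.decode: context-sensitive at position 0.
    return "".join(VOCAB[i] for i in ids).replace("\u2581", " ").lstrip(" ")

def detokenize(tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
    # Mirrors the patched LlamaHFTokenizer.detokenize: decode the whole
    # sequence, then drop the bytes already produced for prev_tokens, so
    # context-dependent spacing lands in the continuation chunk.
    if prev_tokens is not None:
        text = toy_decode(prev_tokens + tokens).encode("utf-8", errors="ignore")
        prev_text = toy_decode(prev_tokens).encode("utf-8", errors="ignore")
        return text[len(prev_text):]
    return toy_decode(tokens).encode("utf-8", errors="ignore")

# Decoding token 2 in isolation loses the space between "Hello" and "world";
# passing prev_tokens preserves it, which is what the streaming path in
# llama.py relies on when it detokenizes tokens one at a time.
assert detokenize([2]) == b"world"
assert detokenize([2], prev_tokens=[1]) == b" world"
assert detokenize([1]) + detokenize([2], prev_tokens=[1]) == b"Hello world"

The same reasoning explains the llama.py side of the diff: every per-token detokenize call in the streaming path now passes the tokens that precede it (prompt_tokens plus the already-returned completion tokens), so concatenating the streamed chunks reproduces the text a single whole-sequence decode would give.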