diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index a8c3f9a..5acc112 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -955,18 +955,53 @@ class Llama:
 
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
+        prefix_token_id: int = int(self.metadata.get("tokenizer.ggml.prefix_token_id", self._model.token_prefix()))
+        middle_token_id: int = int(self.metadata.get("tokenizer.ggml.middle_token_id", self._model.token_middle()))
+        suffix_token_id: int = int(self.metadata.get("tokenizer.ggml.suffix_token_id", self._model.token_suffix()))
         # If prompt is empty, initialize completion with BOS token to avoid
         # detokenization including a space at the beginning of the completion
         completion_tokens: List[int] = [] if len(prompt) > 0 else [self.token_bos()]
         # Add blank space to start of prompt to match OG llama tokenizer
         prompt_tokens: List[int] = (
             (
-                self.tokenize(prompt.encode("utf-8"), special=True)
-                if prompt != ""
-                else [self.token_bos()]
+                [prefix_token_id]
+                if prefix_token_id >= 0 and suffix is not None
+                else []
+            )
+            +
+            (
+                (
+                    self.tokenize(prompt.encode("utf-8"), add_bos=(prefix_token_id < 0 or suffix is None), special=(prefix_token_id < 0 or suffix is None))
+                    if prompt != ""
+                    else (
+                        []
+                        if prefix_token_id >= 0 and suffix is not None
+                        else [self.token_bos()]
+                    )
+                )
+                if isinstance(prompt, str)
+                else prompt
+            )
+            +
+            (
+                (
+                    [suffix_token_id]
+                    +
+                    (
+                        self.tokenize(suffix.encode("utf-8"), add_bos=False, special=False)
+                        if suffix
+                        else []
+                    )
+                )
+                if suffix_token_id >= 0 and suffix is not None
+                else []
+            )
+            +
+            (
+                [middle_token_id]
+                if middle_token_id >= 0 and suffix is not None
+                else []
             )
-            if isinstance(prompt, str)
-            else prompt
         )
         text: bytes = b""
         returned_tokens: int = 0
@@ -1346,7 +1381,7 @@ class Llama:
         if echo:
             text_str = prompt + text_str
 
-        if suffix is not None:
+        if suffix_token_id < 0 and suffix is not None:
             text_str = text_str + suffix
 
         logprobs_or_none: Optional[CompletionLogprobs] = None
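
For reference, a minimal usage sketch of the new code path (not part of the patch; the model file name and sampling parameters are illustrative, and the model must define fill-in-middle special tokens in its GGUF metadata for the infill branch to activate):

```python
from llama_cpp import Llama

# With this patch, passing `suffix=` assembles the prompt from the model's
# native fill-in-middle tokens -- [prefix_token_id] + prompt +
# [suffix_token_id] + suffix + [middle_token_id] -- instead of appending the
# suffix as plain text after generation.
llm = Llama(model_path="codellama-7b.Q4_K_M.gguf")  # hypothetical model file

out = llm.create_completion(
    prompt="def add(a, b):\n    result = ",
    suffix="\n    return result\n",
    max_tokens=32,
)
print(out["choices"][0]["text"])
```

For models whose prefix/middle/suffix token ids resolve to -1 (no FIM support), the second hunk preserves the old behavior: the suffix is concatenated to the completion text as before.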