feat: fill-in-middle support (#1386)

* Proper fill-in-middle support

Use the prefix/middle/suffix tokens when the metadata is present in the GGUF, as in, for example, [CISCai/CodeQwen1.5-7B-Chat-SOTA-GGUF](https://huggingface.co/CISCai/CodeQwen1.5-7B-Chat-SOTA-GGUF).

* fall back to internal prefix/middle/suffix id

In some cases llama.cpp will guess the FIM tokens itself; use those if there is no metadata.

* typo--

* don't insert special tokens that are not there in suffix

Note: add_bos is misnamed; it is actually add_special and can cause several special tokens to be added to the token list (the special parameter is actually parse_special).

* don't add/parse any special tokens when using fim

I've left the original behavior in place when no FIM tokens are found, but this should perhaps be re-evaluated.

* don't append suffix to prompt_tokens unless fim tokens are detected

* make sure we only do this for fim

---------

Co-authored-by: Andrei <abetlen@gmail.com>
Sigbjørn Skjæret 2024-05-08 08:26:22 +02:00 committed by GitHub
parent 228949c1f7
commit 4a7122d22f

@@ -955,19 +955,54 @@ class Llama:
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
+        prefix_token_id: int = int(self.metadata.get("tokenizer.ggml.prefix_token_id", self._model.token_prefix()))
+        middle_token_id: int = int(self.metadata.get("tokenizer.ggml.middle_token_id", self._model.token_middle()))
+        suffix_token_id: int = int(self.metadata.get("tokenizer.ggml.suffix_token_id", self._model.token_suffix()))
         # If prompt is empty, initialize completion with BOS token to avoid
         # detokenization including a space at the beginning of the completion
         completion_tokens: List[int] = [] if len(prompt) > 0 else [self.token_bos()]
         # Add blank space to start of prompt to match OG llama tokenizer
         prompt_tokens: List[int] = (
             (
-                self.tokenize(prompt.encode("utf-8"), special=True)
+                [prefix_token_id]
+                if prefix_token_id >= 0 and suffix is not None
+                else []
+            )
+            +
+            (
+                (
+                    self.tokenize(prompt.encode("utf-8"), add_bos=(prefix_token_id < 0 or suffix is None), special=(prefix_token_id < 0 or suffix is None))
+                    if prompt != ""
+                    else (
+                        []
+                        if prefix_token_id >= 0 and suffix is not None
+                        else [self.token_bos()]
+                    )
+                )
+                if isinstance(prompt, str)
+                else prompt
+            )
+            +
+            (
+                (
+                    [suffix_token_id]
+                    +
+                    (
+                        self.tokenize(suffix.encode("utf-8"), add_bos=False, special=False)
+                        if suffix
+                        else []
+                    )
+                )
+                if suffix_token_id >= 0 and suffix is not None
+                else []
+            )
+            +
+            (
+                [middle_token_id]
+                if middle_token_id >= 0 and suffix is not None
+                else []
+            )
         )
         text: bytes = b""
         returned_tokens: int = 0
         stop = (
@@ -1346,7 +1381,7 @@ class Llama:
         if echo:
             text_str = prompt + text_str
-        if suffix is not None:
+        if suffix_token_id < 0 and suffix is not None:
             text_str = text_str + suffix
         logprobs_or_none: Optional[CompletionLogprobs] = None
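
The nested conditional expression in the first hunk can be hard to follow, so here is a minimal standalone sketch of the same token-assembly order, written as plain `if` statements. It is an illustration under stated assumptions, not code from the commit: `build_fim_tokens` and `toy_tokenize` are hypothetical names, the tokenizer callback stands in for `self.tokenize`, and only the `str` prompt path is covered. A token id below 0 is taken to mean "no such token", as in the diff.

```python
from typing import Callable, List, Optional

# Hypothetical callback type standing in for Llama.tokenize:
# (text, add_bos, special) -> token ids.
Tokenizer = Callable[[str, bool, bool], List[int]]


def build_fim_tokens(
    prompt: str,
    suffix: Optional[str],
    tokenize: Tokenizer,
    bos_id: int,
    prefix_id: int = -1,
    middle_id: int = -1,
    suffix_id: int = -1,
) -> List[int]:
    """Assemble prompt tokens in the order the diff uses:
    [prefix] + prompt + [suffix marker] + suffix + [middle]."""
    # FIM mode for the prompt part requires a prefix token and a suffix argument.
    fim = prefix_id >= 0 and suffix is not None

    tokens: List[int] = []
    if fim:
        tokens.append(prefix_id)

    if prompt != "":
        # In FIM mode, neither add nor parse special tokens in the prompt.
        tokens += tokenize(prompt, not fim, not fim)
    elif not fim:
        # Original behavior: an empty prompt becomes a lone BOS token.
        tokens.append(bos_id)

    if suffix_id >= 0 and suffix is not None:
        tokens.append(suffix_id)
        if suffix:
            tokens += tokenize(suffix, False, False)

    if middle_id >= 0 and suffix is not None:
        tokens.append(middle_id)

    return tokens


def toy_tokenize(text: str, add_bos: bool, special: bool) -> List[int]:
    """Toy tokenizer for the example only: one id per character, BOS id 1."""
    ids = [ord(ch) for ch in text]
    return [1] + ids if add_bos else ids


# With FIM tokens available: prefix, prompt, suffix marker, suffix, middle.
print(build_fim_tokens("ab", "c", toy_tokenize, bos_id=1,
                       prefix_id=1000, middle_id=1001, suffix_id=1002))
# [1000, 97, 98, 1002, 99, 1001]

# Without FIM tokens: the suffix is not tokenized into the prompt.
print(build_fim_tokens("ab", "c", toy_tokenize, bos_id=1))
# [1, 97, 98]
```

In the second case the suffix is left out of the token list, which is why the second hunk still appends `suffix` to the returned text only when `suffix_token_id < 0`.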