feat: Support SPM infill (#1492)
* Support SPM infill
* typo--
* one less layer of parenthesis necessary
* new required internals
* manually add bos/eos if model requires it
* add bos even when unknown

  This is identical behaviour to llama.cpp.
  I guess any model that doesn't use BOS is recent enough to have the add_bos_token metadata.

* don't add bos/eos on non-infill pre-tokenized prompt
* add tokenizer hack to remove leading space in suffix
* I keep forgetting metadata are strings
* check if bos exists
* add example
* add cls/sep instead of bos/eos for WPM vocab
* simplify
* color-code filtered suffix

---------

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
parent e342161371
commit dbcf64cf07
3 changed files with 91 additions and 31 deletions
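For orientation, a minimal sketch of the two infill layouts this commit distinguishes. The <PRE>/<SUF>/<MID> strings below are illustrative placeholders only; the real prompt is assembled from the model's own fill-in-middle special tokens, as the llama.py hunks further down show.

# Sketch only: placeholder markers stand in for the model's fill-in-middle special tokens.
prefix, suffix = "def add(", "\n    return sum\n\n"

# Prefix/Suffix/Middle (PSM) -- the default layout:
psm_prompt = "<PRE>" + prefix + "<SUF>" + suffix + "<MID>"

# Suffix/Prefix/Middle (SPM) -- used when spm_infill=True:
spm_prompt = "<SUF>" + suffix + "<PRE>" + prefix + "<MID>"

# Either layout may additionally be wrapped in BOS/CLS and EOS/SEP tokens,
# depending on the model's tokenizer metadata (see the llama.py hunks below).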
examples/high_level_api/high_level_api_infill.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+import argparse
+
+from llama_cpp import Llama
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-m", "--model", type=str, default="../models/7B/ggml-models.bin")
+parser.add_argument("-p", "--prompt", type=str, default="def add(")
+parser.add_argument("-s", "--suffix", type=str, default="\n    return sum\n\n")
+parser.add_argument("-i", "--spm-infill", action='store_true')
+args = parser.parse_args()
+
+llm = Llama(model_path=args.model, n_gpu_layers=-1, spm_infill=args.spm_infill)
+
+output = llm.create_completion(
+    temperature = 0.0,
+    repeat_penalty = 1.0,
+    prompt = args.prompt,
+    suffix = args.suffix,
+)
+
+# Models sometimes repeat suffix in response, attempt to filter that
+response = output["choices"][0]["text"]
+response_stripped = response.rstrip()
+unwanted_response_suffix = args.suffix.rstrip()
+unwanted_response_length = len(unwanted_response_suffix)
+
+filtered = False
+if unwanted_response_suffix and response_stripped[-unwanted_response_length:] == unwanted_response_suffix:
+    response = response_stripped[:-unwanted_response_length]
+    filtered = True
+
+print(f"Fill-in-Middle completion{' (filtered)' if filtered else ''}:\n\n{args.prompt}\033[32m{response}\033[{'33' if filtered else '0'}m{args.suffix}\033[0m")
@@ -170,6 +170,14 @@ class _LlamaModel:
         assert self.model is not None
         return llama_cpp.llama_token_eot(self.model)
 
+    def add_bos_token(self) -> int:
+        assert self.model is not None
+        return llama_cpp.llama_add_bos_token(self.model)
+
+    def add_eos_token(self) -> int:
+        assert self.model is not None
+        return llama_cpp.llama_add_eos_token(self.model)
+
     # Tokenization
 
     def tokenize(self, text: bytes, add_bos: bool, special: bool):
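These wrappers surface the GGUF add_bos_token/add_eos_token metadata to the completion path. Judging from how they are used later in this diff, the return value is treated roughly as 1 = add, 0 = don't add, and anything else as unknown; the sketch below spells out that interpretation as an assumption about this llama.cpp revision, not a documented contract.

# Sketch of how the completion code below interprets these accessors.
# Assumed convention for this revision: 1 = add, 0 = do not add, -1 = unknown.
def wants_bos(model) -> bool:
    return model.add_bos_token() != 0   # add BOS unless metadata explicitly says no

def wants_eos(model) -> bool:
    return model.add_eos_token() == 1   # add EOS only if metadata explicitly says yes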
@@ -115,6 +115,7 @@ class Llama:
         type_k: Optional[int] = None,
         type_v: Optional[int] = None,
         # Misc
+        spm_infill: bool = False,
         verbose: bool = True,
         # Extra Params
         **kwargs, # type: ignore
@@ -185,6 +186,7 @@ class Llama:
             verbose: Print verbose output to stderr.
             type_k: KV cache data type for K (default: f16)
             type_v: KV cache data type for V (default: f16)
+            spm_infill: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.
 
         Raises:
             ValueError: If the model path does not exist.
@@ -343,6 +345,8 @@ class Llama:
         self.lora_scale = lora_scale
         self.lora_path = lora_path
 
+        self.spm_infill = spm_infill
+
         if not os.path.exists(model_path):
             raise ValueError(f"Model path does not exist: {model_path}")
 
@@ -972,14 +976,33 @@ class Llama:
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
+        bos_token_id: int = self.token_bos()
+        cls_token_id: int = self._model.token_cls()
+        sep_token_id: int = self._model.token_sep()
         prefix_token_id: int = self._model.token_prefix()
         middle_token_id: int = self._model.token_middle()
         suffix_token_id: int = self._model.token_suffix()
+        add_space_prefix: bool = self.metadata.get("tokenizer.ggml.add_space_prefix", "true") == "true"
+        bos_tokens: List[int] = [cls_token_id if cls_token_id != -1 else bos_token_id]
+        eos_tokens: List[int] = [sep_token_id if sep_token_id != -1 else self.token_eos()]
+
+        if (isinstance(prompt, list) and suffix is None) or self._model.add_bos_token() == 0 or bos_tokens[:1] == [-1]:
+            bos_tokens = []
+
+        if (isinstance(prompt, list) and suffix is None) or (self._model.add_eos_token() != 1 and sep_token_id == -1):
+            eos_tokens = []
+
+        suffix_space_prefix: int = 0
+        # Tokenizer hack to remove leading space
+        if add_space_prefix and suffix_token_id >= 0 and suffix:
+            suffix = "☺" + suffix
+            suffix_space_prefix = 2
+
         # If prompt is empty, initialize completion with BOS token to avoid
         # detokenization including a space at the beginning of the completion
-        completion_tokens: List[int] = [] if len(prompt) > 0 else [self.token_bos()]
+        completion_tokens: List[int] = [] if len(prompt) > 0 else [bos_token_id]
         # Add blank space to start of prompt to match OG llama tokenizer
-        prompt_tokens: List[int] = (
+        prefix_tokens: List[int] = (
             (
                 [prefix_token_id]
                 if prefix_token_id >= 0 and suffix is not None
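The "☺" assignment above is the leading-space workaround mentioned in the commit message: tokenizers with tokenizer.ggml.add_space_prefix insert a space in front of any fragment they tokenize, so a sentinel character is prepended and its tokens are sliced off again when the suffix is tokenized (hence suffix_space_prefix = 2, on the assumption that the sentinel plus the injected space encode to two tokens). A standalone sketch of the idea:

def tokenize_suffix_without_space(tokenize, suffix: str) -> list:
    # Sketch only; `tokenize` stands in for Llama.tokenize.
    sentinel = "☺"
    tokens = tokenize((sentinel + suffix).encode("utf-8"), add_bos=False, special=False)
    return tokens[2:]  # drop the sentinel's tokens, taking the injected space with them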
@@ -988,38 +1011,33 @@ class Llama:
             +
             (
                 (
-                    self.tokenize(prompt.encode("utf-8"), add_bos=(prefix_token_id < 0 or suffix is None), special=(prefix_token_id < 0 or suffix is None))
+                    self.tokenize(prompt.encode("utf-8"), add_bos=False, special=(prefix_token_id < 0 or suffix is None))
                     if prompt != ""
-                    else (
-                        []
-                        if prefix_token_id >= 0 and suffix is not None
-                        else [self.token_bos()]
-                    )
+                    else []
                 )
                 if isinstance(prompt, str)
                 else prompt
             )
-            +
-            (
-                (
-                    [suffix_token_id]
-                    +
-                    (
-                        self.tokenize(suffix.encode("utf-8"), add_bos=False, special=False)
-                        if suffix
-                        else []
-                    )
-                )
-                if suffix_token_id >= 0 and suffix is not None
-                else []
-            )
-            +
-            (
-                [middle_token_id]
-                if middle_token_id >= 0 and suffix is not None
-                else []
-            )
         )
+        suffix_tokens: List[int] = (
+            (
+                [suffix_token_id]
+                +
+                (
+                    self.tokenize(suffix.encode("utf-8"), add_bos=False, special=False)[suffix_space_prefix:]
+                    if suffix
+                    else []
+                )
+            )
+            if suffix_token_id >= 0 and suffix is not None
+            else []
+        )
+        middle_tokens: List[int] = (
+            [middle_token_id]
+            if middle_token_id >= 0 and suffix is not None
+            else []
+        )
+        prompt_tokens: List[int] = bos_tokens + ((suffix_tokens + prefix_tokens + middle_tokens) if self.spm_infill else (prefix_tokens + suffix_tokens + middle_tokens)) + eos_tokens
         text: bytes = b""
         returned_tokens: int = 0
         stop = (
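Putting the pieces together, the prompt_tokens line above simply concatenates the per-segment token lists in one of the two orders. A worked illustration with placeholder token ids (not real vocabulary values):

# Placeholder token ids, for illustration only.
bos_tokens    = [1]          # [CLS] for WPM vocabs, else [BOS]; empty if metadata opts out
eos_tokens    = []           # [SEP] for WPM vocabs; otherwise usually empty
prefix_tokens = [32007, 5]   # prefix marker + tokenized prompt
suffix_tokens = [32008, 9]   # suffix marker + tokenized suffix
middle_tokens = [32009]      # middle marker, always last

spm_infill = True
prompt_tokens = bos_tokens + (
    (suffix_tokens + prefix_tokens + middle_tokens)
    if spm_infill
    else (prefix_tokens + suffix_tokens + middle_tokens)
) + eos_tokens
# SPM order: [1, 32008, 9, 32007, 5, 32009]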
@@ -1176,7 +1194,7 @@ class Llama:
                     # not sure how to handle this branch when dealing
                     # with CJK output, so keep it unchanged
                     for token in remaining_tokens:
-                        if token == self.token_bos():
+                        if token == bos_token_id:
                             continue
                         token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))
                         # Check if stop sequence is in the token
@@ -1303,7 +1321,7 @@ class Llama:
                 logprobs_or_none: Optional[CompletionLogprobs] = None
                 if logprobs is not None:
-                    if token == self.token_bos():
+                    if token == bos_token_id:
                         continue
                     token_str = self.detokenize([token]).decode(
                         "utf-8", errors="ignore"
@@ -1431,7 +1449,7 @@ class Llama:
             for idx, (token, token_str, logprobs_token) in enumerate(
                 zip(all_tokens, all_token_strs, all_logprobs)
             ):
-                if token == self.token_bos():
+                if token == bos_token_id:
                     continue
                 text_offsets.append(
                     text_offset
@@ -1858,6 +1876,7 @@ class Llama:
             type_k=self.context_params.type_k,
             type_v=self.context_params.type_v,
             # Misc
+            spm_infill=self.spm_infill,
             verbose=self.verbose,
         )