Allow model to tokenize strings longer than context length and set add_bos. Closes #92

This commit is contained in:
Andrei Betlen 2023-05-12 14:28:22 -04:00
parent 8740ddc58e
commit 7a536e86c2
2 changed files with 18 additions and 4 deletions

View file

@ -174,7 +174,9 @@ class Llama:
if self.verbose: if self.verbose:
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr) print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]: def tokenize(
self, text: bytes, add_bos: bool = True
) -> List[llama_cpp.llama_token]:
"""Tokenize a string. """Tokenize a string.
Args: Args:
@ -194,10 +196,22 @@ class Llama:
text, text,
tokens, tokens,
n_ctx, n_ctx,
llama_cpp.c_bool(True), llama_cpp.c_bool(add_bos),
) )
if int(n_tokens) < 0: if int(n_tokens) < 0:
raise RuntimeError(f'Failed to tokenize: text="{text}" n_tokens={n_tokens}') n_tokens = abs(n_tokens)
tokens = (llama_cpp.llama_token * int(n_tokens))()
n_tokens = llama_cpp.llama_tokenize(
self.ctx,
text,
tokens,
llama_cpp.c_int(n_tokens),
llama_cpp.c_bool(add_bos),
)
if n_tokens < 0:
raise RuntimeError(
f'Failed to tokenize: text="{text}" n_tokens={n_tokens}'
)
return list(tokens[:n_tokens]) return list(tokens[:n_tokens])
def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes: def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:

View file

@ -350,7 +350,7 @@ def llama_tokenize(
tokens, # type: Array[llama_token] tokens, # type: Array[llama_token]
n_max_tokens: c_int, n_max_tokens: c_int,
add_bos: c_bool, add_bos: c_bool,
) -> c_int: ) -> int:
return _lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos) return _lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)