Allow model to tokenize strings longer than context length and set add_bos. Closes #92
This commit is contained in:
parent
8740ddc58e
commit
7a536e86c2
2 changed files with 18 additions and 4 deletions
|
@ -174,7 +174,9 @@ class Llama:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
|
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
|
||||||
|
|
||||||
def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]:
|
def tokenize(
|
||||||
|
self, text: bytes, add_bos: bool = True
|
||||||
|
) -> List[llama_cpp.llama_token]:
|
||||||
"""Tokenize a string.
|
"""Tokenize a string.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -194,10 +196,22 @@ class Llama:
|
||||||
text,
|
text,
|
||||||
tokens,
|
tokens,
|
||||||
n_ctx,
|
n_ctx,
|
||||||
llama_cpp.c_bool(True),
|
llama_cpp.c_bool(add_bos),
|
||||||
)
|
)
|
||||||
if int(n_tokens) < 0:
|
if int(n_tokens) < 0:
|
||||||
raise RuntimeError(f'Failed to tokenize: text="{text}" n_tokens={n_tokens}')
|
n_tokens = abs(n_tokens)
|
||||||
|
tokens = (llama_cpp.llama_token * int(n_tokens))()
|
||||||
|
n_tokens = llama_cpp.llama_tokenize(
|
||||||
|
self.ctx,
|
||||||
|
text,
|
||||||
|
tokens,
|
||||||
|
llama_cpp.c_int(n_tokens),
|
||||||
|
llama_cpp.c_bool(add_bos),
|
||||||
|
)
|
||||||
|
if n_tokens < 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f'Failed to tokenize: text="{text}" n_tokens={n_tokens}'
|
||||||
|
)
|
||||||
return list(tokens[:n_tokens])
|
return list(tokens[:n_tokens])
|
||||||
|
|
||||||
def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
|
def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
|
||||||
|
|
|
@ -350,7 +350,7 @@ def llama_tokenize(
|
||||||
tokens, # type: Array[llama_token]
|
tokens, # type: Array[llama_token]
|
||||||
n_max_tokens: c_int,
|
n_max_tokens: c_int,
|
||||||
add_bos: c_bool,
|
add_bos: c_bool,
|
||||||
) -> c_int:
|
) -> int:
|
||||||
return _lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)
|
return _lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue