diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 47fa543..4295ba7 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -174,7 +174,9 @@ class Llama: if self.verbose: print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr) - def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]: + def tokenize( + self, text: bytes, add_bos: bool = True + ) -> List[llama_cpp.llama_token]: """Tokenize a string. Args: @@ -194,10 +196,22 @@ class Llama: text, tokens, n_ctx, - llama_cpp.c_bool(True), + llama_cpp.c_bool(add_bos), ) if int(n_tokens) < 0: - raise RuntimeError(f'Failed to tokenize: text="{text}" n_tokens={n_tokens}') + n_tokens = abs(n_tokens) + tokens = (llama_cpp.llama_token * int(n_tokens))() + n_tokens = llama_cpp.llama_tokenize( + self.ctx, + text, + tokens, + llama_cpp.c_int(n_tokens), + llama_cpp.c_bool(add_bos), + ) + if n_tokens < 0: + raise RuntimeError( + f'Failed to tokenize: text="{text}" n_tokens={n_tokens}' + ) return list(tokens[:n_tokens]) def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes: diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index e60558c..870eced 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -350,7 +350,7 @@ def llama_tokenize( tokens, # type: Array[llama_token] n_max_tokens: c_int, add_bos: c_bool, -) -> c_int: +) -> int: return _lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)