diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index da5b0e3..564c6c3 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -15,9 +15,7 @@ class LlamaCache: """Cache for a llama.cpp model.""" def __init__(self, capacity_bytes: int = (2 << 30)): - self.cache_state: OrderedDict[ - Tuple[llama_cpp.llama_token, ...], "LlamaState" - ] = OrderedDict() + self.cache_state: OrderedDict[Tuple[int, ...], "LlamaState"] = OrderedDict() self.capacity_bytes = capacity_bytes @property @@ -26,8 +24,8 @@ class LlamaCache: def _find_longest_prefix_key( self, - key: Tuple[llama_cpp.llama_token, ...], - ) -> Optional[Tuple[llama_cpp.llama_token, ...]]: + key: Tuple[int, ...], + ) -> Optional[Tuple[int, ...]]: min_len = 0 min_key = None keys = ( @@ -39,7 +37,7 @@ class LlamaCache: min_key = k return min_key - def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState": + def __getitem__(self, key: Sequence[int]) -> "LlamaState": key = tuple(key) _key = self._find_longest_prefix_key(key) if _key is None: @@ -48,10 +46,10 @@ class LlamaCache: self.cache_state.move_to_end(_key) return value - def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool: + def __contains__(self, key: Sequence[int]) -> bool: return self._find_longest_prefix_key(tuple(key)) is not None - def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"): + def __setitem__(self, key: Sequence[int], value: "LlamaState"): key = tuple(key) if key in self.cache_state: del self.cache_state[key] @@ -63,7 +61,7 @@ class LlamaCache: class LlamaState: def __init__( self, - eval_tokens: Deque[llama_cpp.llama_token], + eval_tokens: Deque[int], eval_logits: Deque[List[float]], llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8] llama_state_size: int, @@ -141,7 +139,7 @@ class Llama: self.last_n_tokens_size = last_n_tokens_size self.n_batch = min(n_ctx, n_batch) - self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx) + self.eval_tokens: Deque[int] = deque(maxlen=n_ctx) self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx if logits_all else 1) self.cache: Optional[LlamaCache] = None @@ -176,9 +174,7 @@ class Llama: if self.verbose: print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr) - def tokenize( - self, text: bytes, add_bos: bool = True - ) -> List[llama_cpp.llama_token]: + def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]: """Tokenize a string. Args: @@ -197,7 +193,7 @@ class Llama: self.ctx, text, tokens, - n_ctx, + llama_cpp.c_int(n_ctx), llama_cpp.c_bool(add_bos), ) if int(n_tokens) < 0: @@ -216,7 +212,7 @@ class Llama: ) return list(tokens[:n_tokens]) - def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes: + def detokenize(self, tokens: List[int]) -> bytes: """Detokenize a list of tokens. Args: @@ -228,7 +224,9 @@ class Llama: assert self.ctx is not None output = b"" for token in tokens: - output += llama_cpp.llama_token_to_str(self.ctx, token) + output += llama_cpp.llama_token_to_str( + self.ctx, llama_cpp.llama_token(token) + ) return output def set_cache(self, cache: Optional[LlamaCache]): @@ -244,7 +242,7 @@ class Llama: self.eval_tokens.clear() self.eval_logits.clear() - def eval(self, tokens: Sequence[llama_cpp.llama_token]): + def eval(self, tokens: Sequence[int]): """Evaluate a list of tokens. Args: @@ -458,7 +456,7 @@ class Llama: def generate( self, - tokens: Sequence[llama_cpp.llama_token], + tokens: Sequence[int], top_k: int = 40, top_p: float = 0.95, temp: float = 0.80, @@ -470,9 +468,7 @@ class Llama: mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, - ) -> Generator[ - llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None - ]: + ) -> Generator[int, Optional[Sequence[int]], None]: """Create a generator of tokens from a prompt. Examples: @@ -617,14 +613,14 @@ class Llama: assert self.ctx is not None completion_id: str = f"cmpl-{str(uuid.uuid4())}" created: int = int(time.time()) - completion_tokens: List[llama_cpp.llama_token] = [] + completion_tokens: List[int] = [] # Add blank space to start of prompt to match OG llama tokenizer - prompt_tokens: List[llama_cpp.llama_token] = self.tokenize( - b" " + prompt.encode("utf-8") - ) + prompt_tokens: List[int] = self.tokenize(b" " + prompt.encode("utf-8")) text: bytes = b"" returned_tokens: int = 0 - stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else [] + stop = ( + stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else [] + ) model_name: str = model if model is not None else self.model_path if self.verbose: @@ -724,7 +720,9 @@ class Llama: for token in remaining_tokens: token_end_position += len(self.detokenize([token])) # Check if stop sequence is in the token - if token_end_position >= (remaining_length - first_stop_position - 1): + if token_end_position >= ( + remaining_length - first_stop_position - 1 + ): break logprobs_or_none: Optional[CompletionLogprobs] = None if logprobs is not None: @@ -744,7 +742,7 @@ class Llama: ) ) top_logprob = { - self.detokenize([llama_cpp.llama_token(i)]).decode( + self.detokenize([i]).decode( "utf-8", errors="ignore" ): logprob for logprob, i in sorted_logprobs[:logprobs] @@ -822,9 +820,7 @@ class Llama: ) ) top_logprob = { - self.detokenize([llama_cpp.llama_token(i)]).decode( - "utf-8", errors="ignore" - ): logprob + self.detokenize([i]).decode("utf-8", errors="ignore"): logprob for logprob, i in sorted_logprobs[:logprobs] } top_logprob.update({token_str: current_logprobs[int(token)]}) @@ -924,9 +920,7 @@ class Llama: ) token_logprobs.append(sorted_logprobs[int(token)][0]) top_logprob: Optional[Dict[str, float]] = { - self.detokenize([llama_cpp.llama_token(i)]).decode( - "utf-8", errors="ignore" - ): logprob + self.detokenize([i]).decode("utf-8", errors="ignore"): logprob for logprob, i in sorted_logprobs[:logprobs] } top_logprob.update({token_str: logprobs_token[int(token)]}) @@ -1188,7 +1182,9 @@ class Llama: Returns: Generated chat completion or a stream of chat completion chunks. """ - stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else [] + stop = ( + stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else [] + ) chat_history = "".join( f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}' for message in messages @@ -1296,17 +1292,17 @@ class Llama: raise RuntimeError("Failed to set llama state data") @staticmethod - def token_eos() -> llama_cpp.llama_token: + def token_eos() -> int: """Return the end-of-sequence token.""" return llama_cpp.llama_token_eos() @staticmethod - def token_bos() -> llama_cpp.llama_token: + def token_bos() -> int: """Return the beginning-of-sequence token.""" return llama_cpp.llama_token_bos() @staticmethod - def token_nl() -> llama_cpp.llama_token: + def token_nl() -> int: """Return the newline token.""" return llama_cpp.llama_token_nl() @@ -1317,9 +1313,7 @@ class Llama: return [math.log(x / sum_exps) for x in exps] @staticmethod - def longest_token_prefix( - a: Sequence[llama_cpp.llama_token], b: Sequence[llama_cpp.llama_token] - ): + def longest_token_prefix(a: Sequence[int], b: Sequence[int]): longest_prefix = 0 for _a, _b in zip(a, b): if _a == _b: diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 24ab40a..0dcb16c 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -44,13 +44,13 @@ def _load_shared_library(lib_base_name: str): _base_path = _lib.parent.resolve() _lib_paths = [_lib.resolve()] - cdll_args = dict() # type: ignore + cdll_args = dict() # type: ignore # Add the library directory to the DLL search path on Windows (if needed) if sys.platform == "win32" and sys.version_info >= (3, 8): os.add_dll_directory(str(_base_path)) if "CUDA_PATH" in os.environ: - os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"],"bin")) - os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"],"lib")) + os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "bin")) + os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "lib")) cdll_args["winmode"] = 0 # Try to load the shared library, handling potential errors @@ -194,7 +194,7 @@ _lib.llama_init_from_file.restype = llama_context_p # Frees all allocated memory def llama_free(ctx: llama_context_p): - _lib.llama_free(ctx) + return _lib.llama_free(ctx) _lib.llama_free.argtypes = [llama_context_p] @@ -206,7 +206,7 @@ _lib.llama_free.restype = None # nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given def llama_model_quantize( fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int -) -> c_int: +) -> int: return _lib.llama_model_quantize(fname_inp, fname_out, ftype, nthread) @@ -225,7 +225,7 @@ def llama_apply_lora_from_file( path_lora: c_char_p, path_base_model: c_char_p, n_threads: c_int, -) -> c_int: +) -> int: return _lib.llama_apply_lora_from_file(ctx, path_lora, path_base_model, n_threads) @@ -234,7 +234,7 @@ _lib.llama_apply_lora_from_file.restype = c_int # Returns the number of tokens in the KV cache -def llama_get_kv_cache_token_count(ctx: llama_context_p) -> c_int: +def llama_get_kv_cache_token_count(ctx: llama_context_p) -> int: return _lib.llama_get_kv_cache_token_count(ctx) @@ -253,7 +253,7 @@ _lib.llama_set_rng_seed.restype = None # Returns the maximum size in bytes of the state (rng, logits, embedding # and kv_cache) - will often be smaller after compacting tokens -def llama_get_state_size(ctx: llama_context_p) -> c_size_t: +def llama_get_state_size(ctx: llama_context_p) -> int: return _lib.llama_get_state_size(ctx) @@ -293,7 +293,7 @@ def llama_load_session_file( tokens_out, # type: Array[llama_token] n_token_capacity: c_size_t, n_token_count_out, # type: _Pointer[c_size_t] -) -> c_size_t: +) -> int: return _lib.llama_load_session_file( ctx, path_session, tokens_out, n_token_capacity, n_token_count_out ) @@ -314,7 +314,7 @@ def llama_save_session_file( path_session: bytes, tokens, # type: Array[llama_token] n_token_count: c_size_t, -) -> c_size_t: +) -> int: return _lib.llama_save_session_file(ctx, path_session, tokens, n_token_count) @@ -337,7 +337,7 @@ def llama_eval( n_tokens: c_int, n_past: c_int, n_threads: c_int, -) -> c_int: +) -> int: return _lib.llama_eval(ctx, tokens, n_tokens, n_past, n_threads) @@ -364,7 +364,7 @@ _lib.llama_tokenize.argtypes = [llama_context_p, c_char_p, llama_token_p, c_int, _lib.llama_tokenize.restype = c_int -def llama_n_vocab(ctx: llama_context_p) -> c_int: +def llama_n_vocab(ctx: llama_context_p) -> int: return _lib.llama_n_vocab(ctx) @@ -372,7 +372,7 @@ _lib.llama_n_vocab.argtypes = [llama_context_p] _lib.llama_n_vocab.restype = c_int -def llama_n_ctx(ctx: llama_context_p) -> c_int: +def llama_n_ctx(ctx: llama_context_p) -> int: return _lib.llama_n_ctx(ctx) @@ -380,7 +380,7 @@ _lib.llama_n_ctx.argtypes = [llama_context_p] _lib.llama_n_ctx.restype = c_int -def llama_n_embd(ctx: llama_context_p) -> c_int: +def llama_n_embd(ctx: llama_context_p) -> int: return _lib.llama_n_embd(ctx) @@ -426,7 +426,7 @@ _lib.llama_token_to_str.restype = c_char_p # Special tokens -def llama_token_bos() -> llama_token: +def llama_token_bos() -> int: return _lib.llama_token_bos() @@ -434,7 +434,7 @@ _lib.llama_token_bos.argtypes = [] _lib.llama_token_bos.restype = llama_token -def llama_token_eos() -> llama_token: +def llama_token_eos() -> int: return _lib.llama_token_eos() @@ -442,7 +442,7 @@ _lib.llama_token_eos.argtypes = [] _lib.llama_token_eos.restype = llama_token -def llama_token_nl() -> llama_token: +def llama_token_nl() -> int: return _lib.llama_token_nl() @@ -625,7 +625,7 @@ def llama_sample_token_mirostat( eta: c_float, m: c_int, mu, # type: _Pointer[c_float] -) -> llama_token: +) -> int: return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu) @@ -651,7 +651,7 @@ def llama_sample_token_mirostat_v2( tau: c_float, eta: c_float, mu, # type: _Pointer[c_float] -) -> llama_token: +) -> int: return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu) @@ -669,7 +669,7 @@ _lib.llama_sample_token_mirostat_v2.restype = llama_token def llama_sample_token_greedy( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] -) -> llama_token: +) -> int: return _lib.llama_sample_token_greedy(ctx, candidates) @@ -684,7 +684,7 @@ _lib.llama_sample_token_greedy.restype = llama_token def llama_sample_token( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] -) -> llama_token: +) -> int: return _lib.llama_sample_token(ctx, candidates) diff --git a/tests/test_llama.py b/tests/test_llama.py index b3426b8..941287d 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -17,7 +17,7 @@ def test_llama(): # @pytest.mark.skip(reason="need to update sample mocking") def test_llama_patch(monkeypatch): llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True) - n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx)) + n_vocab = llama_cpp.llama_n_vocab(llama.ctx) ## Set up mock function def mock_eval(*args, **kwargs): @@ -107,7 +107,7 @@ def test_llama_pickle(): def test_utf8(monkeypatch): llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True) - n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx)) + n_vocab = llama_cpp.llama_n_vocab(llama.ctx) ## Set up mock function def mock_eval(*args, **kwargs):