From e214a584227f0570f90b19c6a8c6efcfa0c8697b Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 6 Nov 2023 09:16:36 -0500 Subject: [PATCH] Refactor Llama class internals --- llama_cpp/llama.py | 949 +++++++++++++++++++++++++++++--------------- tests/test_llama.py | 4 +- 2 files changed, 641 insertions(+), 312 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index d3b85c9..034d7c0 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -208,6 +208,506 @@ class StoppingCriteriaList(List[StoppingCriteria]): return any([stopping_criteria(input_ids, logits) for stopping_criteria in self]) +class _LlamaModel: + """Intermediate Python wrapper for a llama.cpp llama_model. + + NOTE: For stability it's recommended you use the Llama class instead.""" + + _llama_free_model = llama_cpp._lib.llama_free_model # type: ignore + + def __init__( + self, + *, + path_model: str, + params: llama_cpp.llama_model_params, + verbose: bool = True, + ): + self.path_model = path_model + self.params = params + self.verbose = verbose + + if not os.path.exists(path_model): + raise ValueError(f"Model path does not exist: {path_model}") + + with suppress_stdout_stderr(disable=self.verbose): + self.model = llama_cpp.llama_load_model_from_file( + self.path_model.encode("utf-8"), self.params + ) + + def __del__(self): + with suppress_stdout_stderr(disable=self.verbose): + if self.model is not None: + self._llama_free_model(self.model) + self.model = None + + def vocab_type(self) -> int: + assert self.model is not None + return llama_cpp.llama_vocab_type(self.model) + + def n_vocab(self) -> int: + assert self.model is not None + return llama_cpp.llama_n_vocab(self.model) + + def n_ctx_train(self) -> int: + assert self.model is not None + return llama_cpp.llama_n_ctx_train(self.model) + + def n_embd(self) -> int: + assert self.model is not None + return llama_cpp.llama_n_embd(self.model) + + def rope_freq_scale_train(self) -> float: + assert self.model is not None + return llama_cpp.llama_rope_freq_scale_train(self.model) + + def desc(self) -> str: + assert self.model is not None + buf = ctypes.create_string_buffer(1024) + llama_cpp.llama_model_desc(self.model, buf, 1024) # type: ignore + return buf.value.decode("utf-8") + + def size(self) -> int: + assert self.model is not None + return llama_cpp.llama_model_size(self.model) + + def n_params(self) -> int: + assert self.model is not None + return llama_cpp.llama_model_n_params(self.model) + + def get_tensor(self, name: str) -> ctypes.c_void_p: + assert self.model is not None + return llama_cpp.llama_get_model_tensor(self.model, name.encode("utf-8")) + + def apply_lora_from_file( + self, + lora_path: str, + scale: float, + path_base_model: Optional[str], + n_threads: int, + ): + assert self.model is not None + return llama_cpp.llama_model_apply_lora_from_file( + self.model, + lora_path.encode("utf-8"), + scale, + path_base_model.encode("utf-8") + if path_base_model is not None + else llama_cpp.c_char_p(0), + n_threads, + ) + + # Vocab + + def token_get_text(self, token: int) -> str: + # TODO: Fix + assert self.model is not None + return llama_cpp.llama_token_get_text(self.model, token).decode("utf-8") + + def token_get_score(self, token: int) -> float: + assert self.model is not None + return llama_cpp.llama_token_get_score(self.model, token) + + def token_get_type(self, token: int) -> int: + assert self.model is not None + return llama_cpp.llama_token_get_type(self.model, token) + + # Special tokens + + def token_bos(self) -> int: + assert self.model 
is not None + return llama_cpp.llama_token_bos(self.model) + + def token_eos(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_eos(self.model) + + def token_nl(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_nl(self.model) + + def token_prefix(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_prefix(self.model) + + def token_middle(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_middle(self.model) + + def token_suffix(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_suffix(self.model) + + def token_eot(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_eot(self.model) + + # Tokenization + + def tokenize(self, text: bytes, add_bos: bool, special: bool): + assert self.model is not None + n_ctx = self.n_ctx_train() + tokens = (llama_cpp.llama_token * n_ctx)() + n_tokens = llama_cpp.llama_tokenize( + self.model, text, len(text), tokens, n_ctx, add_bos, special + ) + if n_tokens < 0: + n_tokens = abs(n_tokens) + tokens = (llama_cpp.llama_token * n_tokens)() + n_tokens = llama_cpp.llama_tokenize( + self.model, text, len(text), tokens, n_tokens, add_bos, special + ) + if n_tokens < 0: + raise RuntimeError( + f'Failed to tokenize: text="{text}" n_tokens={n_tokens}' + ) + return list(tokens[:n_tokens]) + + def token_to_piece(self, token: int) -> bytes: + assert self.model is not None + buf = ctypes.create_string_buffer(32) + llama_cpp.llama_token_to_piece(self.model, token, buf, 32) # type: ignore + return bytes(buf) + + def detokenize(self, tokens: List[int]) -> bytes: + assert self.model is not None + output = b"" + size = 32 + buffer = (ctypes.c_char * size)() + for token in tokens: + n = llama_cpp.llama_token_to_piece( + self.model, llama_cpp.llama_token(token), buffer, size + ) + assert n <= size + output += bytes(buffer[:n]) + # NOTE: Llama1 models automatically added a space at the start of the prompt + # this line removes a leading space if the first token is a beginning of sentence token + return ( + output[1:] if len(tokens) > 0 and tokens[0] == self.token_bos() else output + ) + + @staticmethod + def default_params(): + """Get the default llama_model_params.""" + return llama_cpp.llama_model_default_params() + + +class _LlamaContext: + """Intermediate Python wrapper for a llama.cpp llama_context. 
+ + NOTE: For stability it's recommended you use the Llama class instead.""" + + _llama_free = llama_cpp._lib.llama_free # type: ignore + + def __init__( + self, + *, + model: _LlamaModel, + params: llama_cpp.llama_context_params, + verbose: bool = True, + ): + self.model = model + self.params = params + self.verbose = verbose + + with suppress_stdout_stderr(disable=self.verbose): + self.ctx = llama_cpp.llama_new_context_with_model( + self.model.model, self.params + ) + + def __del__(self): + with suppress_stdout_stderr(disable=self.verbose): + if self.ctx is not None: + self._llama_free(self.ctx) + self.ctx = None + + def n_ctx(self) -> int: + assert self.ctx is not None + return llama_cpp.llama_n_ctx(self.ctx) + + def kv_cache_clear(self): + assert self.ctx is not None + llama_cpp.llama_kv_cache_clear(self.ctx) + + def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int): + assert self.ctx is not None + llama_cpp.llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1) + + def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int): + assert self.ctx is not None + llama_cpp.llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1) + + def kv_cache_seq_keep(self, seq_id: int): + assert self.ctx is not None + llama_cpp.llama_kv_cache_seq_keep(self.ctx, seq_id) + + def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int): + assert self.ctx is not None + llama_cpp.llama_kv_cache_seq_shift(self.ctx, seq_id, p0, p1, shift) + + def get_state_size(self) -> int: + assert self.ctx is not None + return llama_cpp.llama_get_state_size(self.ctx) + + # TODO: copy_state_data + + # TODO: set_state_data + + # TODO: llama_load_session_file + + # TODO: llama_save_session_file + + def decode(self, batch: "_LlamaBatch"): + assert self.ctx is not None + assert batch.batch is not None + return_code = llama_cpp.llama_decode( + ctx=self.ctx, + batch=batch.batch, + ) + if return_code != 0: + raise RuntimeError(f"llama_decode returned {return_code}") + + def set_n_threads(self, n_threads: int, n_threads_batch: int): + assert self.ctx is not None + llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch) + + def get_logits(self): + assert self.ctx is not None + return llama_cpp.llama_get_logits(self.ctx) + + def get_logits_ith(self, i: int): + assert self.ctx is not None + return llama_cpp.llama_get_logits_ith(self.ctx, i) + + def get_embeddings(self): + assert self.ctx is not None + return llama_cpp.llama_get_embeddings(self.ctx) + + # Sampling functions + + def set_rng_seed(self, seed: int): + assert self.ctx is not None + llama_cpp.llama_set_rng_seed(self.ctx, seed) + + def sample_repetition_penalties( + self, + candidates: "_LlamaTokenDataArray", + last_tokens_data: llama_cpp.Array[llama_cpp.llama_token], + penalty_last_n: int, + penalty_repeat: float, + penalty_freq: float, + penalty_present: float, + ): + assert self.ctx is not None + llama_cpp.llama_sample_repetition_penalties( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + last_tokens_data, + penalty_last_n, + penalty_repeat, + penalty_freq, + penalty_present, + ) + + def sample_classifier_free_guidance( + self, + candidates: "_LlamaTokenDataArray", + guidance_ctx: "_LlamaContext", + scale: float, + ): + assert self.ctx is not None + assert guidance_ctx.ctx is not None + llama_cpp.llama_sample_classifier_free_guidance( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + guidance_ctx.ctx, + scale, + ) + + def sample_softmax(self, candidates: "_LlamaTokenDataArray"): + assert 
self.ctx is not None + llama_cpp.llama_sample_softmax( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + ) + + def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int): + assert self.ctx is not None + llama_cpp.llama_sample_top_k( + self.ctx, ctypes.byref(candidates.candidates), k, min_keep # type: ignore + ) + + def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): + assert self.ctx is not None + llama_cpp.llama_sample_top_p( + self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore + ) + + def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): + assert self.ctx is not None + llama_cpp.llama_sample_min_p( + self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore + ) + + def sample_tail_free( + self, candidates: "_LlamaTokenDataArray", z: float, min_keep: int + ): + assert self.ctx is not None + llama_cpp.llama_sample_tail_free( + self.ctx, ctypes.byref(candidates.candidates), z, min_keep # type: ignore + ) + + def sample_typical( + self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int + ): + assert self.ctx is not None + llama_cpp.llama_sample_typical( + self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore + ) + + def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float): + assert self.ctx is not None + llama_cpp.llama_sample_temp( + self.ctx, ctypes.byref(candidates.candidates), temp # type: ignore + ) + + def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar): + assert self.ctx is not None + assert grammar.grammar is not None + llama_cpp.llama_sample_grammar( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + grammar.grammar, + ) + + def sample_token_mirostat( + self, + candidates: "_LlamaTokenDataArray", + tau: float, + eta: float, + m: int, + mu: float, + ) -> int: + assert self.ctx is not None + return llama_cpp.llama_sample_token_mirostat( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + tau, + eta, + m, + ctypes.pointer(ctypes.c_float(mu)), + ) + + def sample_token_mirostat_v2( + self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: float + ) -> int: + assert self.ctx is not None + return llama_cpp.llama_sample_token_mirostat_v2( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + tau, + eta, + ctypes.pointer(ctypes.c_float(mu)), + ) + + def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int: + assert self.ctx is not None + return llama_cpp.llama_sample_token_greedy( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + ) + + def sample_token(self, candidates: "_LlamaTokenDataArray") -> int: + assert self.ctx is not None + return llama_cpp.llama_sample_token( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + ) + + # Grammar + def grammar_accept_token(self, grammar: LlamaGrammar, token: int): + assert self.ctx is not None + assert grammar.grammar is not None + llama_cpp.llama_grammar_accept_token(self.ctx, grammar.grammar, token) + + def reset_timings(self): + assert self.ctx is not None + llama_cpp.llama_reset_timings(self.ctx) + + def print_timings(self): + assert self.ctx is not None + llama_cpp.llama_print_timings(self.ctx) + + # Utility functions + @staticmethod + def default_params(): + """Get the default llama_context_params.""" + return llama_cpp.llama_context_default_params() + + +class _LlamaBatch: + _llama_batch_free = 
llama_cpp._lib.llama_batch_free # type: ignore + + def __init__( + self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True + ): + self.n_tokens = n_tokens + self.embd = embd + self.n_seq_max = n_seq_max + self.verbose = verbose + + with suppress_stdout_stderr(disable=self.verbose): + self.batch = llama_cpp.llama_batch_init( + self.n_tokens, self.embd, self.n_seq_max + ) + + def __del__(self): + with suppress_stdout_stderr(disable=self.verbose): + if self.batch is not None: + self._llama_batch_free(self.batch) + self.batch = None + + def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool): + assert self.batch is not None + n_tokens = len(batch) + self.batch.n_tokens = n_tokens + for i in range(n_tokens): + self.batch.token[i] = batch[i] + self.batch.pos[i] = n_past + i + self.batch.seq_id[i][0] = 0 + self.batch.n_seq_id[i] = 1 + self.batch.logits[i] = logits_all + self.batch.logits[n_tokens - 1] = True + + +class _LlamaTokenDataArray: + def __init__(self, *, n_vocab: int): + self.n_vocab = n_vocab + self.candidates_data = np.array( + [], + dtype=np.dtype( + [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True + ), + ) + self.candidates_data.resize(3, self.n_vocab, refcheck=False) + self.candidates = llama_cpp.llama_token_data_array( + data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p), + size=self.n_vocab, + sorted=False, + ) + self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc) + self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single) + + def copy_logits(self, logits: npt.NDArray[np.single]): + self.candidates_data["id"][:] = self.default_candidates_data_id + self.candidates_data["logit"][:] = logits + self.candidates_data["p"][:] = self.default_candidates_data_p + self.candidates.data = self.candidates_data.ctypes.data_as( + llama_cpp.llama_token_data_p + ) + self.candidates.sorted = llama_cpp.c_bool(False) + self.candidates.size = llama_cpp.c_size_t(self.n_vocab) + + class Llama: """High-level Python wrapper for a llama.cpp model.""" @@ -312,7 +812,9 @@ class Llama: self._p_tensor_split = None if self.tensor_split is not None: if len(self.tensor_split) > llama_cpp.LLAMA_MAX_DEVICES: - raise ValueError(f"Attempt to split tensors that exceed maximum supported devices. Current LLAMA_MAX_DEVICES={llama_cpp.LLAMA_MAX_DEVICES}") + raise ValueError( + f"Attempt to split tensors that exceed maximum supported devices. 
Current LLAMA_MAX_DEVICES={llama_cpp.LLAMA_MAX_DEVICES}" + ) # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES self._c_tensor_split = FloatArray( @@ -336,7 +838,9 @@ class Llama: self.context_params.n_threads = self.n_threads self.context_params.n_threads_batch = self.n_threads_batch self.context_params.rope_scaling_type = ( - rope_scaling_type if rope_scaling_type is not None else llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED + rope_scaling_type + if rope_scaling_type is not None + else llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED ) self.context_params.rope_freq_base = ( rope_freq_base if rope_freq_base != 0.0 else 0 @@ -356,9 +860,7 @@ class Llama: self.context_params.yarn_beta_slow = ( yarn_beta_slow if yarn_beta_slow != 0.0 else 0 ) - self.context_params.yarn_orig_ctx = ( - yarn_orig_ctx if yarn_orig_ctx != 0 else 0 - ) + self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0 self.context_params.mul_mat_q = mul_mat_q self.context_params.f16_kv = f16_kv self.context_params.logits_all = logits_all @@ -376,32 +878,28 @@ class Llama: if not os.path.exists(model_path): raise ValueError(f"Model path does not exist: {model_path}") - with suppress_stdout_stderr(disable=self.verbose): - self.model = llama_cpp.llama_load_model_from_file( - self.model_path.encode("utf-8"), self.model_params - ) - assert self.model is not None + self._model = _LlamaModel( + path_model=self.model_path, params=self.model_params, verbose=self.verbose + ) - with suppress_stdout_stderr(disable=self.verbose): - self.ctx = llama_cpp.llama_new_context_with_model( - self.model, self.context_params - ) + self._ctx = _LlamaContext( + model=self._model, + params=self.context_params, + verbose=self.verbose, + ) - assert self.ctx is not None - - with suppress_stdout_stderr(disable=self.verbose): - self.batch = llama_cpp.llama_batch_init( - self.n_batch, 0, 1 - ) + self._batch = _LlamaBatch( + n_tokens=self.n_batch, + embd=0, + n_seq_max=self.context_params.n_ctx, + verbose=self.verbose, + ) if self.lora_path: - if llama_cpp.llama_model_apply_lora_from_file( - self.model, - self.lora_path.encode("utf-8"), + if self._model.apply_lora_from_file( + self.lora_path, self.lora_scale, - self.lora_base.encode("utf-8") - if self.lora_base is not None - else llama_cpp.c_char_p(0), + self.lora_base, self.n_threads, ): raise RuntimeError( @@ -415,25 +913,11 @@ class Llama: self._n_vocab = self.n_vocab() self._n_ctx = self.n_ctx() - size = self._n_vocab - sorted = False - self._candidates_data = np.array( - [], - dtype=np.dtype( - [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True - ), - ) - self._candidates_data.resize(3, self._n_vocab, refcheck=False) - candidates = llama_cpp.llama_token_data_array( - data=self._candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p), - size=size, - sorted=sorted, - ) - self._candidates = candidates + self._token_nl = self.token_nl() self._token_eos = self.token_eos() - self._candidates_data_id = np.arange(self._n_vocab, dtype=np.intc) # type: ignore - self._candidates_data_p = np.zeros(self._n_vocab, dtype=np.single) + + self._candidates = _LlamaTokenDataArray(n_vocab=self._n_vocab) self.n_tokens = 0 self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc) @@ -441,6 +925,16 @@ class Llama: (n_ctx, self._n_vocab), dtype=np.single ) + @property + def ctx(self) -> llama_cpp.llama_context_p: + assert self._ctx.ctx is not None + return self._ctx.ctx + + @property + def 
model(self) -> llama_cpp.llama_model_p: + assert self._model.model is not None + return self._model.model + @property def _input_ids(self) -> npt.NDArray[np.intc]: return self.input_ids[: self.n_tokens] @@ -460,7 +954,9 @@ class Llama: maxlen=self._n_ctx if self.context_params.logits_all else 1, ) - def tokenize(self, text: bytes, add_bos: bool = True, special: bool = False) -> List[int]: + def tokenize( + self, text: bytes, add_bos: bool = True, special: bool = False + ) -> List[int]: """Tokenize a string. Args: @@ -472,35 +968,7 @@ class Llama: Returns: A list of tokens. """ - assert self.model is not None - n_ctx = self._n_ctx - tokens = (llama_cpp.llama_token * n_ctx)() - n_tokens = llama_cpp.llama_tokenize( - self.model, - text, - len(text), - tokens, - n_ctx, - add_bos, - special - ) - if n_tokens < 0: - n_tokens = abs(n_tokens) - tokens = (llama_cpp.llama_token * n_tokens)() - n_tokens = llama_cpp.llama_tokenize( - self.model, - text, - len(text), - tokens, - n_tokens, - add_bos, - special - ) - if n_tokens < 0: - raise RuntimeError( - f'Failed to tokenize: text="{text}" n_tokens={n_tokens}' - ) - return list(tokens[:n_tokens]) + return self._model.tokenize(text, add_bos, special) def detokenize(self, tokens: List[int]) -> bytes: """Detokenize a list of tokens. @@ -511,21 +979,7 @@ class Llama: Returns: The detokenized string. """ - assert self.model is not None - output = b"" - size = 32 - buffer = (ctypes.c_char * size)() - for token in tokens: - n = llama_cpp.llama_token_to_piece( - self.model, llama_cpp.llama_token(token), buffer, size - ) - assert n <= size - output += bytes(buffer[:n]) - # NOTE: Llama1 models automatically added a space at the start of the prompt - # this line removes a leading space if the first token is a beginning of sentence token - return ( - output[1:] if len(tokens) > 0 and tokens[0] == self.token_bos() else output - ) + return self._model.detokenize(tokens) def set_cache(self, cache: Optional[BaseLlamaCache]): """Set the cache. @@ -545,28 +999,18 @@ class Llama: Args: tokens: The list of tokens to evaluate. 
""" - assert self.ctx is not None - assert self.batch is not None + assert self._ctx.ctx is not None + assert self._batch.batch is not None n_ctx = self._n_ctx for i in range(0, len(tokens), self.n_batch): batch = tokens[i : min(len(tokens), i + self.n_batch)] - n_past = min(n_ctx - len(batch), len(self._input_ids)) + n_past = min(n_ctx - len(batch), self.n_tokens) n_tokens = len(batch) - llama_cpp.llama_kv_cache_seq_rm(self.ctx, -1, n_past, -1) - self.batch.n_tokens = n_tokens - for i in range(n_tokens): - self.batch.token[i] = batch[i] - self.batch.pos[i] = n_past + i - self.batch.seq_id[i][0] = 0 - self.batch.n_seq_id[i] = 1 - self.batch.logits[i] = True if self.context_params.logits_all else False - self.batch.logits[n_tokens - 1] = True - return_code = llama_cpp.llama_decode( - ctx=self.ctx, - batch=self.batch, + self._ctx.kv_cache_seq_rm(-1, n_past, -1) + self._batch.set_batch( + batch=batch, n_past=n_past, logits_all=self.context_params.logits_all ) - if return_code != 0: - raise RuntimeError(f"llama_decode returned {return_code}") + self._ctx.decode(self._batch) # Save tokens self.input_ids[self.n_tokens : self.n_tokens + n_tokens] = batch # Save logits @@ -577,144 +1021,10 @@ class Llama: ) # NOTE: Only save the last token logits if logits_all is False self.scores[self.n_tokens + offset : self.n_tokens + n_tokens, :].reshape( -1 - )[:] = llama_cpp.llama_get_logits(self.ctx)[: rows * cols] + )[:] = self._ctx.get_logits()[: rows * cols] # Update n_tokens self.n_tokens += n_tokens - def _sample( - self, - last_n_tokens_data, # type: llama_cpp.Array[llama_cpp.llama_token] - last_n_tokens_size: int, - top_k: int, - top_p: float, - temp: float, - tfs_z: float, - repeat_penalty: float, - frequency_penalty: float, - presence_penalty: float, - mirostat_mode: float, - mirostat_tau: float, - mirostat_eta: float, - penalize_nl: bool = True, - logits_processor: Optional[LogitsProcessorList] = None, - grammar: Optional[LlamaGrammar] = None, - ): - assert self.ctx is not None - assert self.n_tokens > 0 - n_vocab = self._n_vocab - n_ctx = self._n_ctx - top_k = n_vocab if top_k <= 0 else top_k - last_n_tokens_size = n_ctx if last_n_tokens_size < 0 else last_n_tokens_size - logits: npt.NDArray[np.single] = self._scores[-1, :] - - if logits_processor is not None: - logits[:] = logits_processor(self._input_ids, logits) - - nl_logit = logits[self._token_nl] - candidates = self._candidates - candidates_data = self._candidates_data - candidates_data["id"][:] = self._candidates_data_id # type: ignore - candidates_data["logit"][:] = logits - candidates_data["p"][:] = self._candidates_data_p # type: ignore - candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p) - candidates.sorted = llama_cpp.c_bool(False) - candidates.size = llama_cpp.c_size_t(n_vocab) - llama_cpp.llama_sample_repetition_penalties( - ctx=self.ctx, - candidates=llama_cpp.ctypes.byref(candidates), # type: ignore - last_tokens_data=last_n_tokens_data, - penalty_last_n=last_n_tokens_size, - penalty_repeat=repeat_penalty, - penalty_freq=frequency_penalty, - penalty_present=presence_penalty, - ) - if not penalize_nl: - candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit) - - if grammar is not None: - llama_cpp.llama_sample_grammar( - ctx=self.ctx, - candidates=llama_cpp.ctypes.byref(candidates), # type: ignore - grammar=grammar.grammar, - ) - - if temp == 0.0: - id = llama_cpp.llama_sample_token_greedy( - ctx=self.ctx, - candidates=llama_cpp.ctypes.byref(candidates), # type: ignore - ) - elif 
mirostat_mode == 1: - mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau) - mirostat_m = llama_cpp.c_int(100) - llama_cpp.llama_sample_temperature( - ctx=self.ctx, - candidates=llama_cpp.ctypes.byref(candidates), # type: ignore - temp=temp, - ) - id = llama_cpp.llama_sample_token_mirostat( - ctx=self.ctx, - candidates=llama_cpp.ctypes.byref(candidates), # type: ignore - tau=mirostat_tau, - eta=mirostat_eta, - mu=llama_cpp.ctypes.byref(mirostat_mu), # type: ignore - m=mirostat_m, - ) - elif mirostat_mode == 2: - mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau) - llama_cpp.llama_sample_temperature( - ctx=self.ctx, - candidates=llama_cpp.ctypes.byref(candidates), # type: ignore - temp=temp, - ) - id = llama_cpp.llama_sample_token_mirostat_v2( - ctx=self.ctx, - candidates=llama_cpp.ctypes.byref(candidates), # type: ignore - tau=mirostat_tau, - eta=mirostat_eta, - mu=llama_cpp.ctypes.byref(mirostat_mu), # type: ignore - ) - else: - llama_cpp.llama_sample_top_k( - ctx=self.ctx, - candidates=llama_cpp.ctypes.byref(candidates), # type: ignore - k=top_k, - min_keep=llama_cpp.c_size_t(1), - ) - llama_cpp.llama_sample_tail_free( - ctx=self.ctx, - candidates=llama_cpp.ctypes.byref(candidates), # type: ignore - z=tfs_z, - min_keep=llama_cpp.c_size_t(1), - ) - llama_cpp.llama_sample_typical( - ctx=self.ctx, - candidates=llama_cpp.ctypes.byref(candidates), # type: ignore - p=llama_cpp.c_float(1.0), - min_keep=llama_cpp.c_size_t(1), - ) - llama_cpp.llama_sample_top_p( - ctx=self.ctx, - candidates=llama_cpp.ctypes.byref(candidates), # type: ignore - p=top_p, - min_keep=llama_cpp.c_size_t(1), - ) - llama_cpp.llama_sample_temperature( - ctx=self.ctx, - candidates=llama_cpp.ctypes.byref(candidates), # type: ignore - temp=temp, - ) - id = llama_cpp.llama_sample_token( - ctx=self.ctx, - candidates=llama_cpp.ctypes.byref(candidates), # type: ignore - ) - if grammar is not None: - llama_cpp.llama_grammar_accept_token( - ctx=self.ctx, - grammar=grammar.grammar, - token=llama_cpp.ctypes.c_int(id), - ) - return id - def sample( self, top_k: int = 40, @@ -742,29 +1052,74 @@ class Llama: Returns: The sampled token. 
""" - assert self.ctx is not None + assert self._ctx is not None + assert self.n_tokens > 0 last_n_tokens_data = [llama_cpp.llama_token(0)] * max( - 0, self.last_n_tokens_size - len(self._input_ids) + 0, self.last_n_tokens_size - self.n_tokens ) + self._input_ids[-self.last_n_tokens_size :].tolist() - return self._sample( - last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)( - *last_n_tokens_data - ), - last_n_tokens_size=self.last_n_tokens_size, - top_k=top_k, - top_p=top_p, - temp=temp, - tfs_z=tfs_z, - repeat_penalty=repeat_penalty, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - penalize_nl=penalize_nl, - logits_processor=logits_processor, - grammar=grammar, + last_n_tokens_size = len(last_n_tokens_data) + n_vocab = self._n_vocab + n_ctx = self._n_ctx + top_k = n_vocab if top_k <= 0 else top_k + last_n_tokens_size = n_ctx if last_n_tokens_size < 0 else last_n_tokens_size + last_n_tokens_data_c = (llama_cpp.llama_token * last_n_tokens_size)( + *last_n_tokens_data ) + logits: npt.NDArray[np.single] = self._scores[-1, :] + + if logits_processor is not None: + logits[:] = logits_processor(self._input_ids, logits) + + nl_logit = logits[self._token_nl] + self._candidates.copy_logits(logits) + self._ctx.sample_repetition_penalties( + candidates=self._candidates, + last_tokens_data=last_n_tokens_data_c, + penalty_last_n=last_n_tokens_size, + penalty_repeat=repeat_penalty, + penalty_freq=frequency_penalty, + penalty_present=presence_penalty, + ) + if not penalize_nl: + self._candidates.candidates.data[self._token_nl].logit = llama_cpp.c_float( + nl_logit + ) + + if grammar is not None: + self._ctx.sample_grammar( + candidates=self._candidates, + grammar=grammar, + ) + + if temp == 0.0: + id = self._ctx.sample_token_greedy(candidates=self._candidates) + elif mirostat_mode == 1: + self._ctx.sample_temp(candidates=self._candidates, temp=temp) + id = self._ctx.sample_token_mirostat( + candidates=self._candidates, + tau=mirostat_tau, + eta=mirostat_eta, + mu=2.0 * mirostat_tau, + m=100, + ) + elif mirostat_mode == 2: + self._ctx.sample_temp(candidates=self._candidates, temp=temp) + id = self._ctx.sample_token_mirostat_v2( + candidates=self._candidates, + tau=mirostat_tau, + eta=mirostat_eta, + mu=2.0 * mirostat_tau, + ) + else: + self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1) + self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1) + self._ctx.sample_typical(candidates=self._candidates, p=1.0, min_keep=1) + self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1) + self._ctx.sample_temp(candidates=self._candidates, temp=temp) + id = self._ctx.sample_token(candidates=self._candidates) + if grammar is not None: + self._ctx.grammar_accept_token(grammar=grammar, token=id) + return id def generate( self, @@ -803,8 +1158,7 @@ class Llama: Yields: The generated tokens. """ - assert self.ctx is not None - if reset and len(self._input_ids) > 0: + if reset and self.n_tokens > 0: longest_prefix = 0 for a, b in zip(self._input_ids, tokens[:-1]): if a == b: @@ -860,8 +1214,8 @@ class Llama: Returns: An embedding object. 
""" - assert self.ctx is not None - assert self.model is not None + assert self._ctx.ctx is not None + assert self._model.model is not None model_name: str = model if model is not None else self.model_path if self.context_params.embedding == False: @@ -870,7 +1224,7 @@ class Llama: ) if self.verbose: - llama_cpp.llama_reset_timings(self.ctx) + llama_cpp.llama_reset_timings(self._ctx.ctx) if isinstance(input, str): inputs = [input] @@ -885,8 +1239,8 @@ class Llama: self.eval(tokens) n_tokens = len(tokens) total_tokens += n_tokens - embedding = llama_cpp.llama_get_embeddings(self.ctx)[ - : llama_cpp.llama_n_embd(self.model) + embedding = llama_cpp.llama_get_embeddings(self._ctx.ctx)[ + : llama_cpp.llama_n_embd(self._model.model) ] data.append( @@ -897,7 +1251,7 @@ class Llama: } ) if self.verbose: - llama_cpp.llama_print_timings(self.ctx) + llama_cpp.llama_print_timings(self._ctx.ctx) return { "object": "list", @@ -944,7 +1298,7 @@ class Llama: logits_processor: Optional[LogitsProcessorList] = None, grammar: Optional[LlamaGrammar] = None, ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]: - assert self.ctx is not None + assert self._ctx is not None assert suffix is None or suffix.__class__ is str completion_id: str = f"cmpl-{str(uuid.uuid4())}" @@ -964,16 +1318,16 @@ class Llama: model_name: str = model if model is not None else self.model_path if self.verbose: - llama_cpp.llama_reset_timings(self.ctx) + self._ctx.reset_timings() - if len(prompt_tokens) >= llama_cpp.llama_n_ctx(self.ctx): + if len(prompt_tokens) >= self._n_ctx: raise ValueError( - f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}" + f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self._ctx)}" ) if max_tokens <= 0: # Unlimited, depending on n_ctx. 
- max_tokens = llama_cpp.llama_n_ctx(self.ctx) - len(prompt_tokens) + max_tokens = self._n_ctx - len(prompt_tokens) # Truncate max_tokens if requested tokens would exceed the context window max_tokens = ( @@ -1184,7 +1538,7 @@ class Llama: finish_reason = "stop" if self.verbose: - llama_cpp.llama_print_timings(self.ctx) + self._ctx.print_timings() if stream: remaining_tokens = completion_tokens[returned_tokens:] @@ -1582,24 +1936,6 @@ class Llama: grammar=grammar, ) - def _free_model(self, *, _lbatch_free=llama_cpp._lib.llama_batch_free, _lfree_model=llama_cpp._lib.llama_free_model, _free=llama_cpp._lib.llama_free): - batch = getattr(self, 'batch', None) - if batch is not None: - _lbatch_free(batch) - self.batch = None - model = getattr(self, 'model', None) - if model is not None: - _lfree_model(model) - self.model = None - ctx = getattr(self, 'ctx', None) - if ctx is not None: - _free(ctx) - self.ctx = None - - def __del__(self): - with suppress_stdout_stderr(disable=self.verbose): - self._free_model() - def __getstate__(self): return dict( model_path=self.model_path, @@ -1684,16 +2020,16 @@ class Llama: ) def save_state(self) -> LlamaState: - assert self.ctx is not None + assert self._ctx.ctx is not None if self.verbose: print("Llama.save_state: saving llama state", file=sys.stderr) - state_size = llama_cpp.llama_get_state_size(self.ctx) + state_size = llama_cpp.llama_get_state_size(self._ctx.ctx) if self.verbose: print(f"Llama.save_state: got state size: {state_size}", file=sys.stderr) llama_state = (llama_cpp.c_uint8 * int(state_size))() if self.verbose: print("Llama.save_state: allocated state", file=sys.stderr) - n_bytes = llama_cpp.llama_copy_state_data(self.ctx, llama_state) + n_bytes = llama_cpp.llama_copy_state_data(self._ctx.ctx, llama_state) if self.verbose: print(f"Llama.save_state: copied llama state: {n_bytes}", file=sys.stderr) if int(n_bytes) > int(state_size): @@ -1714,7 +2050,7 @@ class Llama: ) def load_state(self, state: LlamaState) -> None: - assert self.ctx is not None + assert self._ctx.ctx is not None self.scores = state.scores.copy() self.input_ids = state.input_ids.copy() self.n_tokens = state.n_tokens @@ -1722,43 +2058,36 @@ class Llama: LLamaStateArrayType = llama_cpp.c_uint8 * state_size llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state) - if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size: + if llama_cpp.llama_set_state_data(self._ctx.ctx, llama_state) != state_size: raise RuntimeError("Failed to set llama state data") def n_ctx(self) -> int: """Return the context window size.""" - assert self.ctx is not None - return llama_cpp.llama_n_ctx(self.ctx) + return self._ctx.n_ctx() def n_embd(self) -> int: """Return the embedding size.""" - assert self.model is not None - return llama_cpp.llama_n_embd(self.model) + return self._model.n_embd() def n_vocab(self) -> int: """Return the vocabulary size.""" - assert self.model is not None - return llama_cpp.llama_n_vocab(self.model) + return self._model.n_vocab() def tokenizer(self) -> "LlamaTokenizer": """Return the tokenizer for this model.""" - assert self.ctx is not None return LlamaTokenizer(self) def token_eos(self) -> int: """Return the end-of-sequence token.""" - assert self.model is not None - return llama_cpp.llama_token_eos(self.model) + return self._model.token_eos() def token_bos(self) -> int: """Return the beginning-of-sequence token.""" - assert self.model is not None - return llama_cpp.llama_token_bos(self.model) + return self._model.token_bos() def token_nl(self) -> 
int: """Return the newline token.""" - assert self.model is not None - return llama_cpp.llama_token_nl(self.model) + return self._model.token_nl() @staticmethod def logits_to_logprobs(logits: List[float]) -> List[float]: diff --git a/tests/test_llama.py b/tests/test_llama.py index 54f4bd6..5448743 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -8,7 +8,7 @@ def test_llama_cpp_tokenization(): llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, verbose=False) assert llama - assert llama.ctx is not None + assert llama._ctx.ctx is not None text = b"Hello World" @@ -37,7 +37,7 @@ def test_llama_cpp_tokenization(): def test_llama_patch(monkeypatch): llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True) - n_vocab = llama_cpp.llama_n_vocab(llama.model) + n_vocab = llama_cpp.llama_n_vocab(llama._model.model) ## Set up mock function def mock_eval(*args, **kwargs):
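
Notes (not part of the patch): the sketch below is an illustrative, self-contained example of how the low-level wrappers introduced above — _LlamaModel, _LlamaContext, _LlamaBatch, and _LlamaTokenDataArray — compose, mirroring what Llama.__init__, Llama.eval, and Llama.sample now delegate to. The model path is a placeholder, the import path assumes the classes live at module level in llama_cpp/llama.py as shown in the diff, and these are private, unstable internals; the public Llama class remains the recommended interface.

    # Illustrative sketch only; exercises the private wrappers added in this patch.
    # MODEL_PATH is a placeholder. These underscore-prefixed classes are internal
    # and may change without notice.
    import numpy as np
    import llama_cpp
    from llama_cpp.llama import (
        _LlamaModel,
        _LlamaContext,
        _LlamaBatch,
        _LlamaTokenDataArray,
    )

    MODEL_PATH = "./models/model.gguf"  # placeholder path

    # Load the model with default llama_model_params (what Llama.__init__ now does).
    model = _LlamaModel(
        path_model=MODEL_PATH,
        params=_LlamaModel.default_params(),
        verbose=False,
    )

    # Create a context bound to the model.
    ctx = _LlamaContext(
        model=model,
        params=_LlamaContext.default_params(),
        verbose=False,
    )

    # Tokenize a prompt and feed it through a batch, as Llama.eval does.
    tokens = model.tokenize(b"Hello, world", add_bos=True, special=False)
    batch = _LlamaBatch(n_tokens=len(tokens), embd=0, n_seq_max=1, verbose=False)
    batch.set_batch(batch=tokens, n_past=0, logits_all=False)
    ctx.decode(batch)

    # Greedy-sample the next token from the logits of the last position,
    # roughly what Llama.sample does with temp == 0.0.
    n_vocab = model.n_vocab()
    logits = np.array(ctx.get_logits()[:n_vocab], dtype=np.single)
    candidates = _LlamaTokenDataArray(n_vocab=n_vocab)
    candidates.copy_logits(logits)
    token_id = ctx.sample_token_greedy(candidates=candidates)
    print(model.detokenize([token_id]))

One design point worth noting: each wrapper frees its native handle in its own __del__ (llama_free_model, llama_free, llama_batch_free), which is why the patch removes Llama._free_model and Llama.__del__ — resource lifetime now follows the Python objects that own each handle.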