from __future__ import annotations import os import ctypes from typing import ( List, Optional, Sequence, ) from dataclasses import dataclass, field import numpy as np import numpy.typing as npt from .llama_types import * from .llama_grammar import LlamaGrammar import llama_cpp.llama_cpp as llama_cpp from ._utils import suppress_stdout_stderr # Python wrappers over llama.h structs class _LlamaModel: """Intermediate Python wrapper for a llama.cpp llama_model. NOTE: For stability it's recommended you use the Llama class instead.""" _llama_free_model = None # NOTE: this must be "saved" here to avoid exceptions when calling __del__ _suppress_stdout_stderr = suppress_stdout_stderr def __init__( self, *, path_model: str, params: llama_cpp.llama_model_params, verbose: bool = True, ): self.path_model = path_model self.params = params self.verbose = verbose self._llama_free_model = llama_cpp._lib.llama_free_model # type: ignore if not os.path.exists(path_model): raise ValueError(f"Model path does not exist: {path_model}") with self._suppress_stdout_stderr(disable=self.verbose): self.model = llama_cpp.llama_load_model_from_file( self.path_model.encode("utf-8"), self.params ) def __del__(self): with self._suppress_stdout_stderr(disable=self.verbose): if self.model is not None and self._llama_free_model is not None: self._llama_free_model(self.model) self.model = None def vocab_type(self) -> int: assert self.model is not None return llama_cpp.llama_vocab_type(self.model) def n_vocab(self) -> int: assert self.model is not None return llama_cpp.llama_n_vocab(self.model) def n_ctx_train(self) -> int: assert self.model is not None return llama_cpp.llama_n_ctx_train(self.model) def n_embd(self) -> int: assert self.model is not None return llama_cpp.llama_n_embd(self.model) def rope_freq_scale_train(self) -> float: assert self.model is not None return llama_cpp.llama_rope_freq_scale_train(self.model) def desc(self) -> str: assert self.model is not None buf = ctypes.create_string_buffer(1024) llama_cpp.llama_model_desc(self.model, buf, 1024) # type: ignore return buf.value.decode("utf-8") def size(self) -> int: assert self.model is not None return llama_cpp.llama_model_size(self.model) def n_params(self) -> int: assert self.model is not None return llama_cpp.llama_model_n_params(self.model) def get_tensor(self, name: str) -> ctypes.c_void_p: assert self.model is not None return llama_cpp.llama_get_model_tensor(self.model, name.encode("utf-8")) def apply_lora_from_file( self, lora_path: str, scale: float, path_base_model: Optional[str], n_threads: int, ): assert self.model is not None return llama_cpp.llama_model_apply_lora_from_file( self.model, lora_path.encode("utf-8"), scale, path_base_model.encode("utf-8") if path_base_model is not None else llama_cpp.c_char_p(0), n_threads, ) # Vocab def token_get_text(self, token: int) -> str: # TODO: Fix assert self.model is not None return llama_cpp.llama_token_get_text(self.model, token).decode("utf-8") def token_get_score(self, token: int) -> float: assert self.model is not None return llama_cpp.llama_token_get_score(self.model, token) def token_get_type(self, token: int) -> int: assert self.model is not None return llama_cpp.llama_token_get_type(self.model, token) # Special tokens def token_bos(self) -> int: assert self.model is not None return llama_cpp.llama_token_bos(self.model) def token_eos(self) -> int: assert self.model is not None return llama_cpp.llama_token_eos(self.model) def token_nl(self) -> int: assert self.model is not None return llama_cpp.llama_token_nl(self.model) def token_prefix(self) -> int: assert self.model is not None return llama_cpp.llama_token_prefix(self.model) def token_middle(self) -> int: assert self.model is not None return llama_cpp.llama_token_middle(self.model) def token_suffix(self) -> int: assert self.model is not None return llama_cpp.llama_token_suffix(self.model) def token_eot(self) -> int: assert self.model is not None return llama_cpp.llama_token_eot(self.model) # Tokenization def tokenize(self, text: bytes, add_bos: bool, special: bool): assert self.model is not None n_ctx = self.n_ctx_train() tokens = (llama_cpp.llama_token * n_ctx)() n_tokens = llama_cpp.llama_tokenize( self.model, text, len(text), tokens, n_ctx, add_bos, special ) if n_tokens < 0: n_tokens = abs(n_tokens) tokens = (llama_cpp.llama_token * n_tokens)() n_tokens = llama_cpp.llama_tokenize( self.model, text, len(text), tokens, n_tokens, add_bos, special ) if n_tokens < 0: raise RuntimeError( f'Failed to tokenize: text="{text}" n_tokens={n_tokens}' ) return list(tokens[:n_tokens]) def token_to_piece(self, token: int) -> bytes: assert self.model is not None buf = ctypes.create_string_buffer(32) llama_cpp.llama_token_to_piece(self.model, token, buf, 32) # type: ignore return bytes(buf) def detokenize(self, tokens: List[int]) -> bytes: assert self.model is not None output = b"" size = 32 buffer = (ctypes.c_char * size)() for token in tokens: n = llama_cpp.llama_token_to_piece( self.model, llama_cpp.llama_token(token), buffer, size ) assert n <= size output += bytes(buffer[:n]) # NOTE: Llama1 models automatically added a space at the start of the prompt # this line removes a leading space if the first token is a beginning of sentence token return ( output[1:] if len(tokens) > 0 and tokens[0] == self.token_bos() else output ) # Extra def metadata(self) -> Dict[str, str]: assert self.model is not None metadata: Dict[str, str] = {} buffer_size = 1024 buffer = ctypes.create_string_buffer(buffer_size) # zero the buffer buffer.value = b'\0' * buffer_size # iterate over model keys for i in range(llama_cpp.llama_model_meta_count(self.model)): nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size) if nbytes > buffer_size: buffer_size = nbytes buffer = ctypes.create_string_buffer(buffer_size) nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size) key = buffer.value.decode("utf-8") nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size) if nbytes > buffer_size: buffer_size = nbytes buffer = ctypes.create_string_buffer(buffer_size) nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size) value = buffer.value.decode("utf-8") metadata[key] = value return metadata @staticmethod def default_params(): """Get the default llama_model_params.""" return llama_cpp.llama_model_default_params() class _LlamaContext: """Intermediate Python wrapper for a llama.cpp llama_context. NOTE: For stability it's recommended you use the Llama class instead.""" _llama_free = None # NOTE: this must be "saved" here to avoid exceptions when calling __del__ _suppress_stdout_stderr = suppress_stdout_stderr def __init__( self, *, model: _LlamaModel, params: llama_cpp.llama_context_params, verbose: bool = True, ): self.model = model self.params = params self.verbose = verbose self._llama_free = llama_cpp._lib.llama_free # type: ignore with self._suppress_stdout_stderr(disable=self.verbose): self.ctx = llama_cpp.llama_new_context_with_model( self.model.model, self.params ) def __del__(self): with self._suppress_stdout_stderr(disable=self.verbose): if self.ctx is not None and self._llama_free is not None: self._llama_free(self.ctx) self.ctx = None def n_ctx(self) -> int: assert self.ctx is not None return llama_cpp.llama_n_ctx(self.ctx) def kv_cache_clear(self): assert self.ctx is not None llama_cpp.llama_kv_cache_clear(self.ctx) def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int): assert self.ctx is not None llama_cpp.llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1) def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int): assert self.ctx is not None llama_cpp.llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1) def kv_cache_seq_keep(self, seq_id: int): assert self.ctx is not None llama_cpp.llama_kv_cache_seq_keep(self.ctx, seq_id) def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int): assert self.ctx is not None llama_cpp.llama_kv_cache_seq_shift(self.ctx, seq_id, p0, p1, shift) def get_state_size(self) -> int: assert self.ctx is not None return llama_cpp.llama_get_state_size(self.ctx) # TODO: copy_state_data # TODO: set_state_data # TODO: llama_load_session_file # TODO: llama_save_session_file def decode(self, batch: "_LlamaBatch"): assert self.ctx is not None assert batch.batch is not None return_code = llama_cpp.llama_decode( ctx=self.ctx, batch=batch.batch, ) if return_code != 0: raise RuntimeError(f"llama_decode returned {return_code}") def set_n_threads(self, n_threads: int, n_threads_batch: int): assert self.ctx is not None llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch) def get_logits(self): assert self.ctx is not None return llama_cpp.llama_get_logits(self.ctx) def get_logits_ith(self, i: int): assert self.ctx is not None return llama_cpp.llama_get_logits_ith(self.ctx, i) def get_embeddings(self): assert self.ctx is not None return llama_cpp.llama_get_embeddings(self.ctx) # Sampling functions def set_rng_seed(self, seed: int): assert self.ctx is not None llama_cpp.llama_set_rng_seed(self.ctx, seed) def sample_repetition_penalties( self, candidates: "_LlamaTokenDataArray", last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]", penalty_last_n: int, penalty_repeat: float, penalty_freq: float, penalty_present: float, ): assert self.ctx is not None llama_cpp.llama_sample_repetition_penalties( self.ctx, ctypes.byref(candidates.candidates), # type: ignore last_tokens_data, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, ) def sample_classifier_free_guidance( self, candidates: "_LlamaTokenDataArray", guidance_ctx: "_LlamaContext", scale: float, ): assert self.ctx is not None assert guidance_ctx.ctx is not None llama_cpp.llama_sample_classifier_free_guidance( self.ctx, ctypes.byref(candidates.candidates), # type: ignore guidance_ctx.ctx, scale, ) def sample_softmax(self, candidates: "_LlamaTokenDataArray"): assert self.ctx is not None llama_cpp.llama_sample_softmax( self.ctx, ctypes.byref(candidates.candidates), # type: ignore ) def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int): assert self.ctx is not None llama_cpp.llama_sample_top_k( self.ctx, ctypes.byref(candidates.candidates), k, min_keep # type: ignore ) def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): assert self.ctx is not None llama_cpp.llama_sample_top_p( self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore ) def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): assert self.ctx is not None llama_cpp.llama_sample_min_p( self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore ) def sample_tail_free( self, candidates: "_LlamaTokenDataArray", z: float, min_keep: int ): assert self.ctx is not None llama_cpp.llama_sample_tail_free( self.ctx, ctypes.byref(candidates.candidates), z, min_keep # type: ignore ) def sample_typical( self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int ): assert self.ctx is not None llama_cpp.llama_sample_typical( self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore ) def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float): assert self.ctx is not None llama_cpp.llama_sample_temp( self.ctx, ctypes.byref(candidates.candidates), temp # type: ignore ) def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar): assert self.ctx is not None assert grammar.grammar is not None llama_cpp.llama_sample_grammar( self.ctx, ctypes.byref(candidates.candidates), # type: ignore grammar.grammar, ) def sample_token_mirostat( self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, m: int, mu: ctypes._Pointer[ctypes.c_float], # type: ignore ) -> int: assert self.ctx is not None return llama_cpp.llama_sample_token_mirostat( self.ctx, ctypes.byref(candidates.candidates), # type: ignore tau, eta, m, mu, ) def sample_token_mirostat_v2( self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: ctypes._Pointer[ctypes.c_float] # type: ignore ) -> int: assert self.ctx is not None return llama_cpp.llama_sample_token_mirostat_v2( self.ctx, ctypes.byref(candidates.candidates), # type: ignore tau, eta, mu, ) def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int: assert self.ctx is not None return llama_cpp.llama_sample_token_greedy( self.ctx, ctypes.byref(candidates.candidates), # type: ignore ) def sample_token(self, candidates: "_LlamaTokenDataArray") -> int: assert self.ctx is not None return llama_cpp.llama_sample_token( self.ctx, ctypes.byref(candidates.candidates), # type: ignore ) # Grammar def grammar_accept_token(self, grammar: LlamaGrammar, token: int): assert self.ctx is not None assert grammar.grammar is not None llama_cpp.llama_grammar_accept_token(self.ctx, grammar.grammar, token) def reset_timings(self): assert self.ctx is not None llama_cpp.llama_reset_timings(self.ctx) def print_timings(self): assert self.ctx is not None llama_cpp.llama_print_timings(self.ctx) # Utility functions @staticmethod def default_params(): """Get the default llama_context_params.""" return llama_cpp.llama_context_default_params() class _LlamaBatch: _llama_batch_free = None # NOTE: this must be "saved" here to avoid exceptions when calling __del__ _suppress_stdout_stderr = suppress_stdout_stderr def __init__( self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True ): self.n_tokens = n_tokens self.embd = embd self.n_seq_max = n_seq_max self.verbose = verbose self._llama_batch_free = llama_cpp._lib.llama_batch_free # type: ignore with self._suppress_stdout_stderr(disable=self.verbose): self.batch = llama_cpp.llama_batch_init( self.n_tokens, self.embd, self.n_seq_max ) def __del__(self): with self._suppress_stdout_stderr(disable=self.verbose): if self.batch is not None and self._llama_batch_free is not None: self._llama_batch_free(self.batch) self.batch = None def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool): assert self.batch is not None n_tokens = len(batch) self.batch.n_tokens = n_tokens for i in range(n_tokens): self.batch.token[i] = batch[i] self.batch.pos[i] = n_past + i self.batch.seq_id[i][0] = 0 self.batch.n_seq_id[i] = 1 self.batch.logits[i] = logits_all self.batch.logits[n_tokens - 1] = True class _LlamaTokenDataArray: def __init__(self, *, n_vocab: int): self.n_vocab = n_vocab self.candidates_data = np.array( [], dtype=np.dtype( [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True ), ) self.candidates_data.resize(3, self.n_vocab, refcheck=False) self.candidates = llama_cpp.llama_token_data_array( data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p), size=self.n_vocab, sorted=False, ) self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc) self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single) def copy_logits(self, logits: npt.NDArray[np.single]): self.candidates_data["id"][:] = self.default_candidates_data_id self.candidates_data["logit"][:] = logits self.candidates_data["p"][:] = self.default_candidates_data_p self.candidates.data = self.candidates_data.ctypes.data_as( llama_cpp.llama_token_data_p ) self.candidates.sorted = llama_cpp.c_bool(False) self.candidates.size = llama_cpp.c_size_t(self.n_vocab) # Python wrappers over common/common def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> list[int]: n_tokens = len(text) + 1 if add_bos else len(text) result = (llama_cpp.llama_token * n_tokens)() n_tokens = llama_cpp.llama_tokenize( model.model, text.encode("utf-8"), len(text), result, n_tokens, add_bos, special, ) if n_tokens < 0: result = (llama_cpp.llama_token * -n_tokens)() check = llama_cpp.llama_tokenize( model.model, text.encode("utf-8"), len(text), result, len(result), add_bos, special, ) if check != -n_tokens: raise RuntimeError(f'Failed to tokenize: text="{text}" n_tokens={n_tokens}') else: result = result[:n_tokens] return list(result) def _token_to_piece(model: _LlamaModel, token: int) -> str: assert model.model is not None result = (ctypes.c_char * 8)(0) n_tokens = llama_cpp.llama_token_to_piece(model.model, token, result, len(result)) if n_tokens < 0: result = (ctypes.c_char * -n_tokens)(0) check = llama_cpp.llama_token_to_piece(model.model, token, result, len(result)) if check != -n_tokens: raise RuntimeError(f"Failed to get piece: token={token}") else: result = result[:n_tokens] return bytes(result).decode("utf-8") def _detokenize_spm(model: _LlamaModel, tokens: List[int]) -> str: bos_id = model.token_bos() result = "" for i, token in enumerate(tokens): piece = _token_to_piece(model, token) if ( (tokens[0] == bos_id and i == 1) or (tokens[0] != bos_id and i == 0) ) and piece[0] == " ": piece = piece[1:] result += piece return result def _detokenize_bpe(model: _LlamaModel, tokens: List[int]) -> str: result = "" for token in tokens: piece = _token_to_piece(model, token) result += piece return result def _should_add_bos(model: _LlamaModel) -> bool: assert model.model is not None add_bos = llama_cpp.llama_add_bos_token(model.model) if add_bos != -1: return add_bos != 0 else: return llama_cpp.llama_vocab_type(model.model) == llama_cpp.LLAMA_VOCAB_TYPE_SPM # Python wrappers over common/sampling structs @dataclass class _LlamaSamplingParams: n_prev: int = 64 n_probs: int = 0 top_k: int = 40 top_p: float = 0.95 min_p: float = 0.05 tfs_z: float = 1.00 typical_p: float = 1.00 temp: float = 0.80 penalty_last_n: int = 64 penalty_repeat: float = 1.10 penalty_freq: float = 0.00 penalty_present: float = 0.00 mirostat: int = 0 mirostat_tau: float = 5.00 mirostat_eta: float = 0.10 penalize_nl: bool = True grammar: str = "" cfg_negative_prompt: str = "" cfg_scale: float = 1.00 logit_bias: dict[int, float] = field(default_factory=dict) @dataclass class _LlamaSamplingContext: params: _LlamaSamplingParams = field(default_factory=_LlamaSamplingParams) mirostat_mu: ctypes.c_float = field(default_factory=ctypes.c_float) grammar: Optional[LlamaGrammar] = None # NOTE: Missing parsed_grammar prev: list[int] = field(default_factory=list) cur: list[llama_cpp.llama_token_data] = field(default_factory=list) def reset(self): self.prev = [] self.cur = [] if self.grammar is not None: self.grammar.reset() def cp(self): return _LlamaSamplingContext( params=self.params, mirostat_mu=self.mirostat_mu, grammar=self.grammar, prev=self.prev.copy(), cur=self.cur.copy(), ) def last(self) -> Optional[int]: if len(self.prev) > 0: return self.prev[-1] else: return None def prev_str(self, ctx_main: _LlamaContext, n: int) -> str: return ctx_main.model.detokenize(self.prev[-n:]).decode("utf-8") def sample( self, ctx_main: _LlamaContext, ctx_cfg: Optional[_LlamaContext] = None, idx: int = 0, logits_array: Optional[npt.NDArray[np.single]] = None ): n_vocab = ctx_main.model.n_vocab() id: int = 0 if logits_array is None: logits = ctx_main.get_logits_ith(idx) logits_array = np.array( ctypes.cast(logits, ctypes.POINTER(ctypes.c_float * n_vocab)).contents, dtype=np.single, ) # apply logit_bias for token, logit_bias in self.params.logit_bias.items(): logits_array[token] += logit_bias token_data_array = _LlamaTokenDataArray( n_vocab=n_vocab ) # TODO: Only create this once token_data_array.copy_logits(logits_array) if ctx_cfg is not None: ctx_main.sample_classifier_free_guidance( token_data_array, ctx_cfg, self.params.cfg_scale ) # apply penalties if len(self.prev) > 0: nl_token = ctx_main.model.token_nl() nl_logit = logits_array[nl_token] if self.params.penalty_last_n > 0: ctx_main.sample_repetition_penalties( token_data_array, # TODO: Only create this once (llama_cpp.llama_token * len(self.prev))(*self.prev), # type: ignore self.params.penalty_last_n, self.params.penalty_repeat, self.params.penalty_freq, self.params.penalty_present, ) if not self.params.penalize_nl: token_data_array.candidates_data["logit"][nl_token] = nl_logit if self.grammar is not None: ctx_main.sample_grammar(token_data_array, self.grammar) if self.params.temp < 0: ctx_main.sample_softmax(token_data_array) id = token_data_array.candidates_data["id"][0] elif self.params.temp == 0: id = ctx_main.sample_token_greedy(token_data_array) else: if self.params.mirostat == 1: mirostat_m = 100 ctx_main.sample_temp(token_data_array, self.params.temp) id = ctx_main.sample_token_mirostat( token_data_array, self.params.mirostat_tau, self.params.mirostat_eta, mirostat_m, ctypes.pointer(self.mirostat_mu), ) elif self.params.mirostat == 2: ctx_main.sample_temp(token_data_array, self.params.temp) id = ctx_main.sample_token_mirostat_v2( token_data_array, self.params.mirostat_tau, self.params.mirostat_eta, ctypes.pointer(self.mirostat_mu), ) else: min_keep = max(1, self.params.n_probs) ctx_main.sample_top_k( token_data_array, self.params.top_k, min_keep=min_keep ) ctx_main.sample_tail_free( token_data_array, self.params.tfs_z, min_keep=min_keep ) ctx_main.sample_typical( token_data_array, self.params.typical_p, min_keep=min_keep ) ctx_main.sample_top_p( token_data_array, self.params.top_p, min_keep=min_keep ) ctx_main.sample_min_p( token_data_array, self.params.min_p, min_keep=min_keep ) ctx_main.sample_temp(token_data_array, self.params.temp) id = ctx_main.sample_token(token_data_array) return id def accept(self, ctx_main: _LlamaContext, id: int, apply_grammar: bool): if apply_grammar and self.grammar is not None: ctx_main.grammar_accept_token(self.grammar, id) self.prev.append(id)