diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ca220e..df635fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.1.78] + +### Added + +- Grammar based sampling via LlamaGrammar which can be passed to completions +- Make n_gpu_layers == -1 offload all layers + ## [0.1.77] - (llama.cpp) Update llama.cpp add support for LLaMa 2 70B diff --git a/README.md b/README.md index 639d261..21ff0ed 100644 --- a/README.md +++ b/README.md @@ -140,7 +140,7 @@ llm = Llama(model_path="./models/7B/ggml-model.bin", n_ctx=2048) Llama2 70b must set the `n_gqa` parameter (grouped-query attention factor) to 8 when loading: ```python -llm = Llama(model_path="./models/7B/ggml-model.bin", n_gqa=8) +llm = Llama(model_path="./models/70B/ggml-model.bin", n_gqa=8) ``` ## Web Server @@ -169,7 +169,7 @@ docker run --rm -it -p 8000:8000 -v /path/to/models:/models -e MODEL=/models/ggm ## Low-level API The low-level API is a direct [`ctypes`](https://docs.python.org/3/library/ctypes.html) binding to the C API provided by `llama.cpp`. -The entire lowe-level API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and directly mirrors the C API in [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h). +The entire low-level API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and directly mirrors the C API in [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h). Below is a short example demonstrating how to use the low-level API to tokenize a prompt: diff --git a/docker/openblas_simple/Dockerfile b/docker/openblas_simple/Dockerfile index 8231bdb..020c34d 100644 --- a/docker/openblas_simple/Dockerfile +++ b/docker/openblas_simple/Dockerfile @@ -9,7 +9,7 @@ COPY . . RUN apt update && apt install -y libopenblas-dev ninja-build build-essential RUN python -m pip install --upgrade pip pytest cmake scikit-build setuptools fastapi uvicorn sse-starlette pydantic-settings -RUN LLAMA_OPENBLAS=1 pip install llama_cpp_python --verbose +RUN CMAKE_ARGS="-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS" pip install llama_cpp_python --verbose # Run the server CMD python3 -m llama_cpp.server diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 94ab8c5..036d833 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -23,10 +23,12 @@ import ctypes from . import llama_cpp from .llama_types import * +from .llama_grammar import LlamaGrammar import numpy as np import numpy.typing as npt +from .utils import suppress_stdout_stderr class BaseLlamaCache(ABC): """Base cache class for a llama.cpp model.""" @@ -231,7 +233,8 @@ class Llama: rope_freq_base: float = 10000.0, rope_freq_scale: float = 1.0, n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b - rms_norm_eps: Optional[float] = None, # (TEMPORARY) + rms_norm_eps: Optional[float] = None, # (TEMPORARY) + mul_mat_q: Optional[bool] = None, # (TEMPORARY) verbose: bool = True, ): """Load a llama.cpp model from `model_path`. @@ -241,6 +244,7 @@ class Llama: n_ctx: Maximum context size. n_parts: Number of parts to split the model into. If -1, the number of parts is automatically determined. seed: Random seed. -1 for random. + n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded. f16_kv: Use half-precision for key/value cache. 
logits_all: Return logits for all tokens, not just the last token. vocab_only: Only load the vocabulary no weights. @@ -269,7 +273,7 @@ class Llama: self.params = llama_cpp.llama_context_default_params() self.params.n_ctx = n_ctx - self.params.n_gpu_layers = n_gpu_layers + self.params.n_gpu_layers = 0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers # 0x7FFFFFFF is INT32 max, will be auto set to all layers self.params.seed = seed self.params.f16_kv = f16_kv self.params.logits_all = logits_all @@ -280,7 +284,7 @@ class Llama: self.params.low_vram = low_vram self.tensor_split = tensor_split - self._c_tensor_split = None + self._p_tensor_split = None if self.tensor_split is not None: # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES @@ -299,6 +303,9 @@ class Llama: if rms_norm_eps is not None: self.params.rms_norm_eps = rms_norm_eps + if mul_mat_q is not None: + self.params.mul_mat_q = mul_mat_q + self.last_n_tokens_size = last_n_tokens_size self.n_batch = min(n_ctx, n_batch) @@ -316,12 +323,25 @@ class Llama: if not os.path.exists(model_path): raise ValueError(f"Model path does not exist: {model_path}") - self.model = llama_cpp.llama_load_model_from_file( - self.model_path.encode("utf-8"), self.params - ) + if verbose: + self.model = llama_cpp.llama_load_model_from_file( + self.model_path.encode("utf-8"), self.params + ) + else: + with suppress_stdout_stderr(): + self.model = llama_cpp.llama_load_model_from_file( + self.model_path.encode("utf-8"), self.params + ) assert self.model is not None - self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params) + if verbose: + self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params) + else: + with suppress_stdout_stderr(): + print("here") + self.ctx = llama_cpp.llama_new_context_with_model( + self.model, self.params + ) assert self.ctx is not None @@ -358,8 +378,8 @@ class Llama: sorted=sorted, ) self._candidates = candidates - self._token_nl = Llama.token_nl() - self._token_eos = Llama.token_eos() + self._token_nl = self.token_nl() + self._token_eos = self.token_eos() self._candidates_data_id = np.arange(self._n_vocab, dtype=np.intc) # type: ignore self._candidates_data_p = np.zeros(self._n_vocab, dtype=np.single) @@ -437,10 +457,14 @@ class Llama: """ assert self.ctx is not None output = b"" + buffer_size = 32 + buffer = (ctypes.c_char * buffer_size)() for token in tokens: - output += llama_cpp.llama_token_to_str( - self.ctx, llama_cpp.llama_token(token) + n = llama_cpp.llama_token_to_str( + self.ctx, llama_cpp.llama_token(token), buffer, buffer_size ) + assert n <= buffer_size + output += bytes(buffer[:n]) return output def set_cache(self, cache: Optional[BaseLlamaCache]): @@ -506,6 +530,7 @@ class Llama: mirostat_eta: llama_cpp.c_float, penalize_nl: bool = True, logits_processor: Optional[LogitsProcessorList] = None, + grammar: Optional[LlamaGrammar] = None, ): assert self.ctx is not None assert self.n_tokens > 0 @@ -548,8 +573,16 @@ class Llama: ) if not penalize_nl: candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit) + + if grammar is not None: + llama_cpp.llama_sample_grammar( + ctx=self.ctx, + candidates=llama_cpp.ctypes.byref(candidates), # type: ignore + grammar=grammar.grammar, + ) + if temp.value == 0.0: - return llama_cpp.llama_sample_token_greedy( + id = llama_cpp.llama_sample_token_greedy( ctx=self.ctx, candidates=llama_cpp.ctypes.byref(candidates), # type: ignore ) @@ -561,7 +594,7 @@ class Llama: candidates=llama_cpp.ctypes.byref(candidates), # type: ignore 
temp=temp, ) - return llama_cpp.llama_sample_token_mirostat( + id = llama_cpp.llama_sample_token_mirostat( ctx=self.ctx, candidates=llama_cpp.ctypes.byref(candidates), # type: ignore tau=mirostat_tau, @@ -576,7 +609,7 @@ class Llama: candidates=llama_cpp.ctypes.byref(candidates), # type: ignore temp=temp, ) - return llama_cpp.llama_sample_token_mirostat_v2( + id = llama_cpp.llama_sample_token_mirostat_v2( ctx=self.ctx, candidates=llama_cpp.ctypes.byref(candidates), # type: ignore tau=mirostat_tau, @@ -613,10 +646,17 @@ class Llama: candidates=llama_cpp.ctypes.byref(candidates), # type: ignore temp=temp, ) - return llama_cpp.llama_sample_token( + id = llama_cpp.llama_sample_token( ctx=self.ctx, candidates=llama_cpp.ctypes.byref(candidates), # type: ignore ) + if grammar is not None: + llama_cpp.llama_grammar_accept_token( + ctx=self.ctx, + grammar=grammar.grammar, + token=llama_cpp.ctypes.c_int(id), + ) + return id def sample( self, @@ -632,6 +672,7 @@ class Llama: mirostat_tau: float = 5.0, penalize_nl: bool = True, logits_processor: Optional[LogitsProcessorList] = None, + grammar: Optional[LlamaGrammar] = None, ): """Sample a token from the model. @@ -665,6 +706,7 @@ class Llama: mirostat_eta=llama_cpp.c_float(mirostat_eta), penalize_nl=penalize_nl, logits_processor=logits_processor, + grammar=grammar, ) def generate( @@ -683,6 +725,7 @@ class Llama: mirostat_eta: float = 0.1, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, + grammar: Optional[LlamaGrammar] = None, ) -> Generator[int, Optional[Sequence[int]], None]: """Create a generator of tokens from a prompt. @@ -704,7 +747,6 @@ class Llama: The generated tokens. """ assert self.ctx is not None - if reset and len(self._input_ids) > 0: longest_prefix = 0 for a, b in zip(self._input_ids, tokens[:-1]): @@ -722,6 +764,9 @@ class Llama: if reset: self.reset() + if grammar is not None: + grammar.reset() + while True: self.eval(tokens) token = self.sample( @@ -736,6 +781,7 @@ class Llama: mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, logits_processor=logits_processor, + grammar=grammar, ) if stopping_criteria is not None and stopping_criteria( self._input_ids, self._scores[-1, :] @@ -838,6 +884,7 @@ class Llama: model: Optional[str] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, logits_processor: Optional[LogitsProcessorList] = None, + grammar: Optional[LlamaGrammar] = None, ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]: assert self.ctx is not None @@ -915,6 +962,7 @@ class Llama: repeat_penalty=repeat_penalty, stopping_criteria=stopping_criteria, logits_processor=logits_processor, + grammar=grammar, ): if token == self._token_eos: text = self.detokenize(completion_tokens) @@ -965,9 +1013,7 @@ class Llama: for token in remaining_tokens: token_end_position += len(self.detokenize([token])) # Check if stop sequence is in the token - if token_end_position >= ( - remaining_length - first_stop_position - ): + if token_end_position >= (remaining_length - first_stop_position): break logprobs_or_none: Optional[CompletionLogprobs] = None if logprobs is not None: @@ -1261,6 +1307,7 @@ class Llama: model: Optional[str] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, logits_processor: Optional[LogitsProcessorList] = None, + grammar: Optional[LlamaGrammar] = None, ) -> Union[Completion, Iterator[CompletionChunk]]: """Generate text from a prompt. 
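
The hunks above thread a new optional `grammar` argument from `create_completion` / `create_chat_completion` down through `generate`, `sample`, and `_sample`, where `llama_sample_grammar` constrains the candidate logits and `llama_grammar_accept_token` advances the grammar state after each pick. A minimal usage sketch (not part of the diff), assuming a placeholder model path and a toy GBNF grammar:

```python
from llama_cpp import Llama
from llama_cpp.llama_grammar import LlamaGrammar

# Toy GBNF grammar that restricts output to the literal words "yes" or "no".
YES_NO_GBNF = 'root ::= "yes" | "no"'

# Placeholder model path; per this release, n_gpu_layers=-1 offloads all layers.
llm = Llama(model_path="./models/7B/ggml-model.bin", n_gpu_layers=-1)
grammar = LlamaGrammar.from_string(YES_NO_GBNF)

# The grammar keyword is forwarded all the way into the low-level sampler.
completion = llm.create_completion(
    "Q: Is the sky blue? A: ", max_tokens=8, grammar=grammar
)
print(completion["choices"][0]["text"])
```
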
@@ -1305,6 +1352,7 @@ class Llama: model=model, stopping_criteria=stopping_criteria, logits_processor=logits_processor, + grammar=grammar ) if stream: chunks: Iterator[CompletionChunk] = completion_or_chunks @@ -1334,6 +1382,7 @@ class Llama: model: Optional[str] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, logits_processor: Optional[LogitsProcessorList] = None, + grammar: Optional[LlamaGrammar] = None, ) -> Union[Completion, Iterator[CompletionChunk]]: """Generate text from a prompt. @@ -1378,6 +1427,7 @@ class Llama: model=model, stopping_criteria=stopping_criteria, logits_processor=logits_processor, + grammar=grammar, ) def _convert_text_completion_to_chat( @@ -1460,6 +1510,7 @@ class Llama: mirostat_eta: float = 0.1, model: Optional[str] = None, logits_processor: Optional[LogitsProcessorList] = None, + grammar: Optional[LlamaGrammar] = None, ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: """Generate a chat completion from a list of messages. @@ -1502,6 +1553,7 @@ class Llama: mirostat_eta=mirostat_eta, model=model, logits_processor=logits_processor, + grammar=grammar, ) if stream: chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore @@ -1511,10 +1563,10 @@ class Llama: return self._convert_text_completion_to_chat(completion) def __del__(self): - if self.model is not None: + if hasattr(self, "model") and self.model is not None: llama_cpp.llama_free_model(self.model) self.model = None - if self.ctx is not None: + if hasattr(self, "ctx") and self.ctx is not None: llama_cpp.llama_free(self.ctx) self.ctx = None @@ -1638,20 +1690,20 @@ class Llama: assert self.ctx is not None return LlamaTokenizer(self) - @staticmethod - def token_eos() -> int: + def token_eos(self) -> int: """Return the end-of-sequence token.""" - return llama_cpp.llama_token_eos() + assert self.ctx is not None + return llama_cpp.llama_token_eos(self.ctx) - @staticmethod - def token_bos() -> int: + def token_bos(self) -> int: """Return the beginning-of-sequence token.""" - return llama_cpp.llama_token_bos() + assert self.ctx is not None + return llama_cpp.llama_token_bos(self.ctx) - @staticmethod - def token_nl() -> int: + def token_nl(self) -> int: """Return the newline token.""" - return llama_cpp.llama_token_nl() + assert self.ctx is not None + return llama_cpp.llama_token_nl(self.ctx) @staticmethod def logits_to_logprobs(logits: List[float]) -> List[float]: diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 0f319fc..0332577 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -90,26 +90,17 @@ GGML_USE_CUBLAS = hasattr(_lib, "ggml_init_cublas") GGML_CUDA_MAX_DEVICES = 16 LLAMA_MAX_DEVICES = GGML_CUDA_MAX_DEVICES if GGML_USE_CUBLAS else 1 -# #define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt' -LLAMA_FILE_MAGIC_GGJT = 0x67676A74 -# #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' -LLAMA_FILE_MAGIC_GGLA = 0x67676C61 -# #define LLAMA_FILE_MAGIC_GGMF 0x67676d66u // 'ggmf' -LLAMA_FILE_MAGIC_GGMF = 0x67676D66 -# #define LLAMA_FILE_MAGIC_GGML 0x67676d6cu // 'ggml' -LLAMA_FILE_MAGIC_GGML = 0x67676D6C -# #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' -LLAMA_FILE_MAGIC_GGSN = 0x6767736E +# define LLAMA_DEFAULT_SEED 0xFFFFFFFF +LLAMA_DEFAULT_SEED = ctypes.c_int(0xFFFFFFFF) -# #define LLAMA_FILE_VERSION 3 -LLAMA_FILE_VERSION = 3 -LLAMA_FILE_MAGIC = LLAMA_FILE_MAGIC_GGJT -LLAMA_FILE_MAGIC_UNVERSIONED = LLAMA_FILE_MAGIC_GGML +# define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' +LLAMA_FILE_MAGIC_GGSN = ctypes.c_uint(0x6767736E) + +# define 
LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN -LLAMA_SESSION_VERSION = 1 +# define LLAMA_SESSION_VERSION 1 +LLAMA_SESSION_VERSION = ctypes.c_int(1) -# #define LLAMA_DEFAULT_SEED 0xFFFFFFFF -LLAMA_DEFAULT_SEED = 0xFFFFFFFF # struct llama_model; llama_model_p = c_void_p @@ -122,6 +113,82 @@ llama_context_p = c_void_p llama_token = c_int llama_token_p = POINTER(llama_token) +# enum llama_log_level { +# LLAMA_LOG_LEVEL_ERROR = 2, +# LLAMA_LOG_LEVEL_WARN = 3, +# LLAMA_LOG_LEVEL_INFO = 4 +# }; +LLAMA_LOG_LEVEL_ERROR = c_int(2) +LLAMA_LOG_LEVEL_WARN = c_int(3) +LLAMA_LOG_LEVEL_INFO = c_int(4) + +# enum llama_vocab_type { +# LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece +# LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding +# }; +LLAMA_VOCAB_TYPE_SPM = c_int(0) +LLAMA_VOCAB_TYPE_BPE = c_int(1) + + +# enum llama_token_type { +# LLAMA_TOKEN_TYPE_UNDEFINED = 0, +# LLAMA_TOKEN_TYPE_NORMAL = 1, +# LLAMA_TOKEN_TYPE_UNKNOWN = 2, +# LLAMA_TOKEN_TYPE_CONTROL = 3, +# LLAMA_TOKEN_TYPE_USER_DEFINED = 4, +# LLAMA_TOKEN_TYPE_UNUSED = 5, +# LLAMA_TOKEN_TYPE_BYTE = 6, +# }; +LLAMA_TOKEN_TYPE_UNDEFINED = c_int(0) +LLAMA_TOKEN_TYPE_NORMAL = c_int(1) +LLAMA_TOKEN_TYPE_UNKNOWN = c_int(2) +LLAMA_TOKEN_TYPE_CONTROL = c_int(3) +LLAMA_TOKEN_TYPE_USER_DEFINED = c_int(4) +LLAMA_TOKEN_TYPE_UNUSED = c_int(5) +LLAMA_TOKEN_TYPE_BYTE = c_int(6) + +# enum llama_ftype { +# LLAMA_FTYPE_ALL_F32 = 0, +# LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 +# // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed +# // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed +# LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors +# +# LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file +# }; +LLAMA_FTYPE_ALL_F32 = c_int(0) +LLAMA_FTYPE_MOSTLY_F16 = c_int(1) +LLAMA_FTYPE_MOSTLY_Q4_0 = c_int(2) +LLAMA_FTYPE_MOSTLY_Q4_1 = c_int(3) +LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = c_int(4) +LLAMA_FTYPE_MOSTLY_Q8_0 = c_int(7) +LLAMA_FTYPE_MOSTLY_Q5_0 = c_int(8) +LLAMA_FTYPE_MOSTLY_Q5_1 = c_int(9) +LLAMA_FTYPE_MOSTLY_Q2_K = c_int(10) +LLAMA_FTYPE_MOSTLY_Q3_K_S = c_int(11) +LLAMA_FTYPE_MOSTLY_Q3_K_M = c_int(12) +LLAMA_FTYPE_MOSTLY_Q3_K_L = c_int(13) +LLAMA_FTYPE_MOSTLY_Q4_K_S = c_int(14) +LLAMA_FTYPE_MOSTLY_Q4_K_M = c_int(15) +LLAMA_FTYPE_MOSTLY_Q5_K_S = c_int(16) +LLAMA_FTYPE_MOSTLY_Q5_K_M = c_int(17) +LLAMA_FTYPE_MOSTLY_Q6_K = c_int(18) +LLAMA_FTYPE_GUESSED = c_int(1024) + # typedef struct llama_token_data { # llama_token id; // token id @@ -157,16 +224,13 @@ llama_token_data_array_p = POINTER(llama_token_data_array) # typedef void (*llama_progress_callback)(float progress, void *ctx); llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p) - # struct llama_context_params { # uint32_t seed; // RNG seed, -1 for random 
# int32_t n_ctx; // text context # int32_t n_batch; // prompt processing batch size -# int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams) -# float rms_norm_eps; // rms norm epsilon (TEMP - will be moved to model hparams) # int32_t n_gpu_layers; // number of layers to store in VRAM # int32_t main_gpu; // the GPU that is used for scratch and small tensors -# + # const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) # // ref: https://github.com/ggerganov/llama.cpp/pull/2054 @@ -181,6 +245,7 @@ llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p) # // Keep the booleans together to avoid misalignment during copy-by-value. # bool low_vram; // if true, reduce VRAM usage at the cost of performance +# bool mul_mat_q; // if true, use experimental mul_mat_q kernels # bool f16_kv; // use fp16 for KV cache # bool logits_all; // the llama_eval() call computes all logits, not just the last one # bool vocab_only; // only load the vocabulary, no weights @@ -193,16 +258,15 @@ class llama_context_params(Structure): ("seed", c_uint32), ("n_ctx", c_int32), ("n_batch", c_int32), - ("n_gqa", c_int32), - ("rms_norm_eps", c_float), ("n_gpu_layers", c_int32), ("main_gpu", c_int32), - ("tensor_split", POINTER(c_float)), + ("tensor_split", c_float_p), ("rope_freq_base", c_float), ("rope_freq_scale", c_float), ("progress_callback", llama_progress_callback), ("progress_callback_user_data", c_void_p), ("low_vram", c_bool), + ("mul_mat_q", c_bool), ("f16_kv", c_bool), ("logits_all", c_bool), ("vocab_only", c_bool), @@ -214,50 +278,20 @@ class llama_context_params(Structure): llama_context_params_p = POINTER(llama_context_params) -# enum llama_ftype { -# LLAMA_FTYPE_ALL_F32 = 0, -# LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 -# // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed -# // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed -# LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors -# }; -LLAMA_FTYPE_ALL_F32 = 0 -LLAMA_FTYPE_MOSTLY_F16 = 1 -LLAMA_FTYPE_MOSTLY_Q4_0 = 2 -LLAMA_FTYPE_MOSTLY_Q4_1 = 3 -LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4 -LLAMA_FTYPE_MOSTLY_Q8_0 = 7 -LLAMA_FTYPE_MOSTLY_Q5_0 = 8 -LLAMA_FTYPE_MOSTLY_Q5_1 = 9 -LLAMA_FTYPE_MOSTLY_Q2_K = 10 -LLAMA_FTYPE_MOSTLY_Q3_K_S = 11 -LLAMA_FTYPE_MOSTLY_Q3_K_M = 12 -LLAMA_FTYPE_MOSTLY_Q3_K_L = 13 -LLAMA_FTYPE_MOSTLY_Q4_K_S = 14 -LLAMA_FTYPE_MOSTLY_Q4_K_M = 15 -LLAMA_FTYPE_MOSTLY_Q5_K_S = 16 -LLAMA_FTYPE_MOSTLY_Q5_K_M = 17 -LLAMA_FTYPE_MOSTLY_Q6_K = 18 + +# // Signature for logging events +# // Note that text includes the new line character at the end for most events. 
+# // If your logging mechanism cannot handle that, check if the last character is '\n' and strip it +# // if it exists. +# // It might not exist for progress report where '.' is output repeatedly. +# typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data); +llama_log_callback = ctypes.CFUNCTYPE(None, c_int, c_char_p, c_void_p) # // model quantization parameters # typedef struct llama_model_quantize_params { # int nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() -# enum llama_ftype ftype; // quantize to this llama_ftype +# enum llama_ftype ftype; // quantize to this llama_ftype # bool allow_requantize; // allow quantizing non-f32/f16 tensors # bool quantize_output_tensor; // quantize output.weight # } llama_model_quantize_params; @@ -349,16 +383,7 @@ class llama_timings(Structure): ] -# LLAMA_API int llama_max_devices(); -def llama_max_devices() -> int: - return _lib.llama_max_devices() - - -_lib.llama_max_devices.argtypes = [] -_lib.llama_max_devices.restype = c_int - - -# LLAMA_API struct llama_context_params llama_context_default_params(); +# LLAMA_API struct llama_context_params llama_context_default_params(void); def llama_context_default_params() -> llama_context_params: return _lib.llama_context_default_params() @@ -367,7 +392,7 @@ _lib.llama_context_default_params.argtypes = [] _lib.llama_context_default_params.restype = llama_context_params -# LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(); +# LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); def llama_model_quantize_default_params() -> llama_model_quantize_params: return _lib.llama_model_quantize_default_params() @@ -376,25 +401,6 @@ _lib.llama_model_quantize_default_params.argtypes = [] _lib.llama_model_quantize_default_params.restype = llama_model_quantize_params -# LLAMA_API bool llama_mmap_supported(); -def llama_mmap_supported() -> bool: - return _lib.llama_mmap_supported() - - -_lib.llama_mmap_supported.argtypes = [] -_lib.llama_mmap_supported.restype = c_bool - - -# LLAMA_API bool llama_mlock_supported(); -def llama_mlock_supported() -> bool: - return _lib.llama_mlock_supported() - - -_lib.llama_mlock_supported.argtypes = [] -_lib.llama_mlock_supported.restype = c_bool - - -# // TODO: not great API - very likely to change # // Initialize the llama + ggml backend # // If numa is true, use NUMA optimizations # // Call once at the start of the program @@ -408,7 +414,7 @@ _lib.llama_backend_init.restype = None # // Call once at the end of the program - currently only used for MPI -# LLAMA_API void llama_backend_free(); +# LLAMA_API void llama_backend_free(void); def llama_backend_free(): return _lib.llama_backend_free() @@ -418,7 +424,7 @@ _lib.llama_backend_free.restype = None # LLAMA_API struct llama_model * llama_load_model_from_file( -# const char * path_model, +# const char * path_model, # struct llama_context_params params); def llama_load_model_from_file( path_model: bytes, params: llama_context_params @@ -440,7 +446,7 @@ _lib.llama_free_model.restype = None # LLAMA_API struct llama_context * llama_new_context_with_model( -# struct llama_model * model, +# struct llama_model * model, # struct llama_context_params params); def llama_new_context_with_model( model: llama_model_p, params: llama_context_params @@ -452,7 +458,17 @@ _lib.llama_new_context_with_model.argtypes = [llama_model_p, llama_context_param _lib.llama_new_context_with_model.restype = 
llama_context_p -# LLAMA_API int64_t llama_time_us(); +# // Frees all allocated memory +# LLAMA_API void llama_free(struct llama_context * ctx); +def llama_free(ctx: llama_context_p): + return _lib.llama_free(ctx) + + +_lib.llama_free.argtypes = [llama_context_p] +_lib.llama_free.restype = None + + +# LLAMA_API int64_t llama_time_us(void); def llama_time_us() -> int: return _lib.llama_time_us() @@ -461,30 +477,95 @@ _lib.llama_time_us.argtypes = [] _lib.llama_time_us.restype = ctypes.c_int64 -# // Various functions for loading a ggml llama model. -# // Allocate (almost) all memory needed for the model. -# // Return NULL on failure -# LLAMA_API struct llama_context * llama_init_from_file( -# const char * path_model, -# struct llama_context_params params); -def llama_init_from_file( - path_model: bytes, params: llama_context_params -) -> llama_context_p: - return _lib.llama_init_from_file(path_model, params) +# LLAMA_API int llama_max_devices (void); +def llama_max_devices() -> int: + return _lib.llama_max_devices() -_lib.llama_init_from_file.argtypes = [c_char_p, llama_context_params] -_lib.llama_init_from_file.restype = llama_context_p +_lib.llama_max_devices.argtypes = [] +_lib.llama_max_devices.restype = c_int -# Frees all allocated memory -# LLAMA_API void llama_free(struct llama_context * ctx); -def llama_free(ctx: llama_context_p): - return _lib.llama_free(ctx) +# LLAMA_API bool llama_mmap_supported (void); +def llama_mmap_supported() -> bool: + return _lib.llama_mmap_supported() -_lib.llama_free.argtypes = [llama_context_p] -_lib.llama_free.restype = None +_lib.llama_mmap_supported.argtypes = [] +_lib.llama_mmap_supported.restype = c_bool + + +# LLAMA_API bool llama_mlock_supported(void); +def llama_mlock_supported() -> bool: + return _lib.llama_mlock_supported() + + +_lib.llama_mlock_supported.argtypes = [] +_lib.llama_mlock_supported.restype = c_bool + + +# LLAMA_API int llama_n_vocab(const struct llama_context * ctx); +def llama_n_vocab(ctx: llama_context_p) -> int: + return _lib.llama_n_vocab(ctx) + + +_lib.llama_n_vocab.argtypes = [llama_context_p] +_lib.llama_n_vocab.restype = c_int + + +# LLAMA_API int llama_n_ctx (const struct llama_context * ctx); +def llama_n_ctx(ctx: llama_context_p) -> int: + return _lib.llama_n_ctx(ctx) + + +_lib.llama_n_ctx.argtypes = [llama_context_p] +_lib.llama_n_ctx.restype = c_int + + +# LLAMA_API int llama_n_embd (const struct llama_context * ctx); +def llama_n_embd(ctx: llama_context_p) -> int: + return _lib.llama_n_embd(ctx) + + +_lib.llama_n_embd.argtypes = [llama_context_p] +_lib.llama_n_embd.restype = c_int + + +# LLAMA_API int llama_model_n_vocab(const struct llama_model * model); +def llama_model_n_vocab(model: llama_model_p) -> int: + return _lib.llama_model_n_vocab(model) + + +_lib.llama_model_n_vocab.argtypes = [llama_model_p] +_lib.llama_model_n_vocab.restype = c_int + + +# LLAMA_API int llama_model_n_ctx (const struct llama_model * model); +def llama_model_n_ctx(model: llama_model_p) -> int: + return _lib.llama_model_n_ctx(model) + + +_lib.llama_model_n_ctx.argtypes = [llama_model_p] +_lib.llama_model_n_ctx.restype = c_int + + +# LLAMA_API int llama_model_n_embd (const struct llama_model * model); +def llama_model_n_embd(model: llama_model_p) -> int: + return _lib.llama_model_n_embd(model) + + +_lib.llama_model_n_embd.argtypes = [llama_model_p] +_lib.llama_model_n_embd.restype = c_int + + +# // Get a string describing the model type +# LLAMA_API int llama_model_type(const struct llama_model * model, char * buf, size_t buf_size); +def 
llama_model_type(model: llama_model_p, buf: bytes, buf_size: c_size_t) -> int: + return _lib.llama_model_type(model, buf, buf_size) + + +_lib.llama_model_type.argtypes = [llama_model_p, c_char_p, c_size_t] +_lib.llama_model_type.restype = c_int # // Returns 0 on success @@ -703,147 +784,17 @@ _lib.llama_eval_embd.argtypes = [llama_context_p, c_float_p, c_int, c_int, c_int _lib.llama_eval_embd.restype = c_int -# Convert the provided text into tokens. -# The tokens pointer must be large enough to hold the resulting tokens. -# Returns the number of tokens on success, no more than n_max_tokens -# Returns a negative number on failure - the number of tokens that would have been returned -# TODO: not sure if correct -# LLAMA_API int llama_tokenize( -# struct llama_context * ctx, -# const char * text, -# llama_token * tokens, -# int n_max_tokens, -# bool add_bos); -def llama_tokenize( - ctx: llama_context_p, - text: bytes, - tokens, # type: Array[llama_token] - n_max_tokens: Union[c_int, int], - add_bos: Union[c_bool, bool], -) -> int: - return _lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos) +# // Export a static computation graph for context of 511 and batch size of 1 +# // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these +# // parameters here to keep things simple +# // IMPORTANT: do not use for anything else other than debugging and testing! +# LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname); +def llama_eval_export(ctx: llama_context_p, fname: bytes) -> int: + return _lib.llama_eval_export(ctx, fname) -_lib.llama_tokenize.argtypes = [llama_context_p, c_char_p, llama_token_p, c_int, c_bool] -_lib.llama_tokenize.restype = c_int - - -# LLAMA_API int llama_tokenize_with_model( -# const struct llama_model * model, -# const char * text, -# llama_token * tokens, -# int n_max_tokens, -# bool add_bos); -def llama_tokenize_with_model( - model: llama_model_p, - text: bytes, - tokens, # type: Array[llama_token] - n_max_tokens: Union[c_int, int], - add_bos: Union[c_bool, bool], -) -> int: - return _lib.llama_tokenize_with_model(model, text, tokens, n_max_tokens, add_bos) - - -# LLAMA_API int llama_n_vocab(const struct llama_context * ctx); -def llama_n_vocab(ctx: llama_context_p) -> int: - return _lib.llama_n_vocab(ctx) - - -_lib.llama_n_vocab.argtypes = [llama_context_p] -_lib.llama_n_vocab.restype = c_int - - -# LLAMA_API int llama_n_ctx (const struct llama_context * ctx); -def llama_n_ctx(ctx: llama_context_p) -> int: - return _lib.llama_n_ctx(ctx) - - -_lib.llama_n_ctx.argtypes = [llama_context_p] -_lib.llama_n_ctx.restype = c_int - - -# LLAMA_API int llama_n_embd (const struct llama_context * ctx); -def llama_n_embd(ctx: llama_context_p) -> int: - return _lib.llama_n_embd(ctx) - - -_lib.llama_n_embd.argtypes = [llama_context_p] -_lib.llama_n_embd.restype = c_int - - -# LLAMA_API int llama_n_vocab_from_model(const struct llama_model * model); -def llama_n_vocab_from_model(model: llama_model_p) -> int: - return _lib.llama_n_vocab_from_model(model) - - -_lib.llama_n_vocab_from_model.argtypes = [llama_model_p] -_lib.llama_n_vocab_from_model.restype = c_int - - -# LLAMA_API int llama_n_ctx_from_model (const struct llama_model * model); -def llama_n_ctx_from_model(model: llama_model_p) -> int: - return _lib.llama_n_ctx_from_model(model) - - -_lib.llama_n_ctx_from_model.argtypes = [llama_model_p] -_lib.llama_n_ctx_from_model.restype = c_int - - -# LLAMA_API int llama_n_embd_from_model (const struct llama_model * 
model); -def llama_n_embd_from_model(model: llama_model_p) -> int: - return _lib.llama_n_embd_from_model(model) - - -_lib.llama_n_embd_from_model.argtypes = [llama_model_p] -_lib.llama_n_embd_from_model.restype = c_int - - -# // Get the vocabulary as output parameters. -# // Returns number of results. -# LLAMA_API int llama_get_vocab( -# const struct llama_context * ctx, -# const char * * strings, -# float * scores, -# int capacity); -def llama_get_vocab( - ctx: llama_context_p, - strings, # type: Array[c_char_p] # type: ignore - scores, # type: Array[c_float] # type: ignore - capacity: Union[c_int, int], -) -> int: - return _lib.llama_get_vocab(ctx, strings, scores, capacity) - - -_lib.llama_get_vocab.argtypes = [ - llama_context_p, - POINTER(c_char_p), - POINTER(c_float), - c_int, -] -_lib.llama_get_vocab.restype = c_int - - -# LLAMA_API int llama_get_vocab_from_model( -# const struct llama_model * model, -# const char * * strings, -# float * scores, -# int capacity); -def llama_get_vocab_from_model( - model: llama_model_p, - strings, # type: Array[c_char_p] # type: ignore - scores, # type: Array[c_float] # type: ignore - capacity: Union[c_int, int], -) -> int: - return _lib.llama_get_vocab_from_model(model, strings, scores, capacity) - - -_lib.llama_get_vocab_from_model.argtypes = [ - llama_model_p, - POINTER(c_char_p), - POINTER(c_float), - c_int, -] -_lib.llama_get_vocab_from_model.restype = c_int +_lib.llama_eval_export.argtypes = [llama_context_p, c_char_p] +_lib.llama_eval_export.restype = c_int # Token logits obtained from the last call to llama_eval() @@ -875,16 +826,186 @@ _lib.llama_get_embeddings.argtypes = [llama_context_p] _lib.llama_get_embeddings.restype = c_float_p +# // +# // Vocab +# // + + +# LLAMA_API const char * llama_token_get_text(const struct llama_context * ctx, llama_token token); +def llama_token_get_text(ctx: llama_context_p, token: llama_token) -> bytes: + return _lib.llama_token_get_text(ctx, token) + + +_lib.llama_token_get_text.argtypes = [llama_context_p, llama_token] +_lib.llama_token_get_text.restype = c_char_p + + +# LLAMA_API float llama_token_get_score(const struct llama_context * ctx, llama_token token); +def llama_token_get_score(ctx: llama_context_p, token: llama_token) -> float: + return _lib.llama_token_get_score(ctx, token) + + +_lib.llama_token_get_score.argtypes = [llama_context_p, llama_token] +_lib.llama_token_get_score.restype = c_float + + +# LLAMA_API llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token); +def llama_token_get_type(ctx: llama_context_p, token: llama_token) -> int: + return _lib.llama_token_get_type(ctx, token) + + +_lib.llama_token_get_type.argtypes = [llama_context_p, llama_token] +_lib.llama_token_get_type.restype = ctypes.c_int + + +# // Special tokens + + +# LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence +def llama_token_bos(ctx: llama_context_p) -> llama_token: + return _lib.llama_token_bos(ctx) + + +_lib.llama_token_bos.argtypes = [llama_context_p] +_lib.llama_token_bos.restype = llama_token + + +# LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence +def llama_token_eos(ctx: llama_context_p) -> llama_token: + return _lib.llama_token_eos(ctx) + + +_lib.llama_token_eos.argtypes = [llama_context_p] +_lib.llama_token_eos.restype = llama_token + + +# LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line +def llama_token_nl(ctx: llama_context_p) -> llama_token: + 
return _lib.llama_token_nl(ctx) + + +_lib.llama_token_nl.argtypes = [llama_context_p] +_lib.llama_token_nl.restype = llama_token + + +# // +# // Tokenization +# // + + +# Convert the provided text into tokens. +# The tokens pointer must be large enough to hold the resulting tokens. +# Returns the number of tokens on success, no more than n_max_tokens +# Returns a negative number on failure - the number of tokens that would have been returned +# TODO: not sure if correct +# LLAMA_API int llama_tokenize( +# struct llama_context * ctx, +# const char * text, +# llama_token * tokens, +# int n_max_tokens, +# bool add_bos); +def llama_tokenize( + ctx: llama_context_p, + text: bytes, + tokens, # type: Array[llama_token] + n_max_tokens: c_int, + add_bos: c_bool, +) -> int: + return _lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos) + + +_lib.llama_tokenize.argtypes = [llama_context_p, c_char_p, llama_token_p, c_int, c_bool] +_lib.llama_tokenize.restype = c_int + + +# LLAMA_API int llama_tokenize_bpe( +# struct llama_context * ctx, +# const char * text, +# llama_token * tokens, +# int n_max_tokens, +# bool add_bos); +def llama_tokenize_bpe( + ctx: llama_context_p, + text: bytes, + tokens, # type: Array[llama_token] + n_max_tokens: c_int, + add_bos: c_bool, +) -> int: + return _lib.llama_tokenize_bpe(ctx, text, tokens, n_max_tokens, add_bos) + + +_lib.llama_tokenize_bpe.argtypes = [ + llama_context_p, + c_char_p, + llama_token_p, + c_int, + c_bool, +] +_lib.llama_tokenize_bpe.restype = c_int + + +# LLAMA_API int llama_tokenize_with_model( +# const struct llama_model * model, +# const char * text, +# llama_token * tokens, +# int n_max_tokens, +# bool add_bos); +def llama_tokenize_with_model( + model: llama_model_p, + text: bytes, + tokens, # type: Array[llama_token] + n_max_tokens: c_int, + add_bos: c_bool, +) -> int: + return _lib.llama_tokenize_with_model(model, text, tokens, n_max_tokens, add_bos) + + +_lib.llama_tokenize_with_model.argtypes = [ + llama_model_p, + c_char_p, + llama_token_p, + c_int, + c_bool, +] +_lib.llama_tokenize_with_model.restype = c_int + + # // Token Id -> String. 
Uses the vocabulary in the provided context -# LLAMA_API const char * llama_token_to_str( +# // Does not write null terminator to the buffer +# LLAMA_API int llama_token_to_str( # const struct llama_context * ctx, -# llama_token token); -def llama_token_to_str(ctx: llama_context_p, token: llama_token) -> bytes: - return _lib.llama_token_to_str(ctx, token) +# llama_token token, +# char * buf, +# int length); +def llama_token_to_str( + ctx: llama_context_p, token: llama_token, buf: bytes, length: c_int +) -> int: + return _lib.llama_token_to_str(ctx, token, buf, length) -_lib.llama_token_to_str.argtypes = [llama_context_p, llama_token] -_lib.llama_token_to_str.restype = c_char_p +_lib.llama_tokenize_with_model.argtypes = [ + llama_model_p, + c_char_p, + llama_token_p, + c_int, + c_bool, +] +_lib.llama_tokenize_with_model.restype = c_int + + +# LLAMA_API int llama_token_to_str_bpe( +# const struct llama_context * ctx, +# llama_token token, +# char * buf, +# int length); +def llama_token_to_str_bpe( + ctx: llama_context_p, token: llama_token, buf: bytes, length: c_int +) -> int: + return _lib.llama_token_to_str_bpe(ctx, token, buf, length) + + +_lib.llama_token_to_str_bpe.argtypes = [llama_context_p, llama_token, c_char_p, c_int] +_lib.llama_token_to_str_bpe.restype = c_int # LLAMA_API const char * llama_token_to_str_with_model( @@ -897,38 +1018,12 @@ def llama_token_to_str_with_model(model: llama_model_p, token: llama_token) -> b _lib.llama_token_to_str_with_model.argtypes = [llama_model_p, llama_token] _lib.llama_token_to_str_with_model.restype = c_char_p -# Special tokens - - -# LLAMA_API llama_token llama_token_bos(); // beginning-of-sentence -def llama_token_bos() -> int: - return _lib.llama_token_bos() - - -_lib.llama_token_bos.argtypes = [] -_lib.llama_token_bos.restype = llama_token - - -# LLAMA_API llama_token llama_token_eos(); // end-of-sentence -def llama_token_eos() -> int: - return _lib.llama_token_eos() - - -_lib.llama_token_eos.argtypes = [] -_lib.llama_token_eos.restype = llama_token - - -# LLAMA_API llama_token llama_token_nl(); // next-line -def llama_token_nl() -> int: - return _lib.llama_token_nl() - - -_lib.llama_token_nl.argtypes = [] -_lib.llama_token_nl.restype = llama_token - +# // # // Grammar # // + + # LLAMA_API struct llama_grammar * llama_grammar_init( # const llama_grammar_element ** rules, # size_t n_rules, @@ -958,7 +1053,9 @@ _lib.llama_grammar_free.argtypes = [llama_grammar_p] _lib.llama_grammar_free.restype = None -# Sampling functions +# // +# // Sampling functions +# // # @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. 
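
The bindings above change `llama_token_to_str` from returning a `const char *` to writing into a caller-provided buffer and returning the number of bytes written (with no null terminator), which is the convention the updated `Llama.detokenize` relies on. A small helper sketching that calling convention, assuming an already-initialized context pointer (e.g. `llm.ctx`) and the same 32-byte buffer size the wrapper uses:

```python
import ctypes

from llama_cpp import llama_cpp


def token_to_bytes(ctx, token: int, buffer_size: int = 32) -> bytes:
    """Decode a single token id into raw bytes via the buffer-based binding."""
    buf = (ctypes.c_char * buffer_size)()
    n = llama_cpp.llama_token_to_str(
        ctx, llama_cpp.llama_token(token), buf, buffer_size
    )
    # The C API does not write a null terminator; n is the byte count.
    assert n <= buffer_size
    return bytes(buf[:n])
```

With the high-level wrapper, this is essentially what `llm.detokenize([token])` now does internally.
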
@@ -1157,12 +1254,11 @@ _lib.llama_sample_temperature.argtypes = [ _lib.llama_sample_temperature.restype = None -# /// @details Apply constraints from grammar # LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar); def llama_sample_grammar( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] - grammar: llama_grammar_p, + grammar, # type: llama_grammar_p ): return _lib.llama_sample_grammar(ctx, candidates, grammar) @@ -1265,9 +1361,11 @@ _lib.llama_sample_token.restype = llama_token # /// @details Accepts the sampled token into the grammar # LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token); def llama_grammar_accept_token( - ctx: llama_context_p, grammar: llama_grammar_p, token: llama_token -): - return _lib.llama_grammar_accept_token(ctx, grammar, token) + ctx: llama_context_p, + grammar: llama_grammar_p, + token: llama_token, +) -> None: + _lib.llama_grammar_accept_token(ctx, grammar, token) _lib.llama_grammar_accept_token.argtypes = [ @@ -1316,6 +1414,19 @@ def llama_print_system_info() -> bytes: _lib.llama_print_system_info.argtypes = [] _lib.llama_print_system_info.restype = c_char_p + +# // Set callback for all future logging events. +# // If this is not called, or NULL is supplied, everything is output on stderr. +# LLAMA_API void llama_log_set(llama_log_callback log_callback, void * user_data); +def llama_log_set( + log_callback: "ctypes._FuncPointer", user_data: c_void_p # type: ignore +): + return _lib.llama_log_set(log_callback, user_data) + + +_lib.llama_log_set.argtypes = [llama_log_callback, c_void_p] +_lib.llama_log_set.restype = None + ################################################################################################### diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py new file mode 100644 index 0000000..8ff1565 --- /dev/null +++ b/llama_cpp/llama_grammar.py @@ -0,0 +1,1188 @@ +"""C++ implementation of the llama grammar parser.""" +# flake8: noqa +from pathlib import Path +import sys +from ctypes import * # type: ignore +from enum import Enum +from itertools import islice +from typing import ( + Callable, + Generic, + List, + Optional, + OrderedDict, + TextIO, + Tuple, + TypeVar, + Union, + overload, +) + +from . 
import llama_cpp + +# Type aliases +llama_grammar_element = llama_cpp.llama_grammar_element +llama_grammar_element_p = llama_cpp.llama_grammar_element_p +llama_grammar_p = llama_cpp.llama_grammar_p + +# Type variables +Ptr = TypeVar("Ptr", bound="const_char_p") +T = TypeVar("T") +U = TypeVar("U") +V = TypeVar("V") +W = TypeVar("W") + + +class Sentinel: + """Used to mark the end of a iterator of std::vector & std::map.""" + + +class LlamaGrammar: + """Keeps reference counts of all the arguments, so that they are not + garbage collected by Python.""" + + def __del__(self) -> None: + """Free the grammar pointer when the object is deleted.""" + if self.grammar is not None: + llama_cpp.llama_grammar_free(self.grammar) + self.grammar = None + + def __init__( + self, + parsed_grammar: "parse_state", + ) -> None: + """Initialize the grammar pointer from the parsed state.""" + self._grammar_rules = ( + parsed_grammar.c_rules() + ) # type: std.vector[std.vector[LlamaGrammarElement]] + self._n_rules = self._grammar_rules.size() # type: int + self._start_rule_index = parsed_grammar.symbol_ids.at("root") # type: int + self.grammar = self.init() + + @classmethod + def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar": + parsed_grammar = parse(const_char_p(grammar)) # type: parse_state + if parsed_grammar.rules.empty(): + raise ValueError( + f"{cls.from_string.__name__}: error parsing grammar file: parsed_grammar.rules is empty" + ) + if verbose: + print(f"{cls.from_string.__name__} grammar:", file=sys.stderr) + print_grammar(sys.stdout, parsed_grammar) + print(file=sys.stderr) + return cls(parsed_grammar) + + @classmethod + def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar": + try: + with open(file) as f: + grammar = f.read() + except Exception as err: + raise Exception( + f"{cls.from_file.__name__}: error reading grammar file: {err}" + ) + + if grammar: + return cls.from_string(grammar, verbose=verbose) + + raise ValueError( + f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty" + ) + + def init(self) -> None: + # Step 1: Convert LlamaGrammarElement to llama_grammar_element + self._element_lists = [ + [ + llama_grammar_element(c_int(elem.type.value), c_uint32(elem.value)) + for elem in subvector + ] + for subvector in self._grammar_rules + ] # type: List[List[llama_grammar_element]] + + # Step 2: Convert each list to llama_grammar_element array and get pointer + self._element_arrays = [ + (llama_grammar_element * len(sublist))(*sublist) + for sublist in self._element_lists + ] # type: List[Array[llama_grammar_element]] + + # Step 3: Get pointer of each array + self._element_array_pointers = [ + cast(subarray, llama_grammar_element_p) for subarray in self._element_arrays + ] # type: List[llama_grammar_element_p] + + # Step 4: Make array of these pointers and get its pointer + self._rules = (llama_grammar_element_p * len(self._element_array_pointers))( + *self._element_array_pointers + ) + self.grammar = llama_cpp.llama_grammar_init( + self._rules, c_size_t(self._n_rules), c_size_t(self._start_rule_index) + ) + + def reset(self) -> None: + if self.grammar is not None: + llama_cpp.llama_grammar_free(self.grammar) + self.init() + + +class LlamaGrammarElement: + def __init__(self, type: "llama_gretype", value: int): + self.type = type + self.value = value # Unicode code point or rule ID + + +class const_char_p: + """C++ implementation of const char *.""" + + def __init__(self, value: Union[str, Ptr], move: Optional[int] = 
None): + if isinstance(value, const_char_p): + # We're copying an existing const_char_p + self.value = value.value + self.pos = value.pos + (move or 0) + return + + # We're creating a new const_char_p + self.value = value + self.pos = move or 0 + + def __str__(self) -> str: + assert self.value is not None, "null pointer" + return self.value[self.pos :] + + def __getitem__(self, index: int) -> str: + value = str(self) + return value[index] if index < len(value) else "" + + @overload + def __add__(self: Ptr, other: int) -> Ptr: + ... + + @overload + def __add__(self: Ptr, other: Ptr) -> int: + ... + + def __add__(self: Ptr, other: Union[int, Ptr]) -> Union[int, Ptr]: + return ( + self.__class__(self.value, self.pos + other) + if isinstance(other, int) + else self.pos + other.pos + ) + + @overload + def __sub__(self: Ptr, other: int) -> Ptr: + ... + + @overload + def __sub__(self: Ptr, other: Ptr) -> int: + ... + + def __sub__(self: Ptr, other: Union[int, Ptr]) -> Union[int, Ptr]: + return ( + self.__class__(self.value, self.pos - other) + if isinstance(other, int) + else self.pos - other.pos + ) + + def __eq__(self: Ptr, other: Ptr) -> bool: + assert self.value == other.value, "comparing pointers from different strings" + return self.pos == other.pos + + def __lt__(self: Ptr, other: Ptr) -> bool: + assert self.value == other.value, "comparing pointers from different strings" + return self.pos < other.pos + + def __gt__(self: Ptr, other: Ptr) -> bool: + assert self.value == other.value, "comparing pointers from different strings" + return self.pos > other.pos + + +class std: + @staticmethod + def string(ptr: const_char_p, length: Optional[int] = None) -> str: + """C++ implementation of std::string constructor.""" + value = str(ptr) + if length is not None: + value = value[:length] + return value + + class vector(Generic[T], List[T]): + """C++ implementation of std::vector.""" + + class iterator: + def __init__(self, vector: "std.vector[T]", index: int): + self._vector = vector + self._index = index + self._version = vector._version + + def _check_version(self): + if self._version != self._vector._version: + raise RuntimeError("Iterator used after vector was modified.") + + def __iter__(self): + return self + + def __next__(self) -> T: + self._check_version() + if self._index >= self._vector.size(): + raise StopIteration + value = self._vector[self._index] + self._index += 1 + return value + + def __add__(self, value: int) -> "std.vector[T].iterator": + return self.__class__(self._vector, self._index + value) + + def __sub__(self, value: int) -> "std.vector[T].iterator": + return self.__class__(self._vector, self._index - value) + + def __init__(self): + self._version = 0 + + def modify(self): + # This is a bit of a hack to make sure iterators are invalidated + self._version += 1 + + def push_back(self, value: T) -> None: + self.modify() + self.append(value) + + def pop_back(self) -> None: + self.modify() + if not self.empty(): + self.pop() + + def back(self) -> T: + return self[-1] + + def size(self) -> int: + return len(self) + + def clear(self) -> None: + self.modify() + super().clear() + + def empty(self) -> bool: + return self.size() == 0 + + def data(self) -> "std.vector[T]": + return self + + def resize( + self, + new_size: int, + fill_value_factory: Optional[Callable[[], T]] = None, + ) -> None: + if new_size > self.size(): + if fill_value_factory is None: + raise ValueError("A fill value factory function must be provided.") + self.reserve(new_size, fill_value_factory) + elif new_size 
< self.size(): + self[:] = self[:new_size] + + def reserve(self, capacity: int, fill_value_factory: Callable[[], T]) -> None: + if capacity > self.size(): + fill_value = fill_value_factory() + self.extend([fill_value] * (capacity - self.size())) + + def front(self) -> T: + if not self.empty(): + return self[0] + else: + raise IndexError("Vector is empty.") + + def assign(self, count: int, value: T) -> None: + self.clear() + self.extend([value] * count) + + def insert( + self, + pos: "std.vector[T].iterator", + first: "std.vector[T].iterator", + last: "std.vector[T].iterator", + ) -> None: + self[pos._index : pos._index] = list( + islice(first._vector, first._index, last._index) + ) + + def begin(self) -> "std.vector[T].iterator": + return self.iterator(self, 0) + + def end(self) -> "std.vector[T].iterator": + return self.iterator(self, self.size()) + + class map(Generic[T, U], OrderedDict[T, U]): + """C++ implementation of std::map.""" + + class iterator(Generic[V, W]): + def __init__(self, _map: "std.map[T, U]", key: Union[T, Sentinel]): + self._map = _map + self.iter = iter(_map) + self.key = key + self._advance() + + def _sanitize_key(self) -> T: + if isinstance(self.key, Sentinel): + raise StopIteration + return self.key + + def _advance(self) -> None: + try: + while next(self.iter) != self.key: + pass + except StopIteration: + self.key = Sentinel() + + def __next__(self) -> Tuple[T, U]: + key = self._sanitize_key() + if key in self._map: + value = self._map[key] + self._advance() + return key, value + else: + raise StopIteration + + def get(self) -> Tuple[T, U]: + key = self._sanitize_key() + return key, self._map[key] + + @property + def first(self) -> T: + return self._sanitize_key() + + @property + def second(self) -> U: + return self._map[self._sanitize_key()] + + def insert( + self, key: T, value: U + ) -> Tuple["std.map[T, U].iterator[T, U]", bool]: + if key in self: + return self.iterator(self, key), False + else: + self[key] = value + return self.iterator(self, key), True + + def find(self, key: T) -> "std.map[T, U].iterator[T, U]": + if key in self: + return self.iterator(self, key) + else: + return self.end() + + def at(self, key: T) -> U: + if key in self: + return self[key] + else: + raise KeyError("The provided key is not found in the map.") + + def erase(self, iterator: "std.map[T, U].iterator[T, U]") -> None: + key = iterator.first + if key in self: + del self[key] + + def size(self) -> int: + return len(self) + + def empty(self) -> bool: + return self.size() == 0 + + def lower_bound(self, key: T) -> "std.map[T, U].iterator[T, U]": + try: + keys = sorted(list(self.keys())) # type: ignore + for k in keys: + if k >= key: + return self.iterator(self, k) + raise ValueError("No key found that is not less than the input key") + except TypeError: + raise TypeError("Keys of type T cannot be sorted.") + + def begin(self) -> "std.map[T, U].iterator[T, U]": + return self.iterator(self, next(iter(self))) + + def end(self) -> "std.map[T, U].iterator[T, U]": + return self.iterator(self, Sentinel()) + + +# // grammar element type +# enum llama_gretype { +# // end of rule definition +# LLAMA_GRETYPE_END = 0, + +# // start of alternate definition for rule +# LLAMA_GRETYPE_ALT = 1, + +# // non-terminal element: reference to rule +# LLAMA_GRETYPE_RULE_REF = 2, + +# // terminal element: character (code point) +# LLAMA_GRETYPE_CHAR = 3, + +# // inverse char(s) ([^a], [^a-b] [^abc]) +# LLAMA_GRETYPE_CHAR_NOT = 4, + +# // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to +# 
// be an inclusive range ([a-z]) +# LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, + + +# // modifies a preceding LLAMA_GRETYPE_CHAR or +# // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) +# LLAMA_GRETYPE_CHAR_ALT = 6, +# }; +class llama_gretype(Enum): + """grammar element type""" + + LLAMA_GRETYPE_END = 0 # end of rule definition + LLAMA_GRETYPE_ALT = 1 # start of alternate definition for rule + LLAMA_GRETYPE_RULE_REF = 2 # non-terminal element: reference to rule + LLAMA_GRETYPE_CHAR = 3 # terminal element: character (code point) + LLAMA_GRETYPE_CHAR_NOT = 4 # inverse char(s) ([^a], [^a-b] [^abc]) + LLAMA_GRETYPE_CHAR_RNG_UPPER = 5 # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to be an inclusive range ([a-z]) + LLAMA_GRETYPE_CHAR_ALT = 6 # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + + +# struct parse_state { +# std::map symbol_ids; +# std::vector> rules; +# std::vector c_rules(); +# }; +class parse_state: + def __init__(self): + self.symbol_ids: std.map[str, int] = std.map() + self.rules: std.vector[std.vector[LlamaGrammarElement]] = std.vector() + + # std::vector parse_state::c_rules() { + # std::vector ret; + # for (const auto & rule : rules) { + # ret.push_back(rule.data()); + # } + # return ret; + # } + def c_rules(self) -> std.vector[std.vector[LlamaGrammarElement]]: + ret = std.vector() # type: std.vector[std.vector[LlamaGrammarElement]] + for rule in self.rules: + ret.push_back(rule.data()) + return ret + + def __repr__(self) -> str: + return ( + f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})" + ) + + +# struct llama_grammar { +# const std::vector> rules; +# std::vector> stacks; +# }; +# class llama_grammar: +# def __init__( +# self, +# rules: std.vector[std.vector[llama_grammar_element]], +# stacks: std.vector[std.vector[llama_grammar_element]], +# ): +# self.rules = rules +# self.stacks = stacks + + +# uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { +# uint32_t next_id = static_cast(state.symbol_ids.size()); +# auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id)); +# return result.first->second; +# } +def get_symbol_id(state: parse_state, src: const_char_p, len: int) -> int: + next_id = state.symbol_ids.size() # type: int + result = state.symbol_ids.insert(std.string(src, len), next_id) + return result[0].second # type: ignore + + +# uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { +# uint32_t next_id = static_cast(state.symbol_ids.size()); +# state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; +# return next_id; +# } +def generate_symbol_id(state: parse_state, base_name: str) -> int: + next_id = state.symbol_ids.size() # type: int + state.symbol_ids[base_name + "_" + str(next_id)] = next_id + return next_id + + +# void add_rule( +# parse_state & state, +# uint32_t rule_id, +# const std::vector & rule) { +# if (state.rules.size() <= rule_id) { +# state.rules.resize(rule_id + 1); +# } +# state.rules[rule_id] = rule; +# } +def add_rule( + state: parse_state, + rule_id: int, + rule: std.vector[LlamaGrammarElement], +) -> None: + if state.rules.size() <= rule_id: + state.rules.resize( + rule_id + 1, + fill_value_factory=std.vector[LlamaGrammarElement], + ) + state.rules[rule_id] = rule + + +# std::pair decode_utf8(const char * src) { +# static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; +# uint8_t 
first_byte = static_cast(*src); +# uint8_t highbits = first_byte >> 4; +# int len = lookup[highbits]; +# uint8_t mask = (1 << (8 - len)) - 1; +# uint32_t value = first_byte & mask; +# const char * end = src + len; // may overrun! +# const char * pos = src + 1; +# for ( ; pos < end && *pos; pos++) { +# value = (value << 6) + (static_cast(*pos) & 0x3F); +# } +# return std::make_pair(value, pos); +# } +def decode_utf8(src: const_char_p) -> Tuple[int, const_char_p]: + """Decodes a UTF-8 character from the source string.""" + lookup = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4) + first_byte = ord(src[0]) # type: int + highbits = first_byte >> 4 # type: int + len = lookup[highbits] # type: int + mask = (1 << (8 - len)) - 1 # type: int + value = first_byte & mask # type: int + end = src + len # type: const_char_p # may overrun! + pos = src + 1 # type: const_char_p + while pos < end and pos[0]: + value = (value << 6) + (ord(pos[0]) & 0x3F) + pos += 1 + return value, pos + + +# bool is_word_char(char c) { +# return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); +# } +def is_word_char(c: str) -> bool: + return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9") + + +# std::pair parse_hex(const char * src, int size) { +# const char * pos = src; +# const char * end = src + size; +# uint32_t value = 0; +# for ( ; pos < end && *pos; pos++) { +# value <<= 4; +# char c = *pos; +# if ('a' <= c && c <= 'f') { +# value += c - 'a' + 10; +# } else if ('A' <= c && c <= 'F') { +# value += c - 'A' + 10; +# } else if ('0' <= c && c <= '9') { +# value += c - '0'; +# } else { +# break; +# } +# } +# if (pos != end) { +# throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); +# } +# return std::make_pair(value, pos); +# } +def parse_hex(src: const_char_p, size: int) -> Tuple[int, const_char_p]: + pos = const_char_p(src) # type: const_char_p + end = src + size # type: const_char_p + value = 0 # type: int + while pos < end and pos[0]: + value <<= 4 + c = pos[0] # type: str + if "a" <= c <= "f": + value += ord(c) - ord("a") + 10 + elif "A" <= c <= "F": + value += ord(c) - ord("A") + 10 + elif "0" <= c <= "9": + value += ord(c) - ord("0") + else: + break + pos += 1 + if pos != end: + raise RuntimeError("expecting " + str(size) + " hex chars at " + str(src)) + return (value, pos) + + +# std::pair parse_char(const char * src) { +# if (*src == '\\') { +# switch (src[1]) { +# case 'x': return parse_hex(src + 2, 2); +# case 'u': return parse_hex(src + 2, 4); +# case 'U': return parse_hex(src + 2, 8); +# case 't': return std::make_pair('\t', src + 2); +# case 'r': return std::make_pair('\r', src + 2); +# case 'n': return std::make_pair('\n', src + 2); +# case '\\': +# case '"': +# case '[': +# case ']': +# return std::make_pair(src[1], src + 2); +# default: +# throw std::runtime_error(std::string("unknown escape at ") + src); +# } +# } else if (*src) { +# return decode_utf8(src); +# } +# throw std::runtime_error("unexpected end of input"); +# } +def parse_char(src: const_char_p) -> Tuple[int, const_char_p]: + if src[0] == "\\": + case = src[1] # type: str + if case == "x": + return parse_hex(src + 2, 2) + elif case == "u": + return parse_hex(src + 2, 4) + elif case == "U": + return parse_hex(src + 2, 8) + elif case == "t": + return (ord("\t"), src + 2) # implicit cast + elif case == "r": + return (ord("\r"), src + 2) # implicit cast + elif case == "n": + return (ord("\n"), src + 2) # implicit cast + elif case in ("\\", '"', "[", 
"]"): + return (ord(case), src + 2) # implicit cast + else: + raise RuntimeError("unknown escape at " + str(src)) + elif src[0]: + return decode_utf8(src) + else: + raise RuntimeError("unexpected end of input") + + +# const char * parse_name(const char * src) { +# const char * pos = src; +# while (is_word_char(*pos)) { +# pos++; +# } +# if (pos == src) { +# throw std::runtime_error(std::string("expecting name at ") + src); +# } +# return pos; +# } +def parse_name(src: const_char_p) -> const_char_p: + pos = const_char_p(src) # type: const_char_p + while is_word_char(pos[0]): + pos += 1 + if pos == src: + raise RuntimeError("expecting name at " + str(src)) + return pos + + +# const char * parse_space(const char * src, bool newline_ok) { +# const char * pos = src; +# while (*pos == ' ' || *pos == '\t' || *pos == '#' || +# (newline_ok && (*pos == '\r' || *pos == '\n'))) { +# if (*pos == '#') { +# while (*pos && *pos != '\r' && *pos != '\n') { +# pos++; +# } +# } else { +# pos++; +# } +# } +# return pos; +# } +def parse_space(src: const_char_p, newline_ok: bool) -> const_char_p: + pos = const_char_p(src) # type: const_char_p + while pos[0] in (" ", "\t", "#") or (newline_ok and pos[0] in ("\r", "\n")): + if pos[0] == "#": + while pos[0] is not None and pos[0] not in ("\r", "\n"): + pos += 1 + else: + pos += 1 + return pos + + +# const char * parse_sequence( +# parse_state & state, +# const char * src, +# const std::string & rule_name, +# std::vector & out_elements, +# bool is_nested) { +def parse_sequence( + state: parse_state, + src: const_char_p, + rule_name: str, + out_elements: std.vector[LlamaGrammarElement], + is_nested: bool, +) -> const_char_p: + # size_t last_sym_start = out_elements.size(); + # const char * pos = src; + last_sym_start = out_elements.size() # type: int + pos = const_char_p(src) # type: const_char_p + # while (*pos) { + while pos[0]: + # if (*pos == '"') { // literal string + # pos++; + # last_sym_start = out_elements.size(); + # while (*pos != '"') { + # auto char_pair = parse_char(pos); + # pos = char_pair.second; + # out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); + # } + # pos = parse_space(pos + 1, is_nested); + if pos[0] == '"': # literal string + pos += 1 + last_sym_start = out_elements.size() + while pos[0] != '"': + char_pair = parse_char(pos) # type: Tuple[int, const_char_p] + pos = char_pair[1] + out_elements.push_back( + LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_CHAR, char_pair[0]) + ) + pos = parse_space(pos + 1, is_nested) + # } else if (*pos == '[') { // char range(s) + # pos++; + # enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; + elif pos[0] == "[": # char range(s) + pos += 1 + start_type = llama_gretype.LLAMA_GRETYPE_CHAR # type: llama_gretype + # if (*pos == '^') { + # pos++; + # start_type = LLAMA_GRETYPE_CHAR_NOT; + # } + # last_sym_start = out_elements.size(); + if pos[0] == "^": + pos += 1 + start_type = llama_gretype.LLAMA_GRETYPE_CHAR_NOT + last_sym_start = out_elements.size() + # while (*pos != ']') { + # auto char_pair = parse_char(pos); + # pos = char_pair.second; + # enum llama_gretype type = last_sym_start < out_elements.size() + # ? 
LLAMA_GRETYPE_CHAR_ALT + # : start_type; + # out_elements.push_back({type, char_pair.first}); + while pos[0] != "]": + char_pair = parse_char(pos) # type: Tuple[int, const_char_p] + pos = char_pair[1] + type = ( + llama_gretype.LLAMA_GRETYPE_CHAR_ALT + if last_sym_start < out_elements.size() + else start_type + ) # type: llama_gretype + out_elements.push_back(LlamaGrammarElement(type, char_pair[0])) + # if (pos[0] == '-' && pos[1] != ']') { + # auto endchar_pair = parse_char(pos + 1); + # pos = endchar_pair.second; + # out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); + # } + # } + if pos[0] == "-" and pos[1] != "]": + endchar_pair = parse_char(pos + 1) # type: Tuple[int, const_char_p] + pos = endchar_pair[1] + out_elements.push_back( + LlamaGrammarElement( + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, + endchar_pair[0], + ) + ) + # pos = parse_space(pos + 1, is_nested); + pos = parse_space(pos + 1, is_nested) + # } else if (is_word_char(*pos)) { // rule reference + # const char * name_end = parse_name(pos); + # uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); + # pos = parse_space(name_end, is_nested); + # last_sym_start = out_elements.size(); + # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); + elif is_word_char(pos[0]): # rule reference + name_end = parse_name(pos) # type: const_char_p + ref_rule_id = get_symbol_id(state, pos, name_end - pos) # type: int + pos = parse_space(name_end, is_nested) + last_sym_start = out_elements.size() + out_elements.push_back( + LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, ref_rule_id) + ) + # } else if (*pos == '(') { // grouping + # // parse nested alternates into synthesized rule + # pos = parse_space(pos + 1, true); + # uint32_t sub_rule_id = generate_symbol_id(state, rule_name); + # pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); + # last_sym_start = out_elements.size(); + # // output reference to synthesized rule + # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + # if (*pos != ')') { + # throw std::runtime_error(std::string("expecting ')' at ") + pos); + # } + # pos = parse_space(pos + 1, is_nested); + elif pos[0] == "(": # grouping + # parse nested alternates into synthesized rule + pos = parse_space(pos + 1, True) + sub_rule_id = generate_symbol_id(state, rule_name) # type: int + pos = parse_alternates(state, pos, rule_name, sub_rule_id, True) + last_sym_start = out_elements.size() + # output reference to synthesized rule + out_elements.push_back( + LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id) + ) + if pos[0] != ")": + raise RuntimeError("expecting ')' at " + str(pos)) + pos = parse_space(pos + 1, is_nested) + # } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator + # if (last_sym_start == out_elements.size()) { + # throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos); + # } + elif pos[0] in ("*", "+", "?"): # repetition operator + if last_sym_start == out_elements.size(): + raise RuntimeError("expecting preceding item to */+/? at " + str(pos)) + # // apply transformation to previous symbol (last_sym_start to end) according to + # // rewrite rules: + # // S* --> S' ::= S S' | + # // S+ --> S' ::= S S' | S + # // S? 
--> S' ::= S | + # uint32_t sub_rule_id = generate_symbol_id(state, rule_name); + # std::vector sub_rule; + # // add preceding symbol to generated rule + # sub_rule.insert( + # sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); + sub_rule_id = generate_symbol_id(state, rule_name) # type: int + sub_rule = std.vector[ + LlamaGrammarElement + ]() # type: std.vector[LlamaGrammarElement] + sub_rule.insert( + sub_rule.end(), + out_elements.begin() + last_sym_start, + out_elements.end(), + ) + # if (*pos == '*' || *pos == '+') { + # // cause generated rule to recurse + # sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + # } + # // mark start of alternate def + # sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + if pos[0] in ("*", "+"): + sub_rule.push_back( + LlamaGrammarElement( + llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id + ) + ) + sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0)) + # if (*pos == '+') { + # // add preceding symbol as alternate only for '+' (otherwise empty) + # sub_rule.insert( + # sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); + # } + # sub_rule.push_back({LLAMA_GRETYPE_END, 0}); + # add_rule(state, sub_rule_id, sub_rule); + # // in original rule, replace previous symbol with reference to generated rule + # out_elements.resize(last_sym_start); + # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + # pos = parse_space(pos + 1, is_nested); + if pos[0] == "+": + # add preceding symbol as alternate only for '+' (otherwise empty) + sub_rule.insert( + sub_rule.end(), + out_elements.begin() + last_sym_start, + out_elements.end(), + ) + sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0)) + add_rule(state, sub_rule_id, sub_rule) + # in original rule, replace previous symbol with reference to generated rule + out_elements.resize(last_sym_start) + out_elements.push_back( + LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id) + ) + pos = parse_space(pos + 1, is_nested) + # } else { + # break; + # } + else: + break + # } + # return pos; + # } + return pos + + +# const char * parse_alternates( +# parse_state & state, +# const char * src, +# const std::string & rule_name, +# uint32_t rule_id, +# bool is_nested) { +# std::vector rule; +# const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); +# while (*pos == '|') { +# rule.push_back({LLAMA_GRETYPE_ALT, 0}); +# pos = parse_space(pos + 1, true); +# pos = parse_sequence(state, pos, rule_name, rule, is_nested); +# } +# rule.push_back({LLAMA_GRETYPE_END, 0}); +# add_rule(state, rule_id, rule); +# return pos; +# } +def parse_alternates( + state: parse_state, + src: const_char_p, + rule_name: str, + rule_id: int, + is_nested: bool, +) -> const_char_p: + rule = std.vector() # type: std.vector[LlamaGrammarElement] + pos = parse_sequence(state, src, rule_name, rule, is_nested) # type: const_char_p + while pos[0] == "|": + rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0)) + pos = parse_space(pos + 1, True) + pos = parse_sequence(state, pos, rule_name, rule, is_nested) + rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0)) + add_rule(state, rule_id, rule) + return pos + + +# const char * parse_rule(parse_state & state, const char * src) { +# const char * name_end = parse_name(src); +# const char * pos = parse_space(name_end, false); +# size_t name_len = name_end - src; +# uint32_t rule_id = get_symbol_id(state, src, name_len); +# const 
std::string name(src, name_len); + +# if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { +# throw std::runtime_error(std::string("expecting ::= at ") + pos); +# } +# pos = parse_space(pos + 3, true); + +# pos = parse_alternates(state, pos, name, rule_id, false); + + +# if (*pos == '\r') { +# pos += pos[1] == '\n' ? 2 : 1; +# } else if (*pos == '\n') { +# pos++; +# } else if (*pos) { +# throw std::runtime_error(std::string("expecting newline or end at ") + pos); +# } +# return parse_space(pos, true); +# } +def parse_rule(state: parse_state, src: const_char_p) -> const_char_p: + name_end = parse_name(src) # type: const_char_p + pos = parse_space(name_end, False) # type: const_char_p + name_len = name_end - src # type: int + rule_id = get_symbol_id(state, src, name_len) # type: int + name = std.string(src, name_len) # type: str + + if not (pos[0] == ":" and pos[1] == ":" and pos[2] == "="): + raise RuntimeError("expecting ::= at " + str(pos)) + + pos = parse_space(pos + 3, True) # type: const_char_p + pos = parse_alternates(state, pos, name, rule_id, False) # type: const_char_p + + if pos[0] == "\r": + pos += 2 if pos[1] == "\n" else 1 + elif pos[0] == "\n": + pos += 1 + elif pos[0]: + raise RuntimeError("expecting newline or end at " + str(pos)) + return parse_space(pos, True) + + +# parse_state parse(const char * src) { +# try { +# parse_state state; +# const char * pos = parse_space(src, true); +# while (*pos) { +# pos = parse_rule(state, pos); +# } +# return state; +# } catch (const std::exception & err) { +# fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); +# return parse_state(); +# } +# } +def parse(src: const_char_p) -> parse_state: + try: + state = parse_state() # type: parse_state + pos = parse_space(src, True) # type: const_char_p + while pos[0]: + pos = parse_rule(state, pos) + return state + except Exception as err: + print(f"{parse.__name__}: error parsing grammar: {err}") + return parse_state() + + +# void print_grammar_char(FILE * file, uint32_t c) { +# if (0x20 <= c && c <= 0x7f) { +# fprintf(file, "%c", static_cast<char>(c)); +# } else { +# // cop out of encoding UTF-8 +# fprintf(file, "<U+%04X>", c); +# } +# } +def print_grammar_char(file: TextIO, c: int) -> None: + if 0x20 <= c and c <= 0x7F: + file.write(chr(c)) + else: + # cop out of encoding UTF-8 + file.write(f"<U+{c:04X}>") + + +# bool is_char_element(llama_grammar_element elem) { +# switch (elem.type) { +# case LLAMA_GRETYPE_CHAR: return true; +# case LLAMA_GRETYPE_CHAR_NOT: return true; +# case LLAMA_GRETYPE_CHAR_ALT: return true; +# case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; +# default: return false; +# } +# } +def is_char_element(elem: LlamaGrammarElement) -> bool: + return elem.type in ( + llama_gretype.LLAMA_GRETYPE_CHAR, + llama_gretype.LLAMA_GRETYPE_CHAR_NOT, + llama_gretype.LLAMA_GRETYPE_CHAR_ALT, + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, + ) + + +# void print_rule( +# FILE * file, +# uint32_t rule_id, +# const std::vector<llama_grammar_element> & rule, +# const std::map<uint32_t, std::string> & symbol_id_names) { +def print_rule( + file: TextIO, + rule_id: int, + rule: std.vector[LlamaGrammarElement], + symbol_id_names: std.map[int, str], +) -> None: + # if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { + # throw std::runtime_error( + # "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); + # } + # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); + if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END: + raise RuntimeError( + "malformed rule, does not end with 
LLAMA_GRETYPE_END: " + + str(rule_id) + ) + print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ") + # for (size_t i = 0, end = rule.size() - 1; i < end; i++) { + # llama_grammar_element elem = rule[i]; + # switch (elem.type) { + # case LLAMA_GRETYPE_END: + # throw std::runtime_error( + # "unexpected end of rule: " + std::to_string(rule_id) + "," + + # std::to_string(i)); + # case LLAMA_GRETYPE_ALT: + # fprintf(file, "| "); + # break; + # case LLAMA_GRETYPE_RULE_REF: + # fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); + # break; + # case LLAMA_GRETYPE_CHAR: + # fprintf(file, "["); + # print_grammar_char(file, elem.value); + # break; + # case LLAMA_GRETYPE_CHAR_NOT: + # fprintf(file, "[^"); + # print_grammar_char(file, elem.value); + # break; + # case LLAMA_GRETYPE_CHAR_RNG_UPPER: + # if (i == 0 || !is_char_element(rule[i - 1])) { + # throw std::runtime_error( + # "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + + # std::to_string(rule_id) + "," + std::to_string(i)); + # } + # fprintf(file, "-"); + # print_grammar_char(file, elem.value); + # break; + # case LLAMA_GRETYPE_CHAR_ALT: + # if (i == 0 || !is_char_element(rule[i - 1])) { + # throw std::runtime_error( + # "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + + # std::to_string(rule_id) + "," + std::to_string(i)); + # } + # print_grammar_char(file, elem.value); + # break; + # } + for i, elem in enumerate(rule[:-1]): + case = elem.type # type: llama_gretype + if case is llama_gretype.LLAMA_GRETYPE_END: + raise RuntimeError( + "unexpected end of rule: " + str(rule_id) + "," + str(i) + ) + elif case is llama_gretype.LLAMA_GRETYPE_ALT: + print("| ", file=file, end="") + elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF: + print(f"{symbol_id_names.at(elem.value)} ", file=file, end="") + elif case is llama_gretype.LLAMA_GRETYPE_CHAR: + print("[", file=file, end="") + print_grammar_char(file, elem.value) + elif case is llama_gretype.LLAMA_GRETYPE_CHAR_NOT: + print("[^", file=file, end="") + print_grammar_char(file, elem.value) + elif case is llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER: + if i == 0 or not is_char_element(rule[i - 1]): + raise RuntimeError( + "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + + str(rule_id) + + "," + + str(i) + ) + print("-", file=file, end="") + print_grammar_char(file, elem.value) + elif case is llama_gretype.LLAMA_GRETYPE_CHAR_ALT: + if i == 0 or not is_char_element(rule[i - 1]): + raise RuntimeError( + "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + + str(rule_id) + + "," + + str(i) + ) + print_grammar_char(file, elem.value) + # if (is_char_element(elem)) { + # switch (rule[i + 1].type) { + # case LLAMA_GRETYPE_CHAR_ALT: + # case LLAMA_GRETYPE_CHAR_RNG_UPPER: + # break; + # default: + # fprintf(file, "] "); + if is_char_element(elem): + if rule[i + 1].type in ( + llama_gretype.LLAMA_GRETYPE_CHAR_ALT, + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, + ): + pass + else: + print("] ", file=file, end="") + # } + # } + # } + # fprintf(file, "\n"); + # } + print(file=file) + + +# void print_grammar(FILE * file, const parse_state & state) { +# try { +# std::map symbol_id_names; +# for (auto kv : state.symbol_ids) { +# symbol_id_names[kv.second] = kv.first; +# } +# for (size_t i = 0, end = state.rules.size(); i < end; i++) { +# // fprintf(file, "%zu: ", i); +# // print_rule_binary(file, state.rules[i]); +# print_rule(file, i, state.rules[i], symbol_id_names); +# // fprintf(file, "\n"); +# } +# } catch (const std::exception & err) { +# fprintf(stderr, "\n%s: error printing 
grammar: %s\n", __func__, err.what()); +# } +# } +def print_grammar(file: TextIO, state: parse_state) -> None: + try: + symbol_id_names = std.map() # type: std.map[int, str] + for kv in state.symbol_ids.items(): + symbol_id_names[kv[1]] = kv[0] + + for i, rule in enumerate(state.rules): + print_rule(file, i, rule, symbol_id_names) + except Exception as err: + print( + f"{print_grammar.__name__}: error printing grammar: {err}", + file=sys.stderr, + ) diff --git a/llama_cpp/py.typed b/llama_cpp/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index 58b5551..0dd0749 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -108,6 +108,10 @@ class Settings(BaseSettings): default=None, description="TEMPORARY", ) + mul_mat_q: Optional[bool] = Field( + default=None, + description="TEMPORARY", + ) class ErrorResponse(TypedDict): diff --git a/llama_cpp/utils.py b/llama_cpp/utils.py new file mode 100644 index 0000000..c14f53f --- /dev/null +++ b/llama_cpp/utils.py @@ -0,0 +1,38 @@ +import os +import sys + + +class suppress_stdout_stderr(object): + # Oddly enough this works better than the contextlib version + def __enter__(self): + self.outnull_file = open(os.devnull, "w") + self.errnull_file = open(os.devnull, "w") + + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + + self.old_stdout_fileno = os.dup(sys.stdout.fileno()) + self.old_stderr_fileno = os.dup(sys.stderr.fileno()) + + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) + os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + + sys.stdout = self.outnull_file + sys.stderr = self.errnull_file + return self + + def __exit__(self, *_): + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) + os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + os.close(self.old_stdout_fileno) + os.close(self.old_stderr_fileno) + + self.outnull_file.close() + self.errnull_file.close() diff --git a/pyproject.toml b/pyproject.toml index 6272bb9..84ecf37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "scikit_build_core.build" [project] name = "llama_cpp_python" -version = "0.1.77" +version = "0.1.78" description = "Python bindings for the llama.cpp library" readme = "README.md" license = { text = "MIT" } diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 41c6741..f5fe98d 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 41c674161fb2459bdf7806d1eebead15bc5d046e +Subproject commit f5fe98d11bdf9e7797bcfb05c0c3601ffc4b9d26
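
The new `llama_cpp/llama_grammar.py` above is a line-by-line port of llama.cpp's GBNF parser, so it can be exercised on its own. Below is a minimal sketch of driving the parser directly; it assumes the `const_char_p` helper defined earlier in the module can be constructed from a plain Python string and that `parse`/`print_grammar` keep the signatures shown in this diff, so treat it as illustrative rather than a guaranteed API.

```python
# Sketch: parse a small GBNF grammar with the functions ported above and dump
# the normalized rules (including the helper rules synthesized for *, + and grouping).
# Assumption: const_char_p accepts a Python str, as the rest of the module suggests.
import sys

from llama_cpp.llama_grammar import const_char_p, parse, print_grammar

GBNF = r'''
root  ::= ident (ws ident)*
ident ::= [a-zA-Z_] [a-zA-Z_0-9]*
ws    ::= [ \t]+
'''

state = parse(const_char_p(GBNF))  # fills state.symbol_ids and state.rules
print_grammar(sys.stdout, state)   # prints each rule as `name ::= ...`
```

At the API level, the two changelog entries for 0.1.78 pair naturally: a `LlamaGrammar` constrains sampling, and `n_gpu_layers=-1` now offloads every layer. The sketch below assumes `LlamaGrammar` exposes a `from_string` constructor and that completion calls accept a `grammar` keyword, in line with the changelog's "passed to completions" wording; the model path is a placeholder.

```python
# Sketch of grammar-constrained completion with the 0.1.78 additions.
# Assumptions: LlamaGrammar.from_string exists, completion calls take grammar=...,
# and "./models/7B/ggml-model.bin" is a placeholder path.
from llama_cpp import Llama
from llama_cpp.llama_grammar import LlamaGrammar

yes_no = LlamaGrammar.from_string(r'root ::= "yes" | "no"')

llm = Llama(
    model_path="./models/7B/ggml-model.bin",
    n_gpu_layers=-1,  # new in 0.1.78: -1 offloads all layers
)
out = llm("Q: Is the sky blue? A: ", max_tokens=4, grammar=yes_no)
print(out["choices"][0]["text"])  # generated text is constrained by the grammar
```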