diff --git a/CHANGELOG.md b/CHANGELOG.md
index c0748ee..4fff919 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,30 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
+## [0.2.32]
+
+- feat: Update llama.cpp to ggerganov/llama.cpp@504dc37be8446fb09b1ede70300250ad41be32a2
+- fix: from_json_schema oneOf/anyOf bug by @jndiogo in d3f5528ca8bcb9d69d4f27e21631e911f1fb9bfe
+- fix: pass chat handler, not chat formatter, for huggingface autotokenizer and tokenizer_config formats by @abetlen in 24f39454e91cf5dddbc4b6041aead4accc7c7a2d
+- feat: Add add_generation_prompt option for Jinja2ChatFormatter by @abetlen in 7f3209b1eb4ad3260ba063801fab80a8c25a2f4c
+- feat: Add Jinja2ChatFormatter by @abetlen in be09318c26add8674ce494ae7cc480cce72a4146
+- feat: Expose GGUF model metadata in metadata property by @abetlen in 5a34c57e5479e50c99aba9b38218cc48e6560b81
+
+## [0.2.31]
+
+- feat: Update llama.cpp to ggerganov/llama.cpp@a5cacb22b2114fd9adf61c00cbb237384d86bced
+- fix: Mirostat sampling now passes correct type to ctypes and tracks state during generation by @abetlen in 3babe3512cb95743108f2b595210c38ed6f1b904
+- fix: Python3.8 support in server by @abetlen in 141293a75b564a8699e0acba1da24d9aa1cf0ab1
+
+## [0.2.30]
+
+- feat: Update llama.cpp to ggerganov/llama.cpp@57e2a7a52a819883f40dada8a2edc24ecf48186b
+- feat(server): Add ability to load chat format from huggingface autotokenizer or tokenizer_config.json files by @abetlen in b8fc1c7d83ad4a9207c707ba1d954fe580286a01
+- feat: Integration of Jinja2 Templating for chat formats by @teleprint-me in #875
+- fix: Offload KQV by default by @abetlen in 48c3b77e6f558a9899de0e1155c7dc0c7958d8e8
+- fix: Support Accept text/event-stream in chat and completion endpoints, resolves #1083 by @aniljava in #1088
+- fix(cli): allow passing n_ctx=0 to OpenAI API server args to use the model's n_ctx_train field per #1015 by @K-Mistele in #1093
+
## [0.2.29]
- feat: Update llama.cpp to ggerganov/llama.cpp@4483396751c79dea540808b9cb9238245d06da2b
diff --git a/Makefile b/Makefile
index e930609..5ed3fa2 100644
--- a/Makefile
+++ b/Makefile
@@ -10,22 +10,22 @@ deps:
python3 -m pip install -e ".[all]"
build:
- python3 -m pip install -e .
+ python3 -m pip install --verbose -e .
build.cuda:
- CMAKE_ARGS="-DLLAMA_CUBLAS=on" python3 -m pip install -e .
+ CMAKE_ARGS="-DLLAMA_CUBLAS=on" python3 -m pip install --verbose -e .
build.opencl:
- CMAKE_ARGS="-DLLAMA_CLBLAST=on" python3 -m pip install -e .
+ CMAKE_ARGS="-DLLAMA_CLBLAST=on" python3 -m pip install --verbose -e .
build.openblas:
- CMAKE_ARGS="-DLLAMA_CLBLAST=on" python3 -m pip install -e .
+ CMAKE_ARGS="-DLLAMA_CLBLAST=on" python3 -m pip install --verbose -e .
build.blis:
- CMAKE_ARGS="-DLLAMA_OPENBLAS=on -DLLAMA_OPENBLAS_VENDOR=blis" python3 -m pip install -e .
+ CMAKE_ARGS="-DLLAMA_OPENBLAS=on -DLLAMA_OPENBLAS_VENDOR=blis" python3 -m pip install --verbose -e .
build.metal:
- CMAKE_ARGS="-DLLAMA_METAL=on" python3 -m pip install -e .
+ CMAKE_ARGS="-DLLAMA_METAL=on" python3 -m pip install --verbose -e .
build.sdist:
python3 -m build --sdist
diff --git a/README.md b/README.md
index ad5d0f1..f97ea0f 100644
--- a/README.md
+++ b/README.md
@@ -113,6 +113,10 @@ See the above instructions and set `CMAKE_ARGS` to the BLAS backend you want to
### MacOS Notes
+Detailed MacOS Metal GPU install documentation is available at [docs/install/macos.md](https://llama-cpp-python.readthedocs.io/en/latest/install/macos/)
+
+#### M1 Mac Performance Issue
+
Note: If you are using Apple Silicon (M1) Mac, make sure you have installed a version of Python that supports arm64 architecture. For example:
```
wget https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-MacOSX-arm64.sh
@@ -120,7 +124,13 @@ bash Miniforge3-MacOSX-arm64.sh
```
Otherwise, while installing it will build the llama.cpp x86 version which will be 10x slower on Apple Silicon (M1) Mac.
-Detailed MacOS Metal GPU install documentation is available at [docs/install/macos.md](https://llama-cpp-python.readthedocs.io/en/latest/install/macos/)
+#### M Series Mac Error: `(mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64'))`
+
+Try installing with:
+
+```
+CMAKE_ARGS="-DCMAKE_OSX_ARCHITECTURES=arm64 -DCMAKE_APPLE_SILICON_PROCESSOR=arm64 -DLLAMA_METAL=on" pip install --upgrade --verbose --force-reinstall --no-cache-dir llama-cpp-python
+```
### Upgrading and Reinstalling
diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py
index 65206bf..dda8335 100644
--- a/llama_cpp/__init__.py
+++ b/llama_cpp/__init__.py
@@ -1,4 +1,4 @@
from .llama_cpp import *
from .llama import *
-__version__ = "0.2.29"
\ No newline at end of file
+__version__ = "0.2.32"
\ No newline at end of file
diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py
new file mode 100644
index 0000000..ec47c42
--- /dev/null
+++ b/llama_cpp/_internals.py
@@ -0,0 +1,795 @@
+from __future__ import annotations
+
+import os
+import ctypes
+
+from typing import (
+    Dict,
+    List,
+    Optional,
+    Sequence,
+)
+from dataclasses import dataclass, field
+
+import numpy as np
+import numpy.typing as npt
+
+from .llama_types import *
+from .llama_grammar import LlamaGrammar
+
+import llama_cpp.llama_cpp as llama_cpp
+
+from ._utils import suppress_stdout_stderr
+
+
+# Python wrappers over llama.h structs
+
+
+class _LlamaModel:
+ """Intermediate Python wrapper for a llama.cpp llama_model.
+ NOTE: For stability it's recommended you use the Llama class instead."""
+
+ _llama_free_model = None
+ # NOTE: this must be "saved" here to avoid exceptions when calling __del__
+ _suppress_stdout_stderr = suppress_stdout_stderr
+
+ def __init__(
+ self,
+ *,
+ path_model: str,
+ params: llama_cpp.llama_model_params,
+ verbose: bool = True,
+ ):
+ self.path_model = path_model
+ self.params = params
+ self.verbose = verbose
+
+ self._llama_free_model = llama_cpp._lib.llama_free_model # type: ignore
+
+ if not os.path.exists(path_model):
+ raise ValueError(f"Model path does not exist: {path_model}")
+
+ with self._suppress_stdout_stderr(disable=self.verbose):
+ self.model = llama_cpp.llama_load_model_from_file(
+ self.path_model.encode("utf-8"), self.params
+ )
+
+ def __del__(self):
+ with self._suppress_stdout_stderr(disable=self.verbose):
+ if self.model is not None and self._llama_free_model is not None:
+ self._llama_free_model(self.model)
+ self.model = None
+
+ def vocab_type(self) -> int:
+ assert self.model is not None
+ return llama_cpp.llama_vocab_type(self.model)
+
+ def n_vocab(self) -> int:
+ assert self.model is not None
+ return llama_cpp.llama_n_vocab(self.model)
+
+ def n_ctx_train(self) -> int:
+ assert self.model is not None
+ return llama_cpp.llama_n_ctx_train(self.model)
+
+ def n_embd(self) -> int:
+ assert self.model is not None
+ return llama_cpp.llama_n_embd(self.model)
+
+ def rope_freq_scale_train(self) -> float:
+ assert self.model is not None
+ return llama_cpp.llama_rope_freq_scale_train(self.model)
+
+ def desc(self) -> str:
+ assert self.model is not None
+ buf = ctypes.create_string_buffer(1024)
+ llama_cpp.llama_model_desc(self.model, buf, 1024) # type: ignore
+ return buf.value.decode("utf-8")
+
+ def size(self) -> int:
+ assert self.model is not None
+ return llama_cpp.llama_model_size(self.model)
+
+ def n_params(self) -> int:
+ assert self.model is not None
+ return llama_cpp.llama_model_n_params(self.model)
+
+ def get_tensor(self, name: str) -> ctypes.c_void_p:
+ assert self.model is not None
+ return llama_cpp.llama_get_model_tensor(self.model, name.encode("utf-8"))
+
+ def apply_lora_from_file(
+ self,
+ lora_path: str,
+ scale: float,
+ path_base_model: Optional[str],
+ n_threads: int,
+ ):
+ assert self.model is not None
+ return llama_cpp.llama_model_apply_lora_from_file(
+ self.model,
+ lora_path.encode("utf-8"),
+ scale,
+ path_base_model.encode("utf-8")
+ if path_base_model is not None
+ else llama_cpp.c_char_p(0),
+ n_threads,
+ )
+
+ # Vocab
+
+ def token_get_text(self, token: int) -> str:
+ # TODO: Fix
+ assert self.model is not None
+ return llama_cpp.llama_token_get_text(self.model, token).decode("utf-8")
+
+ def token_get_score(self, token: int) -> float:
+ assert self.model is not None
+ return llama_cpp.llama_token_get_score(self.model, token)
+
+ def token_get_type(self, token: int) -> int:
+ assert self.model is not None
+ return llama_cpp.llama_token_get_type(self.model, token)
+
+ # Special tokens
+
+ def token_bos(self) -> int:
+ assert self.model is not None
+ return llama_cpp.llama_token_bos(self.model)
+
+ def token_eos(self) -> int:
+ assert self.model is not None
+ return llama_cpp.llama_token_eos(self.model)
+
+ def token_nl(self) -> int:
+ assert self.model is not None
+ return llama_cpp.llama_token_nl(self.model)
+
+ def token_prefix(self) -> int:
+ assert self.model is not None
+ return llama_cpp.llama_token_prefix(self.model)
+
+ def token_middle(self) -> int:
+ assert self.model is not None
+ return llama_cpp.llama_token_middle(self.model)
+
+ def token_suffix(self) -> int:
+ assert self.model is not None
+ return llama_cpp.llama_token_suffix(self.model)
+
+ def token_eot(self) -> int:
+ assert self.model is not None
+ return llama_cpp.llama_token_eot(self.model)
+
+ # Tokenization
+
+ def tokenize(self, text: bytes, add_bos: bool, special: bool):
+ assert self.model is not None
+ n_ctx = self.n_ctx_train()
+ tokens = (llama_cpp.llama_token * n_ctx)()
+ n_tokens = llama_cpp.llama_tokenize(
+ self.model, text, len(text), tokens, n_ctx, add_bos, special
+ )
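+        # llama_tokenize returns a negative value when the token buffer is too
+        # small; its absolute value is the number of tokens actually required.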
+ if n_tokens < 0:
+ n_tokens = abs(n_tokens)
+ tokens = (llama_cpp.llama_token * n_tokens)()
+ n_tokens = llama_cpp.llama_tokenize(
+ self.model, text, len(text), tokens, n_tokens, add_bos, special
+ )
+ if n_tokens < 0:
+ raise RuntimeError(
+ f'Failed to tokenize: text="{text}" n_tokens={n_tokens}'
+ )
+ return list(tokens[:n_tokens])
+
+ def token_to_piece(self, token: int) -> bytes:
+ assert self.model is not None
+ buf = ctypes.create_string_buffer(32)
+ llama_cpp.llama_token_to_piece(self.model, token, buf, 32) # type: ignore
+ return bytes(buf)
+
+ def detokenize(self, tokens: List[int]) -> bytes:
+ assert self.model is not None
+ output = b""
+ size = 32
+ buffer = (ctypes.c_char * size)()
+ for token in tokens:
+ n = llama_cpp.llama_token_to_piece(
+ self.model, llama_cpp.llama_token(token), buffer, size
+ )
+ assert n <= size
+ output += bytes(buffer[:n])
+ # NOTE: Llama1 models automatically added a space at the start of the prompt
+ # this line removes a leading space if the first token is a beginning of sentence token
+ return (
+ output[1:] if len(tokens) > 0 and tokens[0] == self.token_bos() else output
+ )
+
+ # Extra
+ def metadata(self) -> Dict[str, str]:
+ assert self.model is not None
+ metadata: Dict[str, str] = {}
+ buffer_size = 1024
+ buffer = ctypes.create_string_buffer(buffer_size)
+ # zero the buffer
+ buffer.value = b'\0' * buffer_size
+ # iterate over model keys
+ for i in range(llama_cpp.llama_model_meta_count(self.model)):
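+            # Fetch the key; if the reported length exceeds the buffer, grow it
+            # and retry, then repeat the same pattern for the value.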
+ nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
+ if nbytes > buffer_size:
+ buffer_size = nbytes
+ buffer = ctypes.create_string_buffer(buffer_size)
+ nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
+ key = buffer.value.decode("utf-8")
+ nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
+ if nbytes > buffer_size:
+ buffer_size = nbytes
+ buffer = ctypes.create_string_buffer(buffer_size)
+ nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
+ value = buffer.value.decode("utf-8")
+ metadata[key] = value
+ return metadata
+
+ @staticmethod
+ def default_params():
+ """Get the default llama_model_params."""
+ return llama_cpp.llama_model_default_params()
+
+
+class _LlamaContext:
+ """Intermediate Python wrapper for a llama.cpp llama_context.
+ NOTE: For stability it's recommended you use the Llama class instead."""
+
+ _llama_free = None
+ # NOTE: this must be "saved" here to avoid exceptions when calling __del__
+ _suppress_stdout_stderr = suppress_stdout_stderr
+
+ def __init__(
+ self,
+ *,
+ model: _LlamaModel,
+ params: llama_cpp.llama_context_params,
+ verbose: bool = True,
+ ):
+ self.model = model
+ self.params = params
+ self.verbose = verbose
+
+ self._llama_free = llama_cpp._lib.llama_free # type: ignore
+
+ with self._suppress_stdout_stderr(disable=self.verbose):
+ self.ctx = llama_cpp.llama_new_context_with_model(
+ self.model.model, self.params
+ )
+
+ def __del__(self):
+ with self._suppress_stdout_stderr(disable=self.verbose):
+ if self.ctx is not None and self._llama_free is not None:
+ self._llama_free(self.ctx)
+ self.ctx = None
+
+ def n_ctx(self) -> int:
+ assert self.ctx is not None
+ return llama_cpp.llama_n_ctx(self.ctx)
+
+ def kv_cache_clear(self):
+ assert self.ctx is not None
+ llama_cpp.llama_kv_cache_clear(self.ctx)
+
+ def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int):
+ assert self.ctx is not None
+ llama_cpp.llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1)
+
+ def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int):
+ assert self.ctx is not None
+ llama_cpp.llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1)
+
+ def kv_cache_seq_keep(self, seq_id: int):
+ assert self.ctx is not None
+ llama_cpp.llama_kv_cache_seq_keep(self.ctx, seq_id)
+
+ def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
+ assert self.ctx is not None
+ llama_cpp.llama_kv_cache_seq_shift(self.ctx, seq_id, p0, p1, shift)
+
+ def get_state_size(self) -> int:
+ assert self.ctx is not None
+ return llama_cpp.llama_get_state_size(self.ctx)
+
+ # TODO: copy_state_data
+
+ # TODO: set_state_data
+
+ # TODO: llama_load_session_file
+
+ # TODO: llama_save_session_file
+
+ def decode(self, batch: "_LlamaBatch"):
+ assert self.ctx is not None
+ assert batch.batch is not None
+ return_code = llama_cpp.llama_decode(
+ ctx=self.ctx,
+ batch=batch.batch,
+ )
+ if return_code != 0:
+ raise RuntimeError(f"llama_decode returned {return_code}")
+
+ def set_n_threads(self, n_threads: int, n_threads_batch: int):
+ assert self.ctx is not None
+ llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch)
+
+ def get_logits(self):
+ assert self.ctx is not None
+ return llama_cpp.llama_get_logits(self.ctx)
+
+ def get_logits_ith(self, i: int):
+ assert self.ctx is not None
+ return llama_cpp.llama_get_logits_ith(self.ctx, i)
+
+ def get_embeddings(self):
+ assert self.ctx is not None
+ return llama_cpp.llama_get_embeddings(self.ctx)
+
+ # Sampling functions
+
+ def set_rng_seed(self, seed: int):
+ assert self.ctx is not None
+ llama_cpp.llama_set_rng_seed(self.ctx, seed)
+
+ def sample_repetition_penalties(
+ self,
+ candidates: "_LlamaTokenDataArray",
+ last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]",
+ penalty_last_n: int,
+ penalty_repeat: float,
+ penalty_freq: float,
+ penalty_present: float,
+ ):
+ assert self.ctx is not None
+ llama_cpp.llama_sample_repetition_penalties(
+ self.ctx,
+ ctypes.byref(candidates.candidates), # type: ignore
+ last_tokens_data,
+ penalty_last_n,
+ penalty_repeat,
+ penalty_freq,
+ penalty_present,
+ )
+
+ def sample_classifier_free_guidance(
+ self,
+ candidates: "_LlamaTokenDataArray",
+ guidance_ctx: "_LlamaContext",
+ scale: float,
+ ):
+ assert self.ctx is not None
+ assert guidance_ctx.ctx is not None
+ llama_cpp.llama_sample_classifier_free_guidance(
+ self.ctx,
+ ctypes.byref(candidates.candidates), # type: ignore
+ guidance_ctx.ctx,
+ scale,
+ )
+
+ def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
+ assert self.ctx is not None
+ llama_cpp.llama_sample_softmax(
+ self.ctx,
+ ctypes.byref(candidates.candidates), # type: ignore
+ )
+
+ def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
+ assert self.ctx is not None
+ llama_cpp.llama_sample_top_k(
+ self.ctx, ctypes.byref(candidates.candidates), k, min_keep # type: ignore
+ )
+
+ def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
+ assert self.ctx is not None
+ llama_cpp.llama_sample_top_p(
+ self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
+ )
+
+ def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
+ assert self.ctx is not None
+ llama_cpp.llama_sample_min_p(
+ self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
+ )
+
+ def sample_tail_free(
+ self, candidates: "_LlamaTokenDataArray", z: float, min_keep: int
+ ):
+ assert self.ctx is not None
+ llama_cpp.llama_sample_tail_free(
+ self.ctx, ctypes.byref(candidates.candidates), z, min_keep # type: ignore
+ )
+
+ def sample_typical(
+ self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
+ ):
+ assert self.ctx is not None
+ llama_cpp.llama_sample_typical(
+ self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
+ )
+
+ def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
+ assert self.ctx is not None
+ llama_cpp.llama_sample_temp(
+ self.ctx, ctypes.byref(candidates.candidates), temp # type: ignore
+ )
+
+ def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
+ assert self.ctx is not None
+ assert grammar.grammar is not None
+ llama_cpp.llama_sample_grammar(
+ self.ctx,
+ ctypes.byref(candidates.candidates), # type: ignore
+ grammar.grammar,
+ )
+
+ def sample_token_mirostat(
+ self,
+ candidates: "_LlamaTokenDataArray",
+ tau: float,
+ eta: float,
+ m: int,
+ mu: ctypes._Pointer[ctypes.c_float], # type: ignore
+ ) -> int:
+ assert self.ctx is not None
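+        # mu is passed as a pointer so llama.cpp can update the mirostat state
+        # in place across successive sampling calls.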
+ return llama_cpp.llama_sample_token_mirostat(
+ self.ctx,
+ ctypes.byref(candidates.candidates), # type: ignore
+ tau,
+ eta,
+ m,
+ mu,
+ )
+
+ def sample_token_mirostat_v2(
+ self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: ctypes._Pointer[ctypes.c_float] # type: ignore
+ ) -> int:
+ assert self.ctx is not None
+ return llama_cpp.llama_sample_token_mirostat_v2(
+ self.ctx,
+ ctypes.byref(candidates.candidates), # type: ignore
+ tau,
+ eta,
+ mu,
+ )
+
+ def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
+ assert self.ctx is not None
+ return llama_cpp.llama_sample_token_greedy(
+ self.ctx,
+ ctypes.byref(candidates.candidates), # type: ignore
+ )
+
+ def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
+ assert self.ctx is not None
+ return llama_cpp.llama_sample_token(
+ self.ctx,
+ ctypes.byref(candidates.candidates), # type: ignore
+ )
+
+ # Grammar
+ def grammar_accept_token(self, grammar: LlamaGrammar, token: int):
+ assert self.ctx is not None
+ assert grammar.grammar is not None
+ llama_cpp.llama_grammar_accept_token(self.ctx, grammar.grammar, token)
+
+ def reset_timings(self):
+ assert self.ctx is not None
+ llama_cpp.llama_reset_timings(self.ctx)
+
+ def print_timings(self):
+ assert self.ctx is not None
+ llama_cpp.llama_print_timings(self.ctx)
+
+ # Utility functions
+ @staticmethod
+ def default_params():
+ """Get the default llama_context_params."""
+ return llama_cpp.llama_context_default_params()
+
+
+class _LlamaBatch:
+ _llama_batch_free = None
+ # NOTE: this must be "saved" here to avoid exceptions when calling __del__
+ _suppress_stdout_stderr = suppress_stdout_stderr
+
+ def __init__(
+ self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
+ ):
+ self.n_tokens = n_tokens
+ self.embd = embd
+ self.n_seq_max = n_seq_max
+ self.verbose = verbose
+
+ self._llama_batch_free = llama_cpp._lib.llama_batch_free # type: ignore
+
+ with self._suppress_stdout_stderr(disable=self.verbose):
+ self.batch = llama_cpp.llama_batch_init(
+ self.n_tokens, self.embd, self.n_seq_max
+ )
+
+ def __del__(self):
+ with self._suppress_stdout_stderr(disable=self.verbose):
+ if self.batch is not None and self._llama_batch_free is not None:
+ self._llama_batch_free(self.batch)
+ self.batch = None
+
+ def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
+ assert self.batch is not None
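+        # Fill the batch as a single sequence (seq_id 0); logits are requested
+        # for every position when logits_all is set, and always for the final token.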
+ n_tokens = len(batch)
+ self.batch.n_tokens = n_tokens
+ for i in range(n_tokens):
+ self.batch.token[i] = batch[i]
+ self.batch.pos[i] = n_past + i
+ self.batch.seq_id[i][0] = 0
+ self.batch.n_seq_id[i] = 1
+ self.batch.logits[i] = logits_all
+ self.batch.logits[n_tokens - 1] = True
+
+
+class _LlamaTokenDataArray:
+ def __init__(self, *, n_vocab: int):
+ self.n_vocab = n_vocab
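+        # Structured array whose fields mirror the C llama_token_data struct
+        # (id, logit, p) so its buffer can be handed to the C API without copying.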
+ self.candidates_data = np.array(
+ [],
+ dtype=np.dtype(
+ [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
+ ),
+ )
+ self.candidates_data.resize(3, self.n_vocab, refcheck=False)
+ self.candidates = llama_cpp.llama_token_data_array(
+ data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
+ size=self.n_vocab,
+ sorted=False,
+ )
+ self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc)
+ self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)
+
+ def copy_logits(self, logits: npt.NDArray[np.single]):
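+        # Refresh the candidate array for a new sampling pass: reset ids and
+        # probabilities, copy in the new logits, and re-point the C struct at the buffer.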
+ self.candidates_data["id"][:] = self.default_candidates_data_id
+ self.candidates_data["logit"][:] = logits
+ self.candidates_data["p"][:] = self.default_candidates_data_p
+ self.candidates.data = self.candidates_data.ctypes.data_as(
+ llama_cpp.llama_token_data_p
+ )
+ self.candidates.sorted = llama_cpp.c_bool(False)
+ self.candidates.size = llama_cpp.c_size_t(self.n_vocab)
+
+
+# Python wrappers over common/common
+def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> list[int]:
+ n_tokens = len(text) + 1 if add_bos else len(text)
+ result = (llama_cpp.llama_token * n_tokens)()
+ n_tokens = llama_cpp.llama_tokenize(
+ model.model,
+ text.encode("utf-8"),
+ len(text),
+ result,
+ n_tokens,
+ add_bos,
+ special,
+ )
+ if n_tokens < 0:
+ result = (llama_cpp.llama_token * -n_tokens)()
+ check = llama_cpp.llama_tokenize(
+ model.model,
+ text.encode("utf-8"),
+ len(text),
+ result,
+ len(result),
+ add_bos,
+ special,
+ )
+ if check != -n_tokens:
+ raise RuntimeError(f'Failed to tokenize: text="{text}" n_tokens={n_tokens}')
+ else:
+ result = result[:n_tokens]
+ return list(result)
+
+
+def _token_to_piece(model: _LlamaModel, token: int) -> str:
+ assert model.model is not None
+ result = (ctypes.c_char * 8)(0)
+ n_tokens = llama_cpp.llama_token_to_piece(model.model, token, result, len(result))
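+    # A negative return value is the negated buffer size required; retry once
+    # with a buffer of exactly that size.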
+ if n_tokens < 0:
+ result = (ctypes.c_char * -n_tokens)(0)
+ check = llama_cpp.llama_token_to_piece(model.model, token, result, len(result))
+ if check != -n_tokens:
+ raise RuntimeError(f"Failed to get piece: token={token}")
+ else:
+ result = result[:n_tokens]
+ return bytes(result).decode("utf-8")
+
+
+def _detokenize_spm(model: _LlamaModel, tokens: List[int]) -> str:
+ bos_id = model.token_bos()
+ result = ""
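+    # SPM tokenizers prepend a space to the first piece; strip it from the first
+    # real token (index 1 when the sequence starts with BOS, index 0 otherwise).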
+ for i, token in enumerate(tokens):
+ piece = _token_to_piece(model, token)
+ if (
+ (tokens[0] == bos_id and i == 1) or (tokens[0] != bos_id and i == 0)
+ ) and piece[0] == " ":
+ piece = piece[1:]
+ result += piece
+ return result
+
+
+def _detokenize_bpe(model: _LlamaModel, tokens: List[int]) -> str:
+ result = ""
+ for token in tokens:
+ piece = _token_to_piece(model, token)
+ result += piece
+ return result
+
+
+def _should_add_bos(model: _LlamaModel) -> bool:
+ assert model.model is not None
+ add_bos = llama_cpp.llama_add_bos_token(model.model)
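+    # llama_add_bos_token returns -1 when the model does not specify a preference;
+    # in that case add BOS only for SPM (SentencePiece) vocabularies.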
+ if add_bos != -1:
+ return add_bos != 0
+ else:
+ return llama_cpp.llama_vocab_type(model.model) == llama_cpp.LLAMA_VOCAB_TYPE_SPM
+
+
+# Python wrappers over common/sampling structs
+
+
+@dataclass
+class _LlamaSamplingParams:
+ n_prev: int = 64
+ n_probs: int = 0
+ top_k: int = 40
+ top_p: float = 0.95
+ min_p: float = 0.05
+ tfs_z: float = 1.00
+ typical_p: float = 1.00
+ temp: float = 0.80
+ penalty_last_n: int = 64
+ penalty_repeat: float = 1.10
+ penalty_freq: float = 0.00
+ penalty_present: float = 0.00
+ mirostat: int = 0
+ mirostat_tau: float = 5.00
+ mirostat_eta: float = 0.10
+ penalize_nl: bool = True
+
+ grammar: str = ""
+
+ cfg_negative_prompt: str = ""
+ cfg_scale: float = 1.00
+
+ logit_bias: dict[int, float] = field(default_factory=dict)
+
+
+@dataclass
+class _LlamaSamplingContext:
+ params: _LlamaSamplingParams = field(default_factory=_LlamaSamplingParams)
+ mirostat_mu: ctypes.c_float = field(default_factory=ctypes.c_float)
+ grammar: Optional[LlamaGrammar] = None
+ # NOTE: Missing parsed_grammar
+ prev: list[int] = field(default_factory=list)
+ cur: list[llama_cpp.llama_token_data] = field(default_factory=list)
+
+ def reset(self):
+ self.prev = []
+ self.cur = []
+ if self.grammar is not None:
+ self.grammar.reset()
+
+ def cp(self):
+ return _LlamaSamplingContext(
+ params=self.params,
+ mirostat_mu=self.mirostat_mu,
+ grammar=self.grammar,
+ prev=self.prev.copy(),
+ cur=self.cur.copy(),
+ )
+
+ def last(self) -> Optional[int]:
+ if len(self.prev) > 0:
+ return self.prev[-1]
+ else:
+ return None
+
+ def prev_str(self, ctx_main: _LlamaContext, n: int) -> str:
+ return ctx_main.model.detokenize(self.prev[-n:]).decode("utf-8")
+
+ def sample(
+ self, ctx_main: _LlamaContext, ctx_cfg: Optional[_LlamaContext] = None, idx: int = 0, logits_array: Optional[npt.NDArray[np.single]] = None
+ ):
+ n_vocab = ctx_main.model.n_vocab()
+ id: int = 0
+
+ if logits_array is None:
+ logits = ctx_main.get_logits_ith(idx)
+ logits_array = np.array(
+ ctypes.cast(logits, ctypes.POINTER(ctypes.c_float * n_vocab)).contents,
+ dtype=np.single,
+ )
+
+ # apply logit_bias
+ for token, logit_bias in self.params.logit_bias.items():
+ logits_array[token] += logit_bias
+
+ token_data_array = _LlamaTokenDataArray(
+ n_vocab=n_vocab
+ ) # TODO: Only create this once
+ token_data_array.copy_logits(logits_array)
+
+ if ctx_cfg is not None:
+ ctx_main.sample_classifier_free_guidance(
+ token_data_array, ctx_cfg, self.params.cfg_scale
+ )
+
+ # apply penalties
+ if len(self.prev) > 0:
+ nl_token = ctx_main.model.token_nl()
+ nl_logit = logits_array[nl_token]
+ if self.params.penalty_last_n > 0:
+ ctx_main.sample_repetition_penalties(
+ token_data_array,
+ # TODO: Only create this once
+ (llama_cpp.llama_token * len(self.prev))(*self.prev), # type: ignore
+ self.params.penalty_last_n,
+ self.params.penalty_repeat,
+ self.params.penalty_freq,
+ self.params.penalty_present,
+ )
+ if not self.params.penalize_nl:
+ token_data_array.candidates_data["logit"][nl_token] = nl_logit
+
+ if self.grammar is not None:
+ ctx_main.sample_grammar(token_data_array, self.grammar)
+
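+        # temp < 0: apply softmax and take the top candidate; temp == 0: greedy
+        # sampling; otherwise run mirostat or the top-k/tfs/typical/top-p/min-p
+        # pipeline followed by temperature sampling.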
+ if self.params.temp < 0:
+ ctx_main.sample_softmax(token_data_array)
+ id = token_data_array.candidates_data["id"][0]
+ elif self.params.temp == 0:
+ id = ctx_main.sample_token_greedy(token_data_array)
+ else:
+ if self.params.mirostat == 1:
+ mirostat_m = 100
+ ctx_main.sample_temp(token_data_array, self.params.temp)
+ id = ctx_main.sample_token_mirostat(
+ token_data_array,
+ self.params.mirostat_tau,
+ self.params.mirostat_eta,
+ mirostat_m,
+ ctypes.pointer(self.mirostat_mu),
+ )
+ elif self.params.mirostat == 2:
+ ctx_main.sample_temp(token_data_array, self.params.temp)
+ id = ctx_main.sample_token_mirostat_v2(
+ token_data_array,
+ self.params.mirostat_tau,
+ self.params.mirostat_eta,
+ ctypes.pointer(self.mirostat_mu),
+ )
+ else:
+ min_keep = max(1, self.params.n_probs)
+ ctx_main.sample_top_k(
+ token_data_array, self.params.top_k, min_keep=min_keep
+ )
+ ctx_main.sample_tail_free(
+ token_data_array, self.params.tfs_z, min_keep=min_keep
+ )
+ ctx_main.sample_typical(
+ token_data_array, self.params.typical_p, min_keep=min_keep
+ )
+ ctx_main.sample_top_p(
+ token_data_array, self.params.top_p, min_keep=min_keep
+ )
+ ctx_main.sample_min_p(
+ token_data_array, self.params.min_p, min_keep=min_keep
+ )
+ ctx_main.sample_temp(token_data_array, self.params.temp)
+ id = ctx_main.sample_token(token_data_array)
+ return id
+
+ def accept(self, ctx_main: _LlamaContext, id: int, apply_grammar: bool):
+ if apply_grammar and self.grammar is not None:
+ ctx_main.grammar_accept_token(self.grammar, id)
+ self.prev.append(id)
\ No newline at end of file
diff --git a/llama_cpp/_utils.py b/llama_cpp/_utils.py
index f7b6ba6..4a10647 100644
--- a/llama_cpp/_utils.py
+++ b/llama_cpp/_utils.py
@@ -1,7 +1,8 @@
import os
import sys
-import sys, traceback
+import sys
+from typing import Any, Dict
# Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor
outnull_file = open(os.devnull, "w")
@@ -55,3 +56,25 @@ class suppress_stdout_stderr(object):
self.os.close(self.old_stdout_fileno)
self.os.close(self.old_stderr_fileno)
+
+
+class MetaSingleton(type):
+ """
+ Metaclass for implementing the Singleton pattern.
+ """
+
+ _instances: Dict[type, Any] = {}
+
+ def __call__(cls, *args: Any, **kwargs: Any) -> Any:
+ if cls not in cls._instances:
+ cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
+ return cls._instances[cls]
+
+
+class Singleton(object, metaclass=MetaSingleton):
+ """
+ Base class for implementing the Singleton pattern.
+ """
+
+ def __init__(self):
+ super(Singleton, self).__init__()
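+
+# NOTE: Every instantiation of a Singleton subclass returns the same shared
+# instance, e.g. LlamaChatCompletionHandlerRegistry() always yields the same
+# registry object.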
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index e4be9d1..5c66bcf 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -1,9 +1,10 @@
+from __future__ import annotations
+
import os
import sys
import uuid
import time
import multiprocessing
-from abc import ABC, abstractmethod
from typing import (
List,
Optional,
@@ -12,16 +13,20 @@ from typing import (
Sequence,
Iterator,
Deque,
- Tuple,
Callable,
)
-from collections import deque, OrderedDict
+from collections import deque
-import diskcache
import ctypes
from .llama_types import *
from .llama_grammar import LlamaGrammar
+from .llama_cache import (
+ BaseLlamaCache,
+ LlamaCache, # type: ignore
+ LlamaDiskCache, # type: ignore
+ LlamaRAMCache, # type: ignore
+)
import llama_cpp.llama_cpp as llama_cpp
import llama_cpp.llama_chat_format as llama_chat_format
@@ -29,694 +34,12 @@ import numpy as np
import numpy.typing as npt
from ._utils import suppress_stdout_stderr
-
-
-class BaseLlamaCache(ABC):
- """Base cache class for a llama.cpp model."""
-
- def __init__(self, capacity_bytes: int = (2 << 30)):
- self.capacity_bytes = capacity_bytes
-
- @property
- @abstractmethod
- def cache_size(self) -> int:
- raise NotImplementedError
-
- def _find_longest_prefix_key(
- self,
- key: Tuple[int, ...],
- ) -> Optional[Tuple[int, ...]]:
- pass
-
- @abstractmethod
- def __getitem__(self, key: Sequence[int]) -> "LlamaState":
- raise NotImplementedError
-
- @abstractmethod
- def __contains__(self, key: Sequence[int]) -> bool:
- raise NotImplementedError
-
- @abstractmethod
- def __setitem__(self, key: Sequence[int], value: "LlamaState") -> None:
- raise NotImplementedError
-
-
-class LlamaRAMCache(BaseLlamaCache):
- """Cache for a llama.cpp model using RAM."""
-
- def __init__(self, capacity_bytes: int = (2 << 30)):
- super().__init__(capacity_bytes)
- self.capacity_bytes = capacity_bytes
- self.cache_state: OrderedDict[Tuple[int, ...], "LlamaState"] = OrderedDict()
-
- @property
- def cache_size(self):
- return sum([state.llama_state_size for state in self.cache_state.values()])
-
- def _find_longest_prefix_key(
- self,
- key: Tuple[int, ...],
- ) -> Optional[Tuple[int, ...]]:
- min_len = 0
- min_key = None
- keys = (
- (k, Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys()
- )
- for k, prefix_len in keys:
- if prefix_len > min_len:
- min_len = prefix_len
- min_key = k
- return min_key
-
- def __getitem__(self, key: Sequence[int]) -> "LlamaState":
- key = tuple(key)
- _key = self._find_longest_prefix_key(key)
- if _key is None:
- raise KeyError("Key not found")
- value = self.cache_state[_key]
- self.cache_state.move_to_end(_key)
- return value
-
- def __contains__(self, key: Sequence[int]) -> bool:
- return self._find_longest_prefix_key(tuple(key)) is not None
-
- def __setitem__(self, key: Sequence[int], value: "LlamaState"):
- key = tuple(key)
- if key in self.cache_state:
- del self.cache_state[key]
- self.cache_state[key] = value
- while self.cache_size > self.capacity_bytes and len(self.cache_state) > 0:
- self.cache_state.popitem(last=False)
-
-
-# Alias for backwards compatibility
-LlamaCache = LlamaRAMCache
-
-
-class LlamaDiskCache(BaseLlamaCache):
- """Cache for a llama.cpp model using disk."""
-
- def __init__(
- self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30)
- ):
- super().__init__(capacity_bytes)
- self.cache = diskcache.Cache(cache_dir)
-
- @property
- def cache_size(self):
- return int(self.cache.volume()) # type: ignore
-
- def _find_longest_prefix_key(
- self,
- key: Tuple[int, ...],
- ) -> Optional[Tuple[int, ...]]:
- min_len = 0
- min_key: Optional[Tuple[int, ...]] = None
- for k in self.cache.iterkeys(): # type: ignore
- prefix_len = Llama.longest_token_prefix(k, key)
- if prefix_len > min_len:
- min_len = prefix_len
- min_key = k # type: ignore
- return min_key
-
- def __getitem__(self, key: Sequence[int]) -> "LlamaState":
- key = tuple(key)
- _key = self._find_longest_prefix_key(key)
- if _key is None:
- raise KeyError("Key not found")
- value: "LlamaState" = self.cache.pop(_key) # type: ignore
- # NOTE: This puts an integer as key in cache, which breaks,
- # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
- # self.cache.push(_key, side="front") # type: ignore
- return value
-
- def __contains__(self, key: Sequence[int]) -> bool:
- return self._find_longest_prefix_key(tuple(key)) is not None
-
- def __setitem__(self, key: Sequence[int], value: "LlamaState"):
- print("LlamaDiskCache.__setitem__: called", file=sys.stderr)
- key = tuple(key)
- if key in self.cache:
- print("LlamaDiskCache.__setitem__: delete", file=sys.stderr)
- del self.cache[key]
- self.cache[key] = value
- print("LlamaDiskCache.__setitem__: set", file=sys.stderr)
- while self.cache_size > self.capacity_bytes and len(self.cache) > 0:
- key_to_remove = next(iter(self.cache))
- del self.cache[key_to_remove]
- print("LlamaDiskCache.__setitem__: trim", file=sys.stderr)
-
-
-class LlamaState:
- def __init__(
- self,
- input_ids: npt.NDArray[np.intc],
- scores: npt.NDArray[np.single],
- n_tokens: int,
- llama_state: bytes,
- llama_state_size: int,
- ):
- self.input_ids = input_ids
- self.scores = scores
- self.n_tokens = n_tokens
- self.llama_state = llama_state
- self.llama_state_size = llama_state_size
-
-
-LogitsProcessor = Callable[
- [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
-]
-
-
-class LogitsProcessorList(List[LogitsProcessor]):
- def __call__(
- self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
- ) -> npt.NDArray[np.single]:
- for processor in self:
- scores = processor(input_ids, scores)
- return scores
-
-
-StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]
-
-
-class StoppingCriteriaList(List[StoppingCriteria]):
- def __call__(
- self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
- ) -> bool:
- return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
-
-
-class _LlamaModel:
- """Intermediate Python wrapper for a llama.cpp llama_model.
-
- NOTE: For stability it's recommended you use the Llama class instead."""
-
- _llama_free_model = None
- # NOTE: this must be "saved" here to avoid exceptions when calling __del__
- suppress_stdout_stderr = suppress_stdout_stderr
-
- def __init__(
- self,
- *,
- path_model: str,
- params: llama_cpp.llama_model_params,
- verbose: bool = True,
- ):
- self.path_model = path_model
- self.params = params
- self.verbose = verbose
-
- self._llama_free_model = llama_cpp._lib.llama_free_model # type: ignore
-
- if not os.path.exists(path_model):
- raise ValueError(f"Model path does not exist: {path_model}")
-
- with suppress_stdout_stderr(disable=self.verbose):
- self.model = llama_cpp.llama_load_model_from_file(
- self.path_model.encode("utf-8"), self.params
- )
-
- def __del__(self):
- with self.suppress_stdout_stderr(disable=self.verbose):
- if self.model is not None and self._llama_free_model is not None:
- self._llama_free_model(self.model)
- self.model = None
-
- def vocab_type(self) -> int:
- assert self.model is not None
- return llama_cpp.llama_vocab_type(self.model)
-
- def n_vocab(self) -> int:
- assert self.model is not None
- return llama_cpp.llama_n_vocab(self.model)
-
- def n_ctx_train(self) -> int:
- assert self.model is not None
- return llama_cpp.llama_n_ctx_train(self.model)
-
- def n_embd(self) -> int:
- assert self.model is not None
- return llama_cpp.llama_n_embd(self.model)
-
- def rope_freq_scale_train(self) -> float:
- assert self.model is not None
- return llama_cpp.llama_rope_freq_scale_train(self.model)
-
- def desc(self) -> str:
- assert self.model is not None
- buf = ctypes.create_string_buffer(1024)
- llama_cpp.llama_model_desc(self.model, buf, 1024) # type: ignore
- return buf.value.decode("utf-8")
-
- def size(self) -> int:
- assert self.model is not None
- return llama_cpp.llama_model_size(self.model)
-
- def n_params(self) -> int:
- assert self.model is not None
- return llama_cpp.llama_model_n_params(self.model)
-
- def get_tensor(self, name: str) -> ctypes.c_void_p:
- assert self.model is not None
- return llama_cpp.llama_get_model_tensor(self.model, name.encode("utf-8"))
-
- def apply_lora_from_file(
- self,
- lora_path: str,
- scale: float,
- path_base_model: Optional[str],
- n_threads: int,
- ):
- assert self.model is not None
- return llama_cpp.llama_model_apply_lora_from_file(
- self.model,
- lora_path.encode("utf-8"),
- scale,
- path_base_model.encode("utf-8")
- if path_base_model is not None
- else llama_cpp.c_char_p(0),
- n_threads,
- )
-
- # Vocab
-
- def token_get_text(self, token: int) -> str:
- # TODO: Fix
- assert self.model is not None
- return llama_cpp.llama_token_get_text(self.model, token).decode("utf-8")
-
- def token_get_score(self, token: int) -> float:
- assert self.model is not None
- return llama_cpp.llama_token_get_score(self.model, token)
-
- def token_get_type(self, token: int) -> int:
- assert self.model is not None
- return llama_cpp.llama_token_get_type(self.model, token)
-
- # Special tokens
-
- def token_bos(self) -> int:
- assert self.model is not None
- return llama_cpp.llama_token_bos(self.model)
-
- def token_eos(self) -> int:
- assert self.model is not None
- return llama_cpp.llama_token_eos(self.model)
-
- def token_nl(self) -> int:
- assert self.model is not None
- return llama_cpp.llama_token_nl(self.model)
-
- def token_prefix(self) -> int:
- assert self.model is not None
- return llama_cpp.llama_token_prefix(self.model)
-
- def token_middle(self) -> int:
- assert self.model is not None
- return llama_cpp.llama_token_middle(self.model)
-
- def token_suffix(self) -> int:
- assert self.model is not None
- return llama_cpp.llama_token_suffix(self.model)
-
- def token_eot(self) -> int:
- assert self.model is not None
- return llama_cpp.llama_token_eot(self.model)
-
- # Tokenization
-
- def tokenize(self, text: bytes, add_bos: bool, special: bool):
- assert self.model is not None
- n_ctx = self.n_ctx_train()
- tokens = (llama_cpp.llama_token * n_ctx)()
- n_tokens = llama_cpp.llama_tokenize(
- self.model, text, len(text), tokens, n_ctx, add_bos, special
- )
- if n_tokens < 0:
- n_tokens = abs(n_tokens)
- tokens = (llama_cpp.llama_token * n_tokens)()
- n_tokens = llama_cpp.llama_tokenize(
- self.model, text, len(text), tokens, n_tokens, add_bos, special
- )
- if n_tokens < 0:
- raise RuntimeError(
- f'Failed to tokenize: text="{text}" n_tokens={n_tokens}'
- )
- return list(tokens[:n_tokens])
-
- def token_to_piece(self, token: int) -> bytes:
- assert self.model is not None
- buf = ctypes.create_string_buffer(32)
- llama_cpp.llama_token_to_piece(self.model, token, buf, 32) # type: ignore
- return bytes(buf)
-
- def detokenize(self, tokens: List[int]) -> bytes:
- assert self.model is not None
- output = b""
- size = 32
- buffer = (ctypes.c_char * size)()
- for token in tokens:
- n = llama_cpp.llama_token_to_piece(
- self.model, llama_cpp.llama_token(token), buffer, size
- )
- assert n <= size
- output += bytes(buffer[:n])
- # NOTE: Llama1 models automatically added a space at the start of the prompt
- # this line removes a leading space if the first token is a beginning of sentence token
- return (
- output[1:] if len(tokens) > 0 and tokens[0] == self.token_bos() else output
- )
-
- @staticmethod
- def default_params():
- """Get the default llama_model_params."""
- return llama_cpp.llama_model_default_params()
-
-
-class _LlamaContext:
- """Intermediate Python wrapper for a llama.cpp llama_context.
-
- NOTE: For stability it's recommended you use the Llama class instead."""
-
- _llama_free = None
- # NOTE: this must be "saved" here to avoid exceptions when calling __del__
- suppress_stdout_stderr = suppress_stdout_stderr
-
- def __init__(
- self,
- *,
- model: _LlamaModel,
- params: llama_cpp.llama_context_params,
- verbose: bool = True,
- ):
- self.model = model
- self.params = params
- self.verbose = verbose
-
- self._llama_free = llama_cpp._lib.llama_free # type: ignore
-
- with suppress_stdout_stderr(disable=self.verbose):
- self.ctx = llama_cpp.llama_new_context_with_model(
- self.model.model, self.params
- )
-
- def __del__(self):
- with self.suppress_stdout_stderr(disable=self.verbose):
- if self.ctx is not None and self._llama_free is not None:
- self._llama_free(self.ctx)
- self.ctx = None
-
- def n_ctx(self) -> int:
- assert self.ctx is not None
- return llama_cpp.llama_n_ctx(self.ctx)
-
- def kv_cache_clear(self):
- assert self.ctx is not None
- llama_cpp.llama_kv_cache_clear(self.ctx)
-
- def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int):
- assert self.ctx is not None
- llama_cpp.llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1)
-
- def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int):
- assert self.ctx is not None
- llama_cpp.llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1)
-
- def kv_cache_seq_keep(self, seq_id: int):
- assert self.ctx is not None
- llama_cpp.llama_kv_cache_seq_keep(self.ctx, seq_id)
-
- def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
- assert self.ctx is not None
- llama_cpp.llama_kv_cache_seq_shift(self.ctx, seq_id, p0, p1, shift)
-
- def get_state_size(self) -> int:
- assert self.ctx is not None
- return llama_cpp.llama_get_state_size(self.ctx)
-
- # TODO: copy_state_data
-
- # TODO: set_state_data
-
- # TODO: llama_load_session_file
-
- # TODO: llama_save_session_file
-
- def decode(self, batch: "_LlamaBatch"):
- assert self.ctx is not None
- assert batch.batch is not None
- return_code = llama_cpp.llama_decode(
- ctx=self.ctx,
- batch=batch.batch,
- )
- if return_code != 0:
- raise RuntimeError(f"llama_decode returned {return_code}")
-
- def set_n_threads(self, n_threads: int, n_threads_batch: int):
- assert self.ctx is not None
- llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch)
-
- def get_logits(self):
- assert self.ctx is not None
- return llama_cpp.llama_get_logits(self.ctx)
-
- def get_logits_ith(self, i: int):
- assert self.ctx is not None
- return llama_cpp.llama_get_logits_ith(self.ctx, i)
-
- def get_embeddings(self):
- assert self.ctx is not None
- return llama_cpp.llama_get_embeddings(self.ctx)
-
- # Sampling functions
-
- def set_rng_seed(self, seed: int):
- assert self.ctx is not None
- llama_cpp.llama_set_rng_seed(self.ctx, seed)
-
- def sample_repetition_penalties(
- self,
- candidates: "_LlamaTokenDataArray",
- last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]",
- penalty_last_n: int,
- penalty_repeat: float,
- penalty_freq: float,
- penalty_present: float,
- ):
- assert self.ctx is not None
- llama_cpp.llama_sample_repetition_penalties(
- self.ctx,
- ctypes.byref(candidates.candidates), # type: ignore
- last_tokens_data,
- penalty_last_n,
- penalty_repeat,
- penalty_freq,
- penalty_present,
- )
-
- def sample_classifier_free_guidance(
- self,
- candidates: "_LlamaTokenDataArray",
- guidance_ctx: "_LlamaContext",
- scale: float,
- ):
- assert self.ctx is not None
- assert guidance_ctx.ctx is not None
- llama_cpp.llama_sample_classifier_free_guidance(
- self.ctx,
- ctypes.byref(candidates.candidates), # type: ignore
- guidance_ctx.ctx,
- scale,
- )
-
- def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
- assert self.ctx is not None
- llama_cpp.llama_sample_softmax(
- self.ctx,
- ctypes.byref(candidates.candidates), # type: ignore
- )
-
- def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
- assert self.ctx is not None
- llama_cpp.llama_sample_top_k(
- self.ctx, ctypes.byref(candidates.candidates), k, min_keep # type: ignore
- )
-
- def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
- assert self.ctx is not None
- llama_cpp.llama_sample_top_p(
- self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
- )
-
- def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
- assert self.ctx is not None
- llama_cpp.llama_sample_min_p(
- self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
- )
-
- def sample_tail_free(
- self, candidates: "_LlamaTokenDataArray", z: float, min_keep: int
- ):
- assert self.ctx is not None
- llama_cpp.llama_sample_tail_free(
- self.ctx, ctypes.byref(candidates.candidates), z, min_keep # type: ignore
- )
-
- def sample_typical(
- self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
- ):
- assert self.ctx is not None
- llama_cpp.llama_sample_typical(
- self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
- )
-
- def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
- assert self.ctx is not None
- llama_cpp.llama_sample_temp(
- self.ctx, ctypes.byref(candidates.candidates), temp # type: ignore
- )
-
- def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
- assert self.ctx is not None
- assert grammar.grammar is not None
- llama_cpp.llama_sample_grammar(
- self.ctx,
- ctypes.byref(candidates.candidates), # type: ignore
- grammar.grammar,
- )
-
- def sample_token_mirostat(
- self,
- candidates: "_LlamaTokenDataArray",
- tau: float,
- eta: float,
- m: int,
- mu: float,
- ) -> int:
- assert self.ctx is not None
- return llama_cpp.llama_sample_token_mirostat(
- self.ctx,
- ctypes.byref(candidates.candidates), # type: ignore
- tau,
- eta,
- m,
- ctypes.pointer(ctypes.c_float(mu)),
- )
-
- def sample_token_mirostat_v2(
- self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: float
- ) -> int:
- assert self.ctx is not None
- return llama_cpp.llama_sample_token_mirostat_v2(
- self.ctx,
- ctypes.byref(candidates.candidates), # type: ignore
- tau,
- eta,
- ctypes.pointer(ctypes.c_float(mu)),
- )
-
- def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
- assert self.ctx is not None
- return llama_cpp.llama_sample_token_greedy(
- self.ctx,
- ctypes.byref(candidates.candidates), # type: ignore
- )
-
- def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
- assert self.ctx is not None
- return llama_cpp.llama_sample_token(
- self.ctx,
- ctypes.byref(candidates.candidates), # type: ignore
- )
-
- # Grammar
- def grammar_accept_token(self, grammar: LlamaGrammar, token: int):
- assert self.ctx is not None
- assert grammar.grammar is not None
- llama_cpp.llama_grammar_accept_token(self.ctx, grammar.grammar, token)
-
- def reset_timings(self):
- assert self.ctx is not None
- llama_cpp.llama_reset_timings(self.ctx)
-
- def print_timings(self):
- assert self.ctx is not None
- llama_cpp.llama_print_timings(self.ctx)
-
- # Utility functions
- @staticmethod
- def default_params():
- """Get the default llama_context_params."""
- return llama_cpp.llama_context_default_params()
-
-
-class _LlamaBatch:
- _llama_batch_free = None
- # NOTE: this must be "saved" here to avoid exceptions when calling __del__
- suppress_stdout_stderr = suppress_stdout_stderr
-
- def __init__(
- self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
- ):
- self.n_tokens = n_tokens
- self.embd = embd
- self.n_seq_max = n_seq_max
- self.verbose = verbose
-
- self._llama_batch_free = llama_cpp._lib.llama_batch_free # type: ignore
-
- with suppress_stdout_stderr(disable=self.verbose):
- self.batch = llama_cpp.llama_batch_init(
- self.n_tokens, self.embd, self.n_seq_max
- )
-
- def __del__(self):
- with self.suppress_stdout_stderr(disable=self.verbose):
- if self.batch is not None and self._llama_batch_free is not None:
- self._llama_batch_free(self.batch)
- self.batch = None
-
- def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
- assert self.batch is not None
- n_tokens = len(batch)
- self.batch.n_tokens = n_tokens
- for i in range(n_tokens):
- self.batch.token[i] = batch[i]
- self.batch.pos[i] = n_past + i
- self.batch.seq_id[i][0] = 0
- self.batch.n_seq_id[i] = 1
- self.batch.logits[i] = logits_all
- self.batch.logits[n_tokens - 1] = True
-
-
-class _LlamaTokenDataArray:
- def __init__(self, *, n_vocab: int):
- self.n_vocab = n_vocab
- self.candidates_data = np.array(
- [],
- dtype=np.dtype(
- [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
- ),
- )
- self.candidates_data.resize(3, self.n_vocab, refcheck=False)
- self.candidates = llama_cpp.llama_token_data_array(
- data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
- size=self.n_vocab,
- sorted=False,
- )
- self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc)
- self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)
-
- def copy_logits(self, logits: npt.NDArray[np.single]):
- self.candidates_data["id"][:] = self.default_candidates_data_id
- self.candidates_data["logit"][:] = logits
- self.candidates_data["p"][:] = self.default_candidates_data_p
- self.candidates.data = self.candidates_data.ctypes.data_as(
- llama_cpp.llama_token_data_p
- )
- self.candidates.sorted = llama_cpp.c_bool(False)
- self.candidates.size = llama_cpp.c_size_t(self.n_vocab)
+from ._internals import (
+ _LlamaModel, # type: ignore
+ _LlamaContext, # type: ignore
+ _LlamaBatch, # type: ignore
+ _LlamaTokenDataArray, # type: ignore
+)
class Llama:
@@ -754,7 +77,7 @@ class Llama:
mul_mat_q: bool = True,
logits_all: bool = False,
embedding: bool = False,
- offload_kqv: bool = False,
+ offload_kqv: bool = True,
# Sampling Params
last_n_tokens_size: int = 64,
# LoRA Params
@@ -1006,6 +329,18 @@ class Llama:
(n_ctx, self._n_vocab), dtype=np.single
)
+ self._mirostat_mu = ctypes.c_float(2.0 * 5.0) # TODO: Move this to sampling context
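+        # mu is kept as persistent ctypes state so the mirostat samplers can
+        # update it in place; generate() re-initializes it to 2 * mirostat_tau per run.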
+
+ try:
+ self.metadata = self._model.metadata()
+ except Exception as e:
+ self.metadata = {}
+ if self.verbose:
+ print(f"Failed to load metadata: {e}", file=sys.stderr)
+
+ if self.verbose:
+ print(f"Model metadata: {self.metadata}", file=sys.stderr)
+
@property
def ctx(self) -> llama_cpp.llama_context_p:
assert self._ctx.ctx is not None
@@ -1193,7 +528,7 @@ class Llama:
candidates=self._candidates,
tau=mirostat_tau,
eta=mirostat_eta,
- mu=2.0 * mirostat_tau,
+ mu=ctypes.pointer(self._mirostat_mu),
m=100,
)
elif mirostat_mode == 2:
@@ -1202,7 +537,7 @@ class Llama:
candidates=self._candidates,
tau=mirostat_tau,
eta=mirostat_eta,
- mu=2.0 * mirostat_tau,
+ mu=ctypes.pointer(self._mirostat_mu)
)
else:
self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
@@ -1258,6 +593,10 @@ class Llama:
Yields:
The generated tokens.
"""
+ # Reset mirostat sampling
+ self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
+
+ # Check for kv cache prefix match
if reset and self.n_tokens > 0:
longest_prefix = 0
for a, b in zip(self._input_ids, tokens[:-1]):
@@ -1272,12 +611,15 @@ class Llama:
tokens = tokens[longest_prefix:]
self.n_tokens = longest_prefix
+ # Reset the model state
if reset:
self.reset()
+ # Reset the grammar
if grammar is not None:
grammar.reset()
+ # Eval and sample
while True:
self.eval(tokens)
token = self.sample(
@@ -2372,3 +1714,43 @@ class LlamaTokenizer:
@classmethod
def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
return cls(Llama(model_path=path, vocab_only=True))
+
+
+class LlamaState:
+ def __init__(
+ self,
+ input_ids: npt.NDArray[np.intc],
+ scores: npt.NDArray[np.single],
+ n_tokens: int,
+ llama_state: bytes,
+ llama_state_size: int,
+ ):
+ self.input_ids = input_ids
+ self.scores = scores
+ self.n_tokens = n_tokens
+ self.llama_state = llama_state
+ self.llama_state_size = llama_state_size
+
+
+LogitsProcessor = Callable[
+ [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
+]
+
+
+class LogitsProcessorList(List[LogitsProcessor]):
+ def __call__(
+ self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
+ ) -> npt.NDArray[np.single]:
+ for processor in self:
+ scores = processor(input_ids, scores)
+ return scores
+
+
+StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]
+
+
+class StoppingCriteriaList(List[StoppingCriteria]):
+ def __call__(
+ self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
+ ) -> bool:
+ return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
diff --git a/llama_cpp/llama_cache.py b/llama_cpp/llama_cache.py
new file mode 100644
index 0000000..9e9870a
--- /dev/null
+++ b/llama_cpp/llama_cache.py
@@ -0,0 +1,150 @@
+import sys
+from abc import ABC, abstractmethod
+from typing import (
+ Optional,
+ Sequence,
+ Tuple,
+)
+from collections import OrderedDict
+
+import diskcache
+
+import llama_cpp.llama
+
+from .llama_types import *
+
+
+class BaseLlamaCache(ABC):
+ """Base cache class for a llama.cpp model."""
+
+ def __init__(self, capacity_bytes: int = (2 << 30)):
+ self.capacity_bytes = capacity_bytes
+
+ @property
+ @abstractmethod
+ def cache_size(self) -> int:
+ raise NotImplementedError
+
+ def _find_longest_prefix_key(
+ self,
+ key: Tuple[int, ...],
+ ) -> Optional[Tuple[int, ...]]:
+ pass
+
+ @abstractmethod
+ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
+ raise NotImplementedError
+
+ @abstractmethod
+ def __contains__(self, key: Sequence[int]) -> bool:
+ raise NotImplementedError
+
+ @abstractmethod
+ def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState") -> None:
+ raise NotImplementedError
+
+
+class LlamaRAMCache(BaseLlamaCache):
+ """Cache for a llama.cpp model using RAM."""
+
+ def __init__(self, capacity_bytes: int = (2 << 30)):
+ super().__init__(capacity_bytes)
+ self.capacity_bytes = capacity_bytes
+ self.cache_state: OrderedDict[Tuple[int, ...], "llama_cpp.llama.LlamaState"] = OrderedDict()
+
+ @property
+ def cache_size(self):
+ return sum([state.llama_state_size for state in self.cache_state.values()])
+
+ def _find_longest_prefix_key(
+ self,
+ key: Tuple[int, ...],
+ ) -> Optional[Tuple[int, ...]]:
+ min_len = 0
+ min_key = None
+ keys = (
+ (k, llama_cpp.llama.Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys()
+ )
+ for k, prefix_len in keys:
+ if prefix_len > min_len:
+ min_len = prefix_len
+ min_key = k
+ return min_key
+
+ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
+ key = tuple(key)
+ _key = self._find_longest_prefix_key(key)
+ if _key is None:
+ raise KeyError("Key not found")
+ value = self.cache_state[_key]
+ self.cache_state.move_to_end(_key)
+ return value
+
+ def __contains__(self, key: Sequence[int]) -> bool:
+ return self._find_longest_prefix_key(tuple(key)) is not None
+
+ def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
+ key = tuple(key)
+ if key in self.cache_state:
+ del self.cache_state[key]
+ self.cache_state[key] = value
+ while self.cache_size > self.capacity_bytes and len(self.cache_state) > 0:
+ self.cache_state.popitem(last=False)
+
+
+# Alias for backwards compatibility
+LlamaCache = LlamaRAMCache
+
+
+class LlamaDiskCache(BaseLlamaCache):
+ """Cache for a llama.cpp model using disk."""
+
+ def __init__(
+ self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30)
+ ):
+ super().__init__(capacity_bytes)
+ self.cache = diskcache.Cache(cache_dir)
+
+ @property
+ def cache_size(self):
+ return int(self.cache.volume()) # type: ignore
+
+ def _find_longest_prefix_key(
+ self,
+ key: Tuple[int, ...],
+ ) -> Optional[Tuple[int, ...]]:
+ min_len = 0
+ min_key: Optional[Tuple[int, ...]] = None
+ for k in self.cache.iterkeys(): # type: ignore
+ prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k, key)
+ if prefix_len > min_len:
+ min_len = prefix_len
+ min_key = k # type: ignore
+ return min_key
+
+ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
+ key = tuple(key)
+ _key = self._find_longest_prefix_key(key)
+ if _key is None:
+ raise KeyError("Key not found")
+ value: "llama_cpp.llama.LlamaState" = self.cache.pop(_key) # type: ignore
+ # NOTE: re-inserting via push() would store the entry under an auto-generated
+ # integer key, which breaks Llama.longest_token_prefix(k, key) above since k
+ # would no longer be a tuple of ints/tokens.
+ # self.cache.push(_key, side="front") # type: ignore
+ return value
+
+ def __contains__(self, key: Sequence[int]) -> bool:
+ return self._find_longest_prefix_key(tuple(key)) is not None
+
+ def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
+ print("LlamaDiskCache.__setitem__: called", file=sys.stderr)
+ key = tuple(key)
+ if key in self.cache:
+ print("LlamaDiskCache.__setitem__: delete", file=sys.stderr)
+ del self.cache[key]
+ self.cache[key] = value
+ print("LlamaDiskCache.__setitem__: set", file=sys.stderr)
+ while self.cache_size > self.capacity_bytes and len(self.cache) > 0:
+ key_to_remove = next(iter(self.cache))
+ del self.cache[key_to_remove]
+ print("LlamaDiskCache.__setitem__: trim", file=sys.stderr)
diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 0ef7bd4..02bdbcf 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -6,18 +6,28 @@ import ctypes
import dataclasses
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
+import jinja2
+
import llama_cpp.llama as llama
import llama_cpp.llama_types as llama_types
import llama_cpp.llama_grammar as llama_grammar
-from ._utils import suppress_stdout_stderr
+from ._utils import suppress_stdout_stderr, Singleton
class LlamaChatCompletionHandler(Protocol):
+ """Base Protocol for a llama chat completion handler.
+
+ Very generic protocol that can be used to implement any chat format.
+ The only hard requirement is that it must return a ChatCompletion when
+ stream=False and an iterator of ChatCompletionChunks when stream=True."""
+
def __call__(
self,
*,
+ # llama.cpp instance
llama: llama.Llama,
+ # openai api parameters
messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
@@ -26,8 +36,6 @@ class LlamaChatCompletionHandler(Protocol):
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
- min_p: float = 0.05,
- typical_p: float = 1.0,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None,
@@ -38,14 +46,17 @@ class LlamaChatCompletionHandler(Protocol):
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
repeat_penalty: float = 1.1,
+ model: Optional[str] = None,
+ logit_bias: Optional[Dict[str, float]] = None,
+ # llama.cpp parameters
+ min_p: float = 0.05,
+ typical_p: float = 1.0,
tfs_z: float = 1.0,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
- model: Optional[str] = None,
logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None,
- logit_bias: Optional[Dict[str, float]] = None,
**kwargs, # type: ignore
) -> Union[
llama_types.CreateChatCompletionResponse,
@@ -54,146 +65,78 @@ class LlamaChatCompletionHandler(Protocol):
...
-CHAT_HANDLERS: Dict[str, LlamaChatCompletionHandler] = {}
+class LlamaChatCompletionHandlerNotFoundException(Exception):
+ pass
+
+
+class LlamaChatCompletionHandlerRegistry(Singleton):
+ _chat_handlers: Dict[str, LlamaChatCompletionHandler] = {}
+
+ def register_chat_completion_handler(
+ self,
+ name: str,
+ chat_handler: LlamaChatCompletionHandler,
+ overwrite: bool = False,
+ ):
+ if not overwrite and name in self._chat_handlers:
+ raise ValueError(
+ f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
+ )
+ self._chat_handlers[name] = chat_handler
+
+ def unregister_chat_handler(self, name: str):
+ if name in self._chat_handlers:
+ del self._chat_handlers[name]
+ else:
+ raise ValueError(f"No formatter registered under the name '{name}'.")
+
+ def get_chat_completion_handler_by_name(
+ self, name: str
+ ) -> LlamaChatCompletionHandler:
+ try:
+ chat_handler = self._chat_handlers[name]
+ return chat_handler
+ except KeyError:
+ raise LlamaChatCompletionHandlerNotFoundException(
+ f"Invalid chat handler: {name} (valid formats: {list(self._chat_handlers.keys())})"
+ )
def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler:
- return CHAT_HANDLERS[name]
+ return LlamaChatCompletionHandlerRegistry().get_chat_completion_handler_by_name(
+ name
+ )
def register_chat_completion_handler(name: str):
def decorator(f: LlamaChatCompletionHandler):
- CHAT_HANDLERS[name] = f
+ LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(name, f)
return f
return decorator
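+
+# Illustrative sketch: registering and retrieving a handler by name. The
+# "my-format" name and `my_handler` body are placeholders, not part of the
+# library.
+#
+#     @register_chat_completion_handler("my-format")
+#     def my_handler(*, llama, messages, **kwargs):
+#         ...
+#
+#     handler = get_chat_completion_handler("my-format")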
-def _get_system_message(
- messages: List[llama_types.ChatCompletionRequestMessage],
-) -> str:
- """Get the first system message."""
- for message in messages:
- if message["role"] == "system":
- return message["content"] or ""
- return ""
-
-
-def _map_roles(
- messages: List[llama_types.ChatCompletionRequestMessage], role_map: Dict[str, str]
-) -> List[Tuple[str, Optional[str]]]:
- """Map the message roles."""
- output: List[Tuple[str, Optional[str]]] = []
- for message in messages:
- role = message["role"]
- if role in role_map:
- output.append((role_map[role], message["content"]))
- return output
-
-
-def _format_llama2(
- system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
-) -> str:
- """Format the prompt with the llama2 style."""
- seps = [sep, sep2]
- ret = system_message + sep
- for i, (role, message) in enumerate(messages):
- if system_message and i == 0:
- ret += message + seps[i % 2]
- elif message:
- ret += role + message + " " + seps[i % 2]
- else:
- ret += role + " "
- return ret
-
-
-def _format_add_colon_single(
- system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
- """Format the prompt with the add-colon-single style."""
- ret = system_message + sep
- for role, message in messages:
- if message:
- ret += role + ": " + message + sep
- else:
- ret += role + ":"
- return ret
-
-
-def _format_add_colon_two(
- system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
-) -> str:
- """Format the prompt with the add-colon-two style."""
- seps = [sep, sep2]
- ret = system_message + seps[0]
- for i, (role, message) in enumerate(messages):
- if message:
- ret += role + ": " + message + seps[i % 2]
- else:
- ret += role + ":"
- return ret
-
-
-def _format_no_colon_single(
- system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
- """Format the prompt with the no-colon-single style."""
- ret = system_message
- for role, message in messages:
- if message:
- ret += role + message + sep
- else:
- ret += role
- return ret
-
-
-def _format_add_colon_space_single(
- system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
- """Format the prompt with the add-colon-space-single style."""
- ret = system_message + sep
- for role, message in messages:
- if message:
- ret += role + ": " + message + sep
- else:
- ret += role + ": " # must be end with a space
- return ret
-
-
-def _format_chatml(
- system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
- """Format the prompt with the chatml style."""
- ret = "" if system_message == "" else system_message + sep + "\n"
- for role, message in messages:
- if message:
- ret += role + "\n" + message + sep + "\n"
- else:
- ret += role + "\n"
- return ret
-
-def _format_chatglm3(
- system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
- """Format the prompt with the chatglm3 style."""
- ret = ""
- if system_message:
- ret += system_message
- for role, message in messages:
- if message:
- ret += role + "\n" + " " + message
- else:
- ret += role
- return ret
+### Chat Formatter ###
@dataclasses.dataclass
class ChatFormatterResponse:
+ """Dataclass that stores completion parameters for a given chat format and
+ create_chat_completion request.
+
+ prompt contains the formatted prompt generated from the chat format and messages.
+ stop contains the stop token or list of stop tokens to use for the chat format."""
+
prompt: str
stop: Optional[Union[str, List[str]]] = None
class ChatFormatter(Protocol):
+ """Base Protocol for a chat formatter. A chat formatter is a function that
+ takes a list of messages and returns a chat format response which can be used
+ to generate a completion. The response can also include a stop token or list
+ of stop tokens to use for the completion."""
+
def __call__(
self,
*,
@@ -203,14 +146,52 @@ class ChatFormatter(Protocol):
...
-class BasicChatHandler:
- def __init__(self, chat_format: str):
- self.chat_format = chat_format
+class Jinja2ChatFormatter(ChatFormatter):
+ def __init__(
+ self,
+ template: str,
+ eos_token: str,
+ bos_token: str,
+ add_generation_prompt: bool = True,
+ ):
+ """A chat formatter that uses jinja2 templates to format the prompt."""
+ self.template = template
+ self.eos_token = eos_token
+ self.bos_token = bos_token
+ self.add_generation_prompt = add_generation_prompt
+
+ self._environment = jinja2.Environment(
+ loader=jinja2.BaseLoader(),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ ).from_string(self.template)
+
+ def __call__(
+ self,
+ *,
+ messages: List[llama_types.ChatCompletionRequestMessage],
+ **kwargs: Any,
+ ) -> ChatFormatterResponse:
+ if self.add_generation_prompt:
+ messages = [
+ *messages,
+ llama_types.ChatCompletionRequestAssistantMessage(
+ role="assistant", content=""
+ ),
+ ]
+ prompt = self._environment.render(
+ messages=messages, eos_token=self.eos_token, bos_token=self.bos_token
+ )
+ return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
+
+ def to_chat_handler(self) -> LlamaChatCompletionHandler:
+ return chat_formatter_to_chat_completion_handler(self)
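+
+# Illustrative sketch: a minimal ChatML-style template rendered through the
+# formatter above. The template, tokens, and message contents here are made up
+# for the example.
+#
+#     formatter = Jinja2ChatFormatter(
+#         template=(
+#             "{% for message in messages %}"
+#             "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n"
+#             "{% endfor %}"
+#         ),
+#         eos_token="<|im_end|>",
+#         bos_token="<s>",
+#     )
+#     response = formatter(messages=[{"role": "user", "content": "Hi"}])
+#     # response.prompt is the rendered string, response.stop == ["<|im_end|>"]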
def _convert_text_completion_to_chat(
completion: llama_types.Completion,
) -> llama_types.ChatCompletion:
+ assert "usage" in completion
return {
"id": "chat" + completion["id"],
"object": "chat.completion",
@@ -286,103 +267,85 @@ def _convert_completion_to_chat(
return _convert_text_completion_to_chat(completion)
-_CHAT_FORMATS: Dict[str, ChatFormatter] = {}
-
-
-def register_chat_format(name: str):
- def decorator(f: ChatFormatter):
- def basic_create_chat_completion(
- *,
- llama: llama.Llama,
- messages: List[llama_types.ChatCompletionRequestMessage],
- functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
- function_call: Optional[
- llama_types.ChatCompletionRequestFunctionCall
- ] = None,
- tools: Optional[List[llama_types.ChatCompletionTool]] = None,
- tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
- temperature: float = 0.2,
- top_p: float = 0.95,
- top_k: int = 40,
- min_p: float = 0.05,
- typical_p: float = 1.0,
- stream: bool = False,
- stop: Optional[Union[str, List[str]]] = [],
- seed: Optional[int] = None,
- response_format: Optional[
- llama_types.ChatCompletionRequestResponseFormat
- ] = None,
- max_tokens: Optional[int] = None,
- presence_penalty: float = 0.0,
- frequency_penalty: float = 0.0,
- repeat_penalty: float = 1.1,
- tfs_z: float = 1.0,
- mirostat_mode: int = 0,
- mirostat_tau: float = 5.0,
- mirostat_eta: float = 0.1,
- model: Optional[str] = None,
- logits_processor: Optional[llama.LogitsProcessorList] = None,
- grammar: Optional[llama.LlamaGrammar] = None,
- logit_bias: Optional[Dict[str, float]] = None,
- **kwargs, # type: ignore
- ) -> Union[
- llama_types.CreateChatCompletionResponse,
- Iterator[llama_types.CreateChatCompletionStreamResponse],
- ]:
- result = f(
- messages=messages,
- functions=functions,
- function_call=function_call,
- )
- prompt = result.prompt
- if result.stop is not None:
- stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
- rstop = result.stop if isinstance(result.stop, list) else [result.stop]
- stop = stop + rstop
-
- if response_format is not None and response_format["type"] == "json_object":
- grammar = llama_grammar.LlamaGrammar.from_string(
- llama_grammar.JSON_GBNF
- )
-
- completion_or_chunks = llama.create_completion(
- prompt=prompt,
- temperature=temperature,
- top_p=top_p,
- top_k=top_k,
- min_p=min_p,
- typical_p=typical_p,
- stream=stream,
- stop=stop,
- seed=seed,
- max_tokens=max_tokens,
- presence_penalty=presence_penalty,
- frequency_penalty=frequency_penalty,
- repeat_penalty=repeat_penalty,
- tfs_z=tfs_z,
- mirostat_mode=mirostat_mode,
- mirostat_tau=mirostat_tau,
- mirostat_eta=mirostat_eta,
- model=model,
- logits_processor=logits_processor,
- grammar=grammar,
- logit_bias=logit_bias,
- )
- return _convert_completion_to_chat(completion_or_chunks, stream=stream)
-
- register_chat_completion_handler(name)(basic_create_chat_completion)
- return f
-
- return decorator
-
-
-def get_chat_format(name: str):
- try:
- return _CHAT_FORMATS[name]
- except KeyError:
- raise ValueError(
- f"Invalid chat format: {name} (valid formats: {list(_CHAT_FORMATS.keys())})"
+def chat_formatter_to_chat_completion_handler(
+ chat_formatter: ChatFormatter,
+) -> LlamaChatCompletionHandler:
+ def chat_completion_handler(
+ *,
+ llama: llama.Llama,
+ messages: List[llama_types.ChatCompletionRequestMessage],
+ functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
+ function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
+ tools: Optional[List[llama_types.ChatCompletionTool]] = None,
+ tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
+ temperature: float = 0.2,
+ top_p: float = 0.95,
+ top_k: int = 40,
+ min_p: float = 0.05,
+ typical_p: float = 1.0,
+ stream: bool = False,
+ stop: Optional[Union[str, List[str]]] = [],
+ seed: Optional[int] = None,
+ response_format: Optional[
+ llama_types.ChatCompletionRequestResponseFormat
+ ] = None,
+ max_tokens: Optional[int] = None,
+ presence_penalty: float = 0.0,
+ frequency_penalty: float = 0.0,
+ repeat_penalty: float = 1.1,
+ tfs_z: float = 1.0,
+ mirostat_mode: int = 0,
+ mirostat_tau: float = 5.0,
+ mirostat_eta: float = 0.1,
+ model: Optional[str] = None,
+ logits_processor: Optional[llama.LogitsProcessorList] = None,
+ grammar: Optional[llama.LlamaGrammar] = None,
+ logit_bias: Optional[Dict[str, float]] = None,
+ **kwargs, # type: ignore
+ ) -> Union[
+ llama_types.CreateChatCompletionResponse,
+ Iterator[llama_types.CreateChatCompletionStreamResponse],
+ ]:
+ result = chat_formatter(
+ messages=messages,
+ functions=functions,
+ function_call=function_call,
)
+ prompt = result.prompt
+ if result.stop is not None:
+ stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
+ rstop = result.stop if isinstance(result.stop, list) else [result.stop]
+ stop = stop + rstop
+
+ if response_format is not None and response_format["type"] == "json_object":
+ grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
+
+ completion_or_chunks = llama.create_completion(
+ prompt=prompt,
+ temperature=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ min_p=min_p,
+ typical_p=typical_p,
+ stream=stream,
+ stop=stop,
+ seed=seed,
+ max_tokens=max_tokens,
+ presence_penalty=presence_penalty,
+ frequency_penalty=frequency_penalty,
+ repeat_penalty=repeat_penalty,
+ tfs_z=tfs_z,
+ mirostat_mode=mirostat_mode,
+ mirostat_tau=mirostat_tau,
+ mirostat_eta=mirostat_eta,
+ model=model,
+ logits_processor=logits_processor,
+ grammar=grammar,
+ logit_bias=logit_bias,
+ )
+ return _convert_completion_to_chat(completion_or_chunks, stream=stream)
+
+ return chat_completion_handler
def hf_autotokenizer_to_chat_formatter(
@@ -391,22 +354,222 @@ def hf_autotokenizer_to_chat_formatter(
# https://huggingface.co/docs/transformers/main/chat_templating
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
- from transformers import AutoTokenizer
+ from transformers import AutoTokenizer # type: ignore
- tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) # type: ignore
def format_autotokenizer(
messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any,
) -> ChatFormatterResponse:
- tokenizer.use_default_system_prompt = False
- _prompt = tokenizer.apply_chat_template(messages, tokenize=False)
+ tokenizer.use_default_system_prompt = False # type: ignore
+ prompt: str = tokenizer.apply_chat_template(messages, tokenize=False) # type: ignore
+ assert isinstance(prompt, str)
# Return formatted prompt and eos token by default
- return ChatFormatterResponse(prompt=_prompt, stop=tokenizer.eos_token)
+ return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token)
return format_autotokenizer
+def hf_autotokenizer_to_chat_completion_handler(
+ pretrained_model_name_or_path: Union[str, os.PathLike[str]]
+) -> LlamaChatCompletionHandler:
+ chat_formatter = hf_autotokenizer_to_chat_formatter(pretrained_model_name_or_path)
+ return chat_formatter_to_chat_completion_handler(chat_formatter)
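+
+# Illustrative sketch: building a handler directly from the HuggingFace repo id
+# referenced in the links above; requires the optional `transformers` dependency.
+#
+#     handler = hf_autotokenizer_to_chat_completion_handler(
+#         "mistralai/Mistral-7B-Instruct-v0.1"
+#     )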
+
+
+def hf_tokenizer_config_to_chat_formatter(
+ tokenizer_config: Dict[str, Any]
+) -> ChatFormatter:
+ assert isinstance(tokenizer_config, dict)
+
+ assert "chat_template" in tokenizer_config
+ assert isinstance(tokenizer_config["chat_template"], str)
+ chat_template = tokenizer_config["chat_template"]
+
+ assert "bos_token" in tokenizer_config
+ assert isinstance(tokenizer_config["bos_token"], str)
+ bos_token = tokenizer_config["bos_token"]
+
+ assert "eos_token" in tokenizer_config
+ assert isinstance(tokenizer_config["eos_token"], str)
+ eos_token = tokenizer_config["eos_token"]
+
+ env = jinja2.Environment(
+ loader=jinja2.BaseLoader(),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ ).from_string(chat_template)
+
+ def format_autotokenizer(
+ messages: List[llama_types.ChatCompletionRequestMessage],
+ **kwargs: Any,
+ ) -> ChatFormatterResponse:
+ # TODO: verify this is correct
+ # Add a blank assistant message to the end of the messages to prompt the model to generate a response
+ prompt = env.render(
+ messages=[
+ *messages,
+ llama_types.ChatCompletionRequestAssistantMessage(
+ role="assistant", content=""
+ ),
+ ],
+ bos_token=bos_token,
+ eos_token=eos_token,
+ )
+ return ChatFormatterResponse(prompt=prompt, stop=eos_token)
+
+ return format_autotokenizer
+
+
+def hf_tokenizer_config_to_chat_completion_handler(
+ tokenizer_config: Dict[str, Any],
+) -> LlamaChatCompletionHandler:
+ chat_formatter = hf_tokenizer_config_to_chat_formatter(tokenizer_config)
+ return chat_formatter_to_chat_completion_handler(chat_formatter)
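+
+# Illustrative sketch: the same conversion starting from a local
+# tokenizer_config.json file (the path is a placeholder).
+#
+#     import json
+#     with open("tokenizer_config.json") as f:
+#         handler = hf_tokenizer_config_to_chat_completion_handler(json.load(f))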
+
+
+### Utility functions for formatting chat prompts ###
+
+
+def _get_system_message(
+ messages: List[llama_types.ChatCompletionRequestMessage],
+) -> str:
+ """Get the first system message."""
+ for message in messages:
+ if message["role"] == "system":
+ return message["content"] or ""
+ return ""
+
+
+def _map_roles(
+ messages: List[llama_types.ChatCompletionRequestMessage],
+ role_map: Dict[str, str],
+) -> List[Tuple[str, Optional[str]]]:
+ """Map the message roles."""
+ output: List[Tuple[str, Optional[str]]] = []
+ for message in messages:
+ role = message["role"]
+ if role in role_map:
+ content: Optional[str] = (
+ message["content"] if isinstance(message["content"], str) else None
+ )
+ output.append((role_map[role], content))
+ return output
+
+
+def _format_llama2(
+ system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
+) -> str:
+ """Format the prompt with the llama2 style."""
+ seps = [sep, sep2]
+ ret = system_message + sep
+ for i, (role, message) in enumerate(messages):
+ if system_message and i == 0:
+ m = message or ""
+ ret += m + seps[i % 2]
+ elif message:
+ ret += role + message + " " + seps[i % 2]
+ else:
+ ret += role + " "
+ return ret
+
+
+def _format_add_colon_single(
+ system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+ """Format the prompt with the add-colon-single style."""
+ ret = system_message + sep
+ for role, message in messages:
+ if message:
+ ret += role + ": " + message + sep
+ else:
+ ret += role + ":"
+ return ret
+
+
+def _format_add_colon_two(
+ system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
+) -> str:
+ """Format the prompt with the add-colon-two style."""
+ seps = [sep, sep2]
+ ret = system_message + seps[0]
+ for i, (role, message) in enumerate(messages):
+ if message:
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ return ret
+
+
+def _format_no_colon_single(
+ system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+ """Format the prompt with the no-colon-single style."""
+ ret = system_message
+ for role, message in messages:
+ if message:
+ ret += role + message + sep
+ else:
+ ret += role
+ return ret
+
+
+def _format_add_colon_space_single(
+ system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+ """Format the prompt with the add-colon-space-single style."""
+ ret = system_message + sep
+ for role, message in messages:
+ if message:
+ ret += role + ": " + message + sep
+ else:
+ ret += role + ": " # must end with a space
+ return ret
+
+
+def _format_chatml(
+ system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+ """Format the prompt with the chatml style."""
+ ret = "" if system_message == "" else system_message + sep + "\n"
+ for role, message in messages:
+ if message:
+ ret += role + "\n" + message + sep + "\n"
+ else:
+ ret += role + "\n"
+ return ret
+
+
+def _format_chatglm3(
+ system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+ """Format the prompt with the chatglm3 style."""
+ ret = ""
+ if system_message:
+ ret += system_message
+ for role, message in messages:
+ if message:
+ ret += role + "\n" + " " + message
+ else:
+ ret += role
+ return ret
+
+
+### Chat Formats ###
+
+
+def register_chat_format(name: str):
+ def decorator(f: ChatFormatter):
+ chat_completion_handler = chat_formatter_to_chat_completion_handler(f)
+ LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
+ name, chat_completion_handler
+ )
+ return f
+
+ return decorator
+
+
# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
# system prompt is "embedded" in the first message
@register_chat_format("llama-2")
@@ -437,21 +600,23 @@ def format_alpaca(
_prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2)
return ChatFormatterResponse(prompt=_prompt)
+
@register_chat_format("qwen")
def format_qwen(
messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any,
) -> ChatFormatterResponse:
_roles = dict(user="<|im_start|>user", assistant="<|im_start|>assistant")
- system_message="You are a helpful assistant."
- system_template="<|im_start|>system\n{system_message}"
- system_message=system_template.format(system_message=system_message)
+ system_message = "You are a helpful assistant."
+ system_template = "<|im_start|>system\n{system_message}"
+ system_message = system_template.format(system_message=system_message)
_messages = _map_roles(messages, _roles)
_messages.append((_roles["assistant"], None))
_sep = "<|im_end|>"
_prompt = _format_chatml(system_message, _messages, _sep)
_sep2 = "<|endoftext|>"
- return ChatFormatterResponse(prompt=_prompt,stop=_sep2)
+ return ChatFormatterResponse(prompt=_prompt, stop=_sep2)
+
@register_chat_format("vicuna")
def format(
@@ -650,6 +815,7 @@ def format_mistrallite(
_prompt = _format_no_colon_single(system_message, _messages, _sep)
return ChatFormatterResponse(prompt=_prompt)
+
@register_chat_format("zephyr")
def format_zephyr(
messages: List[llama_types.ChatCompletionRequestMessage],
@@ -699,6 +865,7 @@ def format_chatml(
_prompt = _format_chatml(system_message, _messages, _sep)
return ChatFormatterResponse(prompt=_prompt, stop=_sep)
+
@register_chat_format("chatglm3")
def format_chatglm3(
messages: List[llama_types.ChatCompletionRequestMessage],
@@ -739,7 +906,7 @@ def format_openchat(
@register_chat_format("saiga")
def format_saiga(
messages: list[llama_types.ChatCompletionRequestMessage],
- **kwargs,
+ **kwargs: Any,
) -> ChatFormatterResponse:
_message_template = "{role}\n{content}"
_roles = dict(user="user", bot="bot", system="system")
diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 9e8e3ce..ef16272 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -91,6 +91,12 @@ c_float_p = POINTER(c_float)
c_uint8_p = POINTER(c_uint8)
c_size_t_p = POINTER(c_size_t)
+# from ggml-backend.h
+# typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
+ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(
+ c_bool, c_void_p, c_bool, c_void_p
+)
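+
+# Illustrative sketch: a Python function can be wrapped into this callback type
+# by calling the CFUNCTYPE object; the callback name below is made up and simply
+# asks the scheduler to evaluate every tensor.
+#
+#     @ggml_backend_sched_eval_callback
+#     def _eval_everything(tensor, ask, user_data):
+#         return True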
+
# llama.h bindings
_lib.llama_max_devices.argtypes = []
@@ -448,6 +454,9 @@ class llama_model_params(Structure):
# float yarn_beta_slow; // YaRN high correction dim
# uint32_t yarn_orig_ctx; // YaRN original context size
+# ggml_backend_sched_eval_callback cb_eval;
+# void * cb_eval_user_data;
+
# enum ggml_type type_k; // data type for K cache
# enum ggml_type type_v; // data type for V cache
@@ -475,6 +484,8 @@ class llama_context_params(Structure):
yarn_beta_fast (float): YaRN low correction dim
yarn_beta_slow (float): YaRN high correction dim
yarn_orig_ctx (int): YaRN original context size
+ cb_eval (ggml_backend_sched_eval_callback): callback for scheduling eval
+ cb_eval_user_data (ctypes.c_void_p): user data for cb_eval
type_k (int): data type for K cache
type_v (int): data type for V cache
mul_mat_q (bool): if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
@@ -497,6 +508,8 @@ class llama_context_params(Structure):
("yarn_beta_fast", c_float),
("yarn_beta_slow", c_float),
("yarn_orig_ctx", c_uint32),
+ ("cb_eval", ggml_backend_sched_eval_callback),
+ ("cb_eval_user_data", c_void_p),
("type_k", c_int),
("type_v", c_int),
("mul_mat_q", c_bool),
diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py
index c02e656..d8ef563 100644
--- a/llama_cpp/llama_grammar.py
+++ b/llama_cpp/llama_grammar.py
@@ -1432,7 +1432,6 @@ class SchemaConverter:
return key
def visit(self, schema: Dict[str, Any], name: str) -> str:
- schema_type: Optional[str] = schema.get("type") # type: ignore
rule_name = name or "root"
if "$defs" in schema:
@@ -1458,7 +1457,19 @@ class SchemaConverter:
rule = " | ".join((self._format_literal(v) for v in schema["enum"]))
return self._add_rule(rule_name, rule)
- elif schema_type == "object" and "properties" in schema:
+ elif "$ref" in schema:
+ ref = schema["$ref"]
+ assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}"
+ # inline $defs
+ def_name = ref[len("#/$defs/") :]
+ def_schema = self._defs[def_name]
+ return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}')
+
+ schema_type: Optional[str] = schema.get("type") # type: ignore
+ assert isinstance(schema_type, str), f"Unrecognized schema: {schema}"
+
+ if schema_type == "object" and "properties" in schema:
# TODO: `required` keyword
prop_order = self._prop_order
prop_pairs = sorted(
@@ -1489,14 +1500,6 @@ class SchemaConverter:
)
return self._add_rule(rule_name, rule)
- elif "$ref" in schema:
- ref = schema["$ref"]
- assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}"
- # inline $defs
- def_name = ref[len("#/$defs/") :]
- def_schema = self._defs[def_name]
- return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}')
-
else:
assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
return self._add_rule(
diff --git a/llama_cpp/server/cli.py b/llama_cpp/server/cli.py
index 8e32d2c..3dd0076 100644
--- a/llama_cpp/server/cli.py
+++ b/llama_cpp/server/cli.py
@@ -55,7 +55,7 @@ def _parse_bool_arg(arg: str | bytes | bool) -> bool:
raise ValueError(f"Invalid boolean argument: {arg}")
-def add_args_from_model(parser: argparse.ArgumentParser, model: type[BaseModel]):
+def add_args_from_model(parser: argparse.ArgumentParser, model: Type[BaseModel]):
"""Add arguments from a pydantic model to an argparse parser."""
for name, field in model.model_fields.items():
@@ -83,7 +83,7 @@ def add_args_from_model(parser: argparse.ArgumentParser, model: type[BaseModel])
)
-T = TypeVar("T", bound=type[BaseModel])
+T = TypeVar("T", bound=Type[BaseModel])
def parse_model_from_args(model: T, args: argparse.Namespace) -> T:
diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py
index f9be323..bbb6806 100644
--- a/llama_cpp/server/model.py
+++ b/llama_cpp/server/model.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import json
+
from typing import Dict, Optional, Union, List
import llama_cpp
@@ -71,7 +73,25 @@ class LlamaProxy:
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
-
+ elif settings.chat_format == "hf-autotokenizer":
+ assert (
+ settings.hf_pretrained_model_name_or_path is not None
+ ), "hf_pretrained_model_name_or_path must be set for hf-autotokenizer"
+ chat_handler = (
+ llama_cpp.llama_chat_format.hf_autotokenizer_to_chat_completion_handler(
+ settings.hf_pretrained_model_name_or_path
+ )
+ )
+ elif settings.chat_format == "hf-tokenizer-config":
+ assert (
+ settings.hf_tokenizer_config_path is not None
+ ), "hf_tokenizer_config_path must be set for hf-tokenizer-config"
+ chat_handler = (
+ llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_completion_handler(
+ json.load(open(settings.hf_tokenizer_config_path))
+ )
+ )
+
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
if settings.kv_overrides is not None:
assert isinstance(settings.kv_overrides, list)
@@ -141,4 +161,3 @@ class LlamaProxy:
cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size)
_model.set_cache(cache)
return _model
-
diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py
index a10390c..9f0dc8a 100644
--- a/llama_cpp/server/settings.py
+++ b/llama_cpp/server/settings.py
@@ -90,7 +90,7 @@ class ModelSettings(BaseSettings):
logits_all: bool = Field(default=True, description="Whether to return logits.")
embedding: bool = Field(default=True, description="Whether to use embeddings.")
offload_kqv: bool = Field(
- default=False, description="Whether to offload kqv to the GPU."
+ default=True, description="Whether to offload kqv to the GPU."
)
# Sampling Params
last_n_tokens_size: int = Field(
@@ -134,6 +134,15 @@ class ModelSettings(BaseSettings):
default=2 << 30,
description="The size of the cache in bytes. Only used if cache is True.",
)
+ # Tokenizer Options
+ hf_tokenizer_config_path: Optional[str] = Field(
+ default=None,
+ description="The path to a HuggingFace tokenizer_config.json file.",
+ )
+ hf_pretrained_model_name_or_path: Optional[str] = Field(
+ default=None,
+ description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
+ )
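+ # Illustrative sketch: the server exposes these as CLI flags generated from the
+ # field names (model and config paths below are placeholders), e.g.
+ #   python3 -m llama_cpp.server --model ./model.gguf \
+ #     --chat_format hf-tokenizer-config \
+ #     --hf_tokenizer_config_path ./tokenizer_config.json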
# Misc
verbose: bool = Field(
default=True, description="Whether to print debug information."
diff --git a/pyproject.toml b/pyproject.toml
index b5affaa..4130972 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,6 +15,7 @@ dependencies = [
"typing-extensions>=4.5.0",
"numpy>=1.20.0",
"diskcache>=5.6.1",
+ "jinja2>=2.11.3",
]
requires-python = ">=3.8"
classifiers = [
@@ -72,4 +73,3 @@ Changelog = "https://llama-cpp-python.readthedocs.io/en/latest/changelog/"
[tool.pytest.ini_options]
addopts = "--ignore=vendor"
-
diff --git a/tests/test_grammar.py b/tests/test_grammar.py
index ef9392b..cb22188 100644
--- a/tests/test_grammar.py
+++ b/tests/test_grammar.py
@@ -50,3 +50,29 @@ def test_composed_pydantic_grammar():
grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema))
assert grammar.grammar is not None
+
+
+def test_grammar_anyof():
+ sch = {
+ "properties": {
+ "temperature": {
+ "description": "The temperature mentioned",
+ "type": "number",
+ },
+ "unit": {
+ "anyOf": [
+ {
+ "description": "Unit for temperature",
+ "enum": ["celsius", "fahrenheit"],
+ "type": "string",
+ },
+ {"type": "null"},
+ ],
+ },
+ },
+ "type": "object",
+ }
+
+ grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(sch))
+
+ assert grammar.grammar is not None
\ No newline at end of file
diff --git a/tests/test_llama_chat_format.py b/tests/test_llama_chat_format.py
new file mode 100644
index 0000000..1ef18d9
--- /dev/null
+++ b/tests/test_llama_chat_format.py
@@ -0,0 +1,65 @@
+import json
+
+from llama_cpp import (
+ ChatCompletionRequestUserMessage,
+)
+from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter
+
+
+mistral_7b_tokenizer_config = """{
+ "add_bos_token": true,
+ "add_eos_token": false,
+ "added_tokens_decoder": {
+ "0": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "1": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "2": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "additional_special_tokens": [],
+ "bos_token": "",
+ "clean_up_tokenization_spaces": false,
+ "eos_token": "",
+ "legacy": true,
+ "model_max_length": 1000000000000000019884624838656,
+ "pad_token": null,
+ "sp_model_kwargs": {},
+ "spaces_between_special_tokens": false,
+ "tokenizer_class": "LlamaTokenizer",
+ "unk_token": "",
+ "use_default_system_prompt": false,
+ "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
+}"""
+
+
+def test_hf_tokenizer_config_str_to_chat_formatter():
+ tokenizer_config = json.loads(mistral_7b_tokenizer_config)
+ chat_formatter = hf_tokenizer_config_to_chat_formatter(
+ tokenizer_config
+ )
+ chat_formatter_response = chat_formatter(
+ messages=[
+ ChatCompletionRequestUserMessage(role="user", content="Hello, world!"),
+ ]
+ )
+
+ assert chat_formatter_response.prompt == ("<s>[INST] Hello, world! [/INST]" "</s>")
diff --git a/vendor/llama.cpp b/vendor/llama.cpp
index 5c99960..504dc37 160000
--- a/vendor/llama.cpp
+++ b/vendor/llama.cpp
@@ -1 +1 @@
-Subproject commit 5c999609013a30c06e6fd28be8db5c2074bcc196
+Subproject commit 504dc37be8446fb09b1ede70300250ad41be32a2