From 3b92419132700259acb4690d0ce2e2ee979e00bc Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 17 Jan 2024 09:09:12 -0500 Subject: [PATCH 01/26] Move cache classes to llama_cache submodule. --- llama_cpp/llama.py | 147 ++------------------------------------ llama_cpp/llama_cache.py | 150 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 140 deletions(-) create mode 100644 llama_cpp/llama_cache.py diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index e4be9d1..c6b55ab 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -3,7 +3,6 @@ import sys import uuid import time import multiprocessing -from abc import ABC, abstractmethod from typing import ( List, Optional, @@ -12,16 +11,20 @@ from typing import ( Sequence, Iterator, Deque, - Tuple, Callable, ) -from collections import deque, OrderedDict +from collections import deque -import diskcache import ctypes from .llama_types import * from .llama_grammar import LlamaGrammar +from .llama_cache import ( + BaseLlamaCache, + LlamaCache, # type: ignore + LlamaDiskCache, # type: ignore + LlamaRAMCache, # type: ignore +) import llama_cpp.llama_cpp as llama_cpp import llama_cpp.llama_chat_format as llama_chat_format @@ -31,142 +34,6 @@ import numpy.typing as npt from ._utils import suppress_stdout_stderr -class BaseLlamaCache(ABC): - """Base cache class for a llama.cpp model.""" - - def __init__(self, capacity_bytes: int = (2 << 30)): - self.capacity_bytes = capacity_bytes - - @property - @abstractmethod - def cache_size(self) -> int: - raise NotImplementedError - - def _find_longest_prefix_key( - self, - key: Tuple[int, ...], - ) -> Optional[Tuple[int, ...]]: - pass - - @abstractmethod - def __getitem__(self, key: Sequence[int]) -> "LlamaState": - raise NotImplementedError - - @abstractmethod - def __contains__(self, key: Sequence[int]) -> bool: - raise NotImplementedError - - @abstractmethod - def __setitem__(self, key: Sequence[int], value: "LlamaState") -> None: - raise NotImplementedError - - -class LlamaRAMCache(BaseLlamaCache): - """Cache for a llama.cpp model using RAM.""" - - def __init__(self, capacity_bytes: int = (2 << 30)): - super().__init__(capacity_bytes) - self.capacity_bytes = capacity_bytes - self.cache_state: OrderedDict[Tuple[int, ...], "LlamaState"] = OrderedDict() - - @property - def cache_size(self): - return sum([state.llama_state_size for state in self.cache_state.values()]) - - def _find_longest_prefix_key( - self, - key: Tuple[int, ...], - ) -> Optional[Tuple[int, ...]]: - min_len = 0 - min_key = None - keys = ( - (k, Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys() - ) - for k, prefix_len in keys: - if prefix_len > min_len: - min_len = prefix_len - min_key = k - return min_key - - def __getitem__(self, key: Sequence[int]) -> "LlamaState": - key = tuple(key) - _key = self._find_longest_prefix_key(key) - if _key is None: - raise KeyError("Key not found") - value = self.cache_state[_key] - self.cache_state.move_to_end(_key) - return value - - def __contains__(self, key: Sequence[int]) -> bool: - return self._find_longest_prefix_key(tuple(key)) is not None - - def __setitem__(self, key: Sequence[int], value: "LlamaState"): - key = tuple(key) - if key in self.cache_state: - del self.cache_state[key] - self.cache_state[key] = value - while self.cache_size > self.capacity_bytes and len(self.cache_state) > 0: - self.cache_state.popitem(last=False) - - -# Alias for backwards compatibility -LlamaCache = LlamaRAMCache - - -class LlamaDiskCache(BaseLlamaCache): - """Cache 
for a llama.cpp model using disk.""" - - def __init__( - self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30) - ): - super().__init__(capacity_bytes) - self.cache = diskcache.Cache(cache_dir) - - @property - def cache_size(self): - return int(self.cache.volume()) # type: ignore - - def _find_longest_prefix_key( - self, - key: Tuple[int, ...], - ) -> Optional[Tuple[int, ...]]: - min_len = 0 - min_key: Optional[Tuple[int, ...]] = None - for k in self.cache.iterkeys(): # type: ignore - prefix_len = Llama.longest_token_prefix(k, key) - if prefix_len > min_len: - min_len = prefix_len - min_key = k # type: ignore - return min_key - - def __getitem__(self, key: Sequence[int]) -> "LlamaState": - key = tuple(key) - _key = self._find_longest_prefix_key(key) - if _key is None: - raise KeyError("Key not found") - value: "LlamaState" = self.cache.pop(_key) # type: ignore - # NOTE: This puts an integer as key in cache, which breaks, - # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens - # self.cache.push(_key, side="front") # type: ignore - return value - - def __contains__(self, key: Sequence[int]) -> bool: - return self._find_longest_prefix_key(tuple(key)) is not None - - def __setitem__(self, key: Sequence[int], value: "LlamaState"): - print("LlamaDiskCache.__setitem__: called", file=sys.stderr) - key = tuple(key) - if key in self.cache: - print("LlamaDiskCache.__setitem__: delete", file=sys.stderr) - del self.cache[key] - self.cache[key] = value - print("LlamaDiskCache.__setitem__: set", file=sys.stderr) - while self.cache_size > self.capacity_bytes and len(self.cache) > 0: - key_to_remove = next(iter(self.cache)) - del self.cache[key_to_remove] - print("LlamaDiskCache.__setitem__: trim", file=sys.stderr) - - class LlamaState: def __init__( self, diff --git a/llama_cpp/llama_cache.py b/llama_cpp/llama_cache.py new file mode 100644 index 0000000..9e9870a --- /dev/null +++ b/llama_cpp/llama_cache.py @@ -0,0 +1,150 @@ +import sys +from abc import ABC, abstractmethod +from typing import ( + Optional, + Sequence, + Tuple, +) +from collections import OrderedDict + +import diskcache + +import llama_cpp.llama + +from .llama_types import * + + +class BaseLlamaCache(ABC): + """Base cache class for a llama.cpp model.""" + + def __init__(self, capacity_bytes: int = (2 << 30)): + self.capacity_bytes = capacity_bytes + + @property + @abstractmethod + def cache_size(self) -> int: + raise NotImplementedError + + def _find_longest_prefix_key( + self, + key: Tuple[int, ...], + ) -> Optional[Tuple[int, ...]]: + pass + + @abstractmethod + def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState": + raise NotImplementedError + + @abstractmethod + def __contains__(self, key: Sequence[int]) -> bool: + raise NotImplementedError + + @abstractmethod + def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState") -> None: + raise NotImplementedError + + +class LlamaRAMCache(BaseLlamaCache): + """Cache for a llama.cpp model using RAM.""" + + def __init__(self, capacity_bytes: int = (2 << 30)): + super().__init__(capacity_bytes) + self.capacity_bytes = capacity_bytes + self.cache_state: OrderedDict[Tuple[int, ...], "llama_cpp.llama.LlamaState"] = OrderedDict() + + @property + def cache_size(self): + return sum([state.llama_state_size for state in self.cache_state.values()]) + + def _find_longest_prefix_key( + self, + key: Tuple[int, ...], + ) -> Optional[Tuple[int, ...]]: + min_len = 0 + min_key = None + keys = ( + (k, 
llama_cpp.llama.Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys() + ) + for k, prefix_len in keys: + if prefix_len > min_len: + min_len = prefix_len + min_key = k + return min_key + + def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState": + key = tuple(key) + _key = self._find_longest_prefix_key(key) + if _key is None: + raise KeyError("Key not found") + value = self.cache_state[_key] + self.cache_state.move_to_end(_key) + return value + + def __contains__(self, key: Sequence[int]) -> bool: + return self._find_longest_prefix_key(tuple(key)) is not None + + def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"): + key = tuple(key) + if key in self.cache_state: + del self.cache_state[key] + self.cache_state[key] = value + while self.cache_size > self.capacity_bytes and len(self.cache_state) > 0: + self.cache_state.popitem(last=False) + + +# Alias for backwards compatibility +LlamaCache = LlamaRAMCache + + +class LlamaDiskCache(BaseLlamaCache): + """Cache for a llama.cpp model using disk.""" + + def __init__( + self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30) + ): + super().__init__(capacity_bytes) + self.cache = diskcache.Cache(cache_dir) + + @property + def cache_size(self): + return int(self.cache.volume()) # type: ignore + + def _find_longest_prefix_key( + self, + key: Tuple[int, ...], + ) -> Optional[Tuple[int, ...]]: + min_len = 0 + min_key: Optional[Tuple[int, ...]] = None + for k in self.cache.iterkeys(): # type: ignore + prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k, key) + if prefix_len > min_len: + min_len = prefix_len + min_key = k # type: ignore + return min_key + + def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState": + key = tuple(key) + _key = self._find_longest_prefix_key(key) + if _key is None: + raise KeyError("Key not found") + value: "llama_cpp.llama.LlamaState" = self.cache.pop(_key) # type: ignore + # NOTE: This puts an integer as key in cache, which breaks, + # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens + # self.cache.push(_key, side="front") # type: ignore + return value + + def __contains__(self, key: Sequence[int]) -> bool: + return self._find_longest_prefix_key(tuple(key)) is not None + + def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"): + print("LlamaDiskCache.__setitem__: called", file=sys.stderr) + key = tuple(key) + if key in self.cache: + print("LlamaDiskCache.__setitem__: delete", file=sys.stderr) + del self.cache[key] + self.cache[key] = value + print("LlamaDiskCache.__setitem__: set", file=sys.stderr) + while self.cache_size > self.capacity_bytes and len(self.cache) > 0: + key_to_remove = next(iter(self.cache)) + del self.cache[key_to_remove] + print("LlamaDiskCache.__setitem__: trim", file=sys.stderr) From cc4630e66f2a7f12c3287aece85779ab499c1c9d Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 17 Jan 2024 09:14:00 -0500 Subject: [PATCH 02/26] Move helper classes to _internals submodule --- llama_cpp/_internals.py | 770 ++++++++++++++++++++++++++++++++++++++++ llama_cpp/llama.py | 518 +-------------------------- 2 files changed, 776 insertions(+), 512 deletions(-) create mode 100644 llama_cpp/_internals.py diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py new file mode 100644 index 0000000..208de8c --- /dev/null +++ b/llama_cpp/_internals.py @@ -0,0 +1,770 @@ +from __future__ import annotations + +import os +import ctypes + +from typing import ( + 
List, + Optional, + Sequence, +) +from dataclasses import dataclass, field + +import numpy as np +import numpy.typing as npt + +from .llama_types import * +from .llama_grammar import LlamaGrammar + +import llama_cpp.llama_cpp as llama_cpp + +from ._utils import suppress_stdout_stderr + + +# Python wrappers over llama.h structs + + +class _LlamaModel: + """Intermediate Python wrapper for a llama.cpp llama_model. + NOTE: For stability it's recommended you use the Llama class instead.""" + + _llama_free_model = None + # NOTE: this must be "saved" here to avoid exceptions when calling __del__ + _suppress_stdout_stderr = suppress_stdout_stderr + + def __init__( + self, + *, + path_model: str, + params: llama_cpp.llama_model_params, + verbose: bool = True, + ): + self.path_model = path_model + self.params = params + self.verbose = verbose + + self._llama_free_model = llama_cpp._lib.llama_free_model # type: ignore + + if not os.path.exists(path_model): + raise ValueError(f"Model path does not exist: {path_model}") + + with self._suppress_stdout_stderr(disable=self.verbose): + self.model = llama_cpp.llama_load_model_from_file( + self.path_model.encode("utf-8"), self.params + ) + + def __del__(self): + with self._suppress_stdout_stderr(disable=self.verbose): + if self.model is not None and self._llama_free_model is not None: + self._llama_free_model(self.model) + self.model = None + + def vocab_type(self) -> int: + assert self.model is not None + return llama_cpp.llama_vocab_type(self.model) + + def n_vocab(self) -> int: + assert self.model is not None + return llama_cpp.llama_n_vocab(self.model) + + def n_ctx_train(self) -> int: + assert self.model is not None + return llama_cpp.llama_n_ctx_train(self.model) + + def n_embd(self) -> int: + assert self.model is not None + return llama_cpp.llama_n_embd(self.model) + + def rope_freq_scale_train(self) -> float: + assert self.model is not None + return llama_cpp.llama_rope_freq_scale_train(self.model) + + def desc(self) -> str: + assert self.model is not None + buf = ctypes.create_string_buffer(1024) + llama_cpp.llama_model_desc(self.model, buf, 1024) # type: ignore + return buf.value.decode("utf-8") + + def size(self) -> int: + assert self.model is not None + return llama_cpp.llama_model_size(self.model) + + def n_params(self) -> int: + assert self.model is not None + return llama_cpp.llama_model_n_params(self.model) + + def get_tensor(self, name: str) -> ctypes.c_void_p: + assert self.model is not None + return llama_cpp.llama_get_model_tensor(self.model, name.encode("utf-8")) + + def apply_lora_from_file( + self, + lora_path: str, + scale: float, + path_base_model: Optional[str], + n_threads: int, + ): + assert self.model is not None + return llama_cpp.llama_model_apply_lora_from_file( + self.model, + lora_path.encode("utf-8"), + scale, + path_base_model.encode("utf-8") + if path_base_model is not None + else llama_cpp.c_char_p(0), + n_threads, + ) + + # Vocab + + def token_get_text(self, token: int) -> str: + # TODO: Fix + assert self.model is not None + return llama_cpp.llama_token_get_text(self.model, token).decode("utf-8") + + def token_get_score(self, token: int) -> float: + assert self.model is not None + return llama_cpp.llama_token_get_score(self.model, token) + + def token_get_type(self, token: int) -> int: + assert self.model is not None + return llama_cpp.llama_token_get_type(self.model, token) + + # Special tokens + + def token_bos(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_bos(self.model) + + def 
token_eos(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_eos(self.model) + + def token_nl(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_nl(self.model) + + def token_prefix(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_prefix(self.model) + + def token_middle(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_middle(self.model) + + def token_suffix(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_suffix(self.model) + + def token_eot(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_eot(self.model) + + # Tokenization + + def tokenize(self, text: bytes, add_bos: bool, special: bool): + assert self.model is not None + n_ctx = self.n_ctx_train() + tokens = (llama_cpp.llama_token * n_ctx)() + n_tokens = llama_cpp.llama_tokenize( + self.model, text, len(text), tokens, n_ctx, add_bos, special + ) + if n_tokens < 0: + n_tokens = abs(n_tokens) + tokens = (llama_cpp.llama_token * n_tokens)() + n_tokens = llama_cpp.llama_tokenize( + self.model, text, len(text), tokens, n_tokens, add_bos, special + ) + if n_tokens < 0: + raise RuntimeError( + f'Failed to tokenize: text="{text}" n_tokens={n_tokens}' + ) + return list(tokens[:n_tokens]) + + def token_to_piece(self, token: int) -> bytes: + assert self.model is not None + buf = ctypes.create_string_buffer(32) + llama_cpp.llama_token_to_piece(self.model, token, buf, 32) # type: ignore + return bytes(buf) + + def detokenize(self, tokens: List[int]) -> bytes: + assert self.model is not None + output = b"" + size = 32 + buffer = (ctypes.c_char * size)() + for token in tokens: + n = llama_cpp.llama_token_to_piece( + self.model, llama_cpp.llama_token(token), buffer, size + ) + assert n <= size + output += bytes(buffer[:n]) + # NOTE: Llama1 models automatically added a space at the start of the prompt + # this line removes a leading space if the first token is a beginning of sentence token + return ( + output[1:] if len(tokens) > 0 and tokens[0] == self.token_bos() else output + ) + + @staticmethod + def default_params(): + """Get the default llama_model_params.""" + return llama_cpp.llama_model_default_params() + + +class _LlamaContext: + """Intermediate Python wrapper for a llama.cpp llama_context. 
+ NOTE: For stability it's recommended you use the Llama class instead.""" + + _llama_free = None + # NOTE: this must be "saved" here to avoid exceptions when calling __del__ + _suppress_stdout_stderr = suppress_stdout_stderr + + def __init__( + self, + *, + model: _LlamaModel, + params: llama_cpp.llama_context_params, + verbose: bool = True, + ): + self.model = model + self.params = params + self.verbose = verbose + + self._llama_free = llama_cpp._lib.llama_free # type: ignore + + with self._suppress_stdout_stderr(disable=self.verbose): + self.ctx = llama_cpp.llama_new_context_with_model( + self.model.model, self.params + ) + + def __del__(self): + with self._suppress_stdout_stderr(disable=self.verbose): + if self.ctx is not None and self._llama_free is not None: + self._llama_free(self.ctx) + self.ctx = None + + def n_ctx(self) -> int: + assert self.ctx is not None + return llama_cpp.llama_n_ctx(self.ctx) + + def kv_cache_clear(self): + assert self.ctx is not None + llama_cpp.llama_kv_cache_clear(self.ctx) + + def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int): + assert self.ctx is not None + llama_cpp.llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1) + + def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int): + assert self.ctx is not None + llama_cpp.llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1) + + def kv_cache_seq_keep(self, seq_id: int): + assert self.ctx is not None + llama_cpp.llama_kv_cache_seq_keep(self.ctx, seq_id) + + def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int): + assert self.ctx is not None + llama_cpp.llama_kv_cache_seq_shift(self.ctx, seq_id, p0, p1, shift) + + def get_state_size(self) -> int: + assert self.ctx is not None + return llama_cpp.llama_get_state_size(self.ctx) + + # TODO: copy_state_data + + # TODO: set_state_data + + # TODO: llama_load_session_file + + # TODO: llama_save_session_file + + def decode(self, batch: "_LlamaBatch"): + assert self.ctx is not None + assert batch.batch is not None + return_code = llama_cpp.llama_decode( + ctx=self.ctx, + batch=batch.batch, + ) + if return_code != 0: + raise RuntimeError(f"llama_decode returned {return_code}") + + def set_n_threads(self, n_threads: int, n_threads_batch: int): + assert self.ctx is not None + llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch) + + def get_logits(self): + assert self.ctx is not None + return llama_cpp.llama_get_logits(self.ctx) + + def get_logits_ith(self, i: int): + assert self.ctx is not None + return llama_cpp.llama_get_logits_ith(self.ctx, i) + + def get_embeddings(self): + assert self.ctx is not None + return llama_cpp.llama_get_embeddings(self.ctx) + + # Sampling functions + + def set_rng_seed(self, seed: int): + assert self.ctx is not None + llama_cpp.llama_set_rng_seed(self.ctx, seed) + + def sample_repetition_penalties( + self, + candidates: "_LlamaTokenDataArray", + last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]", + penalty_last_n: int, + penalty_repeat: float, + penalty_freq: float, + penalty_present: float, + ): + assert self.ctx is not None + llama_cpp.llama_sample_repetition_penalties( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + last_tokens_data, + penalty_last_n, + penalty_repeat, + penalty_freq, + penalty_present, + ) + + def sample_classifier_free_guidance( + self, + candidates: "_LlamaTokenDataArray", + guidance_ctx: "_LlamaContext", + scale: float, + ): + assert self.ctx is not None + assert guidance_ctx.ctx is not None + 
llama_cpp.llama_sample_classifier_free_guidance( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + guidance_ctx.ctx, + scale, + ) + + def sample_softmax(self, candidates: "_LlamaTokenDataArray"): + assert self.ctx is not None + llama_cpp.llama_sample_softmax( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + ) + + def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int): + assert self.ctx is not None + llama_cpp.llama_sample_top_k( + self.ctx, ctypes.byref(candidates.candidates), k, min_keep # type: ignore + ) + + def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): + assert self.ctx is not None + llama_cpp.llama_sample_top_p( + self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore + ) + + def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): + assert self.ctx is not None + llama_cpp.llama_sample_min_p( + self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore + ) + + def sample_tail_free( + self, candidates: "_LlamaTokenDataArray", z: float, min_keep: int + ): + assert self.ctx is not None + llama_cpp.llama_sample_tail_free( + self.ctx, ctypes.byref(candidates.candidates), z, min_keep # type: ignore + ) + + def sample_typical( + self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int + ): + assert self.ctx is not None + llama_cpp.llama_sample_typical( + self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore + ) + + def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float): + assert self.ctx is not None + llama_cpp.llama_sample_temp( + self.ctx, ctypes.byref(candidates.candidates), temp # type: ignore + ) + + def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar): + assert self.ctx is not None + assert grammar.grammar is not None + llama_cpp.llama_sample_grammar( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + grammar.grammar, + ) + + def sample_token_mirostat( + self, + candidates: "_LlamaTokenDataArray", + tau: float, + eta: float, + m: int, + mu: ctypes._Pointer[ctypes.c_float], # type: ignore + ) -> int: + assert self.ctx is not None + return llama_cpp.llama_sample_token_mirostat( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + tau, + eta, + m, + mu, + ) + + def sample_token_mirostat_v2( + self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: ctypes._Pointer[ctypes.c_float] # type: ignore + ) -> int: + assert self.ctx is not None + return llama_cpp.llama_sample_token_mirostat_v2( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + tau, + eta, + mu, + ) + + def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int: + assert self.ctx is not None + return llama_cpp.llama_sample_token_greedy( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + ) + + def sample_token(self, candidates: "_LlamaTokenDataArray") -> int: + assert self.ctx is not None + return llama_cpp.llama_sample_token( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + ) + + # Grammar + def grammar_accept_token(self, grammar: LlamaGrammar, token: int): + assert self.ctx is not None + assert grammar.grammar is not None + llama_cpp.llama_grammar_accept_token(self.ctx, grammar.grammar, token) + + def reset_timings(self): + assert self.ctx is not None + llama_cpp.llama_reset_timings(self.ctx) + + def print_timings(self): + assert self.ctx is not None + 
llama_cpp.llama_print_timings(self.ctx) + + # Utility functions + @staticmethod + def default_params(): + """Get the default llama_context_params.""" + return llama_cpp.llama_context_default_params() + + +class _LlamaBatch: + _llama_batch_free = None + # NOTE: this must be "saved" here to avoid exceptions when calling __del__ + _suppress_stdout_stderr = suppress_stdout_stderr + + def __init__( + self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True + ): + self.n_tokens = n_tokens + self.embd = embd + self.n_seq_max = n_seq_max + self.verbose = verbose + + self._llama_batch_free = llama_cpp._lib.llama_batch_free # type: ignore + + with self._suppress_stdout_stderr(disable=self.verbose): + self.batch = llama_cpp.llama_batch_init( + self.n_tokens, self.embd, self.n_seq_max + ) + + def __del__(self): + with self._suppress_stdout_stderr(disable=self.verbose): + if self.batch is not None and self._llama_batch_free is not None: + self._llama_batch_free(self.batch) + self.batch = None + + def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool): + assert self.batch is not None + n_tokens = len(batch) + self.batch.n_tokens = n_tokens + for i in range(n_tokens): + self.batch.token[i] = batch[i] + self.batch.pos[i] = n_past + i + self.batch.seq_id[i][0] = 0 + self.batch.n_seq_id[i] = 1 + self.batch.logits[i] = logits_all + self.batch.logits[n_tokens - 1] = True + + +class _LlamaTokenDataArray: + def __init__(self, *, n_vocab: int): + self.n_vocab = n_vocab + self.candidates_data = np.array( + [], + dtype=np.dtype( + [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True + ), + ) + self.candidates_data.resize(3, self.n_vocab, refcheck=False) + self.candidates = llama_cpp.llama_token_data_array( + data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p), + size=self.n_vocab, + sorted=False, + ) + self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc) + self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single) + + def copy_logits(self, logits: npt.NDArray[np.single]): + self.candidates_data["id"][:] = self.default_candidates_data_id + self.candidates_data["logit"][:] = logits + self.candidates_data["p"][:] = self.default_candidates_data_p + self.candidates.data = self.candidates_data.ctypes.data_as( + llama_cpp.llama_token_data_p + ) + self.candidates.sorted = llama_cpp.c_bool(False) + self.candidates.size = llama_cpp.c_size_t(self.n_vocab) + + +# Python wrappers over common/common +def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> list[int]: + n_tokens = len(text) + 1 if add_bos else len(text) + result = (llama_cpp.llama_token * n_tokens)() + n_tokens = llama_cpp.llama_tokenize( + model.model, + text.encode("utf-8"), + len(text), + result, + n_tokens, + add_bos, + special, + ) + if n_tokens < 0: + result = (llama_cpp.llama_token * -n_tokens)() + check = llama_cpp.llama_tokenize( + model.model, + text.encode("utf-8"), + len(text), + result, + len(result), + add_bos, + special, + ) + if check != -n_tokens: + raise RuntimeError(f'Failed to tokenize: text="{text}" n_tokens={n_tokens}') + else: + result = result[:n_tokens] + return list(result) + + +def _token_to_piece(model: _LlamaModel, token: int) -> str: + assert model.model is not None + result = (ctypes.c_char * 8)(0) + n_tokens = llama_cpp.llama_token_to_piece(model.model, token, result, len(result)) + if n_tokens < 0: + result = (ctypes.c_char * -n_tokens)(0) + check = llama_cpp.llama_token_to_piece(model.model, token, result, 
len(result)) + if check != -n_tokens: + raise RuntimeError(f"Failed to get piece: token={token}") + else: + result = result[:n_tokens] + return bytes(result).decode("utf-8") + + +def _detokenize_spm(model: _LlamaModel, tokens: List[int]) -> str: + bos_id = model.token_bos() + result = "" + for i, token in enumerate(tokens): + piece = _token_to_piece(model, token) + if ( + (tokens[0] == bos_id and i == 1) or (tokens[0] != bos_id and i == 0) + ) and piece[0] == " ": + piece = piece[1:] + result += piece + return result + + +def _detokenize_bpe(model: _LlamaModel, tokens: List[int]) -> str: + result = "" + for token in tokens: + piece = _token_to_piece(model, token) + result += piece + return result + + +def _should_add_bos(model: _LlamaModel) -> bool: + assert model.model is not None + add_bos = llama_cpp.llama_add_bos_token(model.model) + if add_bos != -1: + return add_bos != 0 + else: + return llama_cpp.llama_vocab_type(model.model) == llama_cpp.LLAMA_VOCAB_TYPE_SPM + + +# Python wrappers over common/sampling structs + + +@dataclass +class _LlamaSamplingParams: + n_prev: int = 64 + n_probs: int = 0 + top_k: int = 40 + top_p: float = 0.95 + min_p: float = 0.05 + tfs_z: float = 1.00 + typical_p: float = 1.00 + temp: float = 0.80 + penalty_last_n: int = 64 + penalty_repeat: float = 1.10 + penalty_freq: float = 0.00 + penalty_present: float = 0.00 + mirostat: int = 0 + mirostat_tau: float = 5.00 + mirostat_eta: float = 0.10 + penalize_nl: bool = True + + grammar: str = "" + + cfg_negative_prompt: str = "" + cfg_scale: float = 1.00 + + logit_bias: dict[int, float] = field(default_factory=dict) + + +@dataclass +class _LlamaSamplingContext: + params: _LlamaSamplingParams = field(default_factory=_LlamaSamplingParams) + mirostat_mu: ctypes.c_float = field(default_factory=ctypes.c_float) + grammar: Optional[LlamaGrammar] = None + # NOTE: Missing parsed_grammar + prev: list[int] = field(default_factory=list) + cur: list[llama_cpp.llama_token_data] = field(default_factory=list) + + def reset(self): + self.prev = [] + self.cur = [] + if self.grammar is not None: + self.grammar.reset() + + def cp(self): + return _LlamaSamplingContext( + params=self.params, + mirostat_mu=self.mirostat_mu, + grammar=self.grammar, + prev=self.prev.copy(), + cur=self.cur.copy(), + ) + + def last(self) -> Optional[int]: + if len(self.prev) > 0: + return self.prev[-1] + else: + return None + + def prev_str(self, ctx_main: _LlamaContext, n: int) -> str: + return ctx_main.model.detokenize(self.prev[-n:]).decode("utf-8") + + def sample( + self, ctx_main: _LlamaContext, ctx_cfg: Optional[_LlamaContext] = None, idx: int = 0, logits_array: Optional[npt.NDArray[np.single]] = None + ): + n_vocab = ctx_main.model.n_vocab() + id: int = 0 + + if logits_array is None: + logits = ctx_main.get_logits_ith(idx) + logits_array = np.array( + ctypes.cast(logits, ctypes.POINTER(ctypes.c_float * n_vocab)).contents, + dtype=np.single, + ) + + # apply logit_bias + for token, logit_bias in self.params.logit_bias.items(): + logits_array[token] += logit_bias + + token_data_array = _LlamaTokenDataArray( + n_vocab=n_vocab + ) # TODO: Only create this once + token_data_array.copy_logits(logits_array) + + if ctx_cfg is not None: + ctx_main.sample_classifier_free_guidance( + token_data_array, ctx_cfg, self.params.cfg_scale + ) + + # apply penalties + if len(self.prev) > 0: + nl_token = ctx_main.model.token_nl() + nl_logit = logits_array[nl_token] + if self.params.penalty_last_n > 0: + ctx_main.sample_repetition_penalties( + token_data_array, + # TODO: Only 
create this once + (llama_cpp.llama_token * len(self.prev))(*self.prev), # type: ignore + self.params.penalty_last_n, + self.params.penalty_repeat, + self.params.penalty_freq, + self.params.penalty_present, + ) + if not self.params.penalize_nl: + token_data_array.candidates_data["logit"][nl_token] = nl_logit + + if self.grammar is not None: + ctx_main.sample_grammar(token_data_array, self.grammar) + + if self.params.temp < 0: + ctx_main.sample_softmax(token_data_array) + id = token_data_array.candidates_data["id"][0] + elif self.params.temp == 0: + id = ctx_main.sample_token_greedy(token_data_array) + else: + if self.params.mirostat == 1: + mirostat_m = 100 + ctx_main.sample_temp(token_data_array, self.params.temp) + id = ctx_main.sample_token_mirostat( + token_data_array, + self.params.mirostat_tau, + self.params.mirostat_eta, + mirostat_m, + ctypes.pointer(self.mirostat_mu), + ) + elif self.params.mirostat == 2: + ctx_main.sample_temp(token_data_array, self.params.temp) + id = ctx_main.sample_token_mirostat_v2( + token_data_array, + self.params.mirostat_tau, + self.params.mirostat_eta, + ctypes.pointer(self.mirostat_mu), + ) + else: + min_keep = max(1, self.params.n_probs) + ctx_main.sample_top_k( + token_data_array, self.params.top_k, min_keep=min_keep + ) + ctx_main.sample_tail_free( + token_data_array, self.params.tfs_z, min_keep=min_keep + ) + ctx_main.sample_typical( + token_data_array, self.params.typical_p, min_keep=min_keep + ) + ctx_main.sample_top_p( + token_data_array, self.params.top_p, min_keep=min_keep + ) + ctx_main.sample_min_p( + token_data_array, self.params.min_p, min_keep=min_keep + ) + ctx_main.sample_temp(token_data_array, self.params.temp) + id = ctx_main.sample_token(token_data_array) + return id + + def accept(self, ctx_main: _LlamaContext, id: int, apply_grammar: bool): + if apply_grammar and self.grammar is not None: + ctx_main.grammar_accept_token(self.grammar, id) + self.prev.append(id) \ No newline at end of file diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index c6b55ab..f4e5dcd 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -32,6 +32,12 @@ import numpy as np import numpy.typing as npt from ._utils import suppress_stdout_stderr +from ._internals import ( + _LlamaModel, # type: ignore + _LlamaContext, # type: ignore + _LlamaBatch, # type: ignore + _LlamaTokenDataArray, # type: ignore +) class LlamaState: @@ -74,518 +80,6 @@ class StoppingCriteriaList(List[StoppingCriteria]): return any([stopping_criteria(input_ids, logits) for stopping_criteria in self]) -class _LlamaModel: - """Intermediate Python wrapper for a llama.cpp llama_model. 
- - NOTE: For stability it's recommended you use the Llama class instead.""" - - _llama_free_model = None - # NOTE: this must be "saved" here to avoid exceptions when calling __del__ - suppress_stdout_stderr = suppress_stdout_stderr - - def __init__( - self, - *, - path_model: str, - params: llama_cpp.llama_model_params, - verbose: bool = True, - ): - self.path_model = path_model - self.params = params - self.verbose = verbose - - self._llama_free_model = llama_cpp._lib.llama_free_model # type: ignore - - if not os.path.exists(path_model): - raise ValueError(f"Model path does not exist: {path_model}") - - with suppress_stdout_stderr(disable=self.verbose): - self.model = llama_cpp.llama_load_model_from_file( - self.path_model.encode("utf-8"), self.params - ) - - def __del__(self): - with self.suppress_stdout_stderr(disable=self.verbose): - if self.model is not None and self._llama_free_model is not None: - self._llama_free_model(self.model) - self.model = None - - def vocab_type(self) -> int: - assert self.model is not None - return llama_cpp.llama_vocab_type(self.model) - - def n_vocab(self) -> int: - assert self.model is not None - return llama_cpp.llama_n_vocab(self.model) - - def n_ctx_train(self) -> int: - assert self.model is not None - return llama_cpp.llama_n_ctx_train(self.model) - - def n_embd(self) -> int: - assert self.model is not None - return llama_cpp.llama_n_embd(self.model) - - def rope_freq_scale_train(self) -> float: - assert self.model is not None - return llama_cpp.llama_rope_freq_scale_train(self.model) - - def desc(self) -> str: - assert self.model is not None - buf = ctypes.create_string_buffer(1024) - llama_cpp.llama_model_desc(self.model, buf, 1024) # type: ignore - return buf.value.decode("utf-8") - - def size(self) -> int: - assert self.model is not None - return llama_cpp.llama_model_size(self.model) - - def n_params(self) -> int: - assert self.model is not None - return llama_cpp.llama_model_n_params(self.model) - - def get_tensor(self, name: str) -> ctypes.c_void_p: - assert self.model is not None - return llama_cpp.llama_get_model_tensor(self.model, name.encode("utf-8")) - - def apply_lora_from_file( - self, - lora_path: str, - scale: float, - path_base_model: Optional[str], - n_threads: int, - ): - assert self.model is not None - return llama_cpp.llama_model_apply_lora_from_file( - self.model, - lora_path.encode("utf-8"), - scale, - path_base_model.encode("utf-8") - if path_base_model is not None - else llama_cpp.c_char_p(0), - n_threads, - ) - - # Vocab - - def token_get_text(self, token: int) -> str: - # TODO: Fix - assert self.model is not None - return llama_cpp.llama_token_get_text(self.model, token).decode("utf-8") - - def token_get_score(self, token: int) -> float: - assert self.model is not None - return llama_cpp.llama_token_get_score(self.model, token) - - def token_get_type(self, token: int) -> int: - assert self.model is not None - return llama_cpp.llama_token_get_type(self.model, token) - - # Special tokens - - def token_bos(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_bos(self.model) - - def token_eos(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_eos(self.model) - - def token_nl(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_nl(self.model) - - def token_prefix(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_prefix(self.model) - - def token_middle(self) -> int: - assert self.model is not None - return 
llama_cpp.llama_token_middle(self.model) - - def token_suffix(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_suffix(self.model) - - def token_eot(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_eot(self.model) - - # Tokenization - - def tokenize(self, text: bytes, add_bos: bool, special: bool): - assert self.model is not None - n_ctx = self.n_ctx_train() - tokens = (llama_cpp.llama_token * n_ctx)() - n_tokens = llama_cpp.llama_tokenize( - self.model, text, len(text), tokens, n_ctx, add_bos, special - ) - if n_tokens < 0: - n_tokens = abs(n_tokens) - tokens = (llama_cpp.llama_token * n_tokens)() - n_tokens = llama_cpp.llama_tokenize( - self.model, text, len(text), tokens, n_tokens, add_bos, special - ) - if n_tokens < 0: - raise RuntimeError( - f'Failed to tokenize: text="{text}" n_tokens={n_tokens}' - ) - return list(tokens[:n_tokens]) - - def token_to_piece(self, token: int) -> bytes: - assert self.model is not None - buf = ctypes.create_string_buffer(32) - llama_cpp.llama_token_to_piece(self.model, token, buf, 32) # type: ignore - return bytes(buf) - - def detokenize(self, tokens: List[int]) -> bytes: - assert self.model is not None - output = b"" - size = 32 - buffer = (ctypes.c_char * size)() - for token in tokens: - n = llama_cpp.llama_token_to_piece( - self.model, llama_cpp.llama_token(token), buffer, size - ) - assert n <= size - output += bytes(buffer[:n]) - # NOTE: Llama1 models automatically added a space at the start of the prompt - # this line removes a leading space if the first token is a beginning of sentence token - return ( - output[1:] if len(tokens) > 0 and tokens[0] == self.token_bos() else output - ) - - @staticmethod - def default_params(): - """Get the default llama_model_params.""" - return llama_cpp.llama_model_default_params() - - -class _LlamaContext: - """Intermediate Python wrapper for a llama.cpp llama_context. 
- - NOTE: For stability it's recommended you use the Llama class instead.""" - - _llama_free = None - # NOTE: this must be "saved" here to avoid exceptions when calling __del__ - suppress_stdout_stderr = suppress_stdout_stderr - - def __init__( - self, - *, - model: _LlamaModel, - params: llama_cpp.llama_context_params, - verbose: bool = True, - ): - self.model = model - self.params = params - self.verbose = verbose - - self._llama_free = llama_cpp._lib.llama_free # type: ignore - - with suppress_stdout_stderr(disable=self.verbose): - self.ctx = llama_cpp.llama_new_context_with_model( - self.model.model, self.params - ) - - def __del__(self): - with self.suppress_stdout_stderr(disable=self.verbose): - if self.ctx is not None and self._llama_free is not None: - self._llama_free(self.ctx) - self.ctx = None - - def n_ctx(self) -> int: - assert self.ctx is not None - return llama_cpp.llama_n_ctx(self.ctx) - - def kv_cache_clear(self): - assert self.ctx is not None - llama_cpp.llama_kv_cache_clear(self.ctx) - - def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int): - assert self.ctx is not None - llama_cpp.llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1) - - def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int): - assert self.ctx is not None - llama_cpp.llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1) - - def kv_cache_seq_keep(self, seq_id: int): - assert self.ctx is not None - llama_cpp.llama_kv_cache_seq_keep(self.ctx, seq_id) - - def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int): - assert self.ctx is not None - llama_cpp.llama_kv_cache_seq_shift(self.ctx, seq_id, p0, p1, shift) - - def get_state_size(self) -> int: - assert self.ctx is not None - return llama_cpp.llama_get_state_size(self.ctx) - - # TODO: copy_state_data - - # TODO: set_state_data - - # TODO: llama_load_session_file - - # TODO: llama_save_session_file - - def decode(self, batch: "_LlamaBatch"): - assert self.ctx is not None - assert batch.batch is not None - return_code = llama_cpp.llama_decode( - ctx=self.ctx, - batch=batch.batch, - ) - if return_code != 0: - raise RuntimeError(f"llama_decode returned {return_code}") - - def set_n_threads(self, n_threads: int, n_threads_batch: int): - assert self.ctx is not None - llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch) - - def get_logits(self): - assert self.ctx is not None - return llama_cpp.llama_get_logits(self.ctx) - - def get_logits_ith(self, i: int): - assert self.ctx is not None - return llama_cpp.llama_get_logits_ith(self.ctx, i) - - def get_embeddings(self): - assert self.ctx is not None - return llama_cpp.llama_get_embeddings(self.ctx) - - # Sampling functions - - def set_rng_seed(self, seed: int): - assert self.ctx is not None - llama_cpp.llama_set_rng_seed(self.ctx, seed) - - def sample_repetition_penalties( - self, - candidates: "_LlamaTokenDataArray", - last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]", - penalty_last_n: int, - penalty_repeat: float, - penalty_freq: float, - penalty_present: float, - ): - assert self.ctx is not None - llama_cpp.llama_sample_repetition_penalties( - self.ctx, - ctypes.byref(candidates.candidates), # type: ignore - last_tokens_data, - penalty_last_n, - penalty_repeat, - penalty_freq, - penalty_present, - ) - - def sample_classifier_free_guidance( - self, - candidates: "_LlamaTokenDataArray", - guidance_ctx: "_LlamaContext", - scale: float, - ): - assert self.ctx is not None - assert guidance_ctx.ctx is not None - 
llama_cpp.llama_sample_classifier_free_guidance( - self.ctx, - ctypes.byref(candidates.candidates), # type: ignore - guidance_ctx.ctx, - scale, - ) - - def sample_softmax(self, candidates: "_LlamaTokenDataArray"): - assert self.ctx is not None - llama_cpp.llama_sample_softmax( - self.ctx, - ctypes.byref(candidates.candidates), # type: ignore - ) - - def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int): - assert self.ctx is not None - llama_cpp.llama_sample_top_k( - self.ctx, ctypes.byref(candidates.candidates), k, min_keep # type: ignore - ) - - def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): - assert self.ctx is not None - llama_cpp.llama_sample_top_p( - self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore - ) - - def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): - assert self.ctx is not None - llama_cpp.llama_sample_min_p( - self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore - ) - - def sample_tail_free( - self, candidates: "_LlamaTokenDataArray", z: float, min_keep: int - ): - assert self.ctx is not None - llama_cpp.llama_sample_tail_free( - self.ctx, ctypes.byref(candidates.candidates), z, min_keep # type: ignore - ) - - def sample_typical( - self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int - ): - assert self.ctx is not None - llama_cpp.llama_sample_typical( - self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore - ) - - def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float): - assert self.ctx is not None - llama_cpp.llama_sample_temp( - self.ctx, ctypes.byref(candidates.candidates), temp # type: ignore - ) - - def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar): - assert self.ctx is not None - assert grammar.grammar is not None - llama_cpp.llama_sample_grammar( - self.ctx, - ctypes.byref(candidates.candidates), # type: ignore - grammar.grammar, - ) - - def sample_token_mirostat( - self, - candidates: "_LlamaTokenDataArray", - tau: float, - eta: float, - m: int, - mu: float, - ) -> int: - assert self.ctx is not None - return llama_cpp.llama_sample_token_mirostat( - self.ctx, - ctypes.byref(candidates.candidates), # type: ignore - tau, - eta, - m, - ctypes.pointer(ctypes.c_float(mu)), - ) - - def sample_token_mirostat_v2( - self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: float - ) -> int: - assert self.ctx is not None - return llama_cpp.llama_sample_token_mirostat_v2( - self.ctx, - ctypes.byref(candidates.candidates), # type: ignore - tau, - eta, - ctypes.pointer(ctypes.c_float(mu)), - ) - - def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int: - assert self.ctx is not None - return llama_cpp.llama_sample_token_greedy( - self.ctx, - ctypes.byref(candidates.candidates), # type: ignore - ) - - def sample_token(self, candidates: "_LlamaTokenDataArray") -> int: - assert self.ctx is not None - return llama_cpp.llama_sample_token( - self.ctx, - ctypes.byref(candidates.candidates), # type: ignore - ) - - # Grammar - def grammar_accept_token(self, grammar: LlamaGrammar, token: int): - assert self.ctx is not None - assert grammar.grammar is not None - llama_cpp.llama_grammar_accept_token(self.ctx, grammar.grammar, token) - - def reset_timings(self): - assert self.ctx is not None - llama_cpp.llama_reset_timings(self.ctx) - - def print_timings(self): - assert self.ctx is not None - llama_cpp.llama_print_timings(self.ctx) - - # 
Utility functions - @staticmethod - def default_params(): - """Get the default llama_context_params.""" - return llama_cpp.llama_context_default_params() - - -class _LlamaBatch: - _llama_batch_free = None - # NOTE: this must be "saved" here to avoid exceptions when calling __del__ - suppress_stdout_stderr = suppress_stdout_stderr - - def __init__( - self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True - ): - self.n_tokens = n_tokens - self.embd = embd - self.n_seq_max = n_seq_max - self.verbose = verbose - - self._llama_batch_free = llama_cpp._lib.llama_batch_free # type: ignore - - with suppress_stdout_stderr(disable=self.verbose): - self.batch = llama_cpp.llama_batch_init( - self.n_tokens, self.embd, self.n_seq_max - ) - - def __del__(self): - with self.suppress_stdout_stderr(disable=self.verbose): - if self.batch is not None and self._llama_batch_free is not None: - self._llama_batch_free(self.batch) - self.batch = None - - def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool): - assert self.batch is not None - n_tokens = len(batch) - self.batch.n_tokens = n_tokens - for i in range(n_tokens): - self.batch.token[i] = batch[i] - self.batch.pos[i] = n_past + i - self.batch.seq_id[i][0] = 0 - self.batch.n_seq_id[i] = 1 - self.batch.logits[i] = logits_all - self.batch.logits[n_tokens - 1] = True - - -class _LlamaTokenDataArray: - def __init__(self, *, n_vocab: int): - self.n_vocab = n_vocab - self.candidates_data = np.array( - [], - dtype=np.dtype( - [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True - ), - ) - self.candidates_data.resize(3, self.n_vocab, refcheck=False) - self.candidates = llama_cpp.llama_token_data_array( - data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p), - size=self.n_vocab, - sorted=False, - ) - self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc) - self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single) - - def copy_logits(self, logits: npt.NDArray[np.single]): - self.candidates_data["id"][:] = self.default_candidates_data_id - self.candidates_data["logit"][:] = logits - self.candidates_data["p"][:] = self.default_candidates_data_p - self.candidates.data = self.candidates_data.ctypes.data_as( - llama_cpp.llama_token_data_p - ) - self.candidates.sorted = llama_cpp.c_bool(False) - self.candidates.size = llama_cpp.c_size_t(self.n_vocab) - - class Llama: """High-level Python wrapper for a llama.cpp model.""" From 7b46bb5a786a0460d06adc053d3e494e92b75b39 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 17 Jan 2024 09:16:13 -0500 Subject: [PATCH 03/26] Re-order classes in llama.py --- llama_cpp/llama.py | 82 ++++++++++++++++++++++++---------------------- 1 file changed, 42 insertions(+), 40 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index f4e5dcd..25abf36 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import sys import uuid @@ -40,46 +42,6 @@ from ._internals import ( ) -class LlamaState: - def __init__( - self, - input_ids: npt.NDArray[np.intc], - scores: npt.NDArray[np.single], - n_tokens: int, - llama_state: bytes, - llama_state_size: int, - ): - self.input_ids = input_ids - self.scores = scores - self.n_tokens = n_tokens - self.llama_state = llama_state - self.llama_state_size = llama_state_size - - -LogitsProcessor = Callable[ - [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single] -] - - -class 
LogitsProcessorList(List[LogitsProcessor]): - def __call__( - self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single] - ) -> npt.NDArray[np.single]: - for processor in self: - scores = processor(input_ids, scores) - return scores - - -StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool] - - -class StoppingCriteriaList(List[StoppingCriteria]): - def __call__( - self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single] - ) -> bool: - return any([stopping_criteria(input_ids, logits) for stopping_criteria in self]) - - class Llama: """High-level Python wrapper for a llama.cpp model.""" @@ -1733,3 +1695,43 @@ class LlamaTokenizer: @classmethod def from_ggml_file(cls, path: str) -> "LlamaTokenizer": return cls(Llama(model_path=path, vocab_only=True)) + + +class LlamaState: + def __init__( + self, + input_ids: npt.NDArray[np.intc], + scores: npt.NDArray[np.single], + n_tokens: int, + llama_state: bytes, + llama_state_size: int, + ): + self.input_ids = input_ids + self.scores = scores + self.n_tokens = n_tokens + self.llama_state = llama_state + self.llama_state_size = llama_state_size + + +LogitsProcessor = Callable[ + [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single] +] + + +class LogitsProcessorList(List[LogitsProcessor]): + def __call__( + self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single] + ) -> npt.NDArray[np.single]: + for processor in self: + scores = processor(input_ids, scores) + return scores + + +StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool] + + +class StoppingCriteriaList(List[StoppingCriteria]): + def __call__( + self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single] + ) -> bool: + return any([stopping_criteria(input_ids, logits) for stopping_criteria in self]) From 52adc231153624d7a1c949c97c5a0c5b3e4eabf7 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 17 Jan 2024 09:27:40 -0500 Subject: [PATCH 04/26] Update llama.cpp --- vendor/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 5c99960..4f4bf35 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 5c999609013a30c06e6fd28be8db5c2074bcc196 +Subproject commit 4f4bf35f46600441dec2f941e667291eeb9a18d8 From 6bfe98bd801262b68c7cb4761e67a626d91534c0 Mon Sep 17 00:00:00 2001 From: Austin <77757836+teleprint-me@users.noreply.github.com> Date: Wed, 17 Jan 2024 09:47:52 -0500 Subject: [PATCH 05/26] Integration of Jinja2 Templating (#875) * feat: Add support for jinja templating Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> * fix: Refactor chat formatter and update interface for jinja templates - Simplify the `llama2_template` in `llama_jinja_format.py` by removing unnecessary line breaks for readability without affecting functionality. - Update `ChatFormatterInterface` constructor to accept a more generic `Optional[object]` type for the template parameter, enhancing flexibility. - Introduce a `template` property to `ChatFormatterInterface` for standardized access to the template string. - Replace `MetaSingleton` metaclass with `Singleton` for the `ChatFormatterFactory` to streamline the singleton implementation. These changes enhance code readability, maintain usability, and ensure consistency in the chat formatter's design pattern usage. 
* Add outline for Jinja2 templating integration documentation Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> * Add jinja2 as a dependency with version range for Hugging Face transformers compatibility Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> * Update jinja2 version constraint for mkdocs-material compatibility Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> * Fix attribute name in AutoChatFormatter - Changed attribute name from `self._renderer` to `self._environment` --------- Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> --- docs/templates.md | 52 ++++++++++++ llama_cpp/llama_jinja_format.py | 138 ++++++++++++++++++++++++++++++++ pyproject.toml | 4 +- tests/test_llama_chat_format.py | 50 ++++++++++++ 4 files changed, 243 insertions(+), 1 deletion(-) create mode 100644 docs/templates.md create mode 100644 llama_cpp/llama_jinja_format.py create mode 100644 tests/test_llama_chat_format.py diff --git a/docs/templates.md b/docs/templates.md new file mode 100644 index 0000000..5acdaa1 --- /dev/null +++ b/docs/templates.md @@ -0,0 +1,52 @@ +# Templates + +This document provides a comprehensive guide to the integration of Jinja2 templating into the `llama-cpp-python` project, with a focus on enhancing the chat functionality of the `llama-2` model. + +## Introduction + +- Brief explanation of the `llama-cpp-python` project's need for a templating system. +- Overview of the `llama-2` model's interaction with templating. + +## Jinja2 Dependency Integration + +- Rationale for choosing Jinja2 as the templating engine. + - Compatibility with Hugging Face's `transformers`. + - Desire for advanced templating features and simplicity. +- Detailed steps for adding `jinja2` to `pyproject.toml` for dependency management. + +## Template Management Refactor + +- Summary of the refactor and the motivation behind it. +- Description of the new chat handler selection logic: + 1. Preference for a user-specified `chat_handler`. + 2. Fallback to a user-specified `chat_format`. + 3. Defaulting to a chat format from a `.gguf` file if available. + 4. Utilizing the `llama2` default chat format as the final fallback. +- Ensuring backward compatibility throughout the refactor. + +## Implementation Details + +- In-depth look at the new `AutoChatFormatter` class. +- Example code snippets showing how to utilize the Jinja2 environment and templates. +- Guidance on how to provide custom templates or use defaults. + +## Testing and Validation + +- Outline of the testing strategy to ensure seamless integration. +- Steps for validating backward compatibility with existing implementations. + +## Benefits and Impact + +- Analysis of the expected benefits, including consistency, performance gains, and improved developer experience. +- Discussion of the potential impact on current users and contributors. + +## Future Work + +- Exploration of how templating can evolve within the project. +- Consideration of additional features or optimizations for the templating engine. +- Mechanisms for community feedback on the templating system. + +## Conclusion + +- Final thoughts on the integration of Jinja2 templating. +- Call to action for community involvement and feedback. 
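The "Implementation Details" section above promises example snippets for the Jinja2 environment and templates. As a rough, self-contained sketch of the rendering mechanism (using a simplified two-branch template for illustration, not the full llama2_template defined in the module below), the flow is: build an Environment with trim_blocks/lstrip_blocks, compile the template string, and render it against a list of message dicts:

    import jinja2

    # Simplified chat template: user turns wrapped in [INST] ... [/INST],
    # everything else emitted verbatim. Illustrative only.
    template_source = (
        "{% for message in messages %}"
        "{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n"
        "{% else %}{{ message['content'] }}\n"
        "{% endif %}{% endfor %}"
    )

    environment = jinja2.Environment(
        loader=jinja2.BaseLoader(),
        trim_blocks=True,
        lstrip_blocks=True,
    )
    template = environment.from_string(template_source)

    messages = [
        {"role": "user", "content": "Hi there! I need some help with Python."},
        {"role": "assistant", "content": "Of course! What do you need help with?"},
    ]
    print(template.render(messages=messages))
    # [INST] Hi there! I need some help with Python. [/INST]
    # Of course! What do you need help with?

The AutoChatFormatter added below performs exactly this compile-and-render step, wrapping the result in a ChatFormatterResponse.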
diff --git a/llama_cpp/llama_jinja_format.py b/llama_cpp/llama_jinja_format.py new file mode 100644 index 0000000..68faaf6 --- /dev/null +++ b/llama_cpp/llama_jinja_format.py @@ -0,0 +1,138 @@ +""" +llama_cpp/llama_jinja_format.py +""" +import dataclasses +from typing import Any, Callable, Dict, List, Optional, Protocol, Union + +import jinja2 +from jinja2 import Template + +# NOTE: We sacrifice readability for usability. +# It will fail to work as expected if we attempt to format it in a readable way. +llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<> {{ message['content'] }} <>\n{% endif %}{% endfor %}""" + + +class MetaSingleton(type): + """ + Metaclass for implementing the Singleton pattern. + """ + + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +class Singleton(object, metaclass=MetaSingleton): + """ + Base class for implementing the Singleton pattern. + """ + + def __init__(self): + super(Singleton, self).__init__() + + +@dataclasses.dataclass +class ChatFormatterResponse: + prompt: str + stop: Optional[Union[str, List[str]]] = None + + +# Base Chat Formatter Protocol +class ChatFormatterInterface(Protocol): + def __init__(self, template: Optional[object] = None): + ... + + def __call__( + self, + messages: List[Dict[str, str]], + **kwargs, + ) -> ChatFormatterResponse: + ... + + @property + def template(self) -> str: + ... + + +class AutoChatFormatter(ChatFormatterInterface): + def __init__( + self, + template: Optional[str] = None, + template_class: Optional[Template] = None, + ): + if template is not None: + self._template = template + else: + self._template = llama2_template # default template + + self._environment = jinja2.Environment( + loader=jinja2.BaseLoader(), + trim_blocks=True, + lstrip_blocks=True, + ).from_string( + self._template, + template_class=template_class, + ) + + def __call__( + self, + messages: List[Dict[str, str]], + **kwargs: Any, + ) -> ChatFormatterResponse: + formatted_sequence = self._environment.render(messages=messages, **kwargs) + return ChatFormatterResponse(prompt=formatted_sequence) + + @property + def template(self) -> str: + return self._template + + +class FormatterNotFoundException(Exception): + pass + + +class ChatFormatterFactory(Singleton): + _chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {} + + def register_formatter( + self, + name: str, + formatter_callable: Callable[[], ChatFormatterInterface], + overwrite=False, + ): + if not overwrite and name in self._chat_formatters: + raise ValueError( + f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it." 
+ ) + self._chat_formatters[name] = formatter_callable + + def unregister_formatter(self, name: str): + if name in self._chat_formatters: + del self._chat_formatters[name] + else: + raise ValueError(f"No formatter registered under the name '{name}'.") + + def get_formatter_by_name(self, name: str) -> ChatFormatterInterface: + try: + formatter_callable = self._chat_formatters[name] + return formatter_callable() + except KeyError: + raise FormatterNotFoundException( + f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})" + ) + + +# Define a chat format class +class Llama2Formatter(AutoChatFormatter): + def __init__(self): + super().__init__(llama2_template) + + +# With the Singleton pattern applied, regardless of where or how many times +# ChatFormatterFactory() is called, it will always return the same instance +# of the factory, ensuring that the factory's state is consistent throughout +# the application. +ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter) diff --git a/pyproject.toml b/pyproject.toml index b5affaa..806127d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,10 +11,13 @@ license = { text = "MIT" } authors = [ { name = "Andrei Betlen", email = "abetlen@gmail.com" }, ] +# mkdocs-martiral requires "jinja2~=3.0" +# transformers requires "jinja2>=2.11.3" dependencies = [ "typing-extensions>=4.5.0", "numpy>=1.20.0", "diskcache>=5.6.1", + "jinja2>=2.11.3", ] requires-python = ">=3.8" classifiers = [ @@ -72,4 +75,3 @@ Changelog = "https://llama-cpp-python.readthedocs.io/en/latest/changelog/" [tool.pytest.ini_options] addopts = "--ignore=vendor" - diff --git a/tests/test_llama_chat_format.py b/tests/test_llama_chat_format.py new file mode 100644 index 0000000..4eebcb6 --- /dev/null +++ b/tests/test_llama_chat_format.py @@ -0,0 +1,50 @@ +from typing import List + +import pytest + +from llama_cpp import ChatCompletionMessage +from llama_cpp.llama_jinja_format import Llama2Formatter + + +@pytest.fixture +def sequence_of_messages() -> List[ChatCompletionMessage]: + return [ + ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"), + ChatCompletionMessage( + role="user", content="Hi there! I need some help with Python." + ), + ChatCompletionMessage( + role="assistant", content="Of course! What do you need help with in Python?" + ), + ChatCompletionMessage( + role="user", + content="I'm trying to write a function to find the factorial of a number, but I'm stuck.", + ), + ChatCompletionMessage( + role="assistant", + content="I can help with that! Would you like a recursive or iterative solution?", + ), + ChatCompletionMessage( + role="user", content="Let's go with a recursive solution." + ), + ] + + +def test_llama2_formatter(sequence_of_messages): + expected_prompt = ( + "<> Welcome to CodeHelp Bot! <>\n" + "[INST] Hi there! I need some help with Python. [/INST]\n" + "Of course! What do you need help with in Python?\n" + "[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\n" + "I can help with that! Would you like a recursive or iterative solution?\n" + "[INST] Let's go with a recursive solution. [/INST]\n" + ) + + llama2_formatter_instance = Llama2Formatter() + formatter_response = llama2_formatter_instance(sequence_of_messages) + assert ( + expected_prompt == formatter_response.prompt + ), "The formatted prompt does not match the expected output." + + +# Optionally, include a test for the 'stop' if it's part of the functionality. 
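Editor's note — a short usage sketch (not part of the patch) for the factory and formatter registered above; the message contents are made up for illustration.

```python
# Editorial sketch: exercising the API added in llama_cpp/llama_jinja_format.py.
from llama_cpp.llama_jinja_format import ChatFormatterFactory

# ChatFormatterFactory is a Singleton, so every call returns the same registry.
factory = ChatFormatterFactory()

# "llama-2" is registered at import time; lookup instantiates Llama2Formatter.
formatter = factory.get_formatter_by_name("llama-2")

messages = [
    {"role": "system", "content": "Welcome to CodeHelp Bot!"},
    {"role": "user", "content": "Hi there! I need some help with Python."},
]

response = formatter(messages)  # returns a ChatFormatterResponse dataclass
print(response.prompt)          # rendered llama-2 style prompt
print(response.stop)            # None unless the formatter supplies stop tokens
```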
From 48c3b77e6f558a9899de0e1155c7dc0c7958d8e8 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 18 Jan 2024 11:08:57 -0500 Subject: [PATCH 06/26] Offload KQV by default --- llama_cpp/llama.py | 2 +- llama_cpp/server/settings.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 25abf36..6cdc1eb 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -77,7 +77,7 @@ class Llama: mul_mat_q: bool = True, logits_all: bool = False, embedding: bool = False, - offload_kqv: bool = False, + offload_kqv: bool = True, # Sampling Params last_n_tokens_size: int = 64, # LoRA Params diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py index a10390c..dc5be20 100644 --- a/llama_cpp/server/settings.py +++ b/llama_cpp/server/settings.py @@ -90,7 +90,7 @@ class ModelSettings(BaseSettings): logits_all: bool = Field(default=True, description="Whether to return logits.") embedding: bool = Field(default=True, description="Whether to use embeddings.") offload_kqv: bool = Field( - default=False, description="Whether to offload kqv to the GPU." + default=True, description="Whether to offload kqv to the GPU." ) # Sampling Params last_n_tokens_size: int = Field( From b8fc1c7d83ad4a9207c707ba1d954fe580286a01 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 18 Jan 2024 21:21:37 -0500 Subject: [PATCH 07/26] feat: Add ability to load chat format from huggingface autotokenizer or tokenizer_config.json files. --- llama_cpp/_utils.py | 25 ++- llama_cpp/llama_chat_format.py | 379 +++++++++++++++++++++----------- llama_cpp/llama_jinja_format.py | 138 ------------ llama_cpp/server/model.py | 23 +- llama_cpp/server/settings.py | 9 + tests/test_llama_chat_format.py | 101 +++++---- 6 files changed, 357 insertions(+), 318 deletions(-) delete mode 100644 llama_cpp/llama_jinja_format.py diff --git a/llama_cpp/_utils.py b/llama_cpp/_utils.py index f7b6ba6..4a10647 100644 --- a/llama_cpp/_utils.py +++ b/llama_cpp/_utils.py @@ -1,7 +1,8 @@ import os import sys -import sys, traceback +import sys +from typing import Any, Dict # Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor outnull_file = open(os.devnull, "w") @@ -55,3 +56,25 @@ class suppress_stdout_stderr(object): self.os.close(self.old_stdout_fileno) self.os.close(self.old_stderr_fileno) + + +class MetaSingleton(type): + """ + Metaclass for implementing the Singleton pattern. + """ + + _instances: Dict[type, Any] = {} + + def __call__(cls, *args: Any, **kwargs: Any) -> Any: + if cls not in cls._instances: + cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +class Singleton(object, metaclass=MetaSingleton): + """ + Base class for implementing the Singleton pattern. + """ + + def __init__(self): + super(Singleton, self).__init__() diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 0ef7bd4..3d18d90 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -6,18 +6,28 @@ import ctypes import dataclasses from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol +import jinja2 + import llama_cpp.llama as llama import llama_cpp.llama_types as llama_types import llama_cpp.llama_grammar as llama_grammar -from ._utils import suppress_stdout_stderr +from ._utils import suppress_stdout_stderr, Singleton class LlamaChatCompletionHandler(Protocol): + """Base Protocol for a llama chat completion handler. 
+ + Very generic protocol that can be used to implement any chat format. + The only hard requirement is that it must return a ChatCompletion when + stream=False and an iterator of ChatCompletionChunks when stream=True.""" + def __call__( self, *, + # llama.cpp instance llama: llama.Llama, + # openai api parameters messages: List[llama_types.ChatCompletionRequestMessage], functions: Optional[List[llama_types.ChatCompletionFunction]] = None, function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, @@ -26,8 +36,6 @@ class LlamaChatCompletionHandler(Protocol): temperature: float = 0.2, top_p: float = 0.95, top_k: int = 40, - min_p: float = 0.05, - typical_p: float = 1.0, stream: bool = False, stop: Optional[Union[str, List[str]]] = [], seed: Optional[int] = None, @@ -38,14 +46,17 @@ class LlamaChatCompletionHandler(Protocol): presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repeat_penalty: float = 1.1, + model: Optional[str] = None, + logit_bias: Optional[Dict[str, float]] = None, + # llama.cpp parameters + min_p: float = 0.05, + typical_p: float = 1.0, tfs_z: float = 1.0, mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, - model: Optional[str] = None, logits_processor: Optional[llama.LogitsProcessorList] = None, grammar: Optional[llama.LlamaGrammar] = None, - logit_bias: Optional[Dict[str, float]] = None, **kwargs, # type: ignore ) -> Union[ llama_types.CreateChatCompletionResponse, @@ -54,21 +65,83 @@ class LlamaChatCompletionHandler(Protocol): ... -CHAT_HANDLERS: Dict[str, LlamaChatCompletionHandler] = {} +class LlamaChatCompletionHandlerNotFoundException(Exception): + pass + + +class LlamaChatCompletionHandlerRegistry(Singleton): + _chat_handlers: Dict[str, LlamaChatCompletionHandler] = {} + + def register_chat_completion_handler( + self, + name: str, + chat_handler: LlamaChatCompletionHandler, + overwrite: bool = False, + ): + if not overwrite and name in self._chat_handlers: + raise ValueError( + f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it." + ) + self._chat_handlers[name] = chat_handler + + def unregister_chat_handler(self, name: str): + if name in self._chat_handlers: + del self._chat_handlers[name] + else: + raise ValueError(f"No formatter registered under the name '{name}'.") + + def get_chat_completion_handler_by_name( + self, name: str + ) -> LlamaChatCompletionHandler: + try: + chat_handler = self._chat_handlers[name] + return chat_handler + except KeyError: + raise LlamaChatCompletionHandlerNotFoundException( + f"Invalid chat handler: {name} (valid formats: {list(self._chat_handlers.keys())})" + ) def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler: - return CHAT_HANDLERS[name] + return LlamaChatCompletionHandlerRegistry().get_chat_completion_handler_by_name( + name + ) def register_chat_completion_handler(name: str): def decorator(f: LlamaChatCompletionHandler): - CHAT_HANDLERS[name] = f + LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(name, f) return f return decorator +### Chat Formatter ### + + +@dataclasses.dataclass +class ChatFormatterResponse: + prompt: str + stop: Optional[Union[str, List[str]]] = None + + +class ChatFormatter(Protocol): + """Base Protocol for a chat formatter. A chat formatter is a function that + takes a list of messages and returns a formatted prompt. 
It can also return + a stop token or list of stop tokens to use for the completion.""" + + def __call__( + self, + *, + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs: Any, + ) -> ChatFormatterResponse: + ... + + +### Utility functions for formatting chat prompts ### + + def _get_system_message( messages: List[llama_types.ChatCompletionRequestMessage], ) -> str: @@ -80,14 +153,18 @@ def _get_system_message( def _map_roles( - messages: List[llama_types.ChatCompletionRequestMessage], role_map: Dict[str, str] + messages: List[llama_types.ChatCompletionRequestMessage], + role_map: Dict[str, str], ) -> List[Tuple[str, Optional[str]]]: """Map the message roles.""" output: List[Tuple[str, Optional[str]]] = [] for message in messages: role = message["role"] if role in role_map: - output.append((role_map[role], message["content"])) + content: str | None = ( + message["content"] if isinstance(message["content"], str) else None + ) + output.append((role_map[role], content)) return output @@ -99,7 +176,8 @@ def _format_llama2( ret = system_message + sep for i, (role, message) in enumerate(messages): if system_message and i == 0: - ret += message + seps[i % 2] + m = message or "" + ret += m + seps[i % 2] elif message: ret += role + message + " " + seps[i % 2] else: @@ -172,6 +250,7 @@ def _format_chatml( ret += role + "\n" return ret + def _format_chatglm3( system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str ) -> str: @@ -187,30 +266,10 @@ def _format_chatglm3( return ret -@dataclasses.dataclass -class ChatFormatterResponse: - prompt: str - stop: Optional[Union[str, List[str]]] = None - - -class ChatFormatter(Protocol): - def __call__( - self, - *, - messages: List[llama_types.ChatCompletionRequestMessage], - **kwargs: Any, - ) -> ChatFormatterResponse: - ... 
- - -class BasicChatHandler: - def __init__(self, chat_format: str): - self.chat_format = chat_format - - def _convert_text_completion_to_chat( completion: llama_types.Completion, ) -> llama_types.ChatCompletion: + assert "usage" in completion return { "id": "chat" + completion["id"], "object": "chat.completion", @@ -286,103 +345,95 @@ def _convert_completion_to_chat( return _convert_text_completion_to_chat(completion) -_CHAT_FORMATS: Dict[str, ChatFormatter] = {} +def chat_formatter_to_chat_completion_handler( + chat_formatter: ChatFormatter, +) -> LlamaChatCompletionHandler: + def chat_completion_handler( + *, + llama: llama.Llama, + messages: List[llama_types.ChatCompletionRequestMessage], + functions: Optional[List[llama_types.ChatCompletionFunction]] = None, + function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, + tools: Optional[List[llama_types.ChatCompletionTool]] = None, + tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, + temperature: float = 0.2, + top_p: float = 0.95, + top_k: int = 40, + min_p: float = 0.05, + typical_p: float = 1.0, + stream: bool = False, + stop: Optional[Union[str, List[str]]] = [], + seed: Optional[int] = None, + response_format: Optional[ + llama_types.ChatCompletionRequestResponseFormat + ] = None, + max_tokens: Optional[int] = None, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repeat_penalty: float = 1.1, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + logits_processor: Optional[llama.LogitsProcessorList] = None, + grammar: Optional[llama.LlamaGrammar] = None, + logit_bias: Optional[Dict[str, float]] = None, + **kwargs, # type: ignore + ) -> Union[ + llama_types.CreateChatCompletionResponse, + Iterator[llama_types.CreateChatCompletionStreamResponse], + ]: + result = chat_formatter( + messages=messages, + functions=functions, + function_call=function_call, + ) + prompt = result.prompt + if result.stop is not None: + stop = [] if stop is None else [stop] if isinstance(stop, str) else stop + rstop = result.stop if isinstance(result.stop, list) else [result.stop] + stop = stop + rstop + + if response_format is not None and response_format["type"] == "json_object": + grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF) + + completion_or_chunks = llama.create_completion( + prompt=prompt, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + typical_p=typical_p, + stream=stream, + stop=stop, + seed=seed, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + repeat_penalty=repeat_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + model=model, + logits_processor=logits_processor, + grammar=grammar, + logit_bias=logit_bias, + ) + return _convert_completion_to_chat(completion_or_chunks, stream=stream) + + return chat_completion_handler def register_chat_format(name: str): def decorator(f: ChatFormatter): - def basic_create_chat_completion( - *, - llama: llama.Llama, - messages: List[llama_types.ChatCompletionRequestMessage], - functions: Optional[List[llama_types.ChatCompletionFunction]] = None, - function_call: Optional[ - llama_types.ChatCompletionRequestFunctionCall - ] = None, - tools: Optional[List[llama_types.ChatCompletionTool]] = None, - tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, - temperature: float = 0.2, - 
top_p: float = 0.95, - top_k: int = 40, - min_p: float = 0.05, - typical_p: float = 1.0, - stream: bool = False, - stop: Optional[Union[str, List[str]]] = [], - seed: Optional[int] = None, - response_format: Optional[ - llama_types.ChatCompletionRequestResponseFormat - ] = None, - max_tokens: Optional[int] = None, - presence_penalty: float = 0.0, - frequency_penalty: float = 0.0, - repeat_penalty: float = 1.1, - tfs_z: float = 1.0, - mirostat_mode: int = 0, - mirostat_tau: float = 5.0, - mirostat_eta: float = 0.1, - model: Optional[str] = None, - logits_processor: Optional[llama.LogitsProcessorList] = None, - grammar: Optional[llama.LlamaGrammar] = None, - logit_bias: Optional[Dict[str, float]] = None, - **kwargs, # type: ignore - ) -> Union[ - llama_types.CreateChatCompletionResponse, - Iterator[llama_types.CreateChatCompletionStreamResponse], - ]: - result = f( - messages=messages, - functions=functions, - function_call=function_call, - ) - prompt = result.prompt - if result.stop is not None: - stop = [] if stop is None else [stop] if isinstance(stop, str) else stop - rstop = result.stop if isinstance(result.stop, list) else [result.stop] - stop = stop + rstop - - if response_format is not None and response_format["type"] == "json_object": - grammar = llama_grammar.LlamaGrammar.from_string( - llama_grammar.JSON_GBNF - ) - - completion_or_chunks = llama.create_completion( - prompt=prompt, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - typical_p=typical_p, - stream=stream, - stop=stop, - seed=seed, - max_tokens=max_tokens, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - repeat_penalty=repeat_penalty, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - model=model, - logits_processor=logits_processor, - grammar=grammar, - logit_bias=logit_bias, - ) - return _convert_completion_to_chat(completion_or_chunks, stream=stream) - - register_chat_completion_handler(name)(basic_create_chat_completion) - return f - - return decorator - - -def get_chat_format(name: str): - try: - return _CHAT_FORMATS[name] - except KeyError: - raise ValueError( - f"Invalid chat format: {name} (valid formats: {list(_CHAT_FORMATS.keys())})" + chat_completion_handler = chat_formatter_to_chat_completion_handler(f) + LlamaChatCompletionHandlerRegistry().register_chat_completion_handler( + name, chat_completion_handler ) + return f + return decorator def hf_autotokenizer_to_chat_formatter( @@ -391,22 +442,78 @@ def hf_autotokenizer_to_chat_formatter( # https://huggingface.co/docs/transformers/main/chat_templating # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json - from transformers import AutoTokenizer + from transformers import AutoTokenizer # type: ignore - tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) # type: ignore def format_autotokenizer( messages: List[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: - tokenizer.use_default_system_prompt = False - _prompt = tokenizer.apply_chat_template(messages, tokenize=False) + tokenizer.use_default_system_prompt = False # type: ignore + prompt: str = tokenizer.apply_chat_template(messages, tokenize=False) # type: ignore + assert isinstance(prompt, str) # Return formatted prompt and eos token by default - return 
ChatFormatterResponse(prompt=_prompt, stop=tokenizer.eos_token) + return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token) return format_autotokenizer +def hf_autotokenizer_to_chat_completion_handler( + pretrained_model_name_or_path: Union[str, os.PathLike[str]] +) -> LlamaChatCompletionHandler: + chat_formatter = hf_autotokenizer_to_chat_formatter(pretrained_model_name_or_path) + return chat_formatter_to_chat_completion_handler(chat_formatter) + + +def hf_tokenizer_config_to_chat_formatter(tokenizer_config: Dict[str, Any]) -> ChatFormatter: + assert isinstance(tokenizer_config, dict) + + assert "chat_template" in tokenizer_config + assert isinstance(tokenizer_config["chat_template"], str) + chat_template = tokenizer_config["chat_template"] + + assert "bos_token" in tokenizer_config + assert isinstance(tokenizer_config["bos_token"], str) + bos_token = tokenizer_config["bos_token"] + + assert "eos_token" in tokenizer_config + assert isinstance(tokenizer_config["eos_token"], str) + eos_token = tokenizer_config["eos_token"] + + env = jinja2.Environment( + loader=jinja2.BaseLoader(), + trim_blocks=True, + lstrip_blocks=True, + ).from_string(chat_template) + + def format_autotokenizer( + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs: Any, + ) -> ChatFormatterResponse: + # TODO: veryify this is correct + # Add a blank assistant message to the end of the messages to prompt the model to generate a response + prompt = env.render( + messages=[ + *messages, + llama_types.ChatCompletionRequestAssistantMessage( + role="assistant", content="" + ), + ], + bos_token=bos_token, + eos_token=eos_token, + ) + return ChatFormatterResponse(prompt=prompt, stop=eos_token) + return format_autotokenizer + + +def hf_tokenizer_config_to_chat_completion_handler( + tokenizer_config: Dict[str, Any], +) -> LlamaChatCompletionHandler: + chat_formatter = hf_tokenizer_config_to_chat_formatter(tokenizer_config) + return chat_formatter_to_chat_completion_handler(chat_formatter) + + # see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py # system prompt is "embedded" in the first message @register_chat_format("llama-2") @@ -437,21 +544,23 @@ def format_alpaca( _prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2) return ChatFormatterResponse(prompt=_prompt) + @register_chat_format("qwen") def format_qwen( messages: List[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: _roles = dict(user="<|im_start|>user", assistant="<|im_start|>assistant") - system_message="You are a helpful assistant." - system_template="<|im_start|>system\n{system_message}" - system_message=system_template.format(system_message=system_message) + system_message = "You are a helpful assistant." 
+ system_template = "<|im_start|>system\n{system_message}" + system_message = system_template.format(system_message=system_message) _messages = _map_roles(messages, _roles) _messages.append((_roles["assistant"], None)) _sep = "<|im_end|>" _prompt = _format_chatml(system_message, _messages, _sep) _sep2 = "<|endoftext|>" - return ChatFormatterResponse(prompt=_prompt,stop=_sep2) + return ChatFormatterResponse(prompt=_prompt, stop=_sep2) + @register_chat_format("vicuna") def format( @@ -650,6 +759,7 @@ def format_mistrallite( _prompt = _format_no_colon_single(system_message, _messages, _sep) return ChatFormatterResponse(prompt=_prompt) + @register_chat_format("zephyr") def format_zephyr( messages: List[llama_types.ChatCompletionRequestMessage], @@ -699,6 +809,7 @@ def format_chatml( _prompt = _format_chatml(system_message, _messages, _sep) return ChatFormatterResponse(prompt=_prompt, stop=_sep) + @register_chat_format("chatglm3") def format_chatglm3( messages: List[llama_types.ChatCompletionRequestMessage], @@ -739,7 +850,7 @@ def format_openchat( @register_chat_format("saiga") def format_saiga( messages: list[llama_types.ChatCompletionRequestMessage], - **kwargs, + **kwargs: Any, ) -> ChatFormatterResponse: _message_template = "{role}\n{content}" _roles = dict(user="user", bot="bot", system="system") diff --git a/llama_cpp/llama_jinja_format.py b/llama_cpp/llama_jinja_format.py deleted file mode 100644 index 68faaf6..0000000 --- a/llama_cpp/llama_jinja_format.py +++ /dev/null @@ -1,138 +0,0 @@ -""" -llama_cpp/llama_jinja_format.py -""" -import dataclasses -from typing import Any, Callable, Dict, List, Optional, Protocol, Union - -import jinja2 -from jinja2 import Template - -# NOTE: We sacrifice readability for usability. -# It will fail to work as expected if we attempt to format it in a readable way. -llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<> {{ message['content'] }} <>\n{% endif %}{% endfor %}""" - - -class MetaSingleton(type): - """ - Metaclass for implementing the Singleton pattern. - """ - - _instances = {} - - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs) - return cls._instances[cls] - - -class Singleton(object, metaclass=MetaSingleton): - """ - Base class for implementing the Singleton pattern. - """ - - def __init__(self): - super(Singleton, self).__init__() - - -@dataclasses.dataclass -class ChatFormatterResponse: - prompt: str - stop: Optional[Union[str, List[str]]] = None - - -# Base Chat Formatter Protocol -class ChatFormatterInterface(Protocol): - def __init__(self, template: Optional[object] = None): - ... - - def __call__( - self, - messages: List[Dict[str, str]], - **kwargs, - ) -> ChatFormatterResponse: - ... - - @property - def template(self) -> str: - ... 
- - -class AutoChatFormatter(ChatFormatterInterface): - def __init__( - self, - template: Optional[str] = None, - template_class: Optional[Template] = None, - ): - if template is not None: - self._template = template - else: - self._template = llama2_template # default template - - self._environment = jinja2.Environment( - loader=jinja2.BaseLoader(), - trim_blocks=True, - lstrip_blocks=True, - ).from_string( - self._template, - template_class=template_class, - ) - - def __call__( - self, - messages: List[Dict[str, str]], - **kwargs: Any, - ) -> ChatFormatterResponse: - formatted_sequence = self._environment.render(messages=messages, **kwargs) - return ChatFormatterResponse(prompt=formatted_sequence) - - @property - def template(self) -> str: - return self._template - - -class FormatterNotFoundException(Exception): - pass - - -class ChatFormatterFactory(Singleton): - _chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {} - - def register_formatter( - self, - name: str, - formatter_callable: Callable[[], ChatFormatterInterface], - overwrite=False, - ): - if not overwrite and name in self._chat_formatters: - raise ValueError( - f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it." - ) - self._chat_formatters[name] = formatter_callable - - def unregister_formatter(self, name: str): - if name in self._chat_formatters: - del self._chat_formatters[name] - else: - raise ValueError(f"No formatter registered under the name '{name}'.") - - def get_formatter_by_name(self, name: str) -> ChatFormatterInterface: - try: - formatter_callable = self._chat_formatters[name] - return formatter_callable() - except KeyError: - raise FormatterNotFoundException( - f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})" - ) - - -# Define a chat format class -class Llama2Formatter(AutoChatFormatter): - def __init__(self): - super().__init__(llama2_template) - - -# With the Singleton pattern applied, regardless of where or how many times -# ChatFormatterFactory() is called, it will always return the same instance -# of the factory, ensuring that the factory's state is consistent throughout -# the application. 
-ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter) diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py index f9be323..c2d6b6d 100644 --- a/llama_cpp/server/model.py +++ b/llama_cpp/server/model.py @@ -1,5 +1,7 @@ from __future__ import annotations +import json + from typing import Dict, Optional, Union, List import llama_cpp @@ -71,7 +73,25 @@ class LlamaProxy: chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler( clip_model_path=settings.clip_model_path, verbose=settings.verbose ) - + elif settings.chat_format == "hf-autotokenizer": + assert ( + settings.hf_pretrained_model_name_or_path is not None + ), "hf_pretrained_model_name_or_path must be set for hf-autotokenizer" + chat_handler = ( + llama_cpp.llama_chat_format.hf_autotokenizer_to_chat_formatter( + settings.hf_pretrained_model_name_or_path + ) + ) + elif settings.chat_format == "hf-tokenizer-config": + assert ( + settings.hf_tokenizer_config_path is not None + ), "hf_tokenizer_config_path must be set for hf-tokenizer-config" + chat_handler = ( + llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_formatter( + json.load(open(settings.hf_tokenizer_config_path)) + ) + ) + kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None if settings.kv_overrides is not None: assert isinstance(settings.kv_overrides, list) @@ -141,4 +161,3 @@ class LlamaProxy: cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size) _model.set_cache(cache) return _model - diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py index dc5be20..9f0dc8a 100644 --- a/llama_cpp/server/settings.py +++ b/llama_cpp/server/settings.py @@ -134,6 +134,15 @@ class ModelSettings(BaseSettings): default=2 << 30, description="The size of the cache in bytes. Only used if cache is True.", ) + # Tokenizer Options + hf_tokenizer_config_path: Optional[str] = Field( + default=None, + description="The path to a HuggingFace tokenizer_config.json file.", + ) + hf_pretrained_model_name_or_path: Optional[str] = Field( + default=None, + description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().", + ) # Misc verbose: bool = Field( default=True, description="Whether to print debug information." diff --git a/tests/test_llama_chat_format.py b/tests/test_llama_chat_format.py index 4eebcb6..1ef18d9 100644 --- a/tests/test_llama_chat_format.py +++ b/tests/test_llama_chat_format.py @@ -1,50 +1,65 @@ -from typing import List +import json -import pytest - -from llama_cpp import ChatCompletionMessage -from llama_cpp.llama_jinja_format import Llama2Formatter +from llama_cpp import ( + ChatCompletionRequestUserMessage, +) +from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter -@pytest.fixture -def sequence_of_messages() -> List[ChatCompletionMessage]: - return [ - ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"), - ChatCompletionMessage( - role="user", content="Hi there! I need some help with Python." - ), - ChatCompletionMessage( - role="assistant", content="Of course! What do you need help with in Python?" - ), - ChatCompletionMessage( - role="user", - content="I'm trying to write a function to find the factorial of a number, but I'm stuck.", - ), - ChatCompletionMessage( - role="assistant", - content="I can help with that! Would you like a recursive or iterative solution?", - ), - ChatCompletionMessage( - role="user", content="Let's go with a recursive solution." 
- ), - ] +mistral_7b_tokenizer_config = """{ + "add_bos_token": true, + "add_eos_token": false, + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "additional_special_tokens": [], + "bos_token": "", + "clean_up_tokenization_spaces": false, + "eos_token": "", + "legacy": true, + "model_max_length": 1000000000000000019884624838656, + "pad_token": null, + "sp_model_kwargs": {}, + "spaces_between_special_tokens": false, + "tokenizer_class": "LlamaTokenizer", + "unk_token": "", + "use_default_system_prompt": false, + "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" +}""" -def test_llama2_formatter(sequence_of_messages): - expected_prompt = ( - "<> Welcome to CodeHelp Bot! <>\n" - "[INST] Hi there! I need some help with Python. [/INST]\n" - "Of course! What do you need help with in Python?\n" - "[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\n" - "I can help with that! Would you like a recursive or iterative solution?\n" - "[INST] Let's go with a recursive solution. [/INST]\n" +def test_hf_tokenizer_config_str_to_chat_formatter(): + tokenizer_config = json.loads(mistral_7b_tokenizer_config) + chat_formatter = hf_tokenizer_config_to_chat_formatter( + tokenizer_config + ) + chat_formatter_respoonse = chat_formatter( + messages=[ + ChatCompletionRequestUserMessage(role="user", content="Hello, world!"), + ] ) - llama2_formatter_instance = Llama2Formatter() - formatter_response = llama2_formatter_instance(sequence_of_messages) - assert ( - expected_prompt == formatter_response.prompt - ), "The formatted prompt does not match the expected output." - - -# Optionally, include a test for the 'stop' if it's part of the functionality. + assert chat_formatter_respoonse.prompt == ("[INST] Hello, world! 
[/INST]" "") From 89cce50f8c332cdb72636d2f61e37a1309feafca Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 18 Jan 2024 21:21:49 -0500 Subject: [PATCH 08/26] Update llama.cpp --- llama_cpp/llama_cpp.py | 13 +++++++++++++ vendor/llama.cpp | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 9e8e3ce..ef16272 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -91,6 +91,12 @@ c_float_p = POINTER(c_float) c_uint8_p = POINTER(c_uint8) c_size_t_p = POINTER(c_size_t) +# from ggml-backend.h +# typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); +ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE( + c_bool, c_void_p, c_bool, c_void_p +) + # llama.h bindings _lib.llama_max_devices.argtypes = [] @@ -448,6 +454,9 @@ class llama_model_params(Structure): # float yarn_beta_slow; // YaRN high correction dim # uint32_t yarn_orig_ctx; // YaRN original context size +# ggml_backend_sched_eval_callback cb_eval; +# void * cb_eval_user_data; + # enum ggml_type type_k; // data type for K cache # enum ggml_type type_v; // data type for V cache @@ -475,6 +484,8 @@ class llama_context_params(Structure): yarn_beta_fast (float): YaRN low correction dim yarn_beta_slow (float): YaRN high correction dim yarn_orig_ctx (int): YaRN original context size + cb_eval (ggml_backend_sched_eval_callback): callback for scheduling eval + cb_eval_user_data (ctypes.c_void_p): user data for cb_eval type_k (int): data type for K cache type_v (int): data type for V cache mul_mat_q (bool): if true, use experimental mul_mat_q kernels (DEPRECATED - always true) @@ -497,6 +508,8 @@ class llama_context_params(Structure): ("yarn_beta_fast", c_float), ("yarn_beta_slow", c_float), ("yarn_orig_ctx", c_uint32), + ("cb_eval", ggml_backend_sched_eval_callback), + ("cb_eval_user_data", c_void_p), ("type_k", c_int), ("type_v", c_int), ("mul_mat_q", c_bool), diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 4f4bf35..2d5419d 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 4f4bf35f46600441dec2f941e667291eeb9a18d8 +Subproject commit 2d5419d08ab1131623e6a1d554607b7663435e87 From be23404ed43b807406a6db55a0cbb830b47b11e3 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 18 Jan 2024 21:22:19 -0500 Subject: [PATCH 09/26] Cleanup pyproject --- pyproject.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 806127d..4130972 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,8 +11,6 @@ license = { text = "MIT" } authors = [ { name = "Andrei Betlen", email = "abetlen@gmail.com" }, ] -# mkdocs-martiral requires "jinja2~=3.0" -# transformers requires "jinja2>=2.11.3" dependencies = [ "typing-extensions>=4.5.0", "numpy>=1.20.0", From 3ca86ab3909444917a272607aa62b105584a50a8 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 18 Jan 2024 21:22:45 -0500 Subject: [PATCH 10/26] Update llama.cpp --- vendor/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 2d5419d..57e2a7a 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 2d5419d08ab1131623e6a1d554607b7663435e87 +Subproject commit 57e2a7a52a819883f40dada8a2edc24ecf48186b From 03ed547bfd6907039e6d594e7203dca155280492 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 18 Jan 2024 21:23:26 -0500 Subject: [PATCH 11/26] Remove templates doc --- docs/templates.md | 52 
----------------------------------------------- 1 file changed, 52 deletions(-) delete mode 100644 docs/templates.md diff --git a/docs/templates.md b/docs/templates.md deleted file mode 100644 index 5acdaa1..0000000 --- a/docs/templates.md +++ /dev/null @@ -1,52 +0,0 @@ -# Templates - -This document provides a comprehensive guide to the integration of Jinja2 templating into the `llama-cpp-python` project, with a focus on enhancing the chat functionality of the `llama-2` model. - -## Introduction - -- Brief explanation of the `llama-cpp-python` project's need for a templating system. -- Overview of the `llama-2` model's interaction with templating. - -## Jinja2 Dependency Integration - -- Rationale for choosing Jinja2 as the templating engine. - - Compatibility with Hugging Face's `transformers`. - - Desire for advanced templating features and simplicity. -- Detailed steps for adding `jinja2` to `pyproject.toml` for dependency management. - -## Template Management Refactor - -- Summary of the refactor and the motivation behind it. -- Description of the new chat handler selection logic: - 1. Preference for a user-specified `chat_handler`. - 2. Fallback to a user-specified `chat_format`. - 3. Defaulting to a chat format from a `.gguf` file if available. - 4. Utilizing the `llama2` default chat format as the final fallback. -- Ensuring backward compatibility throughout the refactor. - -## Implementation Details - -- In-depth look at the new `AutoChatFormatter` class. -- Example code snippets showing how to utilize the Jinja2 environment and templates. -- Guidance on how to provide custom templates or use defaults. - -## Testing and Validation - -- Outline of the testing strategy to ensure seamless integration. -- Steps for validating backward compatibility with existing implementations. - -## Benefits and Impact - -- Analysis of the expected benefits, including consistency, performance gains, and improved developer experience. -- Discussion of the potential impact on current users and contributors. - -## Future Work - -- Exploration of how templating can evolve within the project. -- Consideration of additional features or optimizations for the templating engine. -- Mechanisms for community feedback on the templating system. - -## Conclusion - -- Final thoughts on the integration of Jinja2 templating. -- Call to action for community involvement and feedback. 
From 656f3d896845d428f994dd205b28d8f524cbff67 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 18 Jan 2024 21:30:36 -0500 Subject: [PATCH 12/26] Bump version --- CHANGELOG.md | 9 +++++++++ llama_cpp/__init__.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c0748ee..92636ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.30] + +- feat: Update llama.cpp to ggerganov/llama.cpp@57e2a7a52a819883f40dada8a2edc24ecf48186b +- feat(server): Add ability to load chat format from huggingface autotokenizer or tokenizer_config.json files by @abetlen in b8fc1c7d83ad4a9207c707ba1d954fe580286a01 +- feat: Integration of Jinja2 Templating for chat formats by @teleprint-me in #875 +- fix: Offload KQV by default by @abetlen in 48c3b77e6f558a9899de0e1155c7dc0c7958d8e8 +- fix: Support Accept text/event-stream in chat and completion endpoints, resolves #1083 by @aniljava in #1088 +- fix(cli): allow passing n_ctx=0 to openAI API server args to use model n_ctx_train field per #1015 by @K-Mistele in #1093 + ## [0.2.29] - feat: Update llama.cpp to ggerganov/llama.cpp@4483396751c79dea540808b9cb9238245d06da2b diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py index 65206bf..210404a 100644 --- a/llama_cpp/__init__.py +++ b/llama_cpp/__init__.py @@ -1,4 +1,4 @@ from .llama_cpp import * from .llama import * -__version__ = "0.2.29" \ No newline at end of file +__version__ = "0.2.30" \ No newline at end of file From 141293a75b564a8699e0acba1da24d9aa1cf0ab1 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 19 Jan 2024 08:17:49 -0500 Subject: [PATCH 13/26] Fix python3.8 support --- llama_cpp/server/cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_cpp/server/cli.py b/llama_cpp/server/cli.py index 8e32d2c..3dd0076 100644 --- a/llama_cpp/server/cli.py +++ b/llama_cpp/server/cli.py @@ -55,7 +55,7 @@ def _parse_bool_arg(arg: str | bytes | bool) -> bool: raise ValueError(f"Invalid boolean argument: {arg}") -def add_args_from_model(parser: argparse.ArgumentParser, model: type[BaseModel]): +def add_args_from_model(parser: argparse.ArgumentParser, model: Type[BaseModel]): """Add arguments from a pydantic model to an argparse parser.""" for name, field in model.model_fields.items(): @@ -83,7 +83,7 @@ def add_args_from_model(parser: argparse.ArgumentParser, model: type[BaseModel]) ) -T = TypeVar("T", bound=type[BaseModel]) +T = TypeVar("T", bound=Type[BaseModel]) def parse_model_from_args(model: T, args: argparse.Namespace) -> T: From 3babe3512cb95743108f2b595210c38ed6f1b904 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 19 Jan 2024 08:31:59 -0500 Subject: [PATCH 14/26] Fix mirostat sampling --- llama_cpp/llama.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 6cdc1eb..32eb3fe 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -329,6 +329,8 @@ class Llama: (n_ctx, self._n_vocab), dtype=np.single ) + self._mirostat_mu = ctypes.c_float(2.0 * 5.0) # TODO: Move this to sampling context + @property def ctx(self) -> llama_cpp.llama_context_p: assert self._ctx.ctx is not None @@ -516,7 +518,7 @@ class Llama: candidates=self._candidates, tau=mirostat_tau, eta=mirostat_eta, - mu=2.0 * mirostat_tau, + mu=ctypes.pointer(self._mirostat_mu), m=100, ) elif mirostat_mode == 2: @@ -525,7 +527,7 @@ class Llama: 
candidates=self._candidates, tau=mirostat_tau, eta=mirostat_eta, - mu=2.0 * mirostat_tau, + mu=ctypes.pointer(self._mirostat_mu) ) else: self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1) @@ -581,6 +583,10 @@ class Llama: Yields: The generated tokens. """ + # Reset mirostat sampling + self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau) + + # Check for kv cache prefix match if reset and self.n_tokens > 0: longest_prefix = 0 for a, b in zip(self._input_ids, tokens[:-1]): @@ -595,12 +601,15 @@ class Llama: tokens = tokens[longest_prefix:] self.n_tokens = longest_prefix + # Reset the model state if reset: self.reset() + # Reset the grammar if grammar is not None: grammar.reset() + # Eval and sample while True: self.eval(tokens) token = self.sample( From 0f54948482fc7e2ad71f511c1536b7a77c5c1ba7 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 19 Jan 2024 08:41:52 -0500 Subject: [PATCH 15/26] Update llama.cpp --- vendor/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 57e2a7a..a5cacb2 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 57e2a7a52a819883f40dada8a2edc24ecf48186b +Subproject commit a5cacb22b2114fd9adf61c00cbb237384d86bced From e21c3c7a91f354cdb139f4bd6a5d48193ea806f0 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 19 Jan 2024 08:47:56 -0500 Subject: [PATCH 16/26] Update makefile --- Makefile | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index e930609..5ed3fa2 100644 --- a/Makefile +++ b/Makefile @@ -10,22 +10,22 @@ deps: python3 -m pip install -e ".[all]" build: - python3 -m pip install -e . + python3 -m pip install --verbose -e . build.cuda: - CMAKE_ARGS="-DLLAMA_CUBLAS=on" python3 -m pip install -e . + CMAKE_ARGS="-DLLAMA_CUBLAS=on" python3 -m pip install --verbose -e . build.opencl: - CMAKE_ARGS="-DLLAMA_CLBLAST=on" python3 -m pip install -e . + CMAKE_ARGS="-DLLAMA_CLBLAST=on" python3 -m pip install --verbose -e . build.openblas: - CMAKE_ARGS="-DLLAMA_CLBLAST=on" python3 -m pip install -e . + CMAKE_ARGS="-DLLAMA_CLBLAST=on" python3 -m pip install --verbose -e . build.blis: - CMAKE_ARGS="-DLLAMA_OPENBLAS=on -DLLAMA_OPENBLAS_VENDOR=blis" python3 -m pip install -e . + CMAKE_ARGS="-DLLAMA_OPENBLAS=on -DLLAMA_OPENBLAS_VENDOR=blis" python3 -m pip install --verbose -e . build.metal: - CMAKE_ARGS="-DLLAMA_METAL=on" python3 -m pip install -e . + CMAKE_ARGS="-DLLAMA_METAL=on" python3 -m pip install --verbose -e . 
build.sdist: python3 -m build --sdist From 833a7f1a86f2136df5f75c1bd62d2e4d5adaa439 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 19 Jan 2024 09:03:35 -0500 Subject: [PATCH 17/26] Bump version --- CHANGELOG.md | 6 ++++++ llama_cpp/__init__.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 92636ba..797785e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.31] + +- feat: Update llama.cpp to ggerganov/llama.cpp@a5cacb22b2114fd9adf61c00cbb237384d86bced +- fix: Mirostat sampling now passes correct type to ctypes and tracks state during generation by @abetlen in 3babe3512cb95743108f2b595210c38ed6f1b904 +- fix: Python3.8 support in server by @abetlen in 141293a75b564a8699e0acba1da24d9aa1cf0ab1 + ## [0.2.30] - feat: Update llama.cpp to ggerganov/llama.cpp@57e2a7a52a819883f40dada8a2edc24ecf48186b diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py index 210404a..1869838 100644 --- a/llama_cpp/__init__.py +++ b/llama_cpp/__init__.py @@ -1,4 +1,4 @@ from .llama_cpp import * from .llama import * -__version__ = "0.2.30" \ No newline at end of file +__version__ = "0.2.31" \ No newline at end of file From 5a34c57e5479e50c99aba9b38218cc48e6560b81 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 19 Jan 2024 10:46:03 -0500 Subject: [PATCH 18/26] feat: Expose gguf model metadata in metadata property --- llama_cpp/_internals.py | 25 +++++++++++++++++++++++++ llama_cpp/llama.py | 10 ++++++++++ 2 files changed, 35 insertions(+) diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py index 208de8c..ec47c42 100644 --- a/llama_cpp/_internals.py +++ b/llama_cpp/_internals.py @@ -204,6 +204,31 @@ class _LlamaModel: output[1:] if len(tokens) > 0 and tokens[0] == self.token_bos() else output ) + # Extra + def metadata(self) -> Dict[str, str]: + assert self.model is not None + metadata: Dict[str, str] = {} + buffer_size = 1024 + buffer = ctypes.create_string_buffer(buffer_size) + # zero the buffer + buffer.value = b'\0' * buffer_size + # iterate over model keys + for i in range(llama_cpp.llama_model_meta_count(self.model)): + nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size) + if nbytes > buffer_size: + buffer_size = nbytes + buffer = ctypes.create_string_buffer(buffer_size) + nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size) + key = buffer.value.decode("utf-8") + nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size) + if nbytes > buffer_size: + buffer_size = nbytes + buffer = ctypes.create_string_buffer(buffer_size) + nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size) + value = buffer.value.decode("utf-8") + metadata[key] = value + return metadata + @staticmethod def default_params(): """Get the default llama_model_params.""" diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 32eb3fe..5c66bcf 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -331,6 +331,16 @@ class Llama: self._mirostat_mu = ctypes.c_float(2.0 * 5.0) # TODO: Move this to sampling context + try: + self.metadata = self._model.metadata() + except Exception as e: + self.metadata = {} + if self.verbose: + print(f"Failed to load metadata: {e}", file=sys.stderr) + + if self.verbose: + print(f"Model metadata: {self.metadata}", file=sys.stderr) + @property def ctx(self) -> llama_cpp.llama_context_p: 
assert self._ctx.ctx is not None From be09318c26add8674ce494ae7cc480cce72a4146 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 19 Jan 2024 15:04:42 -0500 Subject: [PATCH 19/26] feat: Add Jinja2ChatFormatter --- llama_cpp/llama_chat_format.py | 323 +++++++++++++++++++-------------- 1 file changed, 188 insertions(+), 135 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 3d18d90..4e1b174 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -121,14 +121,21 @@ def register_chat_completion_handler(name: str): @dataclasses.dataclass class ChatFormatterResponse: + """Dataclass that stores completion parameters for a given chat format and + create_chat_completion request. + + prompt contains the formatted prompt generated from the chat format and messages. + stop contains the stop token or list of stop tokens to use for the chat format.""" + prompt: str stop: Optional[Union[str, List[str]]] = None class ChatFormatter(Protocol): """Base Protocol for a chat formatter. A chat formatter is a function that - takes a list of messages and returns a formatted prompt. It can also return - a stop token or list of stop tokens to use for the completion.""" + takes a list of messages and returns a chat format response which can be used + to generate a completion. The response can also include a stop token or list + of stop tokens to use for the completion.""" def __call__( self, @@ -139,131 +146,43 @@ class ChatFormatter(Protocol): ... -### Utility functions for formatting chat prompts ### +class Jinja2ChatFormatter(ChatFormatter): + def __init__( + self, + template: str, + eos_token: str, + bos_token: str, + ): + """A chat formatter that uses jinja2 templates to format the prompt.""" + self.template = template + self.eos_token = eos_token + self.bos_token = bos_token + self._environment = jinja2.Environment( + loader=jinja2.BaseLoader(), + trim_blocks=True, + lstrip_blocks=True, + ).from_string(self.template) -def _get_system_message( - messages: List[llama_types.ChatCompletionRequestMessage], -) -> str: - """Get the first system message.""" - for message in messages: - if message["role"] == "system": - return message["content"] or "" - return "" + def __call__( + self, + *, + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs: Any, + ) -> ChatFormatterResponse: + messages = [ + *messages, + llama_types.ChatCompletionRequestAssistantMessage( + role="assistant", content="" + ), + ] + prompt = self._environment.render( + messages=messages, eos_token=self.eos_token, bos_token=self.bos_token + ) + return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token]) - -def _map_roles( - messages: List[llama_types.ChatCompletionRequestMessage], - role_map: Dict[str, str], -) -> List[Tuple[str, Optional[str]]]: - """Map the message roles.""" - output: List[Tuple[str, Optional[str]]] = [] - for message in messages: - role = message["role"] - if role in role_map: - content: str | None = ( - message["content"] if isinstance(message["content"], str) else None - ) - output.append((role_map[role], content)) - return output - - -def _format_llama2( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str -) -> str: - """Format the prompt with the llama2 style.""" - seps = [sep, sep2] - ret = system_message + sep - for i, (role, message) in enumerate(messages): - if system_message and i == 0: - m = message or "" - ret += m + seps[i % 2] - elif message: - ret += role + message + " " + seps[i % 2] - 
else: - ret += role + " " - return ret - - -def _format_add_colon_single( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str -) -> str: - """Format the prompt with the add-colon-single style.""" - ret = system_message + sep - for role, message in messages: - if message: - ret += role + ": " + message + sep - else: - ret += role + ":" - return ret - - -def _format_add_colon_two( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str -) -> str: - """Format the prompt with the add-colon-two style.""" - seps = [sep, sep2] - ret = system_message + seps[0] - for i, (role, message) in enumerate(messages): - if message: - ret += role + ": " + message + seps[i % 2] - else: - ret += role + ":" - return ret - - -def _format_no_colon_single( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str -) -> str: - """Format the prompt with the no-colon-single style.""" - ret = system_message - for role, message in messages: - if message: - ret += role + message + sep - else: - ret += role - return ret - - -def _format_add_colon_space_single( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str -) -> str: - """Format the prompt with the add-colon-space-single style.""" - ret = system_message + sep - for role, message in messages: - if message: - ret += role + ": " + message + sep - else: - ret += role + ": " # must be end with a space - return ret - - -def _format_chatml( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str -) -> str: - """Format the prompt with the chatml style.""" - ret = "" if system_message == "" else system_message + sep + "\n" - for role, message in messages: - if message: - ret += role + "\n" + message + sep + "\n" - else: - ret += role + "\n" - return ret - - -def _format_chatglm3( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str -) -> str: - """Format the prompt with the chatglm3 style.""" - ret = "" - if system_message: - ret += system_message - for role, message in messages: - if message: - ret += role + "\n" + " " + message - else: - ret += role - return ret + def to_chat_handler(self) -> LlamaChatCompletionHandler: + return chat_formatter_to_chat_completion_handler(self) def _convert_text_completion_to_chat( @@ -426,16 +345,6 @@ def chat_formatter_to_chat_completion_handler( return chat_completion_handler -def register_chat_format(name: str): - def decorator(f: ChatFormatter): - chat_completion_handler = chat_formatter_to_chat_completion_handler(f) - LlamaChatCompletionHandlerRegistry().register_chat_completion_handler( - name, chat_completion_handler - ) - return f - return decorator - - def hf_autotokenizer_to_chat_formatter( pretrained_model_name_or_path: Union[str, os.PathLike[str]] ) -> ChatFormatter: @@ -466,7 +375,9 @@ def hf_autotokenizer_to_chat_completion_handler( return chat_formatter_to_chat_completion_handler(chat_formatter) -def hf_tokenizer_config_to_chat_formatter(tokenizer_config: Dict[str, Any]) -> ChatFormatter: +def hf_tokenizer_config_to_chat_formatter( + tokenizer_config: Dict[str, Any] +) -> ChatFormatter: assert isinstance(tokenizer_config, dict) assert "chat_template" in tokenizer_config @@ -504,6 +415,7 @@ def hf_tokenizer_config_to_chat_formatter(tokenizer_config: Dict[str, Any]) -> C eos_token=eos_token, ) return ChatFormatterResponse(prompt=prompt, stop=eos_token) + return format_autotokenizer @@ -514,6 +426,147 @@ def hf_tokenizer_config_to_chat_completion_handler( return 
chat_formatter_to_chat_completion_handler(chat_formatter) +### Utility functions for formatting chat prompts ### + + +def _get_system_message( + messages: List[llama_types.ChatCompletionRequestMessage], +) -> str: + """Get the first system message.""" + for message in messages: + if message["role"] == "system": + return message["content"] or "" + return "" + + +def _map_roles( + messages: List[llama_types.ChatCompletionRequestMessage], + role_map: Dict[str, str], +) -> List[Tuple[str, Optional[str]]]: + """Map the message roles.""" + output: List[Tuple[str, Optional[str]]] = [] + for message in messages: + role = message["role"] + if role in role_map: + content: str | None = ( + message["content"] if isinstance(message["content"], str) else None + ) + output.append((role_map[role], content)) + return output + + +def _format_llama2( + system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str +) -> str: + """Format the prompt with the llama2 style.""" + seps = [sep, sep2] + ret = system_message + sep + for i, (role, message) in enumerate(messages): + if system_message and i == 0: + m = message or "" + ret += m + seps[i % 2] + elif message: + ret += role + message + " " + seps[i % 2] + else: + ret += role + " " + return ret + + +def _format_add_colon_single( + system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str +) -> str: + """Format the prompt with the add-colon-single style.""" + ret = system_message + sep + for role, message in messages: + if message: + ret += role + ": " + message + sep + else: + ret += role + ":" + return ret + + +def _format_add_colon_two( + system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str +) -> str: + """Format the prompt with the add-colon-two style.""" + seps = [sep, sep2] + ret = system_message + seps[0] + for i, (role, message) in enumerate(messages): + if message: + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + + +def _format_no_colon_single( + system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str +) -> str: + """Format the prompt with the no-colon-single style.""" + ret = system_message + for role, message in messages: + if message: + ret += role + message + sep + else: + ret += role + return ret + + +def _format_add_colon_space_single( + system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str +) -> str: + """Format the prompt with the add-colon-space-single style.""" + ret = system_message + sep + for role, message in messages: + if message: + ret += role + ": " + message + sep + else: + ret += role + ": " # must be end with a space + return ret + + +def _format_chatml( + system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str +) -> str: + """Format the prompt with the chatml style.""" + ret = "" if system_message == "" else system_message + sep + "\n" + for role, message in messages: + if message: + ret += role + "\n" + message + sep + "\n" + else: + ret += role + "\n" + return ret + + +def _format_chatglm3( + system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str +) -> str: + """Format the prompt with the chatglm3 style.""" + ret = "" + if system_message: + ret += system_message + for role, message in messages: + if message: + ret += role + "\n" + " " + message + else: + ret += role + return ret + + +### Chat Formats ### + + +def register_chat_format(name: str): + def decorator(f: ChatFormatter): + chat_completion_handler = chat_formatter_to_chat_completion_handler(f) + 
LlamaChatCompletionHandlerRegistry().register_chat_completion_handler( + name, chat_completion_handler + ) + return f + + return decorator + + # see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py # system prompt is "embedded" in the first message @register_chat_format("llama-2") From ac2e96d4b4610cb9cd9b0c978c76ece6567f5c02 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 19 Jan 2024 15:33:43 -0500 Subject: [PATCH 20/26] Update llama.cpp --- vendor/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index a5cacb2..381ee19 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit a5cacb22b2114fd9adf61c00cbb237384d86bced +Subproject commit 381ee195721d8e747ee31a60c0751822b3072f02 From 7f3209b1eb4ad3260ba063801fab80a8c25a2f4c Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Sun, 21 Jan 2024 18:37:24 -0500 Subject: [PATCH 21/26] feat: Add add_generation_prompt option for jinja2chatformatter. --- llama_cpp/llama_chat_format.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 4e1b174..02bdbcf 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -152,11 +152,13 @@ class Jinja2ChatFormatter(ChatFormatter): template: str, eos_token: str, bos_token: str, + add_generation_prompt: bool = True, ): """A chat formatter that uses jinja2 templates to format the prompt.""" self.template = template self.eos_token = eos_token self.bos_token = bos_token + self.add_generation_prompt = add_generation_prompt self._environment = jinja2.Environment( loader=jinja2.BaseLoader(), @@ -170,12 +172,13 @@ class Jinja2ChatFormatter(ChatFormatter): messages: List[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: - messages = [ - *messages, - llama_types.ChatCompletionRequestAssistantMessage( - role="assistant", content="" - ), - ] + if self.add_generation_prompt: + messages = [ + *messages, + llama_types.ChatCompletionRequestAssistantMessage( + role="assistant", content="" + ), + ] prompt = self._environment.render( messages=messages, eos_token=self.eos_token, bos_token=self.bos_token ) From 24f39454e91cf5dddbc4b6041aead4accc7c7a2d Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Sun, 21 Jan 2024 18:38:04 -0500 Subject: [PATCH 22/26] fix: pass chat handler not chat formatter for huggingface autotokenizer and tokenizer_config formats. 
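
The server previously passed the return value of the `*_to_chat_formatter` helpers where a chat completion handler is expected. A minimal sketch of the intended wiring for the tokenizer-config path, assuming a local `tokenizer_config.json` and a placeholder GGUF model path (illustrative only, not part of this patch):

```python
# Illustrative sketch; file paths below are placeholders.
import json

import llama_cpp
import llama_cpp.llama_chat_format as llama_chat_format

# Build a chat *handler* (something the server can pass to Llama), not a ChatFormatter.
config = json.load(open("tokenizer_config.json"))
chat_handler = llama_chat_format.hf_tokenizer_config_to_chat_completion_handler(config)

llm = llama_cpp.Llama(model_path="model.gguf", chat_handler=chat_handler)
response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Hello!"}]
)
```

The autotokenizer path is analogous, via `hf_autotokenizer_to_chat_completion_handler`.
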
--- llama_cpp/server/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py index c2d6b6d..bbb6806 100644 --- a/llama_cpp/server/model.py +++ b/llama_cpp/server/model.py @@ -78,7 +78,7 @@ class LlamaProxy: settings.hf_pretrained_model_name_or_path is not None ), "hf_pretrained_model_name_or_path must be set for hf-autotokenizer" chat_handler = ( - llama_cpp.llama_chat_format.hf_autotokenizer_to_chat_formatter( + llama_cpp.llama_chat_format.hf_autotokenizer_to_chat_completion_handler( settings.hf_pretrained_model_name_or_path ) ) @@ -87,7 +87,7 @@ class LlamaProxy: settings.hf_tokenizer_config_path is not None ), "hf_tokenizer_config_path must be set for hf-tokenizer-config" chat_handler = ( - llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_formatter( + llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_completion_handler( json.load(open(settings.hf_tokenizer_config_path)) ) ) From 88fbccaaa3416f38552d80d84c71fdb40c7c477a Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Sun, 21 Jan 2024 18:38:44 -0500 Subject: [PATCH 23/26] docs: Add macosx wrong arch fix to README --- README.md | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ad5d0f1..f97ea0f 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,10 @@ See the above instructions and set `CMAKE_ARGS` to the BLAS backend you want to ### MacOS Notes +Detailed MacOS Metal GPU install documentation is available at [docs/install/macos.md](https://llama-cpp-python.readthedocs.io/en/latest/install/macos/) + +#### M1 Mac Performance Issue + Note: If you are using Apple Silicon (M1) Mac, make sure you have installed a version of Python that supports arm64 architecture. For example: ``` wget https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-MacOSX-arm64.sh @@ -120,7 +124,13 @@ bash Miniforge3-MacOSX-arm64.sh ``` Otherwise, while installing it will build the llama.cpp x86 version which will be 10x slower on Apple Silicon (M1) Mac. -Detailed MacOS Metal GPU install documentation is available at [docs/install/macos.md](https://llama-cpp-python.readthedocs.io/en/latest/install/macos/) +#### M Series Mac Error: `(mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64'))` + +Try installing with + +``` +CMAKE_ARGS="-DCMAKE_OSX_ARCHITECTURES=arm64 -DCMAKE_APPLE_SILICON_PROCESSOR=arm64 -DLLAMA_METAL=on" pip install --upgrade --verbose --force-reinstall --no-cache-dir llama-cpp-python +``` ### Upgrading and Reinstalling From 8eefdbca03d005095f6645d4e5b42b982af9daf0 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Sun, 21 Jan 2024 19:01:27 -0500 Subject: [PATCH 24/26] Update llama.cpp --- vendor/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 381ee19..504dc37 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 381ee195721d8e747ee31a60c0751822b3072f02 +Subproject commit 504dc37be8446fb09b1ede70300250ad41be32a2 From d3f5528ca8bcb9d69d4f27e21631e911f1fb9bfe Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Sun, 21 Jan 2024 19:06:53 -0500 Subject: [PATCH 25/26] fix: from_json_schema oneof/anyof bug. 
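
This reorders `visit` so `$ref` sub-schemas are resolved before the `type`-based branches, fixing conversion of `anyOf`/`oneOf` schemas such as the one exercised by the new test. A minimal standalone check, trimmed from that test:

```python
# Mirrors tests/test_grammar.py::test_grammar_anyof; runs without a model.
import json

import llama_cpp

schema = {
    "type": "object",
    "properties": {
        "unit": {
            "anyOf": [
                {"type": "string", "enum": ["celsius", "fahrenheit"]},
                {"type": "null"},
            ]
        }
    },
}

grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema))
assert grammar.grammar is not None
```
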
Closes #1097 --- llama_cpp/llama_grammar.py | 23 +++++++++++++---------- tests/test_grammar.py | 26 ++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index c02e656..d8ef563 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -1432,7 +1432,6 @@ class SchemaConverter: return key def visit(self, schema: Dict[str, Any], name: str) -> str: - schema_type: Optional[str] = schema.get("type") # type: ignore rule_name = name or "root" if "$defs" in schema: @@ -1458,7 +1457,19 @@ class SchemaConverter: rule = " | ".join((self._format_literal(v) for v in schema["enum"])) return self._add_rule(rule_name, rule) - elif schema_type == "object" and "properties" in schema: + elif "$ref" in schema: + ref = schema["$ref"] + assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}" + # inline $defs + def_name = ref[len("#/$defs/") :] + def_schema = self._defs[def_name] + return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}') + + + schema_type: Optional[str] = schema.get("type") # type: ignore + assert isinstance(schema_type, str), f"Unrecognized schema: {schema}" + + if schema_type == "object" and "properties" in schema: # TODO: `required` keyword prop_order = self._prop_order prop_pairs = sorted( @@ -1489,14 +1500,6 @@ class SchemaConverter: ) return self._add_rule(rule_name, rule) - elif "$ref" in schema: - ref = schema["$ref"] - assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}" - # inline $defs - def_name = ref[len("#/$defs/") :] - def_schema = self._defs[def_name] - return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}') - else: assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}" return self._add_rule( diff --git a/tests/test_grammar.py b/tests/test_grammar.py index ef9392b..cb22188 100644 --- a/tests/test_grammar.py +++ b/tests/test_grammar.py @@ -50,3 +50,29 @@ def test_composed_pydantic_grammar(): grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema)) assert grammar.grammar is not None + + +def test_grammar_anyof(): + sch = { + "properties": { + "temperature": { + "description": "The temperature mentioned", + "type": "number", + }, + "unit": { + "anyOf": [ + { + "description": "Unit for temperature", + "enum": ["celsius", "fahrenheit"], + "type": "string", + }, + {"type": "null"}, + ], + }, + }, + "type": "object", + } + + grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(sch)) + + assert grammar.grammar is not None \ No newline at end of file From 2ce0b8aa2c2f81d999bbb2a7246a9f221f9d52d0 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Sun, 21 Jan 2024 20:30:24 -0500 Subject: [PATCH 26/26] Bump version --- CHANGELOG.md | 9 +++++++++ llama_cpp/__init__.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 797785e..4fff919 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.32] + +- feat: Update llama.cpp to ggerganov/llama.cpp@504dc37be8446fb09b1ede70300250ad41be32a2 +- fix: from_json_schema oneof/anyof bug by @jndiogo in d3f5528ca8bcb9d69d4f27e21631e911f1fb9bfe +- fix: pass chat handler not chat formatter for huggingface autotokenizer and tokenizer_config formats by @abetlen in 24f39454e91cf5dddbc4b6041aead4accc7c7a2d +- feat: Add add_generation_prompt option for jinja2chatformatter by @abetlen in 
7f3209b1eb4ad3260ba063801fab80a8c25a2f4c +- feat: Add Jinja2ChatFormatter by @abetlen in be09318c26add8674ce494ae7cc480cce72a4146 +- feat: Expose gguf model metadata in metadata property by @abetlen in 5a34c57e5479e50c99aba9b38218cc48e6560b81 + ## [0.2.31] - feat: Update llama.cpp to ggerganov/llama.cpp@a5cacb22b2114fd9adf61c00cbb237384d86bced diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py index 1869838..dda8335 100644 --- a/llama_cpp/__init__.py +++ b/llama_cpp/__init__.py @@ -1,4 +1,4 @@ from .llama_cpp import * from .llama import * -__version__ = "0.2.31" \ No newline at end of file +__version__ = "0.2.32" \ No newline at end of file
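
For reference, a sketch of how the `add_generation_prompt` option added in 0.2.32 behaves. The template, special tokens, and output handling below are illustrative assumptions, not code shipped in this series:

```python
# Illustrative sketch of Jinja2ChatFormatter with add_generation_prompt;
# the ChatML-style template and tokens are toy values for demonstration.
import llama_cpp.llama_chat_format as llama_chat_format

template = (
    "{% for message in messages %}"
    "<|im_start|>{{ message.role }}\n"
    "{{ message.content }}{% if message.content %}<|im_end|>\n{% endif %}"
    "{% endfor %}"
)

formatter = llama_chat_format.Jinja2ChatFormatter(
    template=template,
    eos_token="<|im_end|>",
    bos_token="<|im_start|>",
    add_generation_prompt=True,  # appends an empty assistant turn before rendering
)

result = formatter(messages=[{"role": "user", "content": "Hi"}])
print(result.prompt)  # ends with an open "<|im_start|>assistant" turn
```

With `add_generation_prompt=False`, only the supplied messages are rendered, which suits templates that already append the assistant header themselves.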