This commit is contained in:
commit
8806f19ef9
18 changed files with 1640 additions and 954 deletions
24
CHANGELOG.md
24
CHANGELOG.md
|
@ -7,6 +7,30 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
|
## [0.2.32]
|
||||||
|
|
||||||
|
- feat: Update llama.cpp to ggerganov/llama.cpp@504dc37be8446fb09b1ede70300250ad41be32a2
|
||||||
|
- fix: from_json_schema oneof/anyof bug by @jndiogo in d3f5528ca8bcb9d69d4f27e21631e911f1fb9bfe
|
||||||
|
- fix: pass chat handler not chat formatter for huggingface autotokenizer and tokenizer_config formats by @abetlen in 24f39454e91cf5dddbc4b6041aead4accc7c7a2d
|
||||||
|
- feat: Add add_generation_prompt option for jinja2chatformatter by @abetlen in 7f3209b1eb4ad3260ba063801fab80a8c25a2f4c
|
||||||
|
- feat: Add Jinja2ChatFormatter by @abetlen in be09318c26add8674ce494ae7cc480cce72a4146
|
||||||
|
- feat: Expose gguf model metadata in metadata property by @abetlen in 5a34c57e5479e50c99aba9b38218cc48e6560b81
|
||||||
|
|
||||||
|
## [0.2.31]
|
||||||
|
|
||||||
|
- feat: Update llama.cpp to ggerganov/llama.cpp@a5cacb22b2114fd9adf61c00cbb237384d86bced
|
||||||
|
- fix: Mirostat sampling now passes correct type to ctypes and tracks state during generation by @abetlen in 3babe3512cb95743108f2b595210c38ed6f1b904
|
||||||
|
- fix: Python3.8 support in server by @abetlen in 141293a75b564a8699e0acba1da24d9aa1cf0ab1
|
||||||
|
|
||||||
|
## [0.2.30]
|
||||||
|
|
||||||
|
- feat: Update llama.cpp to ggerganov/llama.cpp@57e2a7a52a819883f40dada8a2edc24ecf48186b
|
||||||
|
- feat(server): Add ability to load chat format from huggingface autotokenizer or tokenizer_config.json files by @abetlen in b8fc1c7d83ad4a9207c707ba1d954fe580286a01
|
||||||
|
- feat: Integration of Jinja2 Templating for chat formats by @teleprint-me in #875
|
||||||
|
- fix: Offload KQV by default by @abetlen in 48c3b77e6f558a9899de0e1155c7dc0c7958d8e8
|
||||||
|
- fix: Support Accept text/event-stream in chat and completion endpoints, resolves #1083 by @aniljava in #1088
|
||||||
|
- fix(cli): allow passing n_ctx=0 to openAI API server args to use model n_ctx_train field per #1015 by @K-Mistele in #1093
|
||||||
|
|
||||||
## [0.2.29]
|
## [0.2.29]
|
||||||
|
|
||||||
- feat: Update llama.cpp to ggerganov/llama.cpp@4483396751c79dea540808b9cb9238245d06da2b
|
- feat: Update llama.cpp to ggerganov/llama.cpp@4483396751c79dea540808b9cb9238245d06da2b
|
||||||
|
|
12
Makefile
12
Makefile
|
@ -10,22 +10,22 @@ deps:
|
||||||
python3 -m pip install -e ".[all]"
|
python3 -m pip install -e ".[all]"
|
||||||
|
|
||||||
build:
|
build:
|
||||||
python3 -m pip install -e .
|
python3 -m pip install --verbose -e .
|
||||||
|
|
||||||
build.cuda:
|
build.cuda:
|
||||||
CMAKE_ARGS="-DLLAMA_CUBLAS=on" python3 -m pip install -e .
|
CMAKE_ARGS="-DLLAMA_CUBLAS=on" python3 -m pip install --verbose -e .
|
||||||
|
|
||||||
build.opencl:
|
build.opencl:
|
||||||
CMAKE_ARGS="-DLLAMA_CLBLAST=on" python3 -m pip install -e .
|
CMAKE_ARGS="-DLLAMA_CLBLAST=on" python3 -m pip install --verbose -e .
|
||||||
|
|
||||||
build.openblas:
|
build.openblas:
|
||||||
CMAKE_ARGS="-DLLAMA_CLBLAST=on" python3 -m pip install -e .
|
CMAKE_ARGS="-DLLAMA_CLBLAST=on" python3 -m pip install --verbose -e .
|
||||||
|
|
||||||
build.blis:
|
build.blis:
|
||||||
CMAKE_ARGS="-DLLAMA_OPENBLAS=on -DLLAMA_OPENBLAS_VENDOR=blis" python3 -m pip install -e .
|
CMAKE_ARGS="-DLLAMA_OPENBLAS=on -DLLAMA_OPENBLAS_VENDOR=blis" python3 -m pip install --verbose -e .
|
||||||
|
|
||||||
build.metal:
|
build.metal:
|
||||||
CMAKE_ARGS="-DLLAMA_METAL=on" python3 -m pip install -e .
|
CMAKE_ARGS="-DLLAMA_METAL=on" python3 -m pip install --verbose -e .
|
||||||
|
|
||||||
build.sdist:
|
build.sdist:
|
||||||
python3 -m build --sdist
|
python3 -m build --sdist
|
||||||
|
|
12
README.md
12
README.md
|
@ -113,6 +113,10 @@ See the above instructions and set `CMAKE_ARGS` to the BLAS backend you want to
|
||||||
|
|
||||||
### MacOS Notes
|
### MacOS Notes
|
||||||
|
|
||||||
|
Detailed MacOS Metal GPU install documentation is available at [docs/install/macos.md](https://llama-cpp-python.readthedocs.io/en/latest/install/macos/)
|
||||||
|
|
||||||
|
#### M1 Mac Performance Issue
|
||||||
|
|
||||||
Note: If you are using Apple Silicon (M1) Mac, make sure you have installed a version of Python that supports arm64 architecture. For example:
|
Note: If you are using Apple Silicon (M1) Mac, make sure you have installed a version of Python that supports arm64 architecture. For example:
|
||||||
```
|
```
|
||||||
wget https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-MacOSX-arm64.sh
|
wget https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-MacOSX-arm64.sh
|
||||||
|
@ -120,7 +124,13 @@ bash Miniforge3-MacOSX-arm64.sh
|
||||||
```
|
```
|
||||||
Otherwise, while installing it will build the llama.cpp x86 version which will be 10x slower on Apple Silicon (M1) Mac.
|
Otherwise, while installing it will build the llama.cpp x86 version which will be 10x slower on Apple Silicon (M1) Mac.
|
||||||
|
|
||||||
Detailed MacOS Metal GPU install documentation is available at [docs/install/macos.md](https://llama-cpp-python.readthedocs.io/en/latest/install/macos/)
|
#### M Series Mac Error: `(mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64'))`
|
||||||
|
|
||||||
|
Try installing with
|
||||||
|
|
||||||
|
```
|
||||||
|
CMAKE_ARGS="-DCMAKE_OSX_ARCHITECTURES=arm64 -DCMAKE_APPLE_SILICON_PROCESSOR=arm64 -DLLAMA_METAL=on" pip install --upgrade --verbose --force-reinstall --no-cache-dir llama-cpp-python
|
||||||
|
```
|
||||||
|
|
||||||
### Upgrading and Reinstalling
|
### Upgrading and Reinstalling
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from .llama_cpp import *
|
from .llama_cpp import *
|
||||||
from .llama import *
|
from .llama import *
|
||||||
|
|
||||||
__version__ = "0.2.29"
|
__version__ = "0.2.32"
|
795
llama_cpp/_internals.py
Normal file
795
llama_cpp/_internals.py
Normal file
|
@ -0,0 +1,795 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import ctypes
|
||||||
|
|
||||||
|
from typing import (
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
)
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
|
||||||
|
from .llama_types import *
|
||||||
|
from .llama_grammar import LlamaGrammar
|
||||||
|
|
||||||
|
import llama_cpp.llama_cpp as llama_cpp
|
||||||
|
|
||||||
|
from ._utils import suppress_stdout_stderr
|
||||||
|
|
||||||
|
|
||||||
|
# Python wrappers over llama.h structs
|
||||||
|
|
||||||
|
|
||||||
|
class _LlamaModel:
|
||||||
|
"""Intermediate Python wrapper for a llama.cpp llama_model.
|
||||||
|
NOTE: For stability it's recommended you use the Llama class instead."""
|
||||||
|
|
||||||
|
_llama_free_model = None
|
||||||
|
# NOTE: this must be "saved" here to avoid exceptions when calling __del__
|
||||||
|
_suppress_stdout_stderr = suppress_stdout_stderr
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
path_model: str,
|
||||||
|
params: llama_cpp.llama_model_params,
|
||||||
|
verbose: bool = True,
|
||||||
|
):
|
||||||
|
self.path_model = path_model
|
||||||
|
self.params = params
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
self._llama_free_model = llama_cpp._lib.llama_free_model # type: ignore
|
||||||
|
|
||||||
|
if not os.path.exists(path_model):
|
||||||
|
raise ValueError(f"Model path does not exist: {path_model}")
|
||||||
|
|
||||||
|
with self._suppress_stdout_stderr(disable=self.verbose):
|
||||||
|
self.model = llama_cpp.llama_load_model_from_file(
|
||||||
|
self.path_model.encode("utf-8"), self.params
|
||||||
|
)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
with self._suppress_stdout_stderr(disable=self.verbose):
|
||||||
|
if self.model is not None and self._llama_free_model is not None:
|
||||||
|
self._llama_free_model(self.model)
|
||||||
|
self.model = None
|
||||||
|
|
||||||
|
def vocab_type(self) -> int:
|
||||||
|
assert self.model is not None
|
||||||
|
return llama_cpp.llama_vocab_type(self.model)
|
||||||
|
|
||||||
|
def n_vocab(self) -> int:
|
||||||
|
assert self.model is not None
|
||||||
|
return llama_cpp.llama_n_vocab(self.model)
|
||||||
|
|
||||||
|
def n_ctx_train(self) -> int:
|
||||||
|
assert self.model is not None
|
||||||
|
return llama_cpp.llama_n_ctx_train(self.model)
|
||||||
|
|
||||||
|
def n_embd(self) -> int:
|
||||||
|
assert self.model is not None
|
||||||
|
return llama_cpp.llama_n_embd(self.model)
|
||||||
|
|
||||||
|
def rope_freq_scale_train(self) -> float:
|
||||||
|
assert self.model is not None
|
||||||
|
return llama_cpp.llama_rope_freq_scale_train(self.model)
|
||||||
|
|
||||||
|
def desc(self) -> str:
|
||||||
|
assert self.model is not None
|
||||||
|
buf = ctypes.create_string_buffer(1024)
|
||||||
|
llama_cpp.llama_model_desc(self.model, buf, 1024) # type: ignore
|
||||||
|
return buf.value.decode("utf-8")
|
||||||
|
|
||||||
|
def size(self) -> int:
|
||||||
|
assert self.model is not None
|
||||||
|
return llama_cpp.llama_model_size(self.model)
|
||||||
|
|
||||||
|
def n_params(self) -> int:
|
||||||
|
assert self.model is not None
|
||||||
|
return llama_cpp.llama_model_n_params(self.model)
|
||||||
|
|
||||||
|
def get_tensor(self, name: str) -> ctypes.c_void_p:
|
||||||
|
assert self.model is not None
|
||||||
|
return llama_cpp.llama_get_model_tensor(self.model, name.encode("utf-8"))
|
||||||
|
|
||||||
|
def apply_lora_from_file(
|
||||||
|
self,
|
||||||
|
lora_path: str,
|
||||||
|
scale: float,
|
||||||
|
path_base_model: Optional[str],
|
||||||
|
n_threads: int,
|
||||||
|
):
|
||||||
|
assert self.model is not None
|
||||||
|
return llama_cpp.llama_model_apply_lora_from_file(
|
||||||
|
self.model,
|
||||||
|
lora_path.encode("utf-8"),
|
||||||
|
scale,
|
||||||
|
path_base_model.encode("utf-8")
|
||||||
|
if path_base_model is not None
|
||||||
|
else llama_cpp.c_char_p(0),
|
||||||
|
n_threads,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Vocab
|
||||||
|
|
||||||
|
def token_get_text(self, token: int) -> str:
|
||||||
|
# TODO: Fix
|
||||||
|
assert self.model is not None
|
||||||
|
return llama_cpp.llama_token_get_text(self.model, token).decode("utf-8")
|
||||||
|
|
||||||
|
def token_get_score(self, token: int) -> float:
|
||||||
|
assert self.model is not None
|
||||||
|
return llama_cpp.llama_token_get_score(self.model, token)
|
||||||
|
|
||||||
|
def token_get_type(self, token: int) -> int:
|
||||||
|
assert self.model is not None
|
||||||
|
return llama_cpp.llama_token_get_type(self.model, token)
|
||||||
|
|
||||||
|
# Special tokens
|
||||||
|
|
||||||
|
def token_bos(self) -> int:
|
||||||
|
assert self.model is not None
|
||||||
|
return llama_cpp.llama_token_bos(self.model)
|
||||||
|
|
||||||
|
def token_eos(self) -> int:
|
||||||
|
assert self.model is not None
|
||||||
|
return llama_cpp.llama_token_eos(self.model)
|
||||||
|
|
||||||
|
def token_nl(self) -> int:
|
||||||
|
assert self.model is not None
|
||||||
|
return llama_cpp.llama_token_nl(self.model)
|
||||||
|
|
||||||
|
def token_prefix(self) -> int:
|
||||||
|
assert self.model is not None
|
||||||
|
return llama_cpp.llama_token_prefix(self.model)
|
||||||
|
|
||||||
|
def token_middle(self) -> int:
|
||||||
|
assert self.model is not None
|
||||||
|
return llama_cpp.llama_token_middle(self.model)
|
||||||
|
|
||||||
|
def token_suffix(self) -> int:
|
||||||
|
assert self.model is not None
|
||||||
|
return llama_cpp.llama_token_suffix(self.model)
|
||||||
|
|
||||||
|
def token_eot(self) -> int:
|
||||||
|
assert self.model is not None
|
||||||
|
return llama_cpp.llama_token_eot(self.model)
|
||||||
|
|
||||||
|
# Tokenization
|
||||||
|
|
||||||
|
def tokenize(self, text: bytes, add_bos: bool, special: bool):
|
||||||
|
assert self.model is not None
|
||||||
|
n_ctx = self.n_ctx_train()
|
||||||
|
tokens = (llama_cpp.llama_token * n_ctx)()
|
||||||
|
n_tokens = llama_cpp.llama_tokenize(
|
||||||
|
self.model, text, len(text), tokens, n_ctx, add_bos, special
|
||||||
|
)
|
||||||
|
if n_tokens < 0:
|
||||||
|
n_tokens = abs(n_tokens)
|
||||||
|
tokens = (llama_cpp.llama_token * n_tokens)()
|
||||||
|
n_tokens = llama_cpp.llama_tokenize(
|
||||||
|
self.model, text, len(text), tokens, n_tokens, add_bos, special
|
||||||
|
)
|
||||||
|
if n_tokens < 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f'Failed to tokenize: text="{text}" n_tokens={n_tokens}'
|
||||||
|
)
|
||||||
|
return list(tokens[:n_tokens])
|
||||||
|
|
||||||
|
def token_to_piece(self, token: int) -> bytes:
|
||||||
|
assert self.model is not None
|
||||||
|
buf = ctypes.create_string_buffer(32)
|
||||||
|
llama_cpp.llama_token_to_piece(self.model, token, buf, 32) # type: ignore
|
||||||
|
return bytes(buf)
|
||||||
|
|
||||||
|
def detokenize(self, tokens: List[int]) -> bytes:
|
||||||
|
assert self.model is not None
|
||||||
|
output = b""
|
||||||
|
size = 32
|
||||||
|
buffer = (ctypes.c_char * size)()
|
||||||
|
for token in tokens:
|
||||||
|
n = llama_cpp.llama_token_to_piece(
|
||||||
|
self.model, llama_cpp.llama_token(token), buffer, size
|
||||||
|
)
|
||||||
|
assert n <= size
|
||||||
|
output += bytes(buffer[:n])
|
||||||
|
# NOTE: Llama1 models automatically added a space at the start of the prompt
|
||||||
|
# this line removes a leading space if the first token is a beginning of sentence token
|
||||||
|
return (
|
||||||
|
output[1:] if len(tokens) > 0 and tokens[0] == self.token_bos() else output
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extra
|
||||||
|
def metadata(self) -> Dict[str, str]:
|
||||||
|
assert self.model is not None
|
||||||
|
metadata: Dict[str, str] = {}
|
||||||
|
buffer_size = 1024
|
||||||
|
buffer = ctypes.create_string_buffer(buffer_size)
|
||||||
|
# zero the buffer
|
||||||
|
buffer.value = b'\0' * buffer_size
|
||||||
|
# iterate over model keys
|
||||||
|
for i in range(llama_cpp.llama_model_meta_count(self.model)):
|
||||||
|
nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
|
||||||
|
if nbytes > buffer_size:
|
||||||
|
buffer_size = nbytes
|
||||||
|
buffer = ctypes.create_string_buffer(buffer_size)
|
||||||
|
nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
|
||||||
|
key = buffer.value.decode("utf-8")
|
||||||
|
nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
|
||||||
|
if nbytes > buffer_size:
|
||||||
|
buffer_size = nbytes
|
||||||
|
buffer = ctypes.create_string_buffer(buffer_size)
|
||||||
|
nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
|
||||||
|
value = buffer.value.decode("utf-8")
|
||||||
|
metadata[key] = value
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def default_params():
|
||||||
|
"""Get the default llama_model_params."""
|
||||||
|
return llama_cpp.llama_model_default_params()
|
||||||
|
|
||||||
|
|
||||||
|
class _LlamaContext:
|
||||||
|
"""Intermediate Python wrapper for a llama.cpp llama_context.
|
||||||
|
NOTE: For stability it's recommended you use the Llama class instead."""
|
||||||
|
|
||||||
|
_llama_free = None
|
||||||
|
# NOTE: this must be "saved" here to avoid exceptions when calling __del__
|
||||||
|
_suppress_stdout_stderr = suppress_stdout_stderr
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
model: _LlamaModel,
|
||||||
|
params: llama_cpp.llama_context_params,
|
||||||
|
verbose: bool = True,
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.params = params
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
self._llama_free = llama_cpp._lib.llama_free # type: ignore
|
||||||
|
|
||||||
|
with self._suppress_stdout_stderr(disable=self.verbose):
|
||||||
|
self.ctx = llama_cpp.llama_new_context_with_model(
|
||||||
|
self.model.model, self.params
|
||||||
|
)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
with self._suppress_stdout_stderr(disable=self.verbose):
|
||||||
|
if self.ctx is not None and self._llama_free is not None:
|
||||||
|
self._llama_free(self.ctx)
|
||||||
|
self.ctx = None
|
||||||
|
|
||||||
|
def n_ctx(self) -> int:
|
||||||
|
assert self.ctx is not None
|
||||||
|
return llama_cpp.llama_n_ctx(self.ctx)
|
||||||
|
|
||||||
|
def kv_cache_clear(self):
|
||||||
|
assert self.ctx is not None
|
||||||
|
llama_cpp.llama_kv_cache_clear(self.ctx)
|
||||||
|
|
||||||
|
def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int):
|
||||||
|
assert self.ctx is not None
|
||||||
|
llama_cpp.llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1)
|
||||||
|
|
||||||
|
def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int):
|
||||||
|
assert self.ctx is not None
|
||||||
|
llama_cpp.llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1)
|
||||||
|
|
||||||
|
def kv_cache_seq_keep(self, seq_id: int):
|
||||||
|
assert self.ctx is not None
|
||||||
|
llama_cpp.llama_kv_cache_seq_keep(self.ctx, seq_id)
|
||||||
|
|
||||||
|
def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
|
||||||
|
assert self.ctx is not None
|
||||||
|
llama_cpp.llama_kv_cache_seq_shift(self.ctx, seq_id, p0, p1, shift)
|
||||||
|
|
||||||
|
def get_state_size(self) -> int:
|
||||||
|
assert self.ctx is not None
|
||||||
|
return llama_cpp.llama_get_state_size(self.ctx)
|
||||||
|
|
||||||
|
# TODO: copy_state_data
|
||||||
|
|
||||||
|
# TODO: set_state_data
|
||||||
|
|
||||||
|
# TODO: llama_load_session_file
|
||||||
|
|
||||||
|
# TODO: llama_save_session_file
|
||||||
|
|
||||||
|
def decode(self, batch: "_LlamaBatch"):
|
||||||
|
assert self.ctx is not None
|
||||||
|
assert batch.batch is not None
|
||||||
|
return_code = llama_cpp.llama_decode(
|
||||||
|
ctx=self.ctx,
|
||||||
|
batch=batch.batch,
|
||||||
|
)
|
||||||
|
if return_code != 0:
|
||||||
|
raise RuntimeError(f"llama_decode returned {return_code}")
|
||||||
|
|
||||||
|
def set_n_threads(self, n_threads: int, n_threads_batch: int):
|
||||||
|
assert self.ctx is not None
|
||||||
|
llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch)
|
||||||
|
|
||||||
|
def get_logits(self):
|
||||||
|
assert self.ctx is not None
|
||||||
|
return llama_cpp.llama_get_logits(self.ctx)
|
||||||
|
|
||||||
|
def get_logits_ith(self, i: int):
|
||||||
|
assert self.ctx is not None
|
||||||
|
return llama_cpp.llama_get_logits_ith(self.ctx, i)
|
||||||
|
|
||||||
|
def get_embeddings(self):
|
||||||
|
assert self.ctx is not None
|
||||||
|
return llama_cpp.llama_get_embeddings(self.ctx)
|
||||||
|
|
||||||
|
# Sampling functions
|
||||||
|
|
||||||
|
def set_rng_seed(self, seed: int):
|
||||||
|
assert self.ctx is not None
|
||||||
|
llama_cpp.llama_set_rng_seed(self.ctx, seed)
|
||||||
|
|
||||||
|
def sample_repetition_penalties(
|
||||||
|
self,
|
||||||
|
candidates: "_LlamaTokenDataArray",
|
||||||
|
last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]",
|
||||||
|
penalty_last_n: int,
|
||||||
|
penalty_repeat: float,
|
||||||
|
penalty_freq: float,
|
||||||
|
penalty_present: float,
|
||||||
|
):
|
||||||
|
assert self.ctx is not None
|
||||||
|
llama_cpp.llama_sample_repetition_penalties(
|
||||||
|
self.ctx,
|
||||||
|
ctypes.byref(candidates.candidates), # type: ignore
|
||||||
|
last_tokens_data,
|
||||||
|
penalty_last_n,
|
||||||
|
penalty_repeat,
|
||||||
|
penalty_freq,
|
||||||
|
penalty_present,
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample_classifier_free_guidance(
|
||||||
|
self,
|
||||||
|
candidates: "_LlamaTokenDataArray",
|
||||||
|
guidance_ctx: "_LlamaContext",
|
||||||
|
scale: float,
|
||||||
|
):
|
||||||
|
assert self.ctx is not None
|
||||||
|
assert guidance_ctx.ctx is not None
|
||||||
|
llama_cpp.llama_sample_classifier_free_guidance(
|
||||||
|
self.ctx,
|
||||||
|
ctypes.byref(candidates.candidates), # type: ignore
|
||||||
|
guidance_ctx.ctx,
|
||||||
|
scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
|
||||||
|
assert self.ctx is not None
|
||||||
|
llama_cpp.llama_sample_softmax(
|
||||||
|
self.ctx,
|
||||||
|
ctypes.byref(candidates.candidates), # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
|
||||||
|
assert self.ctx is not None
|
||||||
|
llama_cpp.llama_sample_top_k(
|
||||||
|
self.ctx, ctypes.byref(candidates.candidates), k, min_keep # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
|
||||||
|
assert self.ctx is not None
|
||||||
|
llama_cpp.llama_sample_top_p(
|
||||||
|
self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
|
||||||
|
assert self.ctx is not None
|
||||||
|
llama_cpp.llama_sample_min_p(
|
||||||
|
self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample_tail_free(
|
||||||
|
self, candidates: "_LlamaTokenDataArray", z: float, min_keep: int
|
||||||
|
):
|
||||||
|
assert self.ctx is not None
|
||||||
|
llama_cpp.llama_sample_tail_free(
|
||||||
|
self.ctx, ctypes.byref(candidates.candidates), z, min_keep # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample_typical(
|
||||||
|
self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
|
||||||
|
):
|
||||||
|
assert self.ctx is not None
|
||||||
|
llama_cpp.llama_sample_typical(
|
||||||
|
self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
|
||||||
|
assert self.ctx is not None
|
||||||
|
llama_cpp.llama_sample_temp(
|
||||||
|
self.ctx, ctypes.byref(candidates.candidates), temp # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
|
||||||
|
assert self.ctx is not None
|
||||||
|
assert grammar.grammar is not None
|
||||||
|
llama_cpp.llama_sample_grammar(
|
||||||
|
self.ctx,
|
||||||
|
ctypes.byref(candidates.candidates), # type: ignore
|
||||||
|
grammar.grammar,
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample_token_mirostat(
|
||||||
|
self,
|
||||||
|
candidates: "_LlamaTokenDataArray",
|
||||||
|
tau: float,
|
||||||
|
eta: float,
|
||||||
|
m: int,
|
||||||
|
mu: ctypes._Pointer[ctypes.c_float], # type: ignore
|
||||||
|
) -> int:
|
||||||
|
assert self.ctx is not None
|
||||||
|
return llama_cpp.llama_sample_token_mirostat(
|
||||||
|
self.ctx,
|
||||||
|
ctypes.byref(candidates.candidates), # type: ignore
|
||||||
|
tau,
|
||||||
|
eta,
|
||||||
|
m,
|
||||||
|
mu,
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample_token_mirostat_v2(
|
||||||
|
self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: ctypes._Pointer[ctypes.c_float] # type: ignore
|
||||||
|
) -> int:
|
||||||
|
assert self.ctx is not None
|
||||||
|
return llama_cpp.llama_sample_token_mirostat_v2(
|
||||||
|
self.ctx,
|
||||||
|
ctypes.byref(candidates.candidates), # type: ignore
|
||||||
|
tau,
|
||||||
|
eta,
|
||||||
|
mu,
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
|
||||||
|
assert self.ctx is not None
|
||||||
|
return llama_cpp.llama_sample_token_greedy(
|
||||||
|
self.ctx,
|
||||||
|
ctypes.byref(candidates.candidates), # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
|
||||||
|
assert self.ctx is not None
|
||||||
|
return llama_cpp.llama_sample_token(
|
||||||
|
self.ctx,
|
||||||
|
ctypes.byref(candidates.candidates), # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
# Grammar
|
||||||
|
def grammar_accept_token(self, grammar: LlamaGrammar, token: int):
|
||||||
|
assert self.ctx is not None
|
||||||
|
assert grammar.grammar is not None
|
||||||
|
llama_cpp.llama_grammar_accept_token(self.ctx, grammar.grammar, token)
|
||||||
|
|
||||||
|
def reset_timings(self):
|
||||||
|
assert self.ctx is not None
|
||||||
|
llama_cpp.llama_reset_timings(self.ctx)
|
||||||
|
|
||||||
|
def print_timings(self):
|
||||||
|
assert self.ctx is not None
|
||||||
|
llama_cpp.llama_print_timings(self.ctx)
|
||||||
|
|
||||||
|
# Utility functions
|
||||||
|
@staticmethod
|
||||||
|
def default_params():
|
||||||
|
"""Get the default llama_context_params."""
|
||||||
|
return llama_cpp.llama_context_default_params()
|
||||||
|
|
||||||
|
|
||||||
|
class _LlamaBatch:
|
||||||
|
_llama_batch_free = None
|
||||||
|
# NOTE: this must be "saved" here to avoid exceptions when calling __del__
|
||||||
|
_suppress_stdout_stderr = suppress_stdout_stderr
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
|
||||||
|
):
|
||||||
|
self.n_tokens = n_tokens
|
||||||
|
self.embd = embd
|
||||||
|
self.n_seq_max = n_seq_max
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
self._llama_batch_free = llama_cpp._lib.llama_batch_free # type: ignore
|
||||||
|
|
||||||
|
with self._suppress_stdout_stderr(disable=self.verbose):
|
||||||
|
self.batch = llama_cpp.llama_batch_init(
|
||||||
|
self.n_tokens, self.embd, self.n_seq_max
|
||||||
|
)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
with self._suppress_stdout_stderr(disable=self.verbose):
|
||||||
|
if self.batch is not None and self._llama_batch_free is not None:
|
||||||
|
self._llama_batch_free(self.batch)
|
||||||
|
self.batch = None
|
||||||
|
|
||||||
|
def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
|
||||||
|
assert self.batch is not None
|
||||||
|
n_tokens = len(batch)
|
||||||
|
self.batch.n_tokens = n_tokens
|
||||||
|
for i in range(n_tokens):
|
||||||
|
self.batch.token[i] = batch[i]
|
||||||
|
self.batch.pos[i] = n_past + i
|
||||||
|
self.batch.seq_id[i][0] = 0
|
||||||
|
self.batch.n_seq_id[i] = 1
|
||||||
|
self.batch.logits[i] = logits_all
|
||||||
|
self.batch.logits[n_tokens - 1] = True
|
||||||
|
|
||||||
|
|
||||||
|
class _LlamaTokenDataArray:
|
||||||
|
def __init__(self, *, n_vocab: int):
|
||||||
|
self.n_vocab = n_vocab
|
||||||
|
self.candidates_data = np.array(
|
||||||
|
[],
|
||||||
|
dtype=np.dtype(
|
||||||
|
[("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.candidates_data.resize(3, self.n_vocab, refcheck=False)
|
||||||
|
self.candidates = llama_cpp.llama_token_data_array(
|
||||||
|
data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
|
||||||
|
size=self.n_vocab,
|
||||||
|
sorted=False,
|
||||||
|
)
|
||||||
|
self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc)
|
||||||
|
self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)
|
||||||
|
|
||||||
|
def copy_logits(self, logits: npt.NDArray[np.single]):
|
||||||
|
self.candidates_data["id"][:] = self.default_candidates_data_id
|
||||||
|
self.candidates_data["logit"][:] = logits
|
||||||
|
self.candidates_data["p"][:] = self.default_candidates_data_p
|
||||||
|
self.candidates.data = self.candidates_data.ctypes.data_as(
|
||||||
|
llama_cpp.llama_token_data_p
|
||||||
|
)
|
||||||
|
self.candidates.sorted = llama_cpp.c_bool(False)
|
||||||
|
self.candidates.size = llama_cpp.c_size_t(self.n_vocab)
|
||||||
|
|
||||||
|
|
||||||
|
# Python wrappers over common/common
|
||||||
|
def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> list[int]:
|
||||||
|
n_tokens = len(text) + 1 if add_bos else len(text)
|
||||||
|
result = (llama_cpp.llama_token * n_tokens)()
|
||||||
|
n_tokens = llama_cpp.llama_tokenize(
|
||||||
|
model.model,
|
||||||
|
text.encode("utf-8"),
|
||||||
|
len(text),
|
||||||
|
result,
|
||||||
|
n_tokens,
|
||||||
|
add_bos,
|
||||||
|
special,
|
||||||
|
)
|
||||||
|
if n_tokens < 0:
|
||||||
|
result = (llama_cpp.llama_token * -n_tokens)()
|
||||||
|
check = llama_cpp.llama_tokenize(
|
||||||
|
model.model,
|
||||||
|
text.encode("utf-8"),
|
||||||
|
len(text),
|
||||||
|
result,
|
||||||
|
len(result),
|
||||||
|
add_bos,
|
||||||
|
special,
|
||||||
|
)
|
||||||
|
if check != -n_tokens:
|
||||||
|
raise RuntimeError(f'Failed to tokenize: text="{text}" n_tokens={n_tokens}')
|
||||||
|
else:
|
||||||
|
result = result[:n_tokens]
|
||||||
|
return list(result)
|
||||||
|
|
||||||
|
|
||||||
|
def _token_to_piece(model: _LlamaModel, token: int) -> str:
|
||||||
|
assert model.model is not None
|
||||||
|
result = (ctypes.c_char * 8)(0)
|
||||||
|
n_tokens = llama_cpp.llama_token_to_piece(model.model, token, result, len(result))
|
||||||
|
if n_tokens < 0:
|
||||||
|
result = (ctypes.c_char * -n_tokens)(0)
|
||||||
|
check = llama_cpp.llama_token_to_piece(model.model, token, result, len(result))
|
||||||
|
if check != -n_tokens:
|
||||||
|
raise RuntimeError(f"Failed to get piece: token={token}")
|
||||||
|
else:
|
||||||
|
result = result[:n_tokens]
|
||||||
|
return bytes(result).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def _detokenize_spm(model: _LlamaModel, tokens: List[int]) -> str:
|
||||||
|
bos_id = model.token_bos()
|
||||||
|
result = ""
|
||||||
|
for i, token in enumerate(tokens):
|
||||||
|
piece = _token_to_piece(model, token)
|
||||||
|
if (
|
||||||
|
(tokens[0] == bos_id and i == 1) or (tokens[0] != bos_id and i == 0)
|
||||||
|
) and piece[0] == " ":
|
||||||
|
piece = piece[1:]
|
||||||
|
result += piece
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _detokenize_bpe(model: _LlamaModel, tokens: List[int]) -> str:
|
||||||
|
result = ""
|
||||||
|
for token in tokens:
|
||||||
|
piece = _token_to_piece(model, token)
|
||||||
|
result += piece
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _should_add_bos(model: _LlamaModel) -> bool:
|
||||||
|
assert model.model is not None
|
||||||
|
add_bos = llama_cpp.llama_add_bos_token(model.model)
|
||||||
|
if add_bos != -1:
|
||||||
|
return add_bos != 0
|
||||||
|
else:
|
||||||
|
return llama_cpp.llama_vocab_type(model.model) == llama_cpp.LLAMA_VOCAB_TYPE_SPM
|
||||||
|
|
||||||
|
|
||||||
|
# Python wrappers over common/sampling structs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _LlamaSamplingParams:
|
||||||
|
n_prev: int = 64
|
||||||
|
n_probs: int = 0
|
||||||
|
top_k: int = 40
|
||||||
|
top_p: float = 0.95
|
||||||
|
min_p: float = 0.05
|
||||||
|
tfs_z: float = 1.00
|
||||||
|
typical_p: float = 1.00
|
||||||
|
temp: float = 0.80
|
||||||
|
penalty_last_n: int = 64
|
||||||
|
penalty_repeat: float = 1.10
|
||||||
|
penalty_freq: float = 0.00
|
||||||
|
penalty_present: float = 0.00
|
||||||
|
mirostat: int = 0
|
||||||
|
mirostat_tau: float = 5.00
|
||||||
|
mirostat_eta: float = 0.10
|
||||||
|
penalize_nl: bool = True
|
||||||
|
|
||||||
|
grammar: str = ""
|
||||||
|
|
||||||
|
cfg_negative_prompt: str = ""
|
||||||
|
cfg_scale: float = 1.00
|
||||||
|
|
||||||
|
logit_bias: dict[int, float] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _LlamaSamplingContext:
|
||||||
|
params: _LlamaSamplingParams = field(default_factory=_LlamaSamplingParams)
|
||||||
|
mirostat_mu: ctypes.c_float = field(default_factory=ctypes.c_float)
|
||||||
|
grammar: Optional[LlamaGrammar] = None
|
||||||
|
# NOTE: Missing parsed_grammar
|
||||||
|
prev: list[int] = field(default_factory=list)
|
||||||
|
cur: list[llama_cpp.llama_token_data] = field(default_factory=list)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.prev = []
|
||||||
|
self.cur = []
|
||||||
|
if self.grammar is not None:
|
||||||
|
self.grammar.reset()
|
||||||
|
|
||||||
|
def cp(self):
|
||||||
|
return _LlamaSamplingContext(
|
||||||
|
params=self.params,
|
||||||
|
mirostat_mu=self.mirostat_mu,
|
||||||
|
grammar=self.grammar,
|
||||||
|
prev=self.prev.copy(),
|
||||||
|
cur=self.cur.copy(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def last(self) -> Optional[int]:
|
||||||
|
if len(self.prev) > 0:
|
||||||
|
return self.prev[-1]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def prev_str(self, ctx_main: _LlamaContext, n: int) -> str:
|
||||||
|
return ctx_main.model.detokenize(self.prev[-n:]).decode("utf-8")
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self, ctx_main: _LlamaContext, ctx_cfg: Optional[_LlamaContext] = None, idx: int = 0, logits_array: Optional[npt.NDArray[np.single]] = None
|
||||||
|
):
|
||||||
|
n_vocab = ctx_main.model.n_vocab()
|
||||||
|
id: int = 0
|
||||||
|
|
||||||
|
if logits_array is None:
|
||||||
|
logits = ctx_main.get_logits_ith(idx)
|
||||||
|
logits_array = np.array(
|
||||||
|
ctypes.cast(logits, ctypes.POINTER(ctypes.c_float * n_vocab)).contents,
|
||||||
|
dtype=np.single,
|
||||||
|
)
|
||||||
|
|
||||||
|
# apply logit_bias
|
||||||
|
for token, logit_bias in self.params.logit_bias.items():
|
||||||
|
logits_array[token] += logit_bias
|
||||||
|
|
||||||
|
token_data_array = _LlamaTokenDataArray(
|
||||||
|
n_vocab=n_vocab
|
||||||
|
) # TODO: Only create this once
|
||||||
|
token_data_array.copy_logits(logits_array)
|
||||||
|
|
||||||
|
if ctx_cfg is not None:
|
||||||
|
ctx_main.sample_classifier_free_guidance(
|
||||||
|
token_data_array, ctx_cfg, self.params.cfg_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
# apply penalties
|
||||||
|
if len(self.prev) > 0:
|
||||||
|
nl_token = ctx_main.model.token_nl()
|
||||||
|
nl_logit = logits_array[nl_token]
|
||||||
|
if self.params.penalty_last_n > 0:
|
||||||
|
ctx_main.sample_repetition_penalties(
|
||||||
|
token_data_array,
|
||||||
|
# TODO: Only create this once
|
||||||
|
(llama_cpp.llama_token * len(self.prev))(*self.prev), # type: ignore
|
||||||
|
self.params.penalty_last_n,
|
||||||
|
self.params.penalty_repeat,
|
||||||
|
self.params.penalty_freq,
|
||||||
|
self.params.penalty_present,
|
||||||
|
)
|
||||||
|
if not self.params.penalize_nl:
|
||||||
|
token_data_array.candidates_data["logit"][nl_token] = nl_logit
|
||||||
|
|
||||||
|
if self.grammar is not None:
|
||||||
|
ctx_main.sample_grammar(token_data_array, self.grammar)
|
||||||
|
|
||||||
|
if self.params.temp < 0:
|
||||||
|
ctx_main.sample_softmax(token_data_array)
|
||||||
|
id = token_data_array.candidates_data["id"][0]
|
||||||
|
elif self.params.temp == 0:
|
||||||
|
id = ctx_main.sample_token_greedy(token_data_array)
|
||||||
|
else:
|
||||||
|
if self.params.mirostat == 1:
|
||||||
|
mirostat_m = 100
|
||||||
|
ctx_main.sample_temp(token_data_array, self.params.temp)
|
||||||
|
id = ctx_main.sample_token_mirostat(
|
||||||
|
token_data_array,
|
||||||
|
self.params.mirostat_tau,
|
||||||
|
self.params.mirostat_eta,
|
||||||
|
mirostat_m,
|
||||||
|
ctypes.pointer(self.mirostat_mu),
|
||||||
|
)
|
||||||
|
elif self.params.mirostat == 2:
|
||||||
|
ctx_main.sample_temp(token_data_array, self.params.temp)
|
||||||
|
id = ctx_main.sample_token_mirostat_v2(
|
||||||
|
token_data_array,
|
||||||
|
self.params.mirostat_tau,
|
||||||
|
self.params.mirostat_eta,
|
||||||
|
ctypes.pointer(self.mirostat_mu),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
min_keep = max(1, self.params.n_probs)
|
||||||
|
ctx_main.sample_top_k(
|
||||||
|
token_data_array, self.params.top_k, min_keep=min_keep
|
||||||
|
)
|
||||||
|
ctx_main.sample_tail_free(
|
||||||
|
token_data_array, self.params.tfs_z, min_keep=min_keep
|
||||||
|
)
|
||||||
|
ctx_main.sample_typical(
|
||||||
|
token_data_array, self.params.typical_p, min_keep=min_keep
|
||||||
|
)
|
||||||
|
ctx_main.sample_top_p(
|
||||||
|
token_data_array, self.params.top_p, min_keep=min_keep
|
||||||
|
)
|
||||||
|
ctx_main.sample_min_p(
|
||||||
|
token_data_array, self.params.min_p, min_keep=min_keep
|
||||||
|
)
|
||||||
|
ctx_main.sample_temp(token_data_array, self.params.temp)
|
||||||
|
id = ctx_main.sample_token(token_data_array)
|
||||||
|
return id
|
||||||
|
|
||||||
|
def accept(self, ctx_main: _LlamaContext, id: int, apply_grammar: bool):
|
||||||
|
if apply_grammar and self.grammar is not None:
|
||||||
|
ctx_main.grammar_accept_token(self.grammar, id)
|
||||||
|
self.prev.append(id)
|
|
@ -1,7 +1,8 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import sys, traceback
|
import sys
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
# Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor
|
# Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor
|
||||||
outnull_file = open(os.devnull, "w")
|
outnull_file = open(os.devnull, "w")
|
||||||
|
@ -55,3 +56,25 @@ class suppress_stdout_stderr(object):
|
||||||
|
|
||||||
self.os.close(self.old_stdout_fileno)
|
self.os.close(self.old_stdout_fileno)
|
||||||
self.os.close(self.old_stderr_fileno)
|
self.os.close(self.old_stderr_fileno)
|
||||||
|
|
||||||
|
|
||||||
|
class MetaSingleton(type):
|
||||||
|
"""
|
||||||
|
Metaclass for implementing the Singleton pattern.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instances: Dict[type, Any] = {}
|
||||||
|
|
||||||
|
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
if cls not in cls._instances:
|
||||||
|
cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
|
||||||
|
return cls._instances[cls]
|
||||||
|
|
||||||
|
|
||||||
|
class Singleton(object, metaclass=MetaSingleton):
|
||||||
|
"""
|
||||||
|
Base class for implementing the Singleton pattern.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(Singleton, self).__init__()
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import (
|
from typing import (
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
@ -12,16 +13,20 @@ from typing import (
|
||||||
Sequence,
|
Sequence,
|
||||||
Iterator,
|
Iterator,
|
||||||
Deque,
|
Deque,
|
||||||
Tuple,
|
|
||||||
Callable,
|
Callable,
|
||||||
)
|
)
|
||||||
from collections import deque, OrderedDict
|
from collections import deque
|
||||||
|
|
||||||
import diskcache
|
|
||||||
import ctypes
|
import ctypes
|
||||||
|
|
||||||
from .llama_types import *
|
from .llama_types import *
|
||||||
from .llama_grammar import LlamaGrammar
|
from .llama_grammar import LlamaGrammar
|
||||||
|
from .llama_cache import (
|
||||||
|
BaseLlamaCache,
|
||||||
|
LlamaCache, # type: ignore
|
||||||
|
LlamaDiskCache, # type: ignore
|
||||||
|
LlamaRAMCache, # type: ignore
|
||||||
|
)
|
||||||
import llama_cpp.llama_cpp as llama_cpp
|
import llama_cpp.llama_cpp as llama_cpp
|
||||||
import llama_cpp.llama_chat_format as llama_chat_format
|
import llama_cpp.llama_chat_format as llama_chat_format
|
||||||
|
|
||||||
|
@ -29,694 +34,12 @@ import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
|
|
||||||
from ._utils import suppress_stdout_stderr
|
from ._utils import suppress_stdout_stderr
|
||||||
|
from ._internals import (
|
||||||
|
_LlamaModel, # type: ignore
|
||||||
class BaseLlamaCache(ABC):
|
_LlamaContext, # type: ignore
|
||||||
"""Base cache class for a llama.cpp model."""
|
_LlamaBatch, # type: ignore
|
||||||
|
_LlamaTokenDataArray, # type: ignore
|
||||||
def __init__(self, capacity_bytes: int = (2 << 30)):
|
)
|
||||||
self.capacity_bytes = capacity_bytes
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def cache_size(self) -> int:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def _find_longest_prefix_key(
|
|
||||||
self,
|
|
||||||
key: Tuple[int, ...],
|
|
||||||
) -> Optional[Tuple[int, ...]]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __getitem__(self, key: Sequence[int]) -> "LlamaState":
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __contains__(self, key: Sequence[int]) -> bool:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __setitem__(self, key: Sequence[int], value: "LlamaState") -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaRAMCache(BaseLlamaCache):
|
|
||||||
"""Cache for a llama.cpp model using RAM."""
|
|
||||||
|
|
||||||
def __init__(self, capacity_bytes: int = (2 << 30)):
|
|
||||||
super().__init__(capacity_bytes)
|
|
||||||
self.capacity_bytes = capacity_bytes
|
|
||||||
self.cache_state: OrderedDict[Tuple[int, ...], "LlamaState"] = OrderedDict()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cache_size(self):
|
|
||||||
return sum([state.llama_state_size for state in self.cache_state.values()])
|
|
||||||
|
|
||||||
def _find_longest_prefix_key(
|
|
||||||
self,
|
|
||||||
key: Tuple[int, ...],
|
|
||||||
) -> Optional[Tuple[int, ...]]:
|
|
||||||
min_len = 0
|
|
||||||
min_key = None
|
|
||||||
keys = (
|
|
||||||
(k, Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys()
|
|
||||||
)
|
|
||||||
for k, prefix_len in keys:
|
|
||||||
if prefix_len > min_len:
|
|
||||||
min_len = prefix_len
|
|
||||||
min_key = k
|
|
||||||
return min_key
|
|
||||||
|
|
||||||
def __getitem__(self, key: Sequence[int]) -> "LlamaState":
|
|
||||||
key = tuple(key)
|
|
||||||
_key = self._find_longest_prefix_key(key)
|
|
||||||
if _key is None:
|
|
||||||
raise KeyError("Key not found")
|
|
||||||
value = self.cache_state[_key]
|
|
||||||
self.cache_state.move_to_end(_key)
|
|
||||||
return value
|
|
||||||
|
|
||||||
def __contains__(self, key: Sequence[int]) -> bool:
|
|
||||||
return self._find_longest_prefix_key(tuple(key)) is not None
|
|
||||||
|
|
||||||
def __setitem__(self, key: Sequence[int], value: "LlamaState"):
|
|
||||||
key = tuple(key)
|
|
||||||
if key in self.cache_state:
|
|
||||||
del self.cache_state[key]
|
|
||||||
self.cache_state[key] = value
|
|
||||||
while self.cache_size > self.capacity_bytes and len(self.cache_state) > 0:
|
|
||||||
self.cache_state.popitem(last=False)
|
|
||||||
|
|
||||||
|
|
||||||
# Alias for backwards compatibility
|
|
||||||
LlamaCache = LlamaRAMCache
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaDiskCache(BaseLlamaCache):
|
|
||||||
"""Cache for a llama.cpp model using disk."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30)
|
|
||||||
):
|
|
||||||
super().__init__(capacity_bytes)
|
|
||||||
self.cache = diskcache.Cache(cache_dir)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cache_size(self):
|
|
||||||
return int(self.cache.volume()) # type: ignore
|
|
||||||
|
|
||||||
def _find_longest_prefix_key(
|
|
||||||
self,
|
|
||||||
key: Tuple[int, ...],
|
|
||||||
) -> Optional[Tuple[int, ...]]:
|
|
||||||
min_len = 0
|
|
||||||
min_key: Optional[Tuple[int, ...]] = None
|
|
||||||
for k in self.cache.iterkeys(): # type: ignore
|
|
||||||
prefix_len = Llama.longest_token_prefix(k, key)
|
|
||||||
if prefix_len > min_len:
|
|
||||||
min_len = prefix_len
|
|
||||||
min_key = k # type: ignore
|
|
||||||
return min_key
|
|
||||||
|
|
||||||
def __getitem__(self, key: Sequence[int]) -> "LlamaState":
|
|
||||||
key = tuple(key)
|
|
||||||
_key = self._find_longest_prefix_key(key)
|
|
||||||
if _key is None:
|
|
||||||
raise KeyError("Key not found")
|
|
||||||
value: "LlamaState" = self.cache.pop(_key) # type: ignore
|
|
||||||
# NOTE: This puts an integer as key in cache, which breaks,
|
|
||||||
# Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
|
|
||||||
# self.cache.push(_key, side="front") # type: ignore
|
|
||||||
return value
|
|
||||||
|
|
||||||
def __contains__(self, key: Sequence[int]) -> bool:
|
|
||||||
return self._find_longest_prefix_key(tuple(key)) is not None
|
|
||||||
|
|
||||||
def __setitem__(self, key: Sequence[int], value: "LlamaState"):
|
|
||||||
print("LlamaDiskCache.__setitem__: called", file=sys.stderr)
|
|
||||||
key = tuple(key)
|
|
||||||
if key in self.cache:
|
|
||||||
print("LlamaDiskCache.__setitem__: delete", file=sys.stderr)
|
|
||||||
del self.cache[key]
|
|
||||||
self.cache[key] = value
|
|
||||||
print("LlamaDiskCache.__setitem__: set", file=sys.stderr)
|
|
||||||
while self.cache_size > self.capacity_bytes and len(self.cache) > 0:
|
|
||||||
key_to_remove = next(iter(self.cache))
|
|
||||||
del self.cache[key_to_remove]
|
|
||||||
print("LlamaDiskCache.__setitem__: trim", file=sys.stderr)
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaState:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
input_ids: npt.NDArray[np.intc],
|
|
||||||
scores: npt.NDArray[np.single],
|
|
||||||
n_tokens: int,
|
|
||||||
llama_state: bytes,
|
|
||||||
llama_state_size: int,
|
|
||||||
):
|
|
||||||
self.input_ids = input_ids
|
|
||||||
self.scores = scores
|
|
||||||
self.n_tokens = n_tokens
|
|
||||||
self.llama_state = llama_state
|
|
||||||
self.llama_state_size = llama_state_size
|
|
||||||
|
|
||||||
|
|
||||||
LogitsProcessor = Callable[
|
|
||||||
[npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class LogitsProcessorList(List[LogitsProcessor]):
|
|
||||||
def __call__(
|
|
||||||
self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
|
|
||||||
) -> npt.NDArray[np.single]:
|
|
||||||
for processor in self:
|
|
||||||
scores = processor(input_ids, scores)
|
|
||||||
return scores
|
|
||||||
|
|
||||||
|
|
||||||
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]
|
|
||||||
|
|
||||||
|
|
||||||
class StoppingCriteriaList(List[StoppingCriteria]):
|
|
||||||
def __call__(
|
|
||||||
self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
|
|
||||||
) -> bool:
|
|
||||||
return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
|
|
||||||
|
|
||||||
|
|
||||||
class _LlamaModel:
|
|
||||||
"""Intermediate Python wrapper for a llama.cpp llama_model.
|
|
||||||
|
|
||||||
NOTE: For stability it's recommended you use the Llama class instead."""
|
|
||||||
|
|
||||||
_llama_free_model = None
|
|
||||||
# NOTE: this must be "saved" here to avoid exceptions when calling __del__
|
|
||||||
suppress_stdout_stderr = suppress_stdout_stderr
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
path_model: str,
|
|
||||||
params: llama_cpp.llama_model_params,
|
|
||||||
verbose: bool = True,
|
|
||||||
):
|
|
||||||
self.path_model = path_model
|
|
||||||
self.params = params
|
|
||||||
self.verbose = verbose
|
|
||||||
|
|
||||||
self._llama_free_model = llama_cpp._lib.llama_free_model # type: ignore
|
|
||||||
|
|
||||||
if not os.path.exists(path_model):
|
|
||||||
raise ValueError(f"Model path does not exist: {path_model}")
|
|
||||||
|
|
||||||
with suppress_stdout_stderr(disable=self.verbose):
|
|
||||||
self.model = llama_cpp.llama_load_model_from_file(
|
|
||||||
self.path_model.encode("utf-8"), self.params
|
|
||||||
)
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
with self.suppress_stdout_stderr(disable=self.verbose):
|
|
||||||
if self.model is not None and self._llama_free_model is not None:
|
|
||||||
self._llama_free_model(self.model)
|
|
||||||
self.model = None
|
|
||||||
|
|
||||||
def vocab_type(self) -> int:
|
|
||||||
assert self.model is not None
|
|
||||||
return llama_cpp.llama_vocab_type(self.model)
|
|
||||||
|
|
||||||
def n_vocab(self) -> int:
|
|
||||||
assert self.model is not None
|
|
||||||
return llama_cpp.llama_n_vocab(self.model)
|
|
||||||
|
|
||||||
def n_ctx_train(self) -> int:
|
|
||||||
assert self.model is not None
|
|
||||||
return llama_cpp.llama_n_ctx_train(self.model)
|
|
||||||
|
|
||||||
def n_embd(self) -> int:
|
|
||||||
assert self.model is not None
|
|
||||||
return llama_cpp.llama_n_embd(self.model)
|
|
||||||
|
|
||||||
def rope_freq_scale_train(self) -> float:
|
|
||||||
assert self.model is not None
|
|
||||||
return llama_cpp.llama_rope_freq_scale_train(self.model)
|
|
||||||
|
|
||||||
def desc(self) -> str:
|
|
||||||
assert self.model is not None
|
|
||||||
buf = ctypes.create_string_buffer(1024)
|
|
||||||
llama_cpp.llama_model_desc(self.model, buf, 1024) # type: ignore
|
|
||||||
return buf.value.decode("utf-8")
|
|
||||||
|
|
||||||
def size(self) -> int:
|
|
||||||
assert self.model is not None
|
|
||||||
return llama_cpp.llama_model_size(self.model)
|
|
||||||
|
|
||||||
def n_params(self) -> int:
|
|
||||||
assert self.model is not None
|
|
||||||
return llama_cpp.llama_model_n_params(self.model)
|
|
||||||
|
|
||||||
def get_tensor(self, name: str) -> ctypes.c_void_p:
|
|
||||||
assert self.model is not None
|
|
||||||
return llama_cpp.llama_get_model_tensor(self.model, name.encode("utf-8"))
|
|
||||||
|
|
||||||
def apply_lora_from_file(
|
|
||||||
self,
|
|
||||||
lora_path: str,
|
|
||||||
scale: float,
|
|
||||||
path_base_model: Optional[str],
|
|
||||||
n_threads: int,
|
|
||||||
):
|
|
||||||
assert self.model is not None
|
|
||||||
return llama_cpp.llama_model_apply_lora_from_file(
|
|
||||||
self.model,
|
|
||||||
lora_path.encode("utf-8"),
|
|
||||||
scale,
|
|
||||||
path_base_model.encode("utf-8")
|
|
||||||
if path_base_model is not None
|
|
||||||
else llama_cpp.c_char_p(0),
|
|
||||||
n_threads,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Vocab
|
|
||||||
|
|
||||||
def token_get_text(self, token: int) -> str:
|
|
||||||
# TODO: Fix
|
|
||||||
assert self.model is not None
|
|
||||||
return llama_cpp.llama_token_get_text(self.model, token).decode("utf-8")
|
|
||||||
|
|
||||||
def token_get_score(self, token: int) -> float:
|
|
||||||
assert self.model is not None
|
|
||||||
return llama_cpp.llama_token_get_score(self.model, token)
|
|
||||||
|
|
||||||
def token_get_type(self, token: int) -> int:
|
|
||||||
assert self.model is not None
|
|
||||||
return llama_cpp.llama_token_get_type(self.model, token)
|
|
||||||
|
|
||||||
# Special tokens
|
|
||||||
|
|
||||||
def token_bos(self) -> int:
|
|
||||||
assert self.model is not None
|
|
||||||
return llama_cpp.llama_token_bos(self.model)
|
|
||||||
|
|
||||||
def token_eos(self) -> int:
|
|
||||||
assert self.model is not None
|
|
||||||
return llama_cpp.llama_token_eos(self.model)
|
|
||||||
|
|
||||||
def token_nl(self) -> int:
|
|
||||||
assert self.model is not None
|
|
||||||
return llama_cpp.llama_token_nl(self.model)
|
|
||||||
|
|
||||||
def token_prefix(self) -> int:
|
|
||||||
assert self.model is not None
|
|
||||||
return llama_cpp.llama_token_prefix(self.model)
|
|
||||||
|
|
||||||
def token_middle(self) -> int:
|
|
||||||
assert self.model is not None
|
|
||||||
return llama_cpp.llama_token_middle(self.model)
|
|
||||||
|
|
||||||
def token_suffix(self) -> int:
|
|
||||||
assert self.model is not None
|
|
||||||
return llama_cpp.llama_token_suffix(self.model)
|
|
||||||
|
|
||||||
def token_eot(self) -> int:
|
|
||||||
assert self.model is not None
|
|
||||||
return llama_cpp.llama_token_eot(self.model)
|
|
||||||
|
|
||||||
# Tokenization
|
|
||||||
|
|
||||||
def tokenize(self, text: bytes, add_bos: bool, special: bool):
|
|
||||||
assert self.model is not None
|
|
||||||
n_ctx = self.n_ctx_train()
|
|
||||||
tokens = (llama_cpp.llama_token * n_ctx)()
|
|
||||||
n_tokens = llama_cpp.llama_tokenize(
|
|
||||||
self.model, text, len(text), tokens, n_ctx, add_bos, special
|
|
||||||
)
|
|
||||||
if n_tokens < 0:
|
|
||||||
n_tokens = abs(n_tokens)
|
|
||||||
tokens = (llama_cpp.llama_token * n_tokens)()
|
|
||||||
n_tokens = llama_cpp.llama_tokenize(
|
|
||||||
self.model, text, len(text), tokens, n_tokens, add_bos, special
|
|
||||||
)
|
|
||||||
if n_tokens < 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
f'Failed to tokenize: text="{text}" n_tokens={n_tokens}'
|
|
||||||
)
|
|
||||||
return list(tokens[:n_tokens])
|
|
||||||
|
|
||||||
def token_to_piece(self, token: int) -> bytes:
|
|
||||||
assert self.model is not None
|
|
||||||
buf = ctypes.create_string_buffer(32)
|
|
||||||
llama_cpp.llama_token_to_piece(self.model, token, buf, 32) # type: ignore
|
|
||||||
return bytes(buf)
|
|
||||||
|
|
||||||
def detokenize(self, tokens: List[int]) -> bytes:
|
|
||||||
assert self.model is not None
|
|
||||||
output = b""
|
|
||||||
size = 32
|
|
||||||
buffer = (ctypes.c_char * size)()
|
|
||||||
for token in tokens:
|
|
||||||
n = llama_cpp.llama_token_to_piece(
|
|
||||||
self.model, llama_cpp.llama_token(token), buffer, size
|
|
||||||
)
|
|
||||||
assert n <= size
|
|
||||||
output += bytes(buffer[:n])
|
|
||||||
# NOTE: Llama1 models automatically added a space at the start of the prompt
|
|
||||||
# this line removes a leading space if the first token is a beginning of sentence token
|
|
||||||
return (
|
|
||||||
output[1:] if len(tokens) > 0 and tokens[0] == self.token_bos() else output
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def default_params():
|
|
||||||
"""Get the default llama_model_params."""
|
|
||||||
return llama_cpp.llama_model_default_params()
|
|
||||||
|
|
||||||
|
|
||||||
class _LlamaContext:
|
|
||||||
"""Intermediate Python wrapper for a llama.cpp llama_context.
|
|
||||||
|
|
||||||
NOTE: For stability it's recommended you use the Llama class instead."""
|
|
||||||
|
|
||||||
_llama_free = None
|
|
||||||
# NOTE: this must be "saved" here to avoid exceptions when calling __del__
|
|
||||||
suppress_stdout_stderr = suppress_stdout_stderr
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
model: _LlamaModel,
|
|
||||||
params: llama_cpp.llama_context_params,
|
|
||||||
verbose: bool = True,
|
|
||||||
):
|
|
||||||
self.model = model
|
|
||||||
self.params = params
|
|
||||||
self.verbose = verbose
|
|
||||||
|
|
||||||
self._llama_free = llama_cpp._lib.llama_free # type: ignore
|
|
||||||
|
|
||||||
with suppress_stdout_stderr(disable=self.verbose):
|
|
||||||
self.ctx = llama_cpp.llama_new_context_with_model(
|
|
||||||
self.model.model, self.params
|
|
||||||
)
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
with self.suppress_stdout_stderr(disable=self.verbose):
|
|
||||||
if self.ctx is not None and self._llama_free is not None:
|
|
||||||
self._llama_free(self.ctx)
|
|
||||||
self.ctx = None
|
|
||||||
|
|
||||||
def n_ctx(self) -> int:
|
|
||||||
assert self.ctx is not None
|
|
||||||
return llama_cpp.llama_n_ctx(self.ctx)
|
|
||||||
|
|
||||||
def kv_cache_clear(self):
|
|
||||||
assert self.ctx is not None
|
|
||||||
llama_cpp.llama_kv_cache_clear(self.ctx)
|
|
||||||
|
|
||||||
def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int):
|
|
||||||
assert self.ctx is not None
|
|
||||||
llama_cpp.llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1)
|
|
||||||
|
|
||||||
def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int):
|
|
||||||
assert self.ctx is not None
|
|
||||||
llama_cpp.llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1)
|
|
||||||
|
|
||||||
def kv_cache_seq_keep(self, seq_id: int):
|
|
||||||
assert self.ctx is not None
|
|
||||||
llama_cpp.llama_kv_cache_seq_keep(self.ctx, seq_id)
|
|
||||||
|
|
||||||
def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
|
|
||||||
assert self.ctx is not None
|
|
||||||
llama_cpp.llama_kv_cache_seq_shift(self.ctx, seq_id, p0, p1, shift)
|
|
||||||
|
|
||||||
def get_state_size(self) -> int:
|
|
||||||
assert self.ctx is not None
|
|
||||||
return llama_cpp.llama_get_state_size(self.ctx)
|
|
||||||
|
|
||||||
# TODO: copy_state_data
|
|
||||||
|
|
||||||
# TODO: set_state_data
|
|
||||||
|
|
||||||
# TODO: llama_load_session_file
|
|
||||||
|
|
||||||
# TODO: llama_save_session_file
|
|
||||||
|
|
||||||
def decode(self, batch: "_LlamaBatch"):
|
|
||||||
assert self.ctx is not None
|
|
||||||
assert batch.batch is not None
|
|
||||||
return_code = llama_cpp.llama_decode(
|
|
||||||
ctx=self.ctx,
|
|
||||||
batch=batch.batch,
|
|
||||||
)
|
|
||||||
if return_code != 0:
|
|
||||||
raise RuntimeError(f"llama_decode returned {return_code}")
|
|
||||||
|
|
||||||
def set_n_threads(self, n_threads: int, n_threads_batch: int):
|
|
||||||
assert self.ctx is not None
|
|
||||||
llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch)
|
|
||||||
|
|
||||||
def get_logits(self):
|
|
||||||
assert self.ctx is not None
|
|
||||||
return llama_cpp.llama_get_logits(self.ctx)
|
|
||||||
|
|
||||||
def get_logits_ith(self, i: int):
|
|
||||||
assert self.ctx is not None
|
|
||||||
return llama_cpp.llama_get_logits_ith(self.ctx, i)
|
|
||||||
|
|
||||||
def get_embeddings(self):
|
|
||||||
assert self.ctx is not None
|
|
||||||
return llama_cpp.llama_get_embeddings(self.ctx)
|
|
||||||
|
|
||||||
# Sampling functions
|
|
||||||
|
|
||||||
def set_rng_seed(self, seed: int):
|
|
||||||
assert self.ctx is not None
|
|
||||||
llama_cpp.llama_set_rng_seed(self.ctx, seed)
|
|
||||||
|
|
||||||
def sample_repetition_penalties(
|
|
||||||
self,
|
|
||||||
candidates: "_LlamaTokenDataArray",
|
|
||||||
last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]",
|
|
||||||
penalty_last_n: int,
|
|
||||||
penalty_repeat: float,
|
|
||||||
penalty_freq: float,
|
|
||||||
penalty_present: float,
|
|
||||||
):
|
|
||||||
assert self.ctx is not None
|
|
||||||
llama_cpp.llama_sample_repetition_penalties(
|
|
||||||
self.ctx,
|
|
||||||
ctypes.byref(candidates.candidates), # type: ignore
|
|
||||||
last_tokens_data,
|
|
||||||
penalty_last_n,
|
|
||||||
penalty_repeat,
|
|
||||||
penalty_freq,
|
|
||||||
penalty_present,
|
|
||||||
)
|
|
||||||
|
|
||||||
def sample_classifier_free_guidance(
|
|
||||||
self,
|
|
||||||
candidates: "_LlamaTokenDataArray",
|
|
||||||
guidance_ctx: "_LlamaContext",
|
|
||||||
scale: float,
|
|
||||||
):
|
|
||||||
assert self.ctx is not None
|
|
||||||
assert guidance_ctx.ctx is not None
|
|
||||||
llama_cpp.llama_sample_classifier_free_guidance(
|
|
||||||
self.ctx,
|
|
||||||
ctypes.byref(candidates.candidates), # type: ignore
|
|
||||||
guidance_ctx.ctx,
|
|
||||||
scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
|
|
||||||
assert self.ctx is not None
|
|
||||||
llama_cpp.llama_sample_softmax(
|
|
||||||
self.ctx,
|
|
||||||
ctypes.byref(candidates.candidates), # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
|
|
||||||
assert self.ctx is not None
|
|
||||||
llama_cpp.llama_sample_top_k(
|
|
||||||
self.ctx, ctypes.byref(candidates.candidates), k, min_keep # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
|
|
||||||
assert self.ctx is not None
|
|
||||||
llama_cpp.llama_sample_top_p(
|
|
||||||
self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
|
|
||||||
assert self.ctx is not None
|
|
||||||
llama_cpp.llama_sample_min_p(
|
|
||||||
self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
def sample_tail_free(
|
|
||||||
self, candidates: "_LlamaTokenDataArray", z: float, min_keep: int
|
|
||||||
):
|
|
||||||
assert self.ctx is not None
|
|
||||||
llama_cpp.llama_sample_tail_free(
|
|
||||||
self.ctx, ctypes.byref(candidates.candidates), z, min_keep # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
def sample_typical(
|
|
||||||
self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
|
|
||||||
):
|
|
||||||
assert self.ctx is not None
|
|
||||||
llama_cpp.llama_sample_typical(
|
|
||||||
self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
|
|
||||||
assert self.ctx is not None
|
|
||||||
llama_cpp.llama_sample_temp(
|
|
||||||
self.ctx, ctypes.byref(candidates.candidates), temp # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
|
|
||||||
assert self.ctx is not None
|
|
||||||
assert grammar.grammar is not None
|
|
||||||
llama_cpp.llama_sample_grammar(
|
|
||||||
self.ctx,
|
|
||||||
ctypes.byref(candidates.candidates), # type: ignore
|
|
||||||
grammar.grammar,
|
|
||||||
)
|
|
||||||
|
|
||||||
def sample_token_mirostat(
|
|
||||||
self,
|
|
||||||
candidates: "_LlamaTokenDataArray",
|
|
||||||
tau: float,
|
|
||||||
eta: float,
|
|
||||||
m: int,
|
|
||||||
mu: float,
|
|
||||||
) -> int:
|
|
||||||
assert self.ctx is not None
|
|
||||||
return llama_cpp.llama_sample_token_mirostat(
|
|
||||||
self.ctx,
|
|
||||||
ctypes.byref(candidates.candidates), # type: ignore
|
|
||||||
tau,
|
|
||||||
eta,
|
|
||||||
m,
|
|
||||||
ctypes.pointer(ctypes.c_float(mu)),
|
|
||||||
)
|
|
||||||
|
|
||||||
def sample_token_mirostat_v2(
|
|
||||||
self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: float
|
|
||||||
) -> int:
|
|
||||||
assert self.ctx is not None
|
|
||||||
return llama_cpp.llama_sample_token_mirostat_v2(
|
|
||||||
self.ctx,
|
|
||||||
ctypes.byref(candidates.candidates), # type: ignore
|
|
||||||
tau,
|
|
||||||
eta,
|
|
||||||
ctypes.pointer(ctypes.c_float(mu)),
|
|
||||||
)
|
|
||||||
|
|
||||||
def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
|
|
||||||
assert self.ctx is not None
|
|
||||||
return llama_cpp.llama_sample_token_greedy(
|
|
||||||
self.ctx,
|
|
||||||
ctypes.byref(candidates.candidates), # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
|
|
||||||
assert self.ctx is not None
|
|
||||||
return llama_cpp.llama_sample_token(
|
|
||||||
self.ctx,
|
|
||||||
ctypes.byref(candidates.candidates), # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
# Grammar
|
|
||||||
def grammar_accept_token(self, grammar: LlamaGrammar, token: int):
|
|
||||||
assert self.ctx is not None
|
|
||||||
assert grammar.grammar is not None
|
|
||||||
llama_cpp.llama_grammar_accept_token(self.ctx, grammar.grammar, token)
|
|
||||||
|
|
||||||
def reset_timings(self):
|
|
||||||
assert self.ctx is not None
|
|
||||||
llama_cpp.llama_reset_timings(self.ctx)
|
|
||||||
|
|
||||||
def print_timings(self):
|
|
||||||
assert self.ctx is not None
|
|
||||||
llama_cpp.llama_print_timings(self.ctx)
|
|
||||||
|
|
||||||
# Utility functions
|
|
||||||
@staticmethod
|
|
||||||
def default_params():
|
|
||||||
"""Get the default llama_context_params."""
|
|
||||||
return llama_cpp.llama_context_default_params()
|
|
||||||
|
|
||||||
|
|
||||||
class _LlamaBatch:
|
|
||||||
_llama_batch_free = None
|
|
||||||
# NOTE: this must be "saved" here to avoid exceptions when calling __del__
|
|
||||||
suppress_stdout_stderr = suppress_stdout_stderr
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
|
|
||||||
):
|
|
||||||
self.n_tokens = n_tokens
|
|
||||||
self.embd = embd
|
|
||||||
self.n_seq_max = n_seq_max
|
|
||||||
self.verbose = verbose
|
|
||||||
|
|
||||||
self._llama_batch_free = llama_cpp._lib.llama_batch_free # type: ignore
|
|
||||||
|
|
||||||
with suppress_stdout_stderr(disable=self.verbose):
|
|
||||||
self.batch = llama_cpp.llama_batch_init(
|
|
||||||
self.n_tokens, self.embd, self.n_seq_max
|
|
||||||
)
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
with self.suppress_stdout_stderr(disable=self.verbose):
|
|
||||||
if self.batch is not None and self._llama_batch_free is not None:
|
|
||||||
self._llama_batch_free(self.batch)
|
|
||||||
self.batch = None
|
|
||||||
|
|
||||||
def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
|
|
||||||
assert self.batch is not None
|
|
||||||
n_tokens = len(batch)
|
|
||||||
self.batch.n_tokens = n_tokens
|
|
||||||
for i in range(n_tokens):
|
|
||||||
self.batch.token[i] = batch[i]
|
|
||||||
self.batch.pos[i] = n_past + i
|
|
||||||
self.batch.seq_id[i][0] = 0
|
|
||||||
self.batch.n_seq_id[i] = 1
|
|
||||||
self.batch.logits[i] = logits_all
|
|
||||||
self.batch.logits[n_tokens - 1] = True
|
|
||||||
|
|
||||||
|
|
||||||
class _LlamaTokenDataArray:
|
|
||||||
def __init__(self, *, n_vocab: int):
|
|
||||||
self.n_vocab = n_vocab
|
|
||||||
self.candidates_data = np.array(
|
|
||||||
[],
|
|
||||||
dtype=np.dtype(
|
|
||||||
[("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.candidates_data.resize(3, self.n_vocab, refcheck=False)
|
|
||||||
self.candidates = llama_cpp.llama_token_data_array(
|
|
||||||
data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
|
|
||||||
size=self.n_vocab,
|
|
||||||
sorted=False,
|
|
||||||
)
|
|
||||||
self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc)
|
|
||||||
self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)
|
|
||||||
|
|
||||||
def copy_logits(self, logits: npt.NDArray[np.single]):
|
|
||||||
self.candidates_data["id"][:] = self.default_candidates_data_id
|
|
||||||
self.candidates_data["logit"][:] = logits
|
|
||||||
self.candidates_data["p"][:] = self.default_candidates_data_p
|
|
||||||
self.candidates.data = self.candidates_data.ctypes.data_as(
|
|
||||||
llama_cpp.llama_token_data_p
|
|
||||||
)
|
|
||||||
self.candidates.sorted = llama_cpp.c_bool(False)
|
|
||||||
self.candidates.size = llama_cpp.c_size_t(self.n_vocab)
|
|
||||||
|
|
||||||
|
|
||||||
class Llama:
|
class Llama:
|
||||||
|
@ -754,7 +77,7 @@ class Llama:
|
||||||
mul_mat_q: bool = True,
|
mul_mat_q: bool = True,
|
||||||
logits_all: bool = False,
|
logits_all: bool = False,
|
||||||
embedding: bool = False,
|
embedding: bool = False,
|
||||||
offload_kqv: bool = False,
|
offload_kqv: bool = True,
|
||||||
# Sampling Params
|
# Sampling Params
|
||||||
last_n_tokens_size: int = 64,
|
last_n_tokens_size: int = 64,
|
||||||
# LoRA Params
|
# LoRA Params
|
||||||
|
@ -1006,6 +329,18 @@ class Llama:
|
||||||
(n_ctx, self._n_vocab), dtype=np.single
|
(n_ctx, self._n_vocab), dtype=np.single
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._mirostat_mu = ctypes.c_float(2.0 * 5.0) # TODO: Move this to sampling context
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.metadata = self._model.metadata()
|
||||||
|
except Exception as e:
|
||||||
|
self.metadata = {}
|
||||||
|
if self.verbose:
|
||||||
|
print(f"Failed to load metadata: {e}", file=sys.stderr)
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
print(f"Model metadata: {self.metadata}", file=sys.stderr)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ctx(self) -> llama_cpp.llama_context_p:
|
def ctx(self) -> llama_cpp.llama_context_p:
|
||||||
assert self._ctx.ctx is not None
|
assert self._ctx.ctx is not None
|
||||||
|
@ -1193,7 +528,7 @@ class Llama:
|
||||||
candidates=self._candidates,
|
candidates=self._candidates,
|
||||||
tau=mirostat_tau,
|
tau=mirostat_tau,
|
||||||
eta=mirostat_eta,
|
eta=mirostat_eta,
|
||||||
mu=2.0 * mirostat_tau,
|
mu=ctypes.pointer(self._mirostat_mu),
|
||||||
m=100,
|
m=100,
|
||||||
)
|
)
|
||||||
elif mirostat_mode == 2:
|
elif mirostat_mode == 2:
|
||||||
|
@ -1202,7 +537,7 @@ class Llama:
|
||||||
candidates=self._candidates,
|
candidates=self._candidates,
|
||||||
tau=mirostat_tau,
|
tau=mirostat_tau,
|
||||||
eta=mirostat_eta,
|
eta=mirostat_eta,
|
||||||
mu=2.0 * mirostat_tau,
|
mu=ctypes.pointer(self._mirostat_mu)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
|
self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
|
||||||
|
@ -1258,6 +593,10 @@ class Llama:
|
||||||
Yields:
|
Yields:
|
||||||
The generated tokens.
|
The generated tokens.
|
||||||
"""
|
"""
|
||||||
|
# Reset mirostat sampling
|
||||||
|
self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
|
||||||
|
|
||||||
|
# Check for kv cache prefix match
|
||||||
if reset and self.n_tokens > 0:
|
if reset and self.n_tokens > 0:
|
||||||
longest_prefix = 0
|
longest_prefix = 0
|
||||||
for a, b in zip(self._input_ids, tokens[:-1]):
|
for a, b in zip(self._input_ids, tokens[:-1]):
|
||||||
|
@ -1272,12 +611,15 @@ class Llama:
|
||||||
tokens = tokens[longest_prefix:]
|
tokens = tokens[longest_prefix:]
|
||||||
self.n_tokens = longest_prefix
|
self.n_tokens = longest_prefix
|
||||||
|
|
||||||
|
# Reset the model state
|
||||||
if reset:
|
if reset:
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
|
# Reset the grammar
|
||||||
if grammar is not None:
|
if grammar is not None:
|
||||||
grammar.reset()
|
grammar.reset()
|
||||||
|
|
||||||
|
# Eval and sample
|
||||||
while True:
|
while True:
|
||||||
self.eval(tokens)
|
self.eval(tokens)
|
||||||
token = self.sample(
|
token = self.sample(
|
||||||
|
@ -2372,3 +1714,43 @@ class LlamaTokenizer:
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
|
def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
|
||||||
return cls(Llama(model_path=path, vocab_only=True))
|
return cls(Llama(model_path=path, vocab_only=True))
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaState:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_ids: npt.NDArray[np.intc],
|
||||||
|
scores: npt.NDArray[np.single],
|
||||||
|
n_tokens: int,
|
||||||
|
llama_state: bytes,
|
||||||
|
llama_state_size: int,
|
||||||
|
):
|
||||||
|
self.input_ids = input_ids
|
||||||
|
self.scores = scores
|
||||||
|
self.n_tokens = n_tokens
|
||||||
|
self.llama_state = llama_state
|
||||||
|
self.llama_state_size = llama_state_size
|
||||||
|
|
||||||
|
|
||||||
|
LogitsProcessor = Callable[
|
||||||
|
[npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class LogitsProcessorList(List[LogitsProcessor]):
|
||||||
|
def __call__(
|
||||||
|
self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
|
||||||
|
) -> npt.NDArray[np.single]:
|
||||||
|
for processor in self:
|
||||||
|
scores = processor(input_ids, scores)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]
|
||||||
|
|
||||||
|
|
||||||
|
class StoppingCriteriaList(List[StoppingCriteria]):
|
||||||
|
def __call__(
|
||||||
|
self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
|
||||||
|
) -> bool:
|
||||||
|
return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
|
||||||
|
|
150
llama_cpp/llama_cache.py
Normal file
150
llama_cpp/llama_cache.py
Normal file
|
@ -0,0 +1,150 @@
|
||||||
|
import sys
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import (
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import diskcache
|
||||||
|
|
||||||
|
import llama_cpp.llama
|
||||||
|
|
||||||
|
from .llama_types import *
|
||||||
|
|
||||||
|
|
||||||
|
class BaseLlamaCache(ABC):
|
||||||
|
"""Base cache class for a llama.cpp model."""
|
||||||
|
|
||||||
|
def __init__(self, capacity_bytes: int = (2 << 30)):
|
||||||
|
self.capacity_bytes = capacity_bytes
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def cache_size(self) -> int:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _find_longest_prefix_key(
|
||||||
|
self,
|
||||||
|
key: Tuple[int, ...],
|
||||||
|
) -> Optional[Tuple[int, ...]]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __contains__(self, key: Sequence[int]) -> bool:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState") -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaRAMCache(BaseLlamaCache):
|
||||||
|
"""Cache for a llama.cpp model using RAM."""
|
||||||
|
|
||||||
|
def __init__(self, capacity_bytes: int = (2 << 30)):
|
||||||
|
super().__init__(capacity_bytes)
|
||||||
|
self.capacity_bytes = capacity_bytes
|
||||||
|
self.cache_state: OrderedDict[Tuple[int, ...], "llama_cpp.llama.LlamaState"] = OrderedDict()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache_size(self):
|
||||||
|
return sum([state.llama_state_size for state in self.cache_state.values()])
|
||||||
|
|
||||||
|
def _find_longest_prefix_key(
|
||||||
|
self,
|
||||||
|
key: Tuple[int, ...],
|
||||||
|
) -> Optional[Tuple[int, ...]]:
|
||||||
|
min_len = 0
|
||||||
|
min_key = None
|
||||||
|
keys = (
|
||||||
|
(k, llama_cpp.llama.Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys()
|
||||||
|
)
|
||||||
|
for k, prefix_len in keys:
|
||||||
|
if prefix_len > min_len:
|
||||||
|
min_len = prefix_len
|
||||||
|
min_key = k
|
||||||
|
return min_key
|
||||||
|
|
||||||
|
def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
|
||||||
|
key = tuple(key)
|
||||||
|
_key = self._find_longest_prefix_key(key)
|
||||||
|
if _key is None:
|
||||||
|
raise KeyError("Key not found")
|
||||||
|
value = self.cache_state[_key]
|
||||||
|
self.cache_state.move_to_end(_key)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def __contains__(self, key: Sequence[int]) -> bool:
|
||||||
|
return self._find_longest_prefix_key(tuple(key)) is not None
|
||||||
|
|
||||||
|
def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
|
||||||
|
key = tuple(key)
|
||||||
|
if key in self.cache_state:
|
||||||
|
del self.cache_state[key]
|
||||||
|
self.cache_state[key] = value
|
||||||
|
while self.cache_size > self.capacity_bytes and len(self.cache_state) > 0:
|
||||||
|
self.cache_state.popitem(last=False)
|
||||||
|
|
||||||
|
|
||||||
|
# Alias for backwards compatibility
|
||||||
|
LlamaCache = LlamaRAMCache
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaDiskCache(BaseLlamaCache):
|
||||||
|
"""Cache for a llama.cpp model using disk."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30)
|
||||||
|
):
|
||||||
|
super().__init__(capacity_bytes)
|
||||||
|
self.cache = diskcache.Cache(cache_dir)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache_size(self):
|
||||||
|
return int(self.cache.volume()) # type: ignore
|
||||||
|
|
||||||
|
def _find_longest_prefix_key(
|
||||||
|
self,
|
||||||
|
key: Tuple[int, ...],
|
||||||
|
) -> Optional[Tuple[int, ...]]:
|
||||||
|
min_len = 0
|
||||||
|
min_key: Optional[Tuple[int, ...]] = None
|
||||||
|
for k in self.cache.iterkeys(): # type: ignore
|
||||||
|
prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k, key)
|
||||||
|
if prefix_len > min_len:
|
||||||
|
min_len = prefix_len
|
||||||
|
min_key = k # type: ignore
|
||||||
|
return min_key
|
||||||
|
|
||||||
|
def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
|
||||||
|
key = tuple(key)
|
||||||
|
_key = self._find_longest_prefix_key(key)
|
||||||
|
if _key is None:
|
||||||
|
raise KeyError("Key not found")
|
||||||
|
value: "llama_cpp.llama.LlamaState" = self.cache.pop(_key) # type: ignore
|
||||||
|
# NOTE: This puts an integer as key in cache, which breaks,
|
||||||
|
# Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
|
||||||
|
# self.cache.push(_key, side="front") # type: ignore
|
||||||
|
return value
|
||||||
|
|
||||||
|
def __contains__(self, key: Sequence[int]) -> bool:
|
||||||
|
return self._find_longest_prefix_key(tuple(key)) is not None
|
||||||
|
|
||||||
|
def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
|
||||||
|
print("LlamaDiskCache.__setitem__: called", file=sys.stderr)
|
||||||
|
key = tuple(key)
|
||||||
|
if key in self.cache:
|
||||||
|
print("LlamaDiskCache.__setitem__: delete", file=sys.stderr)
|
||||||
|
del self.cache[key]
|
||||||
|
self.cache[key] = value
|
||||||
|
print("LlamaDiskCache.__setitem__: set", file=sys.stderr)
|
||||||
|
while self.cache_size > self.capacity_bytes and len(self.cache) > 0:
|
||||||
|
key_to_remove = next(iter(self.cache))
|
||||||
|
del self.cache[key_to_remove]
|
||||||
|
print("LlamaDiskCache.__setitem__: trim", file=sys.stderr)
|
|
@ -6,18 +6,28 @@ import ctypes
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
|
||||||
|
|
||||||
|
import jinja2
|
||||||
|
|
||||||
import llama_cpp.llama as llama
|
import llama_cpp.llama as llama
|
||||||
import llama_cpp.llama_types as llama_types
|
import llama_cpp.llama_types as llama_types
|
||||||
import llama_cpp.llama_grammar as llama_grammar
|
import llama_cpp.llama_grammar as llama_grammar
|
||||||
|
|
||||||
from ._utils import suppress_stdout_stderr
|
from ._utils import suppress_stdout_stderr, Singleton
|
||||||
|
|
||||||
|
|
||||||
class LlamaChatCompletionHandler(Protocol):
|
class LlamaChatCompletionHandler(Protocol):
|
||||||
|
"""Base Protocol for a llama chat completion handler.
|
||||||
|
|
||||||
|
Very generic protocol that can be used to implement any chat format.
|
||||||
|
The only hard requirement is that it must return a ChatCompletion when
|
||||||
|
stream=False and an iterator of ChatCompletionChunks when stream=True."""
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
# llama.cpp instance
|
||||||
llama: llama.Llama,
|
llama: llama.Llama,
|
||||||
|
# openai api parameters
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
||||||
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
|
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
|
||||||
|
@ -26,8 +36,6 @@ class LlamaChatCompletionHandler(Protocol):
|
||||||
temperature: float = 0.2,
|
temperature: float = 0.2,
|
||||||
top_p: float = 0.95,
|
top_p: float = 0.95,
|
||||||
top_k: int = 40,
|
top_k: int = 40,
|
||||||
min_p: float = 0.05,
|
|
||||||
typical_p: float = 1.0,
|
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
stop: Optional[Union[str, List[str]]] = [],
|
stop: Optional[Union[str, List[str]]] = [],
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
|
@ -38,14 +46,17 @@ class LlamaChatCompletionHandler(Protocol):
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
repeat_penalty: float = 1.1,
|
repeat_penalty: float = 1.1,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
logit_bias: Optional[Dict[str, float]] = None,
|
||||||
|
# llama.cpp parameters
|
||||||
|
min_p: float = 0.05,
|
||||||
|
typical_p: float = 1.0,
|
||||||
tfs_z: float = 1.0,
|
tfs_z: float = 1.0,
|
||||||
mirostat_mode: int = 0,
|
mirostat_mode: int = 0,
|
||||||
mirostat_tau: float = 5.0,
|
mirostat_tau: float = 5.0,
|
||||||
mirostat_eta: float = 0.1,
|
mirostat_eta: float = 0.1,
|
||||||
model: Optional[str] = None,
|
|
||||||
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||||
grammar: Optional[llama.LlamaGrammar] = None,
|
grammar: Optional[llama.LlamaGrammar] = None,
|
||||||
logit_bias: Optional[Dict[str, float]] = None,
|
|
||||||
**kwargs, # type: ignore
|
**kwargs, # type: ignore
|
||||||
) -> Union[
|
) -> Union[
|
||||||
llama_types.CreateChatCompletionResponse,
|
llama_types.CreateChatCompletionResponse,
|
||||||
|
@ -54,146 +65,78 @@ class LlamaChatCompletionHandler(Protocol):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
CHAT_HANDLERS: Dict[str, LlamaChatCompletionHandler] = {}
|
class LlamaChatCompletionHandlerNotFoundException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaChatCompletionHandlerRegistry(Singleton):
|
||||||
|
_chat_handlers: Dict[str, LlamaChatCompletionHandler] = {}
|
||||||
|
|
||||||
|
def register_chat_completion_handler(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
chat_handler: LlamaChatCompletionHandler,
|
||||||
|
overwrite: bool = False,
|
||||||
|
):
|
||||||
|
if not overwrite and name in self._chat_handlers:
|
||||||
|
raise ValueError(
|
||||||
|
f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
|
||||||
|
)
|
||||||
|
self._chat_handlers[name] = chat_handler
|
||||||
|
|
||||||
|
def unregister_chat_handler(self, name: str):
|
||||||
|
if name in self._chat_handlers:
|
||||||
|
del self._chat_handlers[name]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"No formatter registered under the name '{name}'.")
|
||||||
|
|
||||||
|
def get_chat_completion_handler_by_name(
|
||||||
|
self, name: str
|
||||||
|
) -> LlamaChatCompletionHandler:
|
||||||
|
try:
|
||||||
|
chat_handler = self._chat_handlers[name]
|
||||||
|
return chat_handler
|
||||||
|
except KeyError:
|
||||||
|
raise LlamaChatCompletionHandlerNotFoundException(
|
||||||
|
f"Invalid chat handler: {name} (valid formats: {list(self._chat_handlers.keys())})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler:
|
def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler:
|
||||||
return CHAT_HANDLERS[name]
|
return LlamaChatCompletionHandlerRegistry().get_chat_completion_handler_by_name(
|
||||||
|
name
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_chat_completion_handler(name: str):
|
def register_chat_completion_handler(name: str):
|
||||||
def decorator(f: LlamaChatCompletionHandler):
|
def decorator(f: LlamaChatCompletionHandler):
|
||||||
CHAT_HANDLERS[name] = f
|
LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(name, f)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def _get_system_message(
|
### Chat Formatter ###
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
|
||||||
) -> str:
|
|
||||||
"""Get the first system message."""
|
|
||||||
for message in messages:
|
|
||||||
if message["role"] == "system":
|
|
||||||
return message["content"] or ""
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def _map_roles(
|
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage], role_map: Dict[str, str]
|
|
||||||
) -> List[Tuple[str, Optional[str]]]:
|
|
||||||
"""Map the message roles."""
|
|
||||||
output: List[Tuple[str, Optional[str]]] = []
|
|
||||||
for message in messages:
|
|
||||||
role = message["role"]
|
|
||||||
if role in role_map:
|
|
||||||
output.append((role_map[role], message["content"]))
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def _format_llama2(
|
|
||||||
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
|
|
||||||
) -> str:
|
|
||||||
"""Format the prompt with the llama2 style."""
|
|
||||||
seps = [sep, sep2]
|
|
||||||
ret = system_message + sep
|
|
||||||
for i, (role, message) in enumerate(messages):
|
|
||||||
if system_message and i == 0:
|
|
||||||
ret += message + seps[i % 2]
|
|
||||||
elif message:
|
|
||||||
ret += role + message + " " + seps[i % 2]
|
|
||||||
else:
|
|
||||||
ret += role + " "
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def _format_add_colon_single(
|
|
||||||
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
|
|
||||||
) -> str:
|
|
||||||
"""Format the prompt with the add-colon-single style."""
|
|
||||||
ret = system_message + sep
|
|
||||||
for role, message in messages:
|
|
||||||
if message:
|
|
||||||
ret += role + ": " + message + sep
|
|
||||||
else:
|
|
||||||
ret += role + ":"
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def _format_add_colon_two(
|
|
||||||
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
|
|
||||||
) -> str:
|
|
||||||
"""Format the prompt with the add-colon-two style."""
|
|
||||||
seps = [sep, sep2]
|
|
||||||
ret = system_message + seps[0]
|
|
||||||
for i, (role, message) in enumerate(messages):
|
|
||||||
if message:
|
|
||||||
ret += role + ": " + message + seps[i % 2]
|
|
||||||
else:
|
|
||||||
ret += role + ":"
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def _format_no_colon_single(
|
|
||||||
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
|
|
||||||
) -> str:
|
|
||||||
"""Format the prompt with the no-colon-single style."""
|
|
||||||
ret = system_message
|
|
||||||
for role, message in messages:
|
|
||||||
if message:
|
|
||||||
ret += role + message + sep
|
|
||||||
else:
|
|
||||||
ret += role
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def _format_add_colon_space_single(
|
|
||||||
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
|
|
||||||
) -> str:
|
|
||||||
"""Format the prompt with the add-colon-space-single style."""
|
|
||||||
ret = system_message + sep
|
|
||||||
for role, message in messages:
|
|
||||||
if message:
|
|
||||||
ret += role + ": " + message + sep
|
|
||||||
else:
|
|
||||||
ret += role + ": " # must be end with a space
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def _format_chatml(
|
|
||||||
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
|
|
||||||
) -> str:
|
|
||||||
"""Format the prompt with the chatml style."""
|
|
||||||
ret = "" if system_message == "" else system_message + sep + "\n"
|
|
||||||
for role, message in messages:
|
|
||||||
if message:
|
|
||||||
ret += role + "\n" + message + sep + "\n"
|
|
||||||
else:
|
|
||||||
ret += role + "\n"
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def _format_chatglm3(
|
|
||||||
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
|
|
||||||
) -> str:
|
|
||||||
"""Format the prompt with the chatglm3 style."""
|
|
||||||
ret = ""
|
|
||||||
if system_message:
|
|
||||||
ret += system_message
|
|
||||||
for role, message in messages:
|
|
||||||
if message:
|
|
||||||
ret += role + "\n" + " " + message
|
|
||||||
else:
|
|
||||||
ret += role
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class ChatFormatterResponse:
|
class ChatFormatterResponse:
|
||||||
|
"""Dataclass that stores completion parameters for a given chat format and
|
||||||
|
create_chat_completion request.
|
||||||
|
|
||||||
|
prompt contains the formatted prompt generated from the chat format and messages.
|
||||||
|
stop contains the stop token or list of stop tokens to use for the chat format."""
|
||||||
|
|
||||||
prompt: str
|
prompt: str
|
||||||
stop: Optional[Union[str, List[str]]] = None
|
stop: Optional[Union[str, List[str]]] = None
|
||||||
|
|
||||||
|
|
||||||
class ChatFormatter(Protocol):
|
class ChatFormatter(Protocol):
|
||||||
|
"""Base Protocol for a chat formatter. A chat formatter is a function that
|
||||||
|
takes a list of messages and returns a chat format response which can be used
|
||||||
|
to generate a completion. The response can also include a stop token or list
|
||||||
|
of stop tokens to use for the completion."""
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
@ -203,14 +146,52 @@ class ChatFormatter(Protocol):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class BasicChatHandler:
|
class Jinja2ChatFormatter(ChatFormatter):
|
||||||
def __init__(self, chat_format: str):
|
def __init__(
|
||||||
self.chat_format = chat_format
|
self,
|
||||||
|
template: str,
|
||||||
|
eos_token: str,
|
||||||
|
bos_token: str,
|
||||||
|
add_generation_prompt: bool = True,
|
||||||
|
):
|
||||||
|
"""A chat formatter that uses jinja2 templates to format the prompt."""
|
||||||
|
self.template = template
|
||||||
|
self.eos_token = eos_token
|
||||||
|
self.bos_token = bos_token
|
||||||
|
self.add_generation_prompt = add_generation_prompt
|
||||||
|
|
||||||
|
self._environment = jinja2.Environment(
|
||||||
|
loader=jinja2.BaseLoader(),
|
||||||
|
trim_blocks=True,
|
||||||
|
lstrip_blocks=True,
|
||||||
|
).from_string(self.template)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatFormatterResponse:
|
||||||
|
if self.add_generation_prompt:
|
||||||
|
messages = [
|
||||||
|
*messages,
|
||||||
|
llama_types.ChatCompletionRequestAssistantMessage(
|
||||||
|
role="assistant", content=""
|
||||||
|
),
|
||||||
|
]
|
||||||
|
prompt = self._environment.render(
|
||||||
|
messages=messages, eos_token=self.eos_token, bos_token=self.bos_token
|
||||||
|
)
|
||||||
|
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
|
||||||
|
|
||||||
|
def to_chat_handler(self) -> LlamaChatCompletionHandler:
|
||||||
|
return chat_formatter_to_chat_completion_handler(self)
|
||||||
|
|
||||||
|
|
||||||
def _convert_text_completion_to_chat(
|
def _convert_text_completion_to_chat(
|
||||||
completion: llama_types.Completion,
|
completion: llama_types.Completion,
|
||||||
) -> llama_types.ChatCompletion:
|
) -> llama_types.ChatCompletion:
|
||||||
|
assert "usage" in completion
|
||||||
return {
|
return {
|
||||||
"id": "chat" + completion["id"],
|
"id": "chat" + completion["id"],
|
||||||
"object": "chat.completion",
|
"object": "chat.completion",
|
||||||
|
@ -286,103 +267,85 @@ def _convert_completion_to_chat(
|
||||||
return _convert_text_completion_to_chat(completion)
|
return _convert_text_completion_to_chat(completion)
|
||||||
|
|
||||||
|
|
||||||
_CHAT_FORMATS: Dict[str, ChatFormatter] = {}
|
def chat_formatter_to_chat_completion_handler(
|
||||||
|
chat_formatter: ChatFormatter,
|
||||||
|
) -> LlamaChatCompletionHandler:
|
||||||
def register_chat_format(name: str):
|
def chat_completion_handler(
|
||||||
def decorator(f: ChatFormatter):
|
*,
|
||||||
def basic_create_chat_completion(
|
llama: llama.Llama,
|
||||||
*,
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
llama: llama.Llama,
|
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
|
||||||
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
|
||||||
function_call: Optional[
|
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
|
||||||
llama_types.ChatCompletionRequestFunctionCall
|
temperature: float = 0.2,
|
||||||
] = None,
|
top_p: float = 0.95,
|
||||||
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
|
top_k: int = 40,
|
||||||
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
|
min_p: float = 0.05,
|
||||||
temperature: float = 0.2,
|
typical_p: float = 1.0,
|
||||||
top_p: float = 0.95,
|
stream: bool = False,
|
||||||
top_k: int = 40,
|
stop: Optional[Union[str, List[str]]] = [],
|
||||||
min_p: float = 0.05,
|
seed: Optional[int] = None,
|
||||||
typical_p: float = 1.0,
|
response_format: Optional[
|
||||||
stream: bool = False,
|
llama_types.ChatCompletionRequestResponseFormat
|
||||||
stop: Optional[Union[str, List[str]]] = [],
|
] = None,
|
||||||
seed: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
response_format: Optional[
|
presence_penalty: float = 0.0,
|
||||||
llama_types.ChatCompletionRequestResponseFormat
|
frequency_penalty: float = 0.0,
|
||||||
] = None,
|
repeat_penalty: float = 1.1,
|
||||||
max_tokens: Optional[int] = None,
|
tfs_z: float = 1.0,
|
||||||
presence_penalty: float = 0.0,
|
mirostat_mode: int = 0,
|
||||||
frequency_penalty: float = 0.0,
|
mirostat_tau: float = 5.0,
|
||||||
repeat_penalty: float = 1.1,
|
mirostat_eta: float = 0.1,
|
||||||
tfs_z: float = 1.0,
|
model: Optional[str] = None,
|
||||||
mirostat_mode: int = 0,
|
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||||
mirostat_tau: float = 5.0,
|
grammar: Optional[llama.LlamaGrammar] = None,
|
||||||
mirostat_eta: float = 0.1,
|
logit_bias: Optional[Dict[str, float]] = None,
|
||||||
model: Optional[str] = None,
|
**kwargs, # type: ignore
|
||||||
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
) -> Union[
|
||||||
grammar: Optional[llama.LlamaGrammar] = None,
|
llama_types.CreateChatCompletionResponse,
|
||||||
logit_bias: Optional[Dict[str, float]] = None,
|
Iterator[llama_types.CreateChatCompletionStreamResponse],
|
||||||
**kwargs, # type: ignore
|
]:
|
||||||
) -> Union[
|
result = chat_formatter(
|
||||||
llama_types.CreateChatCompletionResponse,
|
messages=messages,
|
||||||
Iterator[llama_types.CreateChatCompletionStreamResponse],
|
functions=functions,
|
||||||
]:
|
function_call=function_call,
|
||||||
result = f(
|
|
||||||
messages=messages,
|
|
||||||
functions=functions,
|
|
||||||
function_call=function_call,
|
|
||||||
)
|
|
||||||
prompt = result.prompt
|
|
||||||
if result.stop is not None:
|
|
||||||
stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
|
|
||||||
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
|
|
||||||
stop = stop + rstop
|
|
||||||
|
|
||||||
if response_format is not None and response_format["type"] == "json_object":
|
|
||||||
grammar = llama_grammar.LlamaGrammar.from_string(
|
|
||||||
llama_grammar.JSON_GBNF
|
|
||||||
)
|
|
||||||
|
|
||||||
completion_or_chunks = llama.create_completion(
|
|
||||||
prompt=prompt,
|
|
||||||
temperature=temperature,
|
|
||||||
top_p=top_p,
|
|
||||||
top_k=top_k,
|
|
||||||
min_p=min_p,
|
|
||||||
typical_p=typical_p,
|
|
||||||
stream=stream,
|
|
||||||
stop=stop,
|
|
||||||
seed=seed,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
presence_penalty=presence_penalty,
|
|
||||||
frequency_penalty=frequency_penalty,
|
|
||||||
repeat_penalty=repeat_penalty,
|
|
||||||
tfs_z=tfs_z,
|
|
||||||
mirostat_mode=mirostat_mode,
|
|
||||||
mirostat_tau=mirostat_tau,
|
|
||||||
mirostat_eta=mirostat_eta,
|
|
||||||
model=model,
|
|
||||||
logits_processor=logits_processor,
|
|
||||||
grammar=grammar,
|
|
||||||
logit_bias=logit_bias,
|
|
||||||
)
|
|
||||||
return _convert_completion_to_chat(completion_or_chunks, stream=stream)
|
|
||||||
|
|
||||||
register_chat_completion_handler(name)(basic_create_chat_completion)
|
|
||||||
return f
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def get_chat_format(name: str):
|
|
||||||
try:
|
|
||||||
return _CHAT_FORMATS[name]
|
|
||||||
except KeyError:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid chat format: {name} (valid formats: {list(_CHAT_FORMATS.keys())})"
|
|
||||||
)
|
)
|
||||||
|
prompt = result.prompt
|
||||||
|
if result.stop is not None:
|
||||||
|
stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
|
||||||
|
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
|
||||||
|
stop = stop + rstop
|
||||||
|
|
||||||
|
if response_format is not None and response_format["type"] == "json_object":
|
||||||
|
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
|
||||||
|
|
||||||
|
completion_or_chunks = llama.create_completion(
|
||||||
|
prompt=prompt,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
min_p=min_p,
|
||||||
|
typical_p=typical_p,
|
||||||
|
stream=stream,
|
||||||
|
stop=stop,
|
||||||
|
seed=seed,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
presence_penalty=presence_penalty,
|
||||||
|
frequency_penalty=frequency_penalty,
|
||||||
|
repeat_penalty=repeat_penalty,
|
||||||
|
tfs_z=tfs_z,
|
||||||
|
mirostat_mode=mirostat_mode,
|
||||||
|
mirostat_tau=mirostat_tau,
|
||||||
|
mirostat_eta=mirostat_eta,
|
||||||
|
model=model,
|
||||||
|
logits_processor=logits_processor,
|
||||||
|
grammar=grammar,
|
||||||
|
logit_bias=logit_bias,
|
||||||
|
)
|
||||||
|
return _convert_completion_to_chat(completion_or_chunks, stream=stream)
|
||||||
|
|
||||||
|
return chat_completion_handler
|
||||||
|
|
||||||
|
|
||||||
def hf_autotokenizer_to_chat_formatter(
|
def hf_autotokenizer_to_chat_formatter(
|
||||||
|
@ -391,22 +354,222 @@ def hf_autotokenizer_to_chat_formatter(
|
||||||
# https://huggingface.co/docs/transformers/main/chat_templating
|
# https://huggingface.co/docs/transformers/main/chat_templating
|
||||||
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
|
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
|
||||||
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
|
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer # type: ignore
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
|
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) # type: ignore
|
||||||
|
|
||||||
def format_autotokenizer(
|
def format_autotokenizer(
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatFormatterResponse:
|
) -> ChatFormatterResponse:
|
||||||
tokenizer.use_default_system_prompt = False
|
tokenizer.use_default_system_prompt = False # type: ignore
|
||||||
_prompt = tokenizer.apply_chat_template(messages, tokenize=False)
|
prompt: str = tokenizer.apply_chat_template(messages, tokenize=False) # type: ignore
|
||||||
|
assert isinstance(prompt, str)
|
||||||
# Return formatted prompt and eos token by default
|
# Return formatted prompt and eos token by default
|
||||||
return ChatFormatterResponse(prompt=_prompt, stop=tokenizer.eos_token)
|
return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token)
|
||||||
|
|
||||||
return format_autotokenizer
|
return format_autotokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def hf_autotokenizer_to_chat_completion_handler(
|
||||||
|
pretrained_model_name_or_path: Union[str, os.PathLike[str]]
|
||||||
|
) -> LlamaChatCompletionHandler:
|
||||||
|
chat_formatter = hf_autotokenizer_to_chat_formatter(pretrained_model_name_or_path)
|
||||||
|
return chat_formatter_to_chat_completion_handler(chat_formatter)
|
||||||
|
|
||||||
|
|
||||||
|
def hf_tokenizer_config_to_chat_formatter(
|
||||||
|
tokenizer_config: Dict[str, Any]
|
||||||
|
) -> ChatFormatter:
|
||||||
|
assert isinstance(tokenizer_config, dict)
|
||||||
|
|
||||||
|
assert "chat_template" in tokenizer_config
|
||||||
|
assert isinstance(tokenizer_config["chat_template"], str)
|
||||||
|
chat_template = tokenizer_config["chat_template"]
|
||||||
|
|
||||||
|
assert "bos_token" in tokenizer_config
|
||||||
|
assert isinstance(tokenizer_config["bos_token"], str)
|
||||||
|
bos_token = tokenizer_config["bos_token"]
|
||||||
|
|
||||||
|
assert "eos_token" in tokenizer_config
|
||||||
|
assert isinstance(tokenizer_config["eos_token"], str)
|
||||||
|
eos_token = tokenizer_config["eos_token"]
|
||||||
|
|
||||||
|
env = jinja2.Environment(
|
||||||
|
loader=jinja2.BaseLoader(),
|
||||||
|
trim_blocks=True,
|
||||||
|
lstrip_blocks=True,
|
||||||
|
).from_string(chat_template)
|
||||||
|
|
||||||
|
def format_autotokenizer(
|
||||||
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatFormatterResponse:
|
||||||
|
# TODO: veryify this is correct
|
||||||
|
# Add a blank assistant message to the end of the messages to prompt the model to generate a response
|
||||||
|
prompt = env.render(
|
||||||
|
messages=[
|
||||||
|
*messages,
|
||||||
|
llama_types.ChatCompletionRequestAssistantMessage(
|
||||||
|
role="assistant", content=""
|
||||||
|
),
|
||||||
|
],
|
||||||
|
bos_token=bos_token,
|
||||||
|
eos_token=eos_token,
|
||||||
|
)
|
||||||
|
return ChatFormatterResponse(prompt=prompt, stop=eos_token)
|
||||||
|
|
||||||
|
return format_autotokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def hf_tokenizer_config_to_chat_completion_handler(
|
||||||
|
tokenizer_config: Dict[str, Any],
|
||||||
|
) -> LlamaChatCompletionHandler:
|
||||||
|
chat_formatter = hf_tokenizer_config_to_chat_formatter(tokenizer_config)
|
||||||
|
return chat_formatter_to_chat_completion_handler(chat_formatter)
|
||||||
|
|
||||||
|
|
||||||
|
### Utility functions for formatting chat prompts ###
|
||||||
|
|
||||||
|
|
||||||
|
def _get_system_message(
|
||||||
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
|
) -> str:
|
||||||
|
"""Get the first system message."""
|
||||||
|
for message in messages:
|
||||||
|
if message["role"] == "system":
|
||||||
|
return message["content"] or ""
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _map_roles(
|
||||||
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
|
role_map: Dict[str, str],
|
||||||
|
) -> List[Tuple[str, Optional[str]]]:
|
||||||
|
"""Map the message roles."""
|
||||||
|
output: List[Tuple[str, Optional[str]]] = []
|
||||||
|
for message in messages:
|
||||||
|
role = message["role"]
|
||||||
|
if role in role_map:
|
||||||
|
content: str | None = (
|
||||||
|
message["content"] if isinstance(message["content"], str) else None
|
||||||
|
)
|
||||||
|
output.append((role_map[role], content))
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def _format_llama2(
|
||||||
|
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
|
||||||
|
) -> str:
|
||||||
|
"""Format the prompt with the llama2 style."""
|
||||||
|
seps = [sep, sep2]
|
||||||
|
ret = system_message + sep
|
||||||
|
for i, (role, message) in enumerate(messages):
|
||||||
|
if system_message and i == 0:
|
||||||
|
m = message or ""
|
||||||
|
ret += m + seps[i % 2]
|
||||||
|
elif message:
|
||||||
|
ret += role + message + " " + seps[i % 2]
|
||||||
|
else:
|
||||||
|
ret += role + " "
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def _format_add_colon_single(
|
||||||
|
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
|
||||||
|
) -> str:
|
||||||
|
"""Format the prompt with the add-colon-single style."""
|
||||||
|
ret = system_message + sep
|
||||||
|
for role, message in messages:
|
||||||
|
if message:
|
||||||
|
ret += role + ": " + message + sep
|
||||||
|
else:
|
||||||
|
ret += role + ":"
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def _format_add_colon_two(
|
||||||
|
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
|
||||||
|
) -> str:
|
||||||
|
"""Format the prompt with the add-colon-two style."""
|
||||||
|
seps = [sep, sep2]
|
||||||
|
ret = system_message + seps[0]
|
||||||
|
for i, (role, message) in enumerate(messages):
|
||||||
|
if message:
|
||||||
|
ret += role + ": " + message + seps[i % 2]
|
||||||
|
else:
|
||||||
|
ret += role + ":"
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def _format_no_colon_single(
|
||||||
|
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
|
||||||
|
) -> str:
|
||||||
|
"""Format the prompt with the no-colon-single style."""
|
||||||
|
ret = system_message
|
||||||
|
for role, message in messages:
|
||||||
|
if message:
|
||||||
|
ret += role + message + sep
|
||||||
|
else:
|
||||||
|
ret += role
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def _format_add_colon_space_single(
|
||||||
|
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
|
||||||
|
) -> str:
|
||||||
|
"""Format the prompt with the add-colon-space-single style."""
|
||||||
|
ret = system_message + sep
|
||||||
|
for role, message in messages:
|
||||||
|
if message:
|
||||||
|
ret += role + ": " + message + sep
|
||||||
|
else:
|
||||||
|
ret += role + ": " # must be end with a space
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def _format_chatml(
|
||||||
|
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
|
||||||
|
) -> str:
|
||||||
|
"""Format the prompt with the chatml style."""
|
||||||
|
ret = "" if system_message == "" else system_message + sep + "\n"
|
||||||
|
for role, message in messages:
|
||||||
|
if message:
|
||||||
|
ret += role + "\n" + message + sep + "\n"
|
||||||
|
else:
|
||||||
|
ret += role + "\n"
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def _format_chatglm3(
|
||||||
|
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
|
||||||
|
) -> str:
|
||||||
|
"""Format the prompt with the chatglm3 style."""
|
||||||
|
ret = ""
|
||||||
|
if system_message:
|
||||||
|
ret += system_message
|
||||||
|
for role, message in messages:
|
||||||
|
if message:
|
||||||
|
ret += role + "\n" + " " + message
|
||||||
|
else:
|
||||||
|
ret += role
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
### Chat Formats ###
|
||||||
|
|
||||||
|
|
||||||
|
def register_chat_format(name: str):
|
||||||
|
def decorator(f: ChatFormatter):
|
||||||
|
chat_completion_handler = chat_formatter_to_chat_completion_handler(f)
|
||||||
|
LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
|
||||||
|
name, chat_completion_handler
|
||||||
|
)
|
||||||
|
return f
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
|
# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
|
||||||
# system prompt is "embedded" in the first message
|
# system prompt is "embedded" in the first message
|
||||||
@register_chat_format("llama-2")
|
@register_chat_format("llama-2")
|
||||||
|
@ -437,21 +600,23 @@ def format_alpaca(
|
||||||
_prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2)
|
_prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2)
|
||||||
return ChatFormatterResponse(prompt=_prompt)
|
return ChatFormatterResponse(prompt=_prompt)
|
||||||
|
|
||||||
|
|
||||||
@register_chat_format("qwen")
|
@register_chat_format("qwen")
|
||||||
def format_qwen(
|
def format_qwen(
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatFormatterResponse:
|
) -> ChatFormatterResponse:
|
||||||
_roles = dict(user="<|im_start|>user", assistant="<|im_start|>assistant")
|
_roles = dict(user="<|im_start|>user", assistant="<|im_start|>assistant")
|
||||||
system_message="You are a helpful assistant."
|
system_message = "You are a helpful assistant."
|
||||||
system_template="<|im_start|>system\n{system_message}"
|
system_template = "<|im_start|>system\n{system_message}"
|
||||||
system_message=system_template.format(system_message=system_message)
|
system_message = system_template.format(system_message=system_message)
|
||||||
_messages = _map_roles(messages, _roles)
|
_messages = _map_roles(messages, _roles)
|
||||||
_messages.append((_roles["assistant"], None))
|
_messages.append((_roles["assistant"], None))
|
||||||
_sep = "<|im_end|>"
|
_sep = "<|im_end|>"
|
||||||
_prompt = _format_chatml(system_message, _messages, _sep)
|
_prompt = _format_chatml(system_message, _messages, _sep)
|
||||||
_sep2 = "<|endoftext|>"
|
_sep2 = "<|endoftext|>"
|
||||||
return ChatFormatterResponse(prompt=_prompt,stop=_sep2)
|
return ChatFormatterResponse(prompt=_prompt, stop=_sep2)
|
||||||
|
|
||||||
|
|
||||||
@register_chat_format("vicuna")
|
@register_chat_format("vicuna")
|
||||||
def format(
|
def format(
|
||||||
|
@ -650,6 +815,7 @@ def format_mistrallite(
|
||||||
_prompt = _format_no_colon_single(system_message, _messages, _sep)
|
_prompt = _format_no_colon_single(system_message, _messages, _sep)
|
||||||
return ChatFormatterResponse(prompt=_prompt)
|
return ChatFormatterResponse(prompt=_prompt)
|
||||||
|
|
||||||
|
|
||||||
@register_chat_format("zephyr")
|
@register_chat_format("zephyr")
|
||||||
def format_zephyr(
|
def format_zephyr(
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
|
@ -699,6 +865,7 @@ def format_chatml(
|
||||||
_prompt = _format_chatml(system_message, _messages, _sep)
|
_prompt = _format_chatml(system_message, _messages, _sep)
|
||||||
return ChatFormatterResponse(prompt=_prompt, stop=_sep)
|
return ChatFormatterResponse(prompt=_prompt, stop=_sep)
|
||||||
|
|
||||||
|
|
||||||
@register_chat_format("chatglm3")
|
@register_chat_format("chatglm3")
|
||||||
def format_chatglm3(
|
def format_chatglm3(
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
|
@ -739,7 +906,7 @@ def format_openchat(
|
||||||
@register_chat_format("saiga")
|
@register_chat_format("saiga")
|
||||||
def format_saiga(
|
def format_saiga(
|
||||||
messages: list[llama_types.ChatCompletionRequestMessage],
|
messages: list[llama_types.ChatCompletionRequestMessage],
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> ChatFormatterResponse:
|
) -> ChatFormatterResponse:
|
||||||
_message_template = "<s>{role}\n{content}</s>"
|
_message_template = "<s>{role}\n{content}</s>"
|
||||||
_roles = dict(user="user", bot="bot", system="system")
|
_roles = dict(user="user", bot="bot", system="system")
|
||||||
|
|
|
@ -91,6 +91,12 @@ c_float_p = POINTER(c_float)
|
||||||
c_uint8_p = POINTER(c_uint8)
|
c_uint8_p = POINTER(c_uint8)
|
||||||
c_size_t_p = POINTER(c_size_t)
|
c_size_t_p = POINTER(c_size_t)
|
||||||
|
|
||||||
|
# from ggml-backend.h
|
||||||
|
# typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
|
||||||
|
ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(
|
||||||
|
c_bool, c_void_p, c_bool, c_void_p
|
||||||
|
)
|
||||||
|
|
||||||
# llama.h bindings
|
# llama.h bindings
|
||||||
|
|
||||||
_lib.llama_max_devices.argtypes = []
|
_lib.llama_max_devices.argtypes = []
|
||||||
|
@ -448,6 +454,9 @@ class llama_model_params(Structure):
|
||||||
# float yarn_beta_slow; // YaRN high correction dim
|
# float yarn_beta_slow; // YaRN high correction dim
|
||||||
# uint32_t yarn_orig_ctx; // YaRN original context size
|
# uint32_t yarn_orig_ctx; // YaRN original context size
|
||||||
|
|
||||||
|
# ggml_backend_sched_eval_callback cb_eval;
|
||||||
|
# void * cb_eval_user_data;
|
||||||
|
|
||||||
# enum ggml_type type_k; // data type for K cache
|
# enum ggml_type type_k; // data type for K cache
|
||||||
# enum ggml_type type_v; // data type for V cache
|
# enum ggml_type type_v; // data type for V cache
|
||||||
|
|
||||||
|
@ -475,6 +484,8 @@ class llama_context_params(Structure):
|
||||||
yarn_beta_fast (float): YaRN low correction dim
|
yarn_beta_fast (float): YaRN low correction dim
|
||||||
yarn_beta_slow (float): YaRN high correction dim
|
yarn_beta_slow (float): YaRN high correction dim
|
||||||
yarn_orig_ctx (int): YaRN original context size
|
yarn_orig_ctx (int): YaRN original context size
|
||||||
|
cb_eval (ggml_backend_sched_eval_callback): callback for scheduling eval
|
||||||
|
cb_eval_user_data (ctypes.c_void_p): user data for cb_eval
|
||||||
type_k (int): data type for K cache
|
type_k (int): data type for K cache
|
||||||
type_v (int): data type for V cache
|
type_v (int): data type for V cache
|
||||||
mul_mat_q (bool): if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
|
mul_mat_q (bool): if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
|
||||||
|
@ -497,6 +508,8 @@ class llama_context_params(Structure):
|
||||||
("yarn_beta_fast", c_float),
|
("yarn_beta_fast", c_float),
|
||||||
("yarn_beta_slow", c_float),
|
("yarn_beta_slow", c_float),
|
||||||
("yarn_orig_ctx", c_uint32),
|
("yarn_orig_ctx", c_uint32),
|
||||||
|
("cb_eval", ggml_backend_sched_eval_callback),
|
||||||
|
("cb_eval_user_data", c_void_p),
|
||||||
("type_k", c_int),
|
("type_k", c_int),
|
||||||
("type_v", c_int),
|
("type_v", c_int),
|
||||||
("mul_mat_q", c_bool),
|
("mul_mat_q", c_bool),
|
||||||
|
|
|
@ -1432,7 +1432,6 @@ class SchemaConverter:
|
||||||
return key
|
return key
|
||||||
|
|
||||||
def visit(self, schema: Dict[str, Any], name: str) -> str:
|
def visit(self, schema: Dict[str, Any], name: str) -> str:
|
||||||
schema_type: Optional[str] = schema.get("type") # type: ignore
|
|
||||||
rule_name = name or "root"
|
rule_name = name or "root"
|
||||||
|
|
||||||
if "$defs" in schema:
|
if "$defs" in schema:
|
||||||
|
@ -1458,7 +1457,19 @@ class SchemaConverter:
|
||||||
rule = " | ".join((self._format_literal(v) for v in schema["enum"]))
|
rule = " | ".join((self._format_literal(v) for v in schema["enum"]))
|
||||||
return self._add_rule(rule_name, rule)
|
return self._add_rule(rule_name, rule)
|
||||||
|
|
||||||
elif schema_type == "object" and "properties" in schema:
|
elif "$ref" in schema:
|
||||||
|
ref = schema["$ref"]
|
||||||
|
assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}"
|
||||||
|
# inline $defs
|
||||||
|
def_name = ref[len("#/$defs/") :]
|
||||||
|
def_schema = self._defs[def_name]
|
||||||
|
return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}')
|
||||||
|
|
||||||
|
|
||||||
|
schema_type: Optional[str] = schema.get("type") # type: ignore
|
||||||
|
assert isinstance(schema_type, str), f"Unrecognized schema: {schema}"
|
||||||
|
|
||||||
|
if schema_type == "object" and "properties" in schema:
|
||||||
# TODO: `required` keyword
|
# TODO: `required` keyword
|
||||||
prop_order = self._prop_order
|
prop_order = self._prop_order
|
||||||
prop_pairs = sorted(
|
prop_pairs = sorted(
|
||||||
|
@ -1489,14 +1500,6 @@ class SchemaConverter:
|
||||||
)
|
)
|
||||||
return self._add_rule(rule_name, rule)
|
return self._add_rule(rule_name, rule)
|
||||||
|
|
||||||
elif "$ref" in schema:
|
|
||||||
ref = schema["$ref"]
|
|
||||||
assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}"
|
|
||||||
# inline $defs
|
|
||||||
def_name = ref[len("#/$defs/") :]
|
|
||||||
def_schema = self._defs[def_name]
|
|
||||||
return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}')
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
|
assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
|
||||||
return self._add_rule(
|
return self._add_rule(
|
||||||
|
|
|
@ -55,7 +55,7 @@ def _parse_bool_arg(arg: str | bytes | bool) -> bool:
|
||||||
raise ValueError(f"Invalid boolean argument: {arg}")
|
raise ValueError(f"Invalid boolean argument: {arg}")
|
||||||
|
|
||||||
|
|
||||||
def add_args_from_model(parser: argparse.ArgumentParser, model: type[BaseModel]):
|
def add_args_from_model(parser: argparse.ArgumentParser, model: Type[BaseModel]):
|
||||||
"""Add arguments from a pydantic model to an argparse parser."""
|
"""Add arguments from a pydantic model to an argparse parser."""
|
||||||
|
|
||||||
for name, field in model.model_fields.items():
|
for name, field in model.model_fields.items():
|
||||||
|
@ -83,7 +83,7 @@ def add_args_from_model(parser: argparse.ArgumentParser, model: type[BaseModel])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=type[BaseModel])
|
T = TypeVar("T", bound=Type[BaseModel])
|
||||||
|
|
||||||
|
|
||||||
def parse_model_from_args(model: T, args: argparse.Namespace) -> T:
|
def parse_model_from_args(model: T, args: argparse.Namespace) -> T:
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
from typing import Dict, Optional, Union, List
|
from typing import Dict, Optional, Union, List
|
||||||
|
|
||||||
import llama_cpp
|
import llama_cpp
|
||||||
|
@ -71,7 +73,25 @@ class LlamaProxy:
|
||||||
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
|
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
|
||||||
clip_model_path=settings.clip_model_path, verbose=settings.verbose
|
clip_model_path=settings.clip_model_path, verbose=settings.verbose
|
||||||
)
|
)
|
||||||
|
elif settings.chat_format == "hf-autotokenizer":
|
||||||
|
assert (
|
||||||
|
settings.hf_pretrained_model_name_or_path is not None
|
||||||
|
), "hf_pretrained_model_name_or_path must be set for hf-autotokenizer"
|
||||||
|
chat_handler = (
|
||||||
|
llama_cpp.llama_chat_format.hf_autotokenizer_to_chat_completion_handler(
|
||||||
|
settings.hf_pretrained_model_name_or_path
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif settings.chat_format == "hf-tokenizer-config":
|
||||||
|
assert (
|
||||||
|
settings.hf_tokenizer_config_path is not None
|
||||||
|
), "hf_tokenizer_config_path must be set for hf-tokenizer-config"
|
||||||
|
chat_handler = (
|
||||||
|
llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_completion_handler(
|
||||||
|
json.load(open(settings.hf_tokenizer_config_path))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
|
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
|
||||||
if settings.kv_overrides is not None:
|
if settings.kv_overrides is not None:
|
||||||
assert isinstance(settings.kv_overrides, list)
|
assert isinstance(settings.kv_overrides, list)
|
||||||
|
@ -141,4 +161,3 @@ class LlamaProxy:
|
||||||
cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size)
|
cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size)
|
||||||
_model.set_cache(cache)
|
_model.set_cache(cache)
|
||||||
return _model
|
return _model
|
||||||
|
|
||||||
|
|
|
@ -90,7 +90,7 @@ class ModelSettings(BaseSettings):
|
||||||
logits_all: bool = Field(default=True, description="Whether to return logits.")
|
logits_all: bool = Field(default=True, description="Whether to return logits.")
|
||||||
embedding: bool = Field(default=True, description="Whether to use embeddings.")
|
embedding: bool = Field(default=True, description="Whether to use embeddings.")
|
||||||
offload_kqv: bool = Field(
|
offload_kqv: bool = Field(
|
||||||
default=False, description="Whether to offload kqv to the GPU."
|
default=True, description="Whether to offload kqv to the GPU."
|
||||||
)
|
)
|
||||||
# Sampling Params
|
# Sampling Params
|
||||||
last_n_tokens_size: int = Field(
|
last_n_tokens_size: int = Field(
|
||||||
|
@ -134,6 +134,15 @@ class ModelSettings(BaseSettings):
|
||||||
default=2 << 30,
|
default=2 << 30,
|
||||||
description="The size of the cache in bytes. Only used if cache is True.",
|
description="The size of the cache in bytes. Only used if cache is True.",
|
||||||
)
|
)
|
||||||
|
# Tokenizer Options
|
||||||
|
hf_tokenizer_config_path: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The path to a HuggingFace tokenizer_config.json file.",
|
||||||
|
)
|
||||||
|
hf_pretrained_model_name_or_path: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
|
||||||
|
)
|
||||||
# Misc
|
# Misc
|
||||||
verbose: bool = Field(
|
verbose: bool = Field(
|
||||||
default=True, description="Whether to print debug information."
|
default=True, description="Whether to print debug information."
|
||||||
|
|
|
@ -15,6 +15,7 @@ dependencies = [
|
||||||
"typing-extensions>=4.5.0",
|
"typing-extensions>=4.5.0",
|
||||||
"numpy>=1.20.0",
|
"numpy>=1.20.0",
|
||||||
"diskcache>=5.6.1",
|
"diskcache>=5.6.1",
|
||||||
|
"jinja2>=2.11.3",
|
||||||
]
|
]
|
||||||
requires-python = ">=3.8"
|
requires-python = ">=3.8"
|
||||||
classifiers = [
|
classifiers = [
|
||||||
|
@ -72,4 +73,3 @@ Changelog = "https://llama-cpp-python.readthedocs.io/en/latest/changelog/"
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
addopts = "--ignore=vendor"
|
addopts = "--ignore=vendor"
|
||||||
|
|
||||||
|
|
|
@ -50,3 +50,29 @@ def test_composed_pydantic_grammar():
|
||||||
grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema))
|
grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema))
|
||||||
|
|
||||||
assert grammar.grammar is not None
|
assert grammar.grammar is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_grammar_anyof():
|
||||||
|
sch = {
|
||||||
|
"properties": {
|
||||||
|
"temperature": {
|
||||||
|
"description": "The temperature mentioned",
|
||||||
|
"type": "number",
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"anyOf": [
|
||||||
|
{
|
||||||
|
"description": "Unit for temperature",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
{"type": "null"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
|
||||||
|
grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(sch))
|
||||||
|
|
||||||
|
assert grammar.grammar is not None
|
65
tests/test_llama_chat_format.py
Normal file
65
tests/test_llama_chat_format.py
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
from llama_cpp import (
|
||||||
|
ChatCompletionRequestUserMessage,
|
||||||
|
)
|
||||||
|
from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter
|
||||||
|
|
||||||
|
|
||||||
|
mistral_7b_tokenizer_config = """{
|
||||||
|
"add_bos_token": true,
|
||||||
|
"add_eos_token": false,
|
||||||
|
"added_tokens_decoder": {
|
||||||
|
"0": {
|
||||||
|
"content": "<unk>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"1": {
|
||||||
|
"content": "<s>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"content": "</s>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additional_special_tokens": [],
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"clean_up_tokenization_spaces": false,
|
||||||
|
"eos_token": "</s>",
|
||||||
|
"legacy": true,
|
||||||
|
"model_max_length": 1000000000000000019884624838656,
|
||||||
|
"pad_token": null,
|
||||||
|
"sp_model_kwargs": {},
|
||||||
|
"spaces_between_special_tokens": false,
|
||||||
|
"tokenizer_class": "LlamaTokenizer",
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"use_default_system_prompt": false,
|
||||||
|
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
|
||||||
|
}"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_hf_tokenizer_config_str_to_chat_formatter():
|
||||||
|
tokenizer_config = json.loads(mistral_7b_tokenizer_config)
|
||||||
|
chat_formatter = hf_tokenizer_config_to_chat_formatter(
|
||||||
|
tokenizer_config
|
||||||
|
)
|
||||||
|
chat_formatter_respoonse = chat_formatter(
|
||||||
|
messages=[
|
||||||
|
ChatCompletionRequestUserMessage(role="user", content="Hello, world!"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert chat_formatter_respoonse.prompt == ("<s>[INST] Hello, world! [/INST]</s>" "")
|
2
vendor/llama.cpp
vendored
2
vendor/llama.cpp
vendored
|
@ -1 +1 @@
|
||||||
Subproject commit 5c999609013a30c06e6fd28be8db5c2074bcc196
|
Subproject commit 504dc37be8446fb09b1ede70300250ad41be32a2
|
Loading…
Reference in a new issue