This commit is contained in:
baalajimaestro 2024-04-25 10:48:33 +05:30
commit ce85be97e2
Signed by: baalajimaestro
GPG key ID: F93C394FE9BBAFD5
14 changed files with 1201 additions and 191 deletions

View file

@ -41,6 +41,35 @@ jobs:
with:
path: ./wheelhouse/*.whl
build_arm64_wheels:
name: Build arm64 wheels
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
submodules: "recursive"
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
with:
platforms: linux/arm64
- name: Build wheels
uses: pypa/cibuildwheel@v2.16.5
env:
CIBW_SKIP: "*musllinux* pp*"
CIBW_REPAIR_WHEEL_COMMAND: ""
CIBW_ARCHS: "aarch64"
CIBW_BUILD: "cp38-* cp39-* cp310-* cp311-* cp312-*"
with:
output-dir: wheelhouse/
- name: Upload wheels as artifacts
uses: actions/upload-artifact@v4
with:
name: wheels-${{ matrix.version }}
path: wheelhouse/*.whl
build_sdist:
name: Build source distribution
runs-on: ubuntu-latest
@ -65,7 +94,7 @@ jobs:
release:
name: Release
needs: [build_wheels, build_sdist]
needs: [build_wheels, build_arm64_wheels, build_sdist]
runs-on: ubuntu-latest
steps:

View file

@ -7,6 +7,35 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [0.2.64]
- feat: Update llama.cpp to ggerganov/llama.cpp@4e96a812b3ce7322a29a3008db2ed73d9087b176
- feat: Add `llama-3` chat format by @andreabak in #1371
- feat: Use new llama_token_is_eog in create_completions by @abetlen in d40a250ef3cfaa8224d12c83776a2f1de96ae3d1
- feat(server): Provide ability to dynamically allocate all threads if desired using -1 by @sean-bailey in #1364
- ci: Build arm64 wheels by @gaby in 611781f5319719a3d05fefccbbf0cc321742a026
- fix: Update scikit-build-core build dependency avoid bug in 0.9.1 by @evelkey in #1370
## [0.2.63]
- feat: Update llama.cpp to ggerganov/llama.cpp@0e4802b2ecbaab04b4f829fde4a3096ca19c84b5
- feat: Add stopping_criteria to ChatFormatter, allow stopping on arbitrary token ids, fixes llama3 instruct by @abetlen in cc81afebf04d26ca1ac3cf72f23f18da6ab58588
## [0.2.62]
- feat: Update llama.cpp to ggerganov/llama.cpp@3b8f1ec4b18770531d0b1d792f3edf08254e4f0c
- feat: update grammar schema converter to match llama.cpp by @themrzmaster in #1353
- feat: add disable_ping_events flag by @khimaros in #1257
- feat: Make saved state more compact on-disk by @tc-wolf in #1296
- feat: Use all available CPUs for batch processing by @ddh0 in #1345
## [0.2.61]
- feat: Update llama.cpp to ggerganov/llama.cpp@ba5e134e073ec6837078c874aba44a702944a676
- fix: pass correct type to chat handlers for chat completion logprobs by @abetlen in bb65b4d76411112c6fb0bf759efd746f99ef3c6b
- feat: Add support for yaml based server configs by @abetlen in 060bfa64d529ade2af9b1f4e207a3937bbc4138f
- feat: Add typechecking for ctypes structure attributes by @abetlen in 1347e1d050fc5a9a32ffe0bb3e22858da28003bd
## [0.2.60]
- feat: Update llama.cpp to ggerganov/llama.cpp@75cd4c77292034ecec587ecb401366f57338f7c0

View file

@ -0,0 +1,30 @@
"""llama-cpp-python server from scratch in a single file.
"""
# import llama_cpp
# path = b"../../models/Qwen1.5-0.5B-Chat-GGUF/qwen1_5-0_5b-chat-q8_0.gguf"
# model_params = llama_cpp.llama_model_default_params()
# model = llama_cpp.llama_load_model_from_file(path, model_params)
# if model is None:
# raise RuntimeError(f"Failed to load model from file: {path}")
# ctx_params = llama_cpp.llama_context_default_params()
# ctx = llama_cpp.llama_new_context_with_model(model, ctx_params)
# if ctx is None:
# raise RuntimeError("Failed to create context")
from fastapi import FastAPI
app = FastAPI()
import openai.types.chat as types
@app.post("/v1/chat/completions")
def create_chat_completions():
return {"message": "Hello World"}

View file

@ -1,4 +1,4 @@
from .llama_cpp import *
from .llama import *
__version__ = "0.2.60"
__version__ = "0.2.64"

View file

@ -181,20 +181,20 @@ class _LlamaModel:
)
return list(tokens[:n_tokens])
def token_to_piece(self, token: int) -> bytes:
def token_to_piece(self, token: int, special: bool = False) -> bytes:
assert self.model is not None
buf = ctypes.create_string_buffer(32)
llama_cpp.llama_token_to_piece(self.model, token, buf, 32)
llama_cpp.llama_token_to_piece(self.model, token, buf, 32, special)
return bytes(buf)
def detokenize(self, tokens: List[int]) -> bytes:
def detokenize(self, tokens: List[int], special: bool = False) -> bytes:
assert self.model is not None
output = b""
size = 32
buffer = (ctypes.c_char * size)()
for token in tokens:
n = llama_cpp.llama_token_to_piece(
self.model, llama_cpp.llama_token(token), buffer, size
self.model, llama_cpp.llama_token(token), buffer, size, special
)
assert n <= size
output += bytes(buffer[:n])
@ -597,13 +597,13 @@ def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> li
return list(result)
def _token_to_piece(model: _LlamaModel, token: int) -> str:
def _token_to_piece(model: _LlamaModel, token: int, special: bool = False) -> str:
assert model.model is not None
result = (ctypes.c_char * 8)(0)
n_tokens = llama_cpp.llama_token_to_piece(model.model, token, result, len(result))
n_tokens = llama_cpp.llama_token_to_piece(model.model, token, result, len(result), special)
if n_tokens < 0:
result = (ctypes.c_char * -n_tokens)(0)
check = llama_cpp.llama_token_to_piece(model.model, token, result, len(result))
check = llama_cpp.llama_token_to_piece(model.model, token, result, len(result), special)
if check != -n_tokens:
raise RuntimeError(f"Failed to get piece: token={token}")
else:

View file

@ -18,6 +18,7 @@ from typing import (
Iterator,
Deque,
Callable,
Dict,
)
from collections import deque
from pathlib import Path
@ -262,9 +263,7 @@ class Llama:
self.n_batch = min(n_ctx, n_batch) # ???
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
self.n_threads_batch = n_threads_batch or max(
multiprocessing.cpu_count() // 2, 1
)
self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()
# Context Params
self.context_params = llama_cpp.llama_context_default_params()
@ -427,7 +426,10 @@ class Llama:
print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
template=template, eos_token=eos_token, bos_token=bos_token
template=template,
eos_token=eos_token,
bos_token=bos_token,
stop_token_ids=[eos_token_id],
).to_chat_handler()
if self.chat_format is None and self.chat_handler is None:
@ -1032,7 +1034,8 @@ class Llama:
logits_processor=logits_processor,
grammar=grammar,
):
if token == self._token_eos:
assert self._model.model is not None
if llama_cpp.llama_token_is_eog(self._model.model, token):
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
finish_reason = "stop"
break
@ -1664,7 +1667,8 @@ class Llama:
top_k=top_k,
min_p=min_p,
typical_p=typical_p,
logprobs=top_logprobs if logprobs else None,
logprobs=logprobs,
top_logprobs=top_logprobs,
stream=stream,
stop=stop,
seed=seed,
@ -1792,7 +1796,7 @@ class Llama:
file=sys.stderr,
)
return LlamaState(
scores=self.scores.copy(),
scores=self._scores.copy(),
input_ids=self.input_ids.copy(),
n_tokens=self.n_tokens,
llama_state=bytes(llama_state_compact),
@ -1801,7 +1805,9 @@ class Llama:
def load_state(self, state: LlamaState) -> None:
assert self._ctx.ctx is not None
self.scores = state.scores.copy()
# Only filling in up to `n_tokens` and then zero-ing out the rest
self.scores[: state.n_tokens, :] = state.scores.copy()
self.scores[state.n_tokens :, :] = 0.0
self.input_ids = state.input_ids.copy()
self.n_tokens = state.n_tokens
state_size = state.llama_state_size
@ -1952,7 +1958,6 @@ class Llama:
local_dir_use_symlinks=local_dir_use_symlinks,
cache_dir=cache_dir,
local_files_only=True,
)
else:
model_path = os.path.join(local_dir, filename)

View file

@ -10,6 +10,9 @@ from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, P
import jinja2
import numpy as np
import numpy.typing as npt
import llama_cpp.llama as llama
import llama_cpp.llama_types as llama_types
import llama_cpp.llama_grammar as llama_grammar
@ -32,6 +35,9 @@ MISTRAL_INSTRUCT_EOS_TOKEN = "</s>"
# Source: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
MIXTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
# Source: https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
LLAMA3_INSTRUCT_CHAT_TEMPLATE = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
### Chat Completion Handler ###
@ -77,6 +83,8 @@ class LlamaChatCompletionHandler(Protocol):
mirostat_eta: float = 0.1,
logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
**kwargs, # type: ignore
) -> Union[
llama_types.CreateChatCompletionResponse,
@ -148,6 +156,7 @@ class ChatFormatterResponse:
prompt: str
stop: Optional[Union[str, List[str]]] = None
stopping_criteria: Optional[llama.StoppingCriteriaList] = None
class ChatFormatter(Protocol):
@ -171,12 +180,14 @@ class Jinja2ChatFormatter(ChatFormatter):
eos_token: str,
bos_token: str,
add_generation_prompt: bool = True,
stop_token_ids: Optional[List[int]] = None,
):
"""A chat formatter that uses jinja2 templates to format the prompt."""
self.template = template
self.eos_token = eos_token
self.bos_token = bos_token
self.add_generation_prompt = add_generation_prompt
self.stop_token_ids = set(stop_token_ids) if stop_token_ids is not None else None
self._environment = jinja2.Environment(
loader=jinja2.BaseLoader(),
@ -209,7 +220,16 @@ class Jinja2ChatFormatter(ChatFormatter):
tool_choice=tool_choice,
)
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
stopping_criteria = None
if self.stop_token_ids is not None:
def stop_on_last_token(
tokens: npt.NDArray[np.intc],
logits: npt.NDArray[np.single]
) -> bool:
return tokens[-1] in self.stop_token_ids
stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria)
def to_chat_handler(self) -> LlamaChatCompletionHandler:
return chat_formatter_to_chat_completion_handler(self)
@ -338,7 +358,7 @@ def _convert_completion_to_chat_function(
}
],
},
"logprobs": None,
"logprobs": completion["choices"][0]["logprobs"],
"finish_reason": "tool_calls",
}
],
@ -391,7 +411,7 @@ def _convert_completion_to_chat_function(
{
"index": 0,
"finish_reason": None,
"logprobs": None,
"logprobs": chunk["choices"][0]["logprobs"],
"delta": {
"role": None,
"content": None,
@ -426,7 +446,7 @@ def _convert_completion_to_chat_function(
{
"index": 0,
"finish_reason": None,
"logprobs": None,
"logprobs": chunk["choices"][0]["logprobs"],
"delta": {
"role": None,
"content": None,
@ -491,7 +511,6 @@ def chat_formatter_to_chat_completion_handler(
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
logprobs: int = 0,
min_p: float = 0.05,
typical_p: float = 1.0,
stream: bool = False,
@ -512,6 +531,8 @@ def chat_formatter_to_chat_completion_handler(
logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
**kwargs, # type: ignore
) -> Union[
llama_types.CreateChatCompletionResponse,
@ -530,6 +551,10 @@ def chat_formatter_to_chat_completion_handler(
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
stop = stop + rstop
stopping_criteria = None
if result.stopping_criteria is not None:
stopping_criteria = result.stopping_criteria
if response_format is not None and response_format["type"] == "json_object":
grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)
@ -581,7 +606,7 @@ def chat_formatter_to_chat_completion_handler(
top_k=top_k,
min_p=min_p,
typical_p=typical_p,
logprobs=logprobs,
logprobs=top_logprobs if logprobs else None,
stream=stream,
stop=stop,
seed=seed,
@ -595,6 +620,7 @@ def chat_formatter_to_chat_completion_handler(
mirostat_eta=mirostat_eta,
model=model,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
grammar=grammar,
logit_bias=logit_bias,
)
@ -706,6 +732,9 @@ def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[s
metadata["tokenizer.chat_template"] == MIXTRAL_INSTRUCT_CHAT_TEMPLATE):
return "mistral-instruct"
if metadata["tokenizer.chat_template"] == LLAMA3_INSTRUCT_CHAT_TEMPLATE:
return "llama-3"
return None
@ -897,6 +926,26 @@ def format_llama2(
return ChatFormatterResponse(prompt=_prompt)
# Chat format for Llama-3 models, see more details at:
# https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py#L202-L229
@register_chat_format("llama-3")
def format_llama3(
messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any,
) -> ChatFormatterResponse:
_roles = dict(
system="<|start_header_id|>system<|end_header_id|>\n\n",
user="<|start_header_id|>user<|end_header_id|>\n\n",
assistant="<|start_header_id|>assistant<|end_header_id|>\n\n",
)
_begin_token = "<|begin_of_text|>"
_sep = "<|eot_id|>"
_messages = _map_roles(messages, _roles)
_messages.append((_roles["assistant"], None))
_prompt = _format_no_colon_single(_begin_token, _messages, _sep)
return ChatFormatterResponse(prompt=_prompt, stop=_sep)
@register_chat_format("alpaca")
def format_alpaca(
messages: List[llama_types.ChatCompletionRequestMessage],
@ -1628,7 +1677,7 @@ def functionary_chat_handler(
}
],
},
"logprobs": None,
"logprobs": completion["choices"][0]["logprobs"],
"finish_reason": "tool_calls",
}
],
@ -2085,7 +2134,7 @@ def functionary_v1_v2_chat_handler(
choices=[
{
"index": 0,
"logprobs": None,
"logprobs": completion["choices"][0]["logprobs"],
"message": {
"role": "assistant",
"content": None if content == "" else content,
@ -2311,11 +2360,14 @@ def chatml_function_calling(
model: Optional[str] = None,
logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
**kwargs, # type: ignore
) -> Union[
llama_types.CreateChatCompletionResponse,
Iterator[llama_types.CreateChatCompletionStreamResponse],
]:
print(logprobs)
function_calling_template = (
"{% for message in messages %}"
"<|im_start|>{{ message.role }}\n"
@ -2437,6 +2489,7 @@ def chatml_function_calling(
model=model,
logits_processor=logits_processor,
grammar=grammar,
logprobs=top_logprobs if logprobs else None,
),
stream=stream,
)
@ -2549,6 +2602,7 @@ def chatml_function_calling(
typical_p=typical_p,
stream=stream,
stop=["<|im_end|>"],
logprobs=top_logprobs if logprobs else None,
max_tokens=None,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
@ -2660,7 +2714,7 @@ def chatml_function_calling(
{
"finish_reason": "tool_calls",
"index": 0,
"logprobs": None,
"logprobs": completion["choices"][0]["logprobs"],
"message": {
"role": "assistant",
"content": None,
@ -2701,4 +2755,4 @@ def chatml_function_calling(
},
}
raise ValueError("Automatic streaming tool choice is not supported")
raise ValueError("Automatic streaming tool choice is not supported")

View file

@ -237,11 +237,18 @@ LLAMA_FILE_MAGIC_GGLA = 0x67676C61
# define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
LLAMA_FILE_MAGIC_GGSN = 0x6767736E
# define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
LLAMA_FILE_MAGIC_GGSQ = 0x67677371
# define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN
# define LLAMA_SESSION_VERSION 5
LLAMA_SESSION_VERSION = 5
# define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
LLAMA_STATE_SEQ_MAGIC = LLAMA_FILE_MAGIC_GGSQ
# define LLAMA_STATE_SEQ_VERSION 1
LLAMA_STATE_SEQ_VERSION = 1
# struct llama_model;
llama_model_p = NewType("llama_model_p", int)
@ -424,6 +431,11 @@ class llama_token_data(ctypes.Structure):
logit (float): log-odds of the token
p (float): probability of the token"""
if TYPE_CHECKING:
id: llama_token
logit: float
p: float
_fields_ = [
("id", llama_token),
("logit", ctypes.c_float),
@ -447,6 +459,11 @@ class llama_token_data_array(ctypes.Structure):
size (int): size of the array
sorted (bool): whether the array is sorted"""
if TYPE_CHECKING:
data: CtypesArray[llama_token_data]
size: int
sorted: bool
_fields_ = [
("data", llama_token_data_p),
("size", ctypes.c_size_t),
@ -508,6 +525,15 @@ class llama_batch(ctypes.Structure):
logits (ctypes.Array[ctypes.ctypes.c_int8]): if zero, the logits for the respective token will not be output
"""
if TYPE_CHECKING:
n_tokens: int
token: CtypesArray[llama_token]
embd: CtypesArray[ctypes.c_float]
pos: CtypesArray[CtypesArray[llama_pos]]
n_seq_id: CtypesArray[ctypes.c_int]
seq_id: CtypesArray[CtypesArray[llama_seq_id]]
logits: CtypesArray[ctypes.c_int8]
_fields_ = [
("n_tokens", ctypes.c_int32),
("token", ctypes.POINTER(llama_token)),
@ -602,6 +628,18 @@ class llama_model_params(ctypes.Structure):
use_mmap (bool): use mmap if possible
use_mlock (bool): force system to keep model in RAM"""
if TYPE_CHECKING:
n_gpu_layers: int
split_mode: int
main_gpu: int
tensor_split: CtypesArray[ctypes.c_float]
progress_callback: Callable[[float, ctypes.c_void_p], bool]
progress_callback_user_data: ctypes.c_void_p
kv_overrides: CtypesArray[llama_model_kv_override]
vocab_only: bool
use_mmap: bool
use_mlock: bool
_fields_ = [
("n_gpu_layers", ctypes.c_int32),
("split_mode", ctypes.c_int),
@ -689,6 +727,34 @@ class llama_context_params(ctypes.Structure):
abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback
"""
if TYPE_CHECKING:
seed: int
n_ctx: int
n_batch: int
n_ubatch: int
n_seq_max: int
n_threads: int
n_threads_batch: int
rope_scaling_type: int
pooling_type: int
rope_freq_base: float
rope_freq_scale: float
yarn_ext_factor: float
yarn_attn_factor: float
yarn_beta_fast: float
yarn_beta_slow: float
yarn_orig_ctx: int
defrag_thold: float
cb_eval: Callable[[ctypes.c_void_p, bool], bool]
cb_eval_user_data: ctypes.c_void_p
type_k: int
type_v: int
logits_all: bool
embeddings: bool
offload_kqv: bool
abort_callback: Callable[[ctypes.c_void_p], bool]
abort_callback_data: ctypes.c_void_p
_fields_ = [
("seed", ctypes.c_uint32),
("n_ctx", ctypes.c_uint32),
@ -764,6 +830,18 @@ class llama_model_quantize_params(ctypes.Structure):
kv_overrides (ctypes.c_void_p): pointer to vector containing overrides
"""
if TYPE_CHECKING:
nthread: int
ftype: int
output_tensor_type: int
token_embedding_type: int
allow_requantize: bool
quantize_output_tensor: bool
only_copy: bool
pure: bool
imatrix: ctypes.c_void_p
kv_overrides: ctypes.c_void_p
_fields_ = [
("nthread", ctypes.c_int32),
("ftype", ctypes.c_int),
@ -821,6 +899,10 @@ LLAMA_GRETYPE_CHAR_ALT = 6
# uint32_t value; // Unicode code point or rule ID
# } llama_grammar_element;
class llama_grammar_element(ctypes.Structure):
if TYPE_CHECKING:
type: int
value: int
_fields_ = [
("type", ctypes.c_int),
("value", ctypes.c_uint32),
@ -844,6 +926,17 @@ llama_grammar_element_p = ctypes.POINTER(llama_grammar_element)
# int32_t n_eval;
# };
class llama_timings(ctypes.Structure):
if TYPE_CHECKING:
t_start_ms: float
t_end_ms: float
t_load_ms: float
t_sample_ms: float
t_p_eval_ms: float
t_eval_ms: float
n_sample: int
n_p_eval: int
n_eval: int
_fields_ = [
("t_start_ms", ctypes.c_double),
("t_end_ms", ctypes.c_double),
@ -944,7 +1037,8 @@ GGML_NUMA_STRATEGY_COUNT = 5
[ctypes.c_int],
None,
)
def llama_numa_init(numa: int, /): ...
def llama_numa_init(numa: int, /):
...
# // Call once at the end of the program - currently only used for MPI
@ -969,7 +1063,8 @@ def llama_backend_free():
)
def llama_load_model_from_file(
path_model: bytes, params: llama_model_params, /
) -> Optional[llama_model_p]: ...
) -> Optional[llama_model_p]:
...
# LLAMA_API void llama_free_model(struct llama_model * model);
@ -978,7 +1073,8 @@ def llama_load_model_from_file(
[llama_model_p_ctypes],
None,
)
def llama_free_model(model: llama_model_p, /): ...
def llama_free_model(model: llama_model_p, /):
...
# LLAMA_API struct llama_context * llama_new_context_with_model(
@ -991,7 +1087,8 @@ def llama_free_model(model: llama_model_p, /): ...
)
def llama_new_context_with_model(
model: llama_model_p, params: llama_context_params, /
) -> Optional[llama_context_p]: ...
) -> Optional[llama_context_p]:
...
# // Frees all allocated memory
@ -1012,82 +1109,98 @@ def llama_free(ctx: llama_context_p, /):
[],
ctypes.c_int64,
)
def llama_time_us() -> int: ...
def llama_time_us() -> int:
...
# LLAMA_API size_t llama_max_devices(void);
@ctypes_function("llama_max_devices", [], ctypes.c_size_t)
def llama_max_devices() -> int: ...
def llama_max_devices() -> int:
...
# LLAMA_API bool llama_supports_mmap (void);
@ctypes_function("llama_supports_mmap", [], ctypes.c_bool)
def llama_supports_mmap() -> bool: ...
def llama_supports_mmap() -> bool:
...
# LLAMA_API bool llama_supports_mlock (void);
@ctypes_function("llama_supports_mlock", [], ctypes.c_bool)
def llama_supports_mlock() -> bool: ...
def llama_supports_mlock() -> bool:
...
# LLAMA_API bool llama_supports_gpu_offload(void);
@ctypes_function("llama_supports_gpu_offload", [], ctypes.c_bool)
def llama_supports_gpu_offload() -> bool: ...
def llama_supports_gpu_offload() -> bool:
...
# LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
@ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes)
def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: ...
def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]:
...
# LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
@ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_ctx(ctx: llama_context_p, /) -> int: ...
def llama_n_ctx(ctx: llama_context_p, /) -> int:
...
# LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
@ctypes_function("llama_n_batch", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_batch(ctx: llama_context_p, /) -> int: ...
def llama_n_batch(ctx: llama_context_p, /) -> int:
...
# LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
@ctypes_function("llama_n_ubatch", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_ubatch(ctx: llama_context_p, /) -> int: ...
def llama_n_ubatch(ctx: llama_context_p, /) -> int:
...
# LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
@ctypes_function("llama_n_seq_max", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_seq_max(ctx: llama_context_p, /) -> int: ...
def llama_n_seq_max(ctx: llama_context_p, /) -> int:
...
# LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
@ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int)
def llama_vocab_type(model: llama_model_p, /) -> int: ...
def llama_vocab_type(model: llama_model_p, /) -> int:
...
# LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
@ctypes_function("llama_rope_type", [llama_model_p_ctypes], ctypes.c_int)
def llama_rope_type(model: llama_model_p, /) -> int: ...
def llama_rope_type(model: llama_model_p, /) -> int:
...
# LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
@ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_vocab(model: llama_model_p, /) -> int: ...
def llama_n_vocab(model: llama_model_p, /) -> int:
...
# LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
@ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_ctx_train(model: llama_model_p, /) -> int: ...
def llama_n_ctx_train(model: llama_model_p, /) -> int:
...
# LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
@ctypes_function("llama_n_embd", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_embd(model: llama_model_p, /) -> int: ...
def llama_n_embd(model: llama_model_p, /) -> int:
...
# LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
@ctypes_function("llama_n_layer", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_layer(model: llama_model_p, /) -> int: ...
def llama_n_layer(model: llama_model_p, /) -> int:
...
# // Get the model's RoPE frequency scaling factor
@ -1351,6 +1464,9 @@ class llama_kv_cache_view_cell(ctypes.Structure):
pos (llama_pos): The position for this cell. Takes KV cache shifts into account.
May be negative if the cell is not populated."""
if TYPE_CHECKING:
pos: llama_pos
_fields_ = [("pos", llama_pos)]
@ -1387,6 +1503,16 @@ class llama_kv_cache_view_cell(ctypes.Structure):
# llama_seq_id * cells_sequences;
# };
class llama_kv_cache_view(ctypes.Structure):
if TYPE_CHECKING:
n_cells: int
n_max_seq: int
token_count: int
used_cells: int
max_contiguous: int
max_contiguous_idx: int
cells: CtypesArray[llama_kv_cache_view_cell]
cells_sequences: CtypesArray[llama_seq_id]
_fields_ = [
("n_cells", ctypes.c_int32),
("n_max_seq", ctypes.c_int32),
@ -1467,6 +1593,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /):
# // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
# // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
# // seq_id < 0 : match any sequence
# // p0 < 0 : [0, p1]
# // p1 < 0 : [p0, inf)
@ -1493,6 +1620,9 @@ def llama_kv_cache_seq_rm(
/,
) -> bool:
"""Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
seq_id < 0 : match any sequence
p0 < 0 : [0, p1]
p1 < 0 : [p0, inf)"""
@ -1652,7 +1782,16 @@ def llama_kv_cache_update(ctx: llama_context_p, /):
# Returns the maximum size in bytes of the state (rng, logits, embedding
# and kv_cache) - will often be smaller after compacting tokens
# LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
# LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
@ctypes_function("llama_state_get_size", [llama_context_p_ctypes], ctypes.c_size_t)
def llama_state_get_size(ctx: llama_context_p, /) -> int:
"""Returns the maximum size in bytes of the state (rng, logits, embedding
and kv_cache) - will often be smaller after compacting tokens"""
...
# LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx),
# "use llama_state_get_size instead");
@ctypes_function("llama_get_state_size", [llama_context_p_ctypes], ctypes.c_size_t)
def llama_get_state_size(ctx: llama_context_p, /) -> int:
"""Returns the maximum size in bytes of the state (rng, logits, embedding
@ -1663,9 +1802,30 @@ def llama_get_state_size(ctx: llama_context_p, /) -> int:
# Copies the state to the specified destination address.
# Destination needs to have allocated enough memory.
# Returns the number of bytes copied
# LLAMA_API size_t llama_copy_state_data(
# LLAMA_API size_t llama_state_get_data(
# struct llama_context * ctx,
# uint8_t * dst);
@ctypes_function(
"llama_state_get_data",
[
llama_context_p_ctypes,
ctypes.POINTER(ctypes.c_uint8),
],
ctypes.c_size_t,
)
def llama_state_get_data(
ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], /
) -> int:
"""Copies the state to the specified destination address.
Destination needs to have allocated enough memory.
Returns the number of bytes copied"""
...
# LLAMA_API DEPRECATED(size_t llama_copy_state_data(
# struct llama_context * ctx,
# uint8_t * dst),
# "use llama_state_get_data instead");
@ctypes_function(
"llama_copy_state_data",
[
@ -1685,9 +1845,26 @@ def llama_copy_state_data(
# // Set the state reading from the specified address
# // Returns the number of bytes read
# LLAMA_API size_t llama_set_state_data(
# LLAMA_API size_t llama_state_set_data(
# struct llama_context * ctx,
# const uint8_t * src);
@ctypes_function(
"llama_state_set_data",
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)],
ctypes.c_size_t,
)
def llama_state_set_data(
ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], /
) -> int:
"""Set the state reading from the specified address
Returns the number of bytes read"""
...
# LLAMA_API DEPRECATED(size_t llama_set_state_data(
# struct llama_context * ctx,
# const uint8_t * src),
# "use llama_state_set_data instead");
@ctypes_function(
"llama_set_state_data",
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)],
@ -1701,12 +1878,41 @@ def llama_set_state_data(
# Save/load session file
# LLAMA_API bool llama_load_session_file(
# LLAMA_API bool llama_state_load_file(
# struct llama_context * ctx,
# const char * path_session,
# llama_token * tokens_out,
# size_t n_token_capacity,
# size_t * n_token_count_out);
@ctypes_function(
"llama_state_load_file",
[
llama_context_p_ctypes,
ctypes.c_char_p,
llama_token_p,
ctypes.c_size_t,
ctypes.POINTER(ctypes.c_size_t),
],
ctypes.c_bool,
)
def llama_state_load_file(
ctx: llama_context_p,
path_session: bytes,
tokens_out: CtypesArray[llama_token],
n_token_capacity: Union[ctypes.c_size_t, int],
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
/,
) -> bool:
...
# LLAMA_API DEPRECATED(bool llama_load_session_file(
# struct llama_context * ctx,
# const char * path_session,
# llama_token * tokens_out,
# size_t n_token_capacity,
# size_t * n_token_count_out),
# "use llama_state_load_file instead");
@ctypes_function(
"llama_load_session_file",
[
@ -1725,14 +1931,41 @@ def llama_load_session_file(
n_token_capacity: Union[ctypes.c_size_t, int],
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
/,
) -> int: ...
) -> int:
...
# LLAMA_API bool llama_save_session_file(
# LLAMA_API bool llama_state_save_file(
# struct llama_context * ctx,
# const char * path_session,
# const llama_token * tokens,
# size_t n_token_count);
@ctypes_function(
"llama_state_save_file",
[
llama_context_p_ctypes,
ctypes.c_char_p,
llama_token_p,
ctypes.c_size_t,
],
ctypes.c_bool,
)
def llama_state_save_file(
ctx: llama_context_p,
path_session: bytes,
tokens: CtypesArray[llama_token],
n_token_count: Union[ctypes.c_size_t, int],
/,
) -> bool:
...
# LLAMA_API DEPRECATED(bool llama_save_session_file(
# struct llama_context * ctx,
# const char * path_session,
# const llama_token * tokens,
# size_t n_token_count),
# "use llama_state_save_file instead");
@ctypes_function(
"llama_save_session_file",
[
@ -1749,7 +1982,118 @@ def llama_save_session_file(
tokens: CtypesArray[llama_token],
n_token_count: Union[ctypes.c_size_t, int],
/,
) -> int: ...
) -> int:
...
# // Get the exact size needed to copy the KV cache of a single sequence
# LLAMA_API size_t llama_state_seq_get_size(
# struct llama_context * ctx,
# llama_seq_id seq_id);
@ctypes_function(
"llama_state_seq_get_size",
[llama_context_p_ctypes, llama_seq_id],
ctypes.c_size_t,
)
def llama_state_seq_get_size(ctx: llama_context_p, seq_id: llama_seq_id, /) -> int:
"""Get the exact size needed to copy the KV cache of a single sequence"""
...
# // Copy the KV cache of a single sequence into the specified buffer
# LLAMA_API size_t llama_state_seq_get_data(
# struct llama_context * ctx,
# uint8_t * dst,
# llama_seq_id seq_id);
@ctypes_function(
"llama_state_seq_get_data",
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8), llama_seq_id],
ctypes.c_size_t,
)
def llama_state_seq_get_data(
ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], seq_id: llama_seq_id, /
) -> int:
"""Copy the KV cache of a single sequence into the specified buffer"""
...
# // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
# // Returns:
# // - Positive: Ok
# // - Zero: Failed to load
# LLAMA_API size_t llama_state_seq_set_data(
# struct llama_context * ctx,
# const uint8_t * src,
# llama_seq_id dest_seq_id);
@ctypes_function(
"llama_state_seq_set_data",
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8), llama_seq_id],
ctypes.c_size_t,
)
def llama_state_seq_set_data(
ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], dest_seq_id: llama_seq_id, /
) -> int:
"""Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence"""
...
# LLAMA_API size_t llama_state_seq_save_file(
# struct llama_context * ctx,
# const char * filepath,
# llama_seq_id seq_id,
# const llama_token * tokens,
# size_t n_token_count);
@ctypes_function(
"llama_state_seq_save_file",
[
llama_context_p_ctypes,
ctypes.c_char_p,
llama_seq_id,
llama_token_p,
ctypes.c_size_t,
],
ctypes.c_size_t,
)
def llama_state_seq_save_file(
ctx: llama_context_p,
filepath: bytes,
seq_id: llama_seq_id,
tokens: CtypesArray[llama_token],
n_token_count: Union[ctypes.c_size_t, int],
/,
) -> int:
...
# LLAMA_API size_t llama_state_seq_load_file(
# struct llama_context * ctx,
# const char * filepath,
# llama_seq_id dest_seq_id,
# llama_token * tokens_out,
# size_t n_token_capacity,
# size_t * n_token_count_out);
@ctypes_function(
"llama_state_seq_load_file",
[
llama_context_p_ctypes,
ctypes.c_char_p,
llama_seq_id,
llama_token_p,
ctypes.c_size_t,
ctypes.POINTER(ctypes.c_size_t),
],
ctypes.c_size_t,
)
def llama_state_seq_load_file(
ctx: llama_context_p,
filepath: bytes,
dest_seq_id: llama_seq_id,
tokens_out: CtypesArray[llama_token],
n_token_capacity: Union[ctypes.c_size_t, int],
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
/,
) -> int:
...
# //
@ -1930,8 +2274,9 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]:
...
# // Logits for the ith token. Equivalent to:
# // Logits for the ith token. For positive indices, Equivalent to:
# // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
# // Negative indicies can be used to access logits in reverse order, -1 is the last logit.
# // returns NULL for invalid ids.
# LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
@ctypes_function(
@ -1963,8 +2308,9 @@ def llama_get_embeddings(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]
...
# // Get the embeddings for the ith token. Equivalent to:
# // Get the embeddings for the ith token. For positive indices, Equivalent to:
# // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
# // Negative indicies can be used to access embeddings in reverse order, -1 is the last embedding.
# // shape: [n_embd] (1-dimensional)
# // returns NULL for invalid ids.
# LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
@ -2010,7 +2356,8 @@ def llama_get_embeddings_seq(
)
def llama_token_get_text(
model: llama_model_p, token: Union[llama_token, int], /
) -> bytes: ...
) -> bytes:
...
# LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
@ -2019,7 +2366,8 @@ def llama_token_get_text(
)
def llama_token_get_score(
model: llama_model_p, token: Union[llama_token, int], /
) -> float: ...
) -> float:
...
# LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
@ -2028,7 +2376,20 @@ def llama_token_get_score(
)
def llama_token_get_type(
model: llama_model_p, token: Union[llama_token, int], /
) -> int: ...
) -> int:
...
# // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
# LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
@ctypes_function(
"llama_token_is_eog", [llama_model_p_ctypes, llama_token], ctypes.c_bool
)
def llama_token_is_eog(
model: llama_model_p, token: Union[llama_token, int], /
) -> bool:
"""Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)"""
...
# // Special tokens
@ -2048,6 +2409,20 @@ def llama_token_eos(model: llama_model_p, /) -> int:
...
# LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
@ctypes_function("llama_token_cls", [llama_model_p_ctypes], llama_token)
def llama_token_cls(model: llama_model_p, /) -> int:
"""classification"""
...
# LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
@ctypes_function("llama_token_sep", [llama_model_p_ctypes], llama_token)
def llama_token_sep(model: llama_model_p, /) -> int:
"""sentence separator"""
...
# LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
@ctypes_function("llama_token_nl", [llama_model_p_ctypes], llama_token)
def llama_token_nl(model: llama_model_p, /) -> int:
@ -2071,7 +2446,7 @@ def llama_add_eos_token(model: llama_model_p, /) -> int:
...
# // codellama infill tokens
# // Codellama infill tokens
# LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
@ctypes_function("llama_token_prefix", [llama_model_p_ctypes], llama_token)
def llama_token_prefix(model: llama_model_p) -> int:
@ -2081,17 +2456,20 @@ def llama_token_prefix(model: llama_model_p) -> int:
# LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
@ctypes_function("llama_token_middle", [llama_model_p_ctypes], llama_token)
def llama_token_middle(model: llama_model_p, /) -> int: ...
def llama_token_middle(model: llama_model_p, /) -> int:
...
# LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
@ctypes_function("llama_token_suffix", [llama_model_p_ctypes], llama_token)
def llama_token_suffix(model: llama_model_p, /) -> int: ...
def llama_token_suffix(model: llama_model_p, /) -> int:
...
# LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
@ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token)
def llama_token_eot(model: llama_model_p, /) -> int: ...
def llama_token_eot(model: llama_model_p, /) -> int:
...
# //
@ -2103,16 +2481,16 @@ def llama_token_eot(model: llama_model_p, /) -> int: ...
# /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
# /// @return Returns the number of tokens on success, no more than n_tokens_max
# /// @return Returns a negative number on failure - the number of tokens that would have been returned
# /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
# /// Does not insert a leading space.
# /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
# /// as plaintext. Does not insert a leading space.
# LLAMA_API int32_t llama_tokenize(
# const struct llama_model * model,
# const char * text,
# int32_t text_len,
# llama_token * tokens,
# int32_t n_tokens_max,
# bool add_bos,
# bool special);
# bool add_special,
# bool parse_special);
@ctypes_function(
"llama_tokenize",
[
@ -2132,8 +2510,8 @@ def llama_tokenize(
text_len: Union[ctypes.c_int, int],
tokens: CtypesArray[llama_token],
n_tokens_max: Union[ctypes.c_int, int],
add_bos: Union[ctypes.c_bool, bool],
special: Union[ctypes.c_bool, bool],
add_special: Union[ctypes.c_bool, bool],
parse_special: Union[ctypes.c_bool, bool],
/,
) -> int:
"""Convert the provided text into tokens.
@ -2144,9 +2522,8 @@ def llama_tokenize(
text_len: The length of the text.
tokens: The tokens pointer must be large enough to hold the resulting tokens.
n_max_tokens: The maximum number of tokens to return.
add_bos: Whether to add a beginning-of-sentence token.
special: Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
Does not insert a leading space.
add_special: Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.
parse_special: Allow parsing special tokens.
Returns:
Returns the number of tokens on success, no more than n_tokens_max
@ -2159,11 +2536,13 @@ def llama_tokenize(
# // Uses the vocabulary in the provided context.
# // Does not write null terminator to the buffer.
# // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
# // @param special If true, special tokens are rendered in the output.
# LLAMA_API int32_t llama_token_to_piece(
# const struct llama_model * model,
# llama_token token,
# char * buf,
# int32_t length);
# int32_t length,
# bool special);
@ctypes_function(
"llama_token_to_piece",
[
@ -2171,6 +2550,7 @@ def llama_tokenize(
llama_token,
ctypes.c_char_p,
ctypes.c_int32,
ctypes.c_bool,
],
ctypes.c_int32,
)
@ -2179,13 +2559,20 @@ def llama_token_to_piece(
token: Union[llama_token, int],
buf: Union[ctypes.c_char_p, bytes, CtypesArray[ctypes.c_char]],
length: Union[ctypes.c_int, int],
special: Union[ctypes.c_bool, bool],
/,
) -> int:
"""Token Id -> Piece.
Uses the vocabulary in the provided context.
Does not write null terminator to the buffer.
User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
"""
Args:
model: The model to use for tokenization.
token: The token to convert.
buf: The buffer to write the token to.
length: The length of the buffer.
special: If true, special tokens are rendered in the output."""
...
@ -2223,7 +2610,8 @@ def llama_chat_apply_template(
chat: CtypesArray[llama_chat_message],
n_msg: int,
/,
) -> int: ...
) -> int:
...
# //
@ -2753,6 +3141,12 @@ def llama_grammar_accept_token(
# bool eob; // Callback should set this to true when a beam is at end-of-beam.
# };
class llama_beam_view(ctypes.Structure):
if TYPE_CHECKING:
tokens: CtypesArray[llama_token]
n_tokens: int
p: float
eob: bool
_fields_ = [
("tokens", llama_token_p),
("n_tokens", ctypes.c_size_t),
@ -2772,6 +3166,12 @@ class llama_beam_view(ctypes.Structure):
# bool last_call; // True iff this is the last callback invocation.
# };
class llama_beams_state(ctypes.Structure):
if TYPE_CHECKING:
beam_views: CtypesArray[llama_beam_view]
n_beams: int
common_prefix_length: int
last_call: bool
_fields_ = [
("beam_views", ctypes.POINTER(llama_beam_view)),
("n_beams", ctypes.c_size_t),
@ -2824,7 +3224,8 @@ def llama_beam_search(
n_past: Union[ctypes.c_int, int],
n_predict: Union[ctypes.c_int, int],
/,
): ...
):
...
# /// @details Build a split GGUF final path for this chunk.
@ -2943,4 +3344,5 @@ def llama_log_set(
[ctypes.c_void_p, llama_context_p_ctypes],
None,
)
def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /): ...
def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /):
...

View file

@ -5,11 +5,12 @@ from pathlib import Path
import sys
from ctypes import * # type: ignore
from enum import Enum
from itertools import islice
from itertools import islice, groupby
from typing import (
Any,
Callable,
Dict,
Set,
Generic,
List,
Optional,
@ -1391,145 +1392,561 @@ from typing import List, Optional
# whitespace. Also maybe improves generation quality?
SPACE_RULE = '" "?'
PRIMITIVE_RULES = {
"boolean": '("true" | "false") space',
"number": '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space',
"integer": '("-"? ([0-9] | [1-9] [0-9]*)) space',
"string": r""" "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space """,
"null": '"null" space',
}
INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'}
# whitespace is constrained to a single space char to prevent model "running away" in
# whitespace. Also maybe improves generation quality?
SPACE_RULE = '" "?'
def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False):
if not separator_rule:
if min_items == 0 and max_items == 1:
return f'{item_rule}?'
elif min_items == 1 and max_items is None:
return f'{item_rule}+'
result = ''
if min_items > 0:
if item_rule_is_literal and separator_rule is None:
result = '"' + (item_rule[1:-1] * min_items) + '"'
else:
result = (f' {separator_rule} ' if separator_rule else ' ').join([item_rule] * min_items)
def opt_repetitions(up_to_n, prefix_with_sep=False):
'''
- n=4, no sep: '(a (a (a (a)?)?)?)?'
- n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?'
- n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?'
'''
content = f'{separator_rule} {item_rule}' if prefix_with_sep and separator_rule else item_rule
if up_to_n == 0:
return ''
elif up_to_n == 1:
return f'({content})?'
elif separator_rule and not prefix_with_sep:
return f'({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?'
else:
return (f'({content} ' * up_to_n).rstrip() + (')?' * up_to_n)
if min_items > 0 and max_items != min_items:
result += ' '
if max_items is not None:
result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0)
else:
item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})'
if min_items == 0 and separator_rule:
result = f'({item_rule} {item_operator}*)?'
else:
result += f'{item_operator}*'
return result
class BuiltinRule:
def __init__(self, content: str, deps: list = None):
self.content = content
self.deps = deps or []
_up_to_15_digits = _build_repetition('[0-9]', 0, 15)
PRIMITIVE_RULES = {
'boolean' : BuiltinRule('("true" | "false") space', []),
'decimal-part' : BuiltinRule('[0-9] ' + _up_to_15_digits, []),
'integral-part': BuiltinRule('[0-9] | [1-9] ' + _up_to_15_digits, []),
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']),
'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
'uuid' : BuiltinRule(r'"\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + r' "\"" space', []),
'char' : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])', []),
'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']),
'null' : BuiltinRule('"null" space', []),
}
# TODO: support "uri", "email" string formats
STRING_FORMAT_RULES = {
'date' : BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
'date-time' : BuiltinRule('date "T" time', ['date', 'time']),
'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']),
'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']),
'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
}
DOTALL = '[\\U00000000-\\U0010FFFF]'
DOT = '[^\\x0A\\x0D]'
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])
NON_LITERAL_SET = set('|.()[]{}*+?')
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?')
class SchemaConverter:
def __init__(self, prop_order):
def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
self._prop_order = prop_order
self._rules = {"space": SPACE_RULE}
self._defs: Dict[str, Any] = {}
self._allow_fetch = allow_fetch
self._dotall = dotall
self._raw_pattern = raw_pattern
self._rules = {
'space': SPACE_RULE,
}
self._refs = {}
self._refs_being_resolved = set()
def _format_literal(self, literal: str):
escaped: str = GRAMMAR_LITERAL_ESCAPE_RE.sub(
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal)
def _format_literal(self, literal):
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
)
return f'"{escaped}"'
def _add_rule(self, name: str, rule: str):
esc_name = INVALID_RULE_CHARS_RE.sub("-", name)
def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str:
'''
not_literal('a') -> '[^a]'
not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
'''
assert len(literal) > 0, 'Empty literal not supported'
def recurse(i: int):
c = literal[i]
if maybe_escaped_underscores and c == '_':
yield f'[^{c}\\\\]'
yield ' | '
yield f'"\\\\"? "{c}"'
else:
yield f'[^{c}]'
if i < len(literal) - 1:
yield ' | '
yield self._format_literal(c)
yield ' ('
yield from recurse(i + 1)
yield ')?'
return ''.join(('(', *recurse(0), ')'))
def _add_rule(self, name, rule):
esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
if esc_name not in self._rules or self._rules[esc_name] == rule:
key = esc_name
else:
i = 0
while f"{esc_name}{i}" in self._rules:
while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule:
i += 1
key = f"{esc_name}{i}"
key = f'{esc_name}{i}'
self._rules[key] = rule
return key
def visit(self, schema: Dict[str, Any], name: str) -> str:
rule_name = name or "root"
def resolve_refs(self, schema: dict, url: str):
'''
Resolves all $ref fields in the given schema, fetching any remote schemas,
replacing $ref with absolute reference URL and populating self._refs with the
respective referenced (sub)schema dictionaries.
'''
def visit(n: dict):
if isinstance(n, list):
return [visit(x) for x in n]
elif isinstance(n, dict):
ref = n.get('$ref')
if ref is not None and ref not in self._refs:
if ref.startswith('https://'):
assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
import requests
if "$defs" in schema:
# add defs to self._defs for later inlining
for def_name, def_schema in schema["$defs"].items():
self._defs[def_name] = def_schema
frag_split = ref.split('#')
base_url = frag_split[0]
if "oneOf" in schema or "anyOf" in schema:
rule = " | ".join(
(
self.visit(alt_schema, f'{name}{"-" if name else ""}{i}')
for i, alt_schema in enumerate(
schema.get("oneOf") or schema["anyOf"]
)
)
target = self._refs.get(base_url)
if target is None:
target = self.resolve_refs(requests.get(ref).json(), base_url)
self._refs[base_url] = target
if len(frag_split) == 1 or frag_split[-1] == '':
return target
elif ref.startswith('#/'):
target = schema
ref = f'{url}{ref}'
n['$ref'] = ref
else:
raise ValueError(f'Unsupported ref {ref}')
for sel in ref.split('#')[-1].split('/')[1:]:
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
target = target[sel]
self._refs[ref] = target
else:
for v in n.values():
visit(v)
return n
return visit(schema)
def _generate_union_rule(self, name, alt_schemas):
return ' | '.join((
self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
for i, alt_schema in enumerate(alt_schemas)
))
def _visit_pattern(self, pattern, name):
'''
Transforms a regular expression pattern into a GBNF rule.
Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
we define sub-rules to keep the output lean.
'''
assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
pattern = pattern[1:-1]
sub_rule_ids = {}
i = 0
length = len(pattern)
def to_rule(s: Tuple[str, bool]) -> str:
(txt, is_literal) = s
return "\"" + txt + "\"" if is_literal else txt
def transform() -> Tuple[str, bool]:
'''
Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
'''
nonlocal i
nonlocal pattern
nonlocal sub_rule_ids
start = i
# For each component of this sequence, store its string representation and whether it's a literal.
# We only need a flat structure here to apply repetition operators to the last item, and
# to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
# (GBNF's syntax is luckily very close to regular expressions!)
seq: list[Tuple[str, bool]] = []
def get_dot():
if self._dotall:
rule = DOTALL
else:
# Accept any character... except \n and \r line break chars (\x0A and \xOD)
rule = DOT
return self._add_rule(f'dot', rule)
def join_seq():
nonlocal seq
ret = []
for is_literal, g in groupby(seq, lambda x: x[1]):
if is_literal:
ret.append((''.join(x[0] for x in g), True))
else:
ret.extend(g)
if len(ret) == 1:
return ret[0]
return (' '.join(to_rule(x) for x in seq), False)
while i < length:
c = pattern[i]
if c == '.':
seq.append((get_dot(), False))
i += 1
elif c == '(':
i += 1
if i < length:
assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
seq.append((f'({to_rule(transform())})', False))
elif c == ')':
i += 1
assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
return join_seq()
elif c == '[':
square_brackets = c
i += 1
while i < length and pattern[i] != ']':
if pattern[i] == '\\':
square_brackets += pattern[i:i+2]
i += 2
else:
square_brackets += pattern[i]
i += 1
assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
square_brackets += ']'
i += 1
seq.append((square_brackets, False))
elif c == '|':
seq.append(('|', False))
i += 1
elif c in ('*', '+', '?'):
seq[-1] = (to_rule(seq[-1]) + c, False)
i += 1
elif c == '{':
curly_brackets = c
i += 1
while i < length and pattern[i] != '}':
curly_brackets += pattern[i]
i += 1
assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
curly_brackets += '}'
i += 1
nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
min_times = 0
max_times = None
try:
if len(nums) == 1:
min_times = int(nums[0])
max_times = min_times
else:
assert len(nums) == 2
min_times = int(nums[0]) if nums[0] else 0
max_times = int(nums[1]) if nums[1] else None
except ValueError:
raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')
(sub, sub_is_literal) = seq[-1]
if not sub_is_literal:
id = sub_rule_ids.get(sub)
if id is None:
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
sub_rule_ids[sub] = id
sub = id
seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times, item_rule_is_literal=sub_is_literal), False)
else:
literal = ''
while i < length:
if pattern[i] == '\\' and i < length - 1:
next = pattern[i + 1]
if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
i += 1
literal += pattern[i]
i += 1
else:
literal += pattern[i:i+2]
i += 2
elif pattern[i] == '"' and not self._raw_pattern:
literal += '\\"'
i += 1
elif pattern[i] not in NON_LITERAL_SET and \
(i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET):
literal += pattern[i]
i += 1
else:
break
if literal:
seq.append((literal, True))
return join_seq()
return self._add_rule(
name,
to_rule(transform()) if self._raw_pattern \
else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space")
def _resolve_ref(self, ref):
ref_name = ref.split('/')[-1]
if ref_name not in self._rules and ref not in self._refs_being_resolved:
self._refs_being_resolved.add(ref)
resolved = self._refs[ref]
ref_name = self.visit(resolved, ref_name)
self._refs_being_resolved.remove(ref)
return ref_name
def _generate_constant_rule(self, value):
return self._format_literal(json.dumps(value))
def visit(self, schema, name):
schema_type = schema.get('type')
schema_format = schema.get('format')
rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'
if (ref := schema.get('$ref')) is not None:
return self._add_rule(rule_name, self._resolve_ref(ref))
elif 'oneOf' in schema or 'anyOf' in schema:
return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))
elif isinstance(schema_type, list):
return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type]))
elif 'const' in schema:
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))
elif 'enum' in schema:
rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum']))
return self._add_rule(rule_name, rule)
elif schema_type in (None, 'object') and \
('properties' in schema or \
('additionalProperties' in schema and schema['additionalProperties'] is not True)):
required = set(schema.get('required', []))
properties = list(schema.get('properties', {}).items())
return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
elif schema_type in (None, 'object') and 'allOf' in schema:
required = set()
properties = []
hybrid_name = name
def add_component(comp_schema, is_required):
if (ref := comp_schema.get('$ref')) is not None:
comp_schema = self._refs[ref]
if 'properties' in comp_schema:
for prop_name, prop_schema in comp_schema['properties'].items():
properties.append((prop_name, prop_schema))
if is_required:
required.add(prop_name)
for t in schema['allOf']:
if 'anyOf' in t:
for tt in t['anyOf']:
add_component(tt, is_required=False)
else:
add_component(t, is_required=True)
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[]))
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
items = schema.get('items') or schema['prefixItems']
if isinstance(items, list):
return self._add_rule(
rule_name,
'"[" space ' +
' "," space '.join(
self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
for i, item in enumerate(items)) +
' "]" space')
else:
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
min_items = schema.get("minItems", 0)
max_items = schema.get("maxItems")
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')
elif schema_type in (None, 'string') and 'pattern' in schema:
return self._visit_pattern(schema['pattern'], rule_name)
elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
return self._add_primitive(
'root' if rule_name == 'root' else schema_format,
PRIMITIVE_RULES['uuid']
)
return self._add_rule(rule_name, rule)
elif "const" in schema:
return self._add_rule(rule_name, self._format_literal(schema["const"]))
elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES:
prim_name = f'{schema_format}-string'
return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]))
elif "enum" in schema:
rule = " | ".join((self._format_literal(v) for v in schema["enum"]))
return self._add_rule(rule_name, rule)
elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema):
char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
min_len = schema.get('minLength', 0)
max_len = schema.get('maxLength')
elif "$ref" in schema:
ref = schema["$ref"]
assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}"
# inline $defs
def_name = ref[len("#/$defs/") :]
def_schema = self._defs[def_name]
return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}')
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')
schema_type: Optional[str] = schema.get("type") # type: ignore
assert isinstance(schema_type, str), f"Unrecognized schema: {schema}"
if schema_type == "object" and "properties" in schema:
# TODO: `required` keyword
if self._prop_order:
prop_order = self._prop_order
prop_pairs = sorted(
schema["properties"].items(),
# sort by position in prop_order (if specified) then by key
key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]),
)
else:
prop_pairs = schema["properties"].items()
rule = '"{" space'
for i, (prop_name, prop_schema) in enumerate(prop_pairs):
prop_rule_name = self.visit(
prop_schema, f'{name}{"-" if name else ""}{prop_name}'
)
if i > 0:
rule += ' "," space'
rule += rf' {self._format_literal(prop_name)} space ":" space {prop_rule_name}'
rule += ' "}" space'
return self._add_rule(rule_name, rule)
elif schema_type == "array" and "items" in schema:
# TODO `prefixItems` keyword
item_rule_name = self.visit(
schema["items"], f'{name}{"-" if name else ""}item'
)
list_item_operator = f'("," space {item_rule_name})'
successive_items = ""
min_items = schema.get("minItems", 0)
if min_items > 0:
first_item = f"({item_rule_name})"
successive_items = list_item_operator * (min_items - 1)
min_items -= 1
else:
first_item = f"({item_rule_name})?"
max_items = schema.get("maxItems")
if max_items is not None and max_items > min_items:
successive_items += (list_item_operator + "?") * (max_items - min_items - 1)
else:
successive_items += list_item_operator + "*"
rule = f'"[" space {first_item} {successive_items} "]" space'
return self._add_rule(rule_name, rule)
elif (schema_type == 'object') or (len(schema) == 0):
return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))
else:
assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
return self._add_rule(
"root" if rule_name == "root" else schema_type,
PRIMITIVE_RULES[schema_type],
assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
# TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type])
def _add_primitive(self, name: str, rule: BuiltinRule):
n = self._add_rule(name, rule.content)
for dep in rule.deps:
dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
assert dep_rule, f'Rule {dep} not known'
if dep not in self._rules:
self._add_primitive(dep, dep_rule)
return n
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
prop_order = self._prop_order
# sort by position in prop_order (if specified) then by original order
sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]
prop_kv_rule_names = {}
for prop_name, prop_schema in properties:
prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
prop_kv_rule_names[prop_name] = self._add_rule(
f'{name}{"-" if name else ""}{prop_name}-kv',
fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
)
required_props = [k for k in sorted_props if k in required]
optional_props = [k for k in sorted_props if k not in required]
if additional_properties == True or isinstance(additional_properties, dict):
sub_name = f'{name}{"-" if name else ""}additional'
value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value')
prop_kv_rule_names["*"] = self._add_rule(
f'{sub_name}-kv',
self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
)
optional_props.append("*")
rule = '"{" space '
rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)
if optional_props:
rule += ' ('
if required_props:
rule += ' "," space ( '
def get_recursive_refs(ks, first_is_optional):
[k, *rest] = ks
kv_rule_name = prop_kv_rule_names[k]
if k == '*':
res = self._add_rule(
f'{name}{"-" if name else ""}additional-kvs',
f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*'
)
elif first_is_optional:
res = f'( "," space {kv_rule_name} )?'
else:
res = kv_rule_name
if len(rest) > 0:
res += ' ' + self._add_rule(
f'{name}{"-" if name else ""}{k}-rest',
get_recursive_refs(rest, first_is_optional=True)
)
return res
rule += ' | '.join(
get_recursive_refs(optional_props[i:], first_is_optional=False)
for i in range(len(optional_props))
)
if required_props:
rule += ' )'
rule += ' )?'
rule += ' "}" space'
return rule
def format_grammar(self):
return "\n".join((f"{name} ::= {rule}" for name, rule in self._rules.items()))
return '\n'.join(
f'{name} ::= {rule}'
for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
)
def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
prop_order = prop_order or []
schema = json.loads(schema)
prop_order = {name: idx for idx, name in enumerate(prop_order)}
converter = SchemaConverter(prop_order)
converter = SchemaConverter(prop_order=prop_order, allow_fetch=False, dotall=False, raw_pattern=False)
schema = converter.resolve_refs(schema, "stdin")
converter.visit(schema, "")
return converter.format_grammar()

View file

@ -59,7 +59,16 @@ def main():
if not os.path.exists(config_file):
raise ValueError(f"Config file {config_file} not found!")
with open(config_file, "rb") as f:
config_file_settings = ConfigFileSettings.model_validate_json(f.read())
# Check if yaml file
if config_file.endswith(".yaml") or config_file.endswith(".yml"):
import yaml
import json
config_file_settings = ConfigFileSettings.model_validate_json(
json.dumps(yaml.safe_load(f))
)
else:
config_file_settings = ConfigFileSettings.model_validate_json(f.read())
server_settings = ServerSettings.model_validate(config_file_settings)
model_settings = config_file_settings.models
else:

View file

@ -87,6 +87,13 @@ def get_llama_proxy():
llama_outer_lock.release()
_ping_message_factory = None
def set_ping_message_factory(factory):
global _ping_message_factory
_ping_message_factory = factory
def create_app(
settings: Settings | None = None,
server_settings: ServerSettings | None = None,
@ -97,7 +104,15 @@ def create_app(
if not os.path.exists(config_file):
raise ValueError(f"Config file {config_file} not found!")
with open(config_file, "rb") as f:
config_file_settings = ConfigFileSettings.model_validate_json(f.read())
# Check if yaml file
if config_file.endswith(".yaml") or config_file.endswith(".yml"):
import yaml
config_file_settings = ConfigFileSettings.model_validate_json(
json.dumps(yaml.safe_load(f))
)
else:
config_file_settings = ConfigFileSettings.model_validate_json(f.read())
server_settings = ServerSettings.model_validate(config_file_settings)
model_settings = config_file_settings.models
@ -130,6 +145,9 @@ def create_app(
assert model_settings is not None
set_llama_proxy(model_settings=model_settings)
if server_settings.disable_ping_events:
set_ping_message_factory(lambda: bytes())
return app
@ -294,6 +312,7 @@ async def create_completion(
iterator=iterator(),
),
sep="\n",
ping_message_factory=_ping_message_factory,
)
else:
return iterator_or_completion
@ -462,6 +481,7 @@ async def create_chat_completion(
iterator=iterator(),
),
sep="\n",
ping_message_factory=_ping_message_factory,
)
else:
return iterator_or_completion

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import multiprocessing
from typing import Optional, List, Literal, Union
from pydantic import Field
from pydantic import Field, root_validator
from pydantic_settings import BaseSettings
import llama_cpp
@ -67,12 +67,12 @@ class ModelSettings(BaseSettings):
n_threads: int = Field(
default=max(multiprocessing.cpu_count() // 2, 1),
ge=1,
description="The number of threads to use.",
description="The number of threads to use. Use -1 for max cpu threads",
)
n_threads_batch: int = Field(
default=max(multiprocessing.cpu_count() // 2, 1),
default=max(multiprocessing.cpu_count(), 1),
ge=0,
description="The number of threads to use when batch processing.",
description="The number of threads to use when batch processing. Use -1 for max cpu threads",
)
rope_scaling_type: int = Field(
default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
@ -173,6 +173,16 @@ class ModelSettings(BaseSettings):
default=True, description="Whether to print debug information."
)
@root_validator(pre=True) # pre=True to ensure this runs before any other validation
def set_dynamic_defaults(cls, values):
# If n_threads or n_threads_batch is -1, set it to multiprocessing.cpu_count()
cpu_count = multiprocessing.cpu_count()
if values.get('n_threads', 0) == -1:
values['n_threads'] = cpu_count
if values.get('n_threads_batch', 0) == -1:
values['n_threads_batch'] = cpu_count
return values
class ServerSettings(BaseSettings):
"""Server settings used to configure the FastAPI and Uvicorn server."""
@ -195,6 +205,10 @@ class ServerSettings(BaseSettings):
default=True,
description="Whether to interrupt requests when a new request is received.",
)
disable_ping_events: bool = Field(
default=False,
description="Disable EventSource pings (may be needed for some clients).",
)
class Settings(ServerSettings, ModelSettings):

View file

@ -1,5 +1,5 @@
[build-system]
requires = ["scikit-build-core[pyproject]>=0.5.1"]
requires = ["scikit-build-core[pyproject]>=0.9.2"]
build-backend = "scikit_build_core.build"
[project]
@ -35,6 +35,7 @@ server = [
"pydantic-settings>=2.0.1",
"sse-starlette>=1.6.1",
"starlette-context>=0.3.6,<0.4",
"PyYAML>=5.1",
]
test = [
"pytest>=7.4.0",

2
vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit 75cd4c77292034ecec587ecb401366f57338f7c0
Subproject commit 4e96a812b3ce7322a29a3008db2ed73d9087b176