This commit is contained in:
commit
ce85be97e2
14 changed files with 1201 additions and 191 deletions
31
.github/workflows/build-and-release.yaml
vendored
31
.github/workflows/build-and-release.yaml
vendored
|
@ -41,6 +41,35 @@ jobs:
|
|||
with:
|
||||
path: ./wheelhouse/*.whl
|
||||
|
||||
build_arm64_wheels:
|
||||
name: Build arm64 wheels
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: "recursive"
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
with:
|
||||
platforms: linux/arm64
|
||||
|
||||
- name: Build wheels
|
||||
uses: pypa/cibuildwheel@v2.16.5
|
||||
env:
|
||||
CIBW_SKIP: "*musllinux* pp*"
|
||||
CIBW_REPAIR_WHEEL_COMMAND: ""
|
||||
CIBW_ARCHS: "aarch64"
|
||||
CIBW_BUILD: "cp38-* cp39-* cp310-* cp311-* cp312-*"
|
||||
with:
|
||||
output-dir: wheelhouse/
|
||||
|
||||
- name: Upload wheels as artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
# This job defines no build matrix, so `matrix.version` would expand to an empty string; use a fixed, unique artifact name instead.
name: wheels-arm64
|
||||
path: wheelhouse/*.whl
|
||||
|
||||
build_sdist:
|
||||
name: Build source distribution
|
||||
runs-on: ubuntu-latest
|
||||
|
@ -65,7 +94,7 @@ jobs:
|
|||
|
||||
release:
|
||||
name: Release
|
||||
needs: [build_wheels, build_sdist]
|
||||
needs: [build_wheels, build_arm64_wheels, build_sdist]
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
|
|
29
CHANGELOG.md
29
CHANGELOG.md
|
@ -7,6 +7,35 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
## [Unreleased]
|
||||
|
||||
## [0.2.64]
|
||||
|
||||
- feat: Update llama.cpp to ggerganov/llama.cpp@4e96a812b3ce7322a29a3008db2ed73d9087b176
|
||||
- feat: Add `llama-3` chat format by @andreabak in #1371
|
||||
- feat: Use new llama_token_is_eog in create_completions by @abetlen in d40a250ef3cfaa8224d12c83776a2f1de96ae3d1
|
||||
- feat(server): Provide ability to dynamically allocate all threads if desired using -1 by @sean-bailey in #1364
|
||||
- ci: Build arm64 wheels by @gaby in 611781f5319719a3d05fefccbbf0cc321742a026
|
||||
- fix: Update scikit-build-core build dependency to avoid bug in 0.9.1 by @evelkey in #1370
|
||||
|
||||
## [0.2.63]
|
||||
|
||||
- feat: Update llama.cpp to ggerganov/llama.cpp@0e4802b2ecbaab04b4f829fde4a3096ca19c84b5
|
||||
- feat: Add stopping_criteria to ChatFormatter, allow stopping on arbitrary token ids, fixes llama3 instruct by @abetlen in cc81afebf04d26ca1ac3cf72f23f18da6ab58588
|
||||
|
||||
## [0.2.62]
|
||||
|
||||
- feat: Update llama.cpp to ggerganov/llama.cpp@3b8f1ec4b18770531d0b1d792f3edf08254e4f0c
|
||||
- feat: update grammar schema converter to match llama.cpp by @themrzmaster in #1353
|
||||
- feat: add disable_ping_events flag by @khimaros in #1257
|
||||
- feat: Make saved state more compact on-disk by @tc-wolf in #1296
|
||||
- feat: Use all available CPUs for batch processing by @ddh0 in #1345
|
||||
|
||||
## [0.2.61]
|
||||
|
||||
- feat: Update llama.cpp to ggerganov/llama.cpp@ba5e134e073ec6837078c874aba44a702944a676
|
||||
- fix: pass correct type to chat handlers for chat completion logprobs by @abetlen in bb65b4d76411112c6fb0bf759efd746f99ef3c6b
|
||||
- feat: Add support for yaml based server configs by @abetlen in 060bfa64d529ade2af9b1f4e207a3937bbc4138f
|
||||
- feat: Add typechecking for ctypes structure attributes by @abetlen in 1347e1d050fc5a9a32ffe0bb3e22858da28003bd
|
||||
|
||||
## [0.2.60]
|
||||
|
||||
- feat: Update llama.cpp to ggerganov/llama.cpp@75cd4c77292034ecec587ecb401366f57338f7c0
|
||||
|
|
30
examples/batch-processing/server.py
Normal file
30
examples/batch-processing/server.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
"""llama-cpp-python server from scratch in a single file.
|
||||
"""
|
||||
|
||||
# import llama_cpp
|
||||
|
||||
# path = b"../../models/Qwen1.5-0.5B-Chat-GGUF/qwen1_5-0_5b-chat-q8_0.gguf"
|
||||
|
||||
# model_params = llama_cpp.llama_model_default_params()
|
||||
# model = llama_cpp.llama_load_model_from_file(path, model_params)
|
||||
|
||||
# if model is None:
|
||||
# raise RuntimeError(f"Failed to load model from file: {path}")
|
||||
|
||||
|
||||
# ctx_params = llama_cpp.llama_context_default_params()
|
||||
# ctx = llama_cpp.llama_new_context_with_model(model, ctx_params)
|
||||
|
||||
# if ctx is None:
|
||||
# raise RuntimeError("Failed to create context")
|
||||
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
import openai.types.chat as types
|
||||
|
||||
@app.post("/v1/chat/completions")
def create_chat_completions():
    """Placeholder handler for the chat-completions route.

    Takes no request body and returns a fixed greeting payload.
    """
    placeholder = {"message": "Hello World"}
    return placeholder
|
|
@ -1,4 +1,4 @@
|
|||
from .llama_cpp import *
|
||||
from .llama import *
|
||||
|
||||
__version__ = "0.2.60"
|
||||
__version__ = "0.2.64"
|
|
@ -181,20 +181,20 @@ class _LlamaModel:
|
|||
)
|
||||
return list(tokens[:n_tokens])
|
||||
|
||||
def token_to_piece(self, token: int) -> bytes:
|
||||
def token_to_piece(self, token: int, special: bool = False) -> bytes:
|
||||
assert self.model is not None
|
||||
buf = ctypes.create_string_buffer(32)
|
||||
llama_cpp.llama_token_to_piece(self.model, token, buf, 32)
|
||||
llama_cpp.llama_token_to_piece(self.model, token, buf, 32, special)
|
||||
return bytes(buf)
|
||||
|
||||
def detokenize(self, tokens: List[int]) -> bytes:
|
||||
def detokenize(self, tokens: List[int], special: bool = False) -> bytes:
|
||||
assert self.model is not None
|
||||
output = b""
|
||||
size = 32
|
||||
buffer = (ctypes.c_char * size)()
|
||||
for token in tokens:
|
||||
n = llama_cpp.llama_token_to_piece(
|
||||
self.model, llama_cpp.llama_token(token), buffer, size
|
||||
self.model, llama_cpp.llama_token(token), buffer, size, special
|
||||
)
|
||||
assert n <= size
|
||||
output += bytes(buffer[:n])
|
||||
|
@ -597,13 +597,13 @@ def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> li
|
|||
return list(result)
|
||||
|
||||
|
||||
def _token_to_piece(model: _LlamaModel, token: int) -> str:
|
||||
def _token_to_piece(model: _LlamaModel, token: int, special: bool = False) -> str:
|
||||
assert model.model is not None
|
||||
result = (ctypes.c_char * 8)(0)
|
||||
n_tokens = llama_cpp.llama_token_to_piece(model.model, token, result, len(result))
|
||||
n_tokens = llama_cpp.llama_token_to_piece(model.model, token, result, len(result), special)
|
||||
if n_tokens < 0:
|
||||
result = (ctypes.c_char * -n_tokens)(0)
|
||||
check = llama_cpp.llama_token_to_piece(model.model, token, result, len(result))
|
||||
check = llama_cpp.llama_token_to_piece(model.model, token, result, len(result), special)
|
||||
if check != -n_tokens:
|
||||
raise RuntimeError(f"Failed to get piece: token={token}")
|
||||
else:
|
||||
|
|
|
@ -18,6 +18,7 @@ from typing import (
|
|||
Iterator,
|
||||
Deque,
|
||||
Callable,
|
||||
Dict,
|
||||
)
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
|
@ -262,9 +263,7 @@ class Llama:
|
|||
|
||||
self.n_batch = min(n_ctx, n_batch) # ???
|
||||
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
|
||||
self.n_threads_batch = n_threads_batch or max(
|
||||
multiprocessing.cpu_count() // 2, 1
|
||||
)
|
||||
self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()
|
||||
|
||||
# Context Params
|
||||
self.context_params = llama_cpp.llama_context_default_params()
|
||||
|
@ -427,7 +426,10 @@ class Llama:
|
|||
print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
|
||||
|
||||
self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
|
||||
template=template, eos_token=eos_token, bos_token=bos_token
|
||||
template=template,
|
||||
eos_token=eos_token,
|
||||
bos_token=bos_token,
|
||||
stop_token_ids=[eos_token_id],
|
||||
).to_chat_handler()
|
||||
|
||||
if self.chat_format is None and self.chat_handler is None:
|
||||
|
@ -1032,7 +1034,8 @@ class Llama:
|
|||
logits_processor=logits_processor,
|
||||
grammar=grammar,
|
||||
):
|
||||
if token == self._token_eos:
|
||||
assert self._model.model is not None
|
||||
if llama_cpp.llama_token_is_eog(self._model.model, token):
|
||||
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
|
||||
finish_reason = "stop"
|
||||
break
|
||||
|
@ -1664,7 +1667,8 @@ class Llama:
|
|||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
logprobs=top_logprobs if logprobs else None,
|
||||
logprobs=logprobs,
|
||||
top_logprobs=top_logprobs,
|
||||
stream=stream,
|
||||
stop=stop,
|
||||
seed=seed,
|
||||
|
@ -1792,7 +1796,7 @@ class Llama:
|
|||
file=sys.stderr,
|
||||
)
|
||||
return LlamaState(
|
||||
scores=self.scores.copy(),
|
||||
scores=self._scores.copy(),
|
||||
input_ids=self.input_ids.copy(),
|
||||
n_tokens=self.n_tokens,
|
||||
llama_state=bytes(llama_state_compact),
|
||||
|
@ -1801,7 +1805,9 @@ class Llama:
|
|||
|
||||
def load_state(self, state: LlamaState) -> None:
|
||||
assert self._ctx.ctx is not None
|
||||
self.scores = state.scores.copy()
|
||||
# Only filling in up to `n_tokens` and then zero-ing out the rest
|
||||
self.scores[: state.n_tokens, :] = state.scores.copy()
|
||||
self.scores[state.n_tokens :, :] = 0.0
|
||||
self.input_ids = state.input_ids.copy()
|
||||
self.n_tokens = state.n_tokens
|
||||
state_size = state.llama_state_size
|
||||
|
@ -1952,7 +1958,6 @@ class Llama:
|
|||
local_dir_use_symlinks=local_dir_use_symlinks,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=True,
|
||||
|
||||
)
|
||||
else:
|
||||
model_path = os.path.join(local_dir, filename)
|
||||
|
|
|
@ -10,6 +10,9 @@ from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, P
|
|||
|
||||
import jinja2
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
import llama_cpp.llama as llama
|
||||
import llama_cpp.llama_types as llama_types
|
||||
import llama_cpp.llama_grammar as llama_grammar
|
||||
|
@ -32,6 +35,9 @@ MISTRAL_INSTRUCT_EOS_TOKEN = "</s>"
|
|||
# Source: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
|
||||
MIXTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
|
||||
|
||||
# Source: https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
|
||||
LLAMA3_INSTRUCT_CHAT_TEMPLATE = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
|
||||
|
||||
### Chat Completion Handler ###
|
||||
|
||||
|
||||
|
@ -77,6 +83,8 @@ class LlamaChatCompletionHandler(Protocol):
|
|||
mirostat_eta: float = 0.1,
|
||||
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||
grammar: Optional[llama.LlamaGrammar] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
**kwargs, # type: ignore
|
||||
) -> Union[
|
||||
llama_types.CreateChatCompletionResponse,
|
||||
|
@ -148,6 +156,7 @@ class ChatFormatterResponse:
|
|||
|
||||
prompt: str
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stopping_criteria: Optional[llama.StoppingCriteriaList] = None
|
||||
|
||||
|
||||
class ChatFormatter(Protocol):
|
||||
|
@ -171,12 +180,14 @@ class Jinja2ChatFormatter(ChatFormatter):
|
|||
eos_token: str,
|
||||
bos_token: str,
|
||||
add_generation_prompt: bool = True,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
):
|
||||
"""A chat formatter that uses jinja2 templates to format the prompt."""
|
||||
self.template = template
|
||||
self.eos_token = eos_token
|
||||
self.bos_token = bos_token
|
||||
self.add_generation_prompt = add_generation_prompt
|
||||
self.stop_token_ids = set(stop_token_ids) if stop_token_ids is not None else None
|
||||
|
||||
self._environment = jinja2.Environment(
|
||||
loader=jinja2.BaseLoader(),
|
||||
|
@ -209,7 +220,16 @@ class Jinja2ChatFormatter(ChatFormatter):
|
|||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
|
||||
stopping_criteria = None
|
||||
if self.stop_token_ids is not None:
|
||||
def stop_on_last_token(
|
||||
tokens: npt.NDArray[np.intc],
|
||||
logits: npt.NDArray[np.single]
|
||||
) -> bool:
|
||||
return tokens[-1] in self.stop_token_ids
|
||||
stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])
|
||||
|
||||
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria)
|
||||
|
||||
def to_chat_handler(self) -> LlamaChatCompletionHandler:
|
||||
return chat_formatter_to_chat_completion_handler(self)
|
||||
|
@ -338,7 +358,7 @@ def _convert_completion_to_chat_function(
|
|||
}
|
||||
],
|
||||
},
|
||||
"logprobs": None,
|
||||
"logprobs": completion["choices"][0]["logprobs"],
|
||||
"finish_reason": "tool_calls",
|
||||
}
|
||||
],
|
||||
|
@ -391,7 +411,7 @@ def _convert_completion_to_chat_function(
|
|||
{
|
||||
"index": 0,
|
||||
"finish_reason": None,
|
||||
"logprobs": None,
|
||||
"logprobs": chunk["choices"][0]["logprobs"],
|
||||
"delta": {
|
||||
"role": None,
|
||||
"content": None,
|
||||
|
@ -426,7 +446,7 @@ def _convert_completion_to_chat_function(
|
|||
{
|
||||
"index": 0,
|
||||
"finish_reason": None,
|
||||
"logprobs": None,
|
||||
"logprobs": chunk["choices"][0]["logprobs"],
|
||||
"delta": {
|
||||
"role": None,
|
||||
"content": None,
|
||||
|
@ -491,7 +511,6 @@ def chat_formatter_to_chat_completion_handler(
|
|||
temperature: float = 0.2,
|
||||
top_p: float = 0.95,
|
||||
top_k: int = 40,
|
||||
logprobs: int = 0,
|
||||
min_p: float = 0.05,
|
||||
typical_p: float = 1.0,
|
||||
stream: bool = False,
|
||||
|
@ -512,6 +531,8 @@ def chat_formatter_to_chat_completion_handler(
|
|||
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||
grammar: Optional[llama.LlamaGrammar] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
**kwargs, # type: ignore
|
||||
) -> Union[
|
||||
llama_types.CreateChatCompletionResponse,
|
||||
|
@ -530,6 +551,10 @@ def chat_formatter_to_chat_completion_handler(
|
|||
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
|
||||
stop = stop + rstop
|
||||
|
||||
stopping_criteria = None
|
||||
if result.stopping_criteria is not None:
|
||||
stopping_criteria = result.stopping_criteria
|
||||
|
||||
if response_format is not None and response_format["type"] == "json_object":
|
||||
grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)
|
||||
|
||||
|
@ -581,7 +606,7 @@ def chat_formatter_to_chat_completion_handler(
|
|||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
logprobs=logprobs,
|
||||
logprobs=top_logprobs if logprobs else None,
|
||||
stream=stream,
|
||||
stop=stop,
|
||||
seed=seed,
|
||||
|
@ -595,6 +620,7 @@ def chat_formatter_to_chat_completion_handler(
|
|||
mirostat_eta=mirostat_eta,
|
||||
model=model,
|
||||
logits_processor=logits_processor,
|
||||
stopping_criteria=stopping_criteria,
|
||||
grammar=grammar,
|
||||
logit_bias=logit_bias,
|
||||
)
|
||||
|
@ -706,6 +732,9 @@ def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[s
|
|||
metadata["tokenizer.chat_template"] == MIXTRAL_INSTRUCT_CHAT_TEMPLATE):
|
||||
return "mistral-instruct"
|
||||
|
||||
if metadata["tokenizer.chat_template"] == LLAMA3_INSTRUCT_CHAT_TEMPLATE:
|
||||
return "llama-3"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
@ -897,6 +926,26 @@ def format_llama2(
|
|||
return ChatFormatterResponse(prompt=_prompt)
|
||||
|
||||
|
||||
# Chat format for Llama-3 models, see more details at:
|
||||
# https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py#L202-L229
|
||||
@register_chat_format("llama-3")
def format_llama3(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Llama-3 style prompt from a list of chat messages.

    Each role is mapped to its ``<|start_header_id|>...<|end_header_id|>``
    header, an assistant header with no content is appended so the model
    continues as the assistant, and the turns are joined after the
    ``<|begin_of_text|>`` token using ``<|eot_id|>`` as the separator.
    The same ``<|eot_id|>`` string is returned as the stop sequence.
    """
    eot = "<|eot_id|>"
    role_headers = {
        "system": "<|start_header_id|>system<|end_header_id|>\n\n",
        "user": "<|start_header_id|>user<|end_header_id|>\n\n",
        "assistant": "<|start_header_id|>assistant<|end_header_id|>\n\n",
    }

    turns = _map_roles(messages, role_headers)
    # Open (content-less) assistant turn: the generation prompt.
    turns.append((role_headers["assistant"], None))

    return ChatFormatterResponse(
        prompt=_format_no_colon_single("<|begin_of_text|>", turns, eot),
        stop=eot,
    )
|
||||
|
||||
|
||||
@register_chat_format("alpaca")
|
||||
def format_alpaca(
|
||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||
|
@ -1628,7 +1677,7 @@ def functionary_chat_handler(
|
|||
}
|
||||
],
|
||||
},
|
||||
"logprobs": None,
|
||||
"logprobs": completion["choices"][0]["logprobs"],
|
||||
"finish_reason": "tool_calls",
|
||||
}
|
||||
],
|
||||
|
@ -2085,7 +2134,7 @@ def functionary_v1_v2_chat_handler(
|
|||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"logprobs": completion["choices"][0]["logprobs"],
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": None if content == "" else content,
|
||||
|
@ -2311,11 +2360,14 @@ def chatml_function_calling(
|
|||
model: Optional[str] = None,
|
||||
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||
grammar: Optional[llama.LlamaGrammar] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
**kwargs, # type: ignore
|
||||
) -> Union[
|
||||
llama_types.CreateChatCompletionResponse,
|
||||
Iterator[llama_types.CreateChatCompletionStreamResponse],
|
||||
]:
|
||||
print(logprobs)
|
||||
function_calling_template = (
|
||||
"{% for message in messages %}"
|
||||
"<|im_start|>{{ message.role }}\n"
|
||||
|
@ -2437,6 +2489,7 @@ def chatml_function_calling(
|
|||
model=model,
|
||||
logits_processor=logits_processor,
|
||||
grammar=grammar,
|
||||
logprobs=top_logprobs if logprobs else None,
|
||||
),
|
||||
stream=stream,
|
||||
)
|
||||
|
@ -2549,6 +2602,7 @@ def chatml_function_calling(
|
|||
typical_p=typical_p,
|
||||
stream=stream,
|
||||
stop=["<|im_end|>"],
|
||||
logprobs=top_logprobs if logprobs else None,
|
||||
max_tokens=None,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
|
@ -2660,7 +2714,7 @@ def chatml_function_calling(
|
|||
{
|
||||
"finish_reason": "tool_calls",
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"logprobs": completion["choices"][0]["logprobs"],
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
|
@ -2701,4 +2755,4 @@ def chatml_function_calling(
|
|||
},
|
||||
}
|
||||
|
||||
raise ValueError("Automatic streaming tool choice is not supported")
|
||||
raise ValueError("Automatic streaming tool choice is not supported")
|
|
@ -237,11 +237,18 @@ LLAMA_FILE_MAGIC_GGLA = 0x67676C61
|
|||
# define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
||||
LLAMA_FILE_MAGIC_GGSN = 0x6767736E
|
||||
|
||||
# define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
|
||||
LLAMA_FILE_MAGIC_GGSQ = 0x67677371
|
||||
|
||||
# define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||
LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN
|
||||
# define LLAMA_SESSION_VERSION 5
|
||||
LLAMA_SESSION_VERSION = 5
|
||||
|
||||
# define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
||||
LLAMA_STATE_SEQ_MAGIC = LLAMA_FILE_MAGIC_GGSQ
|
||||
# define LLAMA_STATE_SEQ_VERSION 1
|
||||
LLAMA_STATE_SEQ_VERSION = 1
|
||||
|
||||
# struct llama_model;
|
||||
llama_model_p = NewType("llama_model_p", int)
|
||||
|
@ -424,6 +431,11 @@ class llama_token_data(ctypes.Structure):
|
|||
logit (float): log-odds of the token
|
||||
p (float): probability of the token"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
id: llama_token
|
||||
logit: float
|
||||
p: float
|
||||
|
||||
_fields_ = [
|
||||
("id", llama_token),
|
||||
("logit", ctypes.c_float),
|
||||
|
@ -447,6 +459,11 @@ class llama_token_data_array(ctypes.Structure):
|
|||
size (int): size of the array
|
||||
sorted (bool): whether the array is sorted"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
data: CtypesArray[llama_token_data]
|
||||
size: int
|
||||
sorted: bool
|
||||
|
||||
_fields_ = [
|
||||
("data", llama_token_data_p),
|
||||
("size", ctypes.c_size_t),
|
||||
|
@ -508,6 +525,15 @@ class llama_batch(ctypes.Structure):
|
|||
logits (ctypes.Array[ctypes.ctypes.c_int8]): if zero, the logits for the respective token will not be output
|
||||
"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
n_tokens: int
|
||||
token: CtypesArray[llama_token]
|
||||
embd: CtypesArray[ctypes.c_float]
|
||||
pos: CtypesArray[CtypesArray[llama_pos]]
|
||||
n_seq_id: CtypesArray[ctypes.c_int]
|
||||
seq_id: CtypesArray[CtypesArray[llama_seq_id]]
|
||||
logits: CtypesArray[ctypes.c_int8]
|
||||
|
||||
_fields_ = [
|
||||
("n_tokens", ctypes.c_int32),
|
||||
("token", ctypes.POINTER(llama_token)),
|
||||
|
@ -602,6 +628,18 @@ class llama_model_params(ctypes.Structure):
|
|||
use_mmap (bool): use mmap if possible
|
||||
use_mlock (bool): force system to keep model in RAM"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
n_gpu_layers: int
|
||||
split_mode: int
|
||||
main_gpu: int
|
||||
tensor_split: CtypesArray[ctypes.c_float]
|
||||
progress_callback: Callable[[float, ctypes.c_void_p], bool]
|
||||
progress_callback_user_data: ctypes.c_void_p
|
||||
kv_overrides: CtypesArray[llama_model_kv_override]
|
||||
vocab_only: bool
|
||||
use_mmap: bool
|
||||
use_mlock: bool
|
||||
|
||||
_fields_ = [
|
||||
("n_gpu_layers", ctypes.c_int32),
|
||||
("split_mode", ctypes.c_int),
|
||||
|
@ -689,6 +727,34 @@ class llama_context_params(ctypes.Structure):
|
|||
abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback
|
||||
"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
seed: int
|
||||
n_ctx: int
|
||||
n_batch: int
|
||||
n_ubatch: int
|
||||
n_seq_max: int
|
||||
n_threads: int
|
||||
n_threads_batch: int
|
||||
rope_scaling_type: int
|
||||
pooling_type: int
|
||||
rope_freq_base: float
|
||||
rope_freq_scale: float
|
||||
yarn_ext_factor: float
|
||||
yarn_attn_factor: float
|
||||
yarn_beta_fast: float
|
||||
yarn_beta_slow: float
|
||||
yarn_orig_ctx: int
|
||||
defrag_thold: float
|
||||
cb_eval: Callable[[ctypes.c_void_p, bool], bool]
|
||||
cb_eval_user_data: ctypes.c_void_p
|
||||
type_k: int
|
||||
type_v: int
|
||||
logits_all: bool
|
||||
embeddings: bool
|
||||
offload_kqv: bool
|
||||
abort_callback: Callable[[ctypes.c_void_p], bool]
|
||||
abort_callback_data: ctypes.c_void_p
|
||||
|
||||
_fields_ = [
|
||||
("seed", ctypes.c_uint32),
|
||||
("n_ctx", ctypes.c_uint32),
|
||||
|
@ -764,6 +830,18 @@ class llama_model_quantize_params(ctypes.Structure):
|
|||
kv_overrides (ctypes.c_void_p): pointer to vector containing overrides
|
||||
"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
nthread: int
|
||||
ftype: int
|
||||
output_tensor_type: int
|
||||
token_embedding_type: int
|
||||
allow_requantize: bool
|
||||
quantize_output_tensor: bool
|
||||
only_copy: bool
|
||||
pure: bool
|
||||
imatrix: ctypes.c_void_p
|
||||
kv_overrides: ctypes.c_void_p
|
||||
|
||||
_fields_ = [
|
||||
("nthread", ctypes.c_int32),
|
||||
("ftype", ctypes.c_int),
|
||||
|
@ -821,6 +899,10 @@ LLAMA_GRETYPE_CHAR_ALT = 6
|
|||
# uint32_t value; // Unicode code point or rule ID
|
||||
# } llama_grammar_element;
|
||||
class llama_grammar_element(ctypes.Structure):
|
||||
if TYPE_CHECKING:
|
||||
type: int
|
||||
value: int
|
||||
|
||||
_fields_ = [
|
||||
("type", ctypes.c_int),
|
||||
("value", ctypes.c_uint32),
|
||||
|
@ -844,6 +926,17 @@ llama_grammar_element_p = ctypes.POINTER(llama_grammar_element)
|
|||
# int32_t n_eval;
|
||||
# };
|
||||
class llama_timings(ctypes.Structure):
|
||||
if TYPE_CHECKING:
|
||||
t_start_ms: float
|
||||
t_end_ms: float
|
||||
t_load_ms: float
|
||||
t_sample_ms: float
|
||||
t_p_eval_ms: float
|
||||
t_eval_ms: float
|
||||
n_sample: int
|
||||
n_p_eval: int
|
||||
n_eval: int
|
||||
|
||||
_fields_ = [
|
||||
("t_start_ms", ctypes.c_double),
|
||||
("t_end_ms", ctypes.c_double),
|
||||
|
@ -944,7 +1037,8 @@ GGML_NUMA_STRATEGY_COUNT = 5
|
|||
[ctypes.c_int],
|
||||
None,
|
||||
)
|
||||
def llama_numa_init(numa: int, /): ...
|
||||
def llama_numa_init(numa: int, /):
|
||||
...
|
||||
|
||||
|
||||
# // Call once at the end of the program - currently only used for MPI
|
||||
|
@ -969,7 +1063,8 @@ def llama_backend_free():
|
|||
)
|
||||
def llama_load_model_from_file(
|
||||
path_model: bytes, params: llama_model_params, /
|
||||
) -> Optional[llama_model_p]: ...
|
||||
) -> Optional[llama_model_p]:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API void llama_free_model(struct llama_model * model);
|
||||
|
@ -978,7 +1073,8 @@ def llama_load_model_from_file(
|
|||
[llama_model_p_ctypes],
|
||||
None,
|
||||
)
|
||||
def llama_free_model(model: llama_model_p, /): ...
|
||||
def llama_free_model(model: llama_model_p, /):
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API struct llama_context * llama_new_context_with_model(
|
||||
|
@ -991,7 +1087,8 @@ def llama_free_model(model: llama_model_p, /): ...
|
|||
)
|
||||
def llama_new_context_with_model(
|
||||
model: llama_model_p, params: llama_context_params, /
|
||||
) -> Optional[llama_context_p]: ...
|
||||
) -> Optional[llama_context_p]:
|
||||
...
|
||||
|
||||
|
||||
# // Frees all allocated memory
|
||||
|
@ -1012,82 +1109,98 @@ def llama_free(ctx: llama_context_p, /):
|
|||
[],
|
||||
ctypes.c_int64,
|
||||
)
|
||||
def llama_time_us() -> int: ...
|
||||
def llama_time_us() -> int:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API size_t llama_max_devices(void);
|
||||
@ctypes_function("llama_max_devices", [], ctypes.c_size_t)
|
||||
def llama_max_devices() -> int: ...
|
||||
def llama_max_devices() -> int:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API bool llama_supports_mmap (void);
|
||||
@ctypes_function("llama_supports_mmap", [], ctypes.c_bool)
|
||||
def llama_supports_mmap() -> bool: ...
|
||||
def llama_supports_mmap() -> bool:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API bool llama_supports_mlock (void);
|
||||
@ctypes_function("llama_supports_mlock", [], ctypes.c_bool)
|
||||
def llama_supports_mlock() -> bool: ...
|
||||
def llama_supports_mlock() -> bool:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API bool llama_supports_gpu_offload(void);
|
||||
@ctypes_function("llama_supports_gpu_offload", [], ctypes.c_bool)
|
||||
def llama_supports_gpu_offload() -> bool: ...
|
||||
def llama_supports_gpu_offload() -> bool:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
|
||||
@ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes)
|
||||
def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: ...
|
||||
def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
|
||||
@ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32)
|
||||
def llama_n_ctx(ctx: llama_context_p, /) -> int: ...
|
||||
def llama_n_ctx(ctx: llama_context_p, /) -> int:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
|
||||
@ctypes_function("llama_n_batch", [llama_context_p_ctypes], ctypes.c_uint32)
|
||||
def llama_n_batch(ctx: llama_context_p, /) -> int: ...
|
||||
def llama_n_batch(ctx: llama_context_p, /) -> int:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
|
||||
@ctypes_function("llama_n_ubatch", [llama_context_p_ctypes], ctypes.c_uint32)
|
||||
def llama_n_ubatch(ctx: llama_context_p, /) -> int: ...
|
||||
def llama_n_ubatch(ctx: llama_context_p, /) -> int:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
|
||||
@ctypes_function("llama_n_seq_max", [llama_context_p_ctypes], ctypes.c_uint32)
|
||||
def llama_n_seq_max(ctx: llama_context_p, /) -> int: ...
|
||||
def llama_n_seq_max(ctx: llama_context_p, /) -> int:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
|
||||
@ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int)
|
||||
def llama_vocab_type(model: llama_model_p, /) -> int: ...
|
||||
def llama_vocab_type(model: llama_model_p, /) -> int:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
|
||||
@ctypes_function("llama_rope_type", [llama_model_p_ctypes], ctypes.c_int)
|
||||
def llama_rope_type(model: llama_model_p, /) -> int: ...
|
||||
def llama_rope_type(model: llama_model_p, /) -> int:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
|
||||
@ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32)
|
||||
def llama_n_vocab(model: llama_model_p, /) -> int: ...
|
||||
def llama_n_vocab(model: llama_model_p, /) -> int:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
|
||||
@ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32)
|
||||
def llama_n_ctx_train(model: llama_model_p, /) -> int: ...
|
||||
def llama_n_ctx_train(model: llama_model_p, /) -> int:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
|
||||
@ctypes_function("llama_n_embd", [llama_model_p_ctypes], ctypes.c_int32)
|
||||
def llama_n_embd(model: llama_model_p, /) -> int: ...
|
||||
def llama_n_embd(model: llama_model_p, /) -> int:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
|
||||
@ctypes_function("llama_n_layer", [llama_model_p_ctypes], ctypes.c_int32)
|
||||
def llama_n_layer(model: llama_model_p, /) -> int: ...
|
||||
def llama_n_layer(model: llama_model_p, /) -> int:
|
||||
...
|
||||
|
||||
|
||||
# // Get the model's RoPE frequency scaling factor
|
||||
|
@ -1351,6 +1464,9 @@ class llama_kv_cache_view_cell(ctypes.Structure):
|
|||
pos (llama_pos): The position for this cell. Takes KV cache shifts into account.
|
||||
May be negative if the cell is not populated."""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pos: llama_pos
|
||||
|
||||
_fields_ = [("pos", llama_pos)]
|
||||
|
||||
|
||||
|
@ -1387,6 +1503,16 @@ class llama_kv_cache_view_cell(ctypes.Structure):
|
|||
# llama_seq_id * cells_sequences;
|
||||
# };
|
||||
class llama_kv_cache_view(ctypes.Structure):
|
||||
if TYPE_CHECKING:
|
||||
n_cells: int
|
||||
n_max_seq: int
|
||||
token_count: int
|
||||
used_cells: int
|
||||
max_contiguous: int
|
||||
max_contiguous_idx: int
|
||||
cells: CtypesArray[llama_kv_cache_view_cell]
|
||||
cells_sequences: CtypesArray[llama_seq_id]
|
||||
|
||||
_fields_ = [
|
||||
("n_cells", ctypes.c_int32),
|
||||
("n_max_seq", ctypes.c_int32),
|
||||
|
@ -1467,6 +1593,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /):
|
|||
|
||||
|
||||
# // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
# // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
|
||||
# // seq_id < 0 : match any sequence
|
||||
# // p0 < 0 : [0, p1]
|
||||
# // p1 < 0 : [p0, inf)
|
||||
|
@ -1493,6 +1620,9 @@ def llama_kv_cache_seq_rm(
|
|||
/,
|
||||
) -> bool:
|
||||
"""Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
|
||||
Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
|
||||
|
||||
seq_id < 0 : match any sequence
|
||||
p0 < 0 : [0, p1]
|
||||
p1 < 0 : [p0, inf)"""
|
||||
|
@ -1652,7 +1782,16 @@ def llama_kv_cache_update(ctx: llama_context_p, /):
|
|||
|
||||
# Returns the maximum size in bytes of the state (rng, logits, embedding
|
||||
# and kv_cache) - will often be smaller after compacting tokens
|
||||
# LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
|
||||
# LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
|
||||
@ctypes_function("llama_state_get_size", [llama_context_p_ctypes], ctypes.c_size_t)
|
||||
def llama_state_get_size(ctx: llama_context_p, /) -> int:
|
||||
"""Returns the maximum size in bytes of the state (rng, logits, embedding
|
||||
and kv_cache) - will often be smaller after compacting tokens"""
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx),
|
||||
# "use llama_state_get_size instead");
|
||||
@ctypes_function("llama_get_state_size", [llama_context_p_ctypes], ctypes.c_size_t)
|
||||
def llama_get_state_size(ctx: llama_context_p, /) -> int:
|
||||
"""Returns the maximum size in bytes of the state (rng, logits, embedding
|
||||
|
@ -1663,9 +1802,30 @@ def llama_get_state_size(ctx: llama_context_p, /) -> int:
|
|||
# Copies the state to the specified destination address.
|
||||
# Destination needs to have allocated enough memory.
|
||||
# Returns the number of bytes copied
|
||||
# LLAMA_API size_t llama_copy_state_data(
|
||||
# LLAMA_API size_t llama_state_get_data(
|
||||
# struct llama_context * ctx,
|
||||
# uint8_t * dst);
|
||||
@ctypes_function(
|
||||
"llama_state_get_data",
|
||||
[
|
||||
llama_context_p_ctypes,
|
||||
ctypes.POINTER(ctypes.c_uint8),
|
||||
],
|
||||
ctypes.c_size_t,
|
||||
)
|
||||
def llama_state_get_data(
|
||||
ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], /
|
||||
) -> int:
|
||||
"""Copies the state to the specified destination address.
|
||||
Destination needs to have allocated enough memory.
|
||||
Returns the number of bytes copied"""
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API DEPRECATED(size_t llama_copy_state_data(
|
||||
# struct llama_context * ctx,
|
||||
# uint8_t * dst),
|
||||
# "use llama_state_get_data instead");
|
||||
@ctypes_function(
|
||||
"llama_copy_state_data",
|
||||
[
|
||||
|
@ -1685,9 +1845,26 @@ def llama_copy_state_data(
|
|||
|
||||
# // Set the state reading from the specified address
|
||||
# // Returns the number of bytes read
|
||||
# LLAMA_API size_t llama_set_state_data(
|
||||
# LLAMA_API size_t llama_state_set_data(
|
||||
# struct llama_context * ctx,
|
||||
# const uint8_t * src);
|
||||
@ctypes_function(
|
||||
"llama_state_set_data",
|
||||
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)],
|
||||
ctypes.c_size_t,
|
||||
)
|
||||
def llama_state_set_data(
|
||||
ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], /
|
||||
) -> int:
|
||||
"""Set the state reading from the specified address
|
||||
Returns the number of bytes read"""
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API DEPRECATED(size_t llama_set_state_data(
|
||||
# struct llama_context * ctx,
|
||||
# const uint8_t * src),
|
||||
# "use llama_state_set_data instead");
|
||||
@ctypes_function(
|
||||
"llama_set_state_data",
|
||||
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)],
|
||||
|
@ -1701,12 +1878,41 @@ def llama_set_state_data(
|
|||
|
||||
|
||||
# Save/load session file
|
||||
# LLAMA_API bool llama_load_session_file(
|
||||
# LLAMA_API bool llama_state_load_file(
|
||||
# struct llama_context * ctx,
|
||||
# const char * path_session,
|
||||
# llama_token * tokens_out,
|
||||
# size_t n_token_capacity,
|
||||
# size_t * n_token_count_out);
|
||||
@ctypes_function(
|
||||
"llama_state_load_file",
|
||||
[
|
||||
llama_context_p_ctypes,
|
||||
ctypes.c_char_p,
|
||||
llama_token_p,
|
||||
ctypes.c_size_t,
|
||||
ctypes.POINTER(ctypes.c_size_t),
|
||||
],
|
||||
ctypes.c_bool,
|
||||
)
|
||||
def llama_state_load_file(
|
||||
ctx: llama_context_p,
|
||||
path_session: bytes,
|
||||
tokens_out: CtypesArray[llama_token],
|
||||
n_token_capacity: Union[ctypes.c_size_t, int],
|
||||
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
|
||||
/,
|
||||
) -> bool:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API DEPRECATED(bool llama_load_session_file(
|
||||
# struct llama_context * ctx,
|
||||
# const char * path_session,
|
||||
# llama_token * tokens_out,
|
||||
# size_t n_token_capacity,
|
||||
# size_t * n_token_count_out),
|
||||
# "use llama_state_load_file instead");
|
||||
@ctypes_function(
|
||||
"llama_load_session_file",
|
||||
[
|
||||
|
@ -1725,14 +1931,41 @@ def llama_load_session_file(
|
|||
n_token_capacity: Union[ctypes.c_size_t, int],
|
||||
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
|
||||
/,
|
||||
) -> int: ...
|
||||
) -> int:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API bool llama_save_session_file(
|
||||
# LLAMA_API bool llama_state_save_file(
|
||||
# struct llama_context * ctx,
|
||||
# const char * path_session,
|
||||
# const llama_token * tokens,
|
||||
# size_t n_token_count);
|
||||
@ctypes_function(
|
||||
"llama_state_save_file",
|
||||
[
|
||||
llama_context_p_ctypes,
|
||||
ctypes.c_char_p,
|
||||
llama_token_p,
|
||||
ctypes.c_size_t,
|
||||
],
|
||||
ctypes.c_bool,
|
||||
)
|
||||
def llama_state_save_file(
|
||||
ctx: llama_context_p,
|
||||
path_session: bytes,
|
||||
tokens: CtypesArray[llama_token],
|
||||
n_token_count: Union[ctypes.c_size_t, int],
|
||||
/,
|
||||
) -> bool:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API DEPRECATED(bool llama_save_session_file(
|
||||
# struct llama_context * ctx,
|
||||
# const char * path_session,
|
||||
# const llama_token * tokens,
|
||||
# size_t n_token_count),
|
||||
# "use llama_state_save_file instead");
|
||||
@ctypes_function(
|
||||
"llama_save_session_file",
|
||||
[
|
||||
|
@ -1749,7 +1982,118 @@ def llama_save_session_file(
|
|||
tokens: CtypesArray[llama_token],
|
||||
n_token_count: Union[ctypes.c_size_t, int],
|
||||
/,
|
||||
) -> int: ...
|
||||
) -> int:
|
||||
...
|
||||
|
||||
|
||||
# // Get the exact size needed to copy the KV cache of a single sequence
|
||||
# LLAMA_API size_t llama_state_seq_get_size(
|
||||
# struct llama_context * ctx,
|
||||
# llama_seq_id seq_id);
|
||||
@ctypes_function(
|
||||
"llama_state_seq_get_size",
|
||||
[llama_context_p_ctypes, llama_seq_id],
|
||||
ctypes.c_size_t,
|
||||
)
|
||||
def llama_state_seq_get_size(ctx: llama_context_p, seq_id: llama_seq_id, /) -> int:
|
||||
"""Get the exact size needed to copy the KV cache of a single sequence"""
|
||||
...
|
||||
|
||||
|
||||
# // Copy the KV cache of a single sequence into the specified buffer
|
||||
# LLAMA_API size_t llama_state_seq_get_data(
|
||||
# struct llama_context * ctx,
|
||||
# uint8_t * dst,
|
||||
# llama_seq_id seq_id);
|
||||
@ctypes_function(
|
||||
"llama_state_seq_get_data",
|
||||
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8), llama_seq_id],
|
||||
ctypes.c_size_t,
|
||||
)
|
||||
def llama_state_seq_get_data(
|
||||
ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], seq_id: llama_seq_id, /
|
||||
) -> int:
|
||||
"""Copy the KV cache of a single sequence into the specified buffer"""
|
||||
...
|
||||
|
||||
|
||||
# // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
|
||||
# // Returns:
|
||||
# // - Positive: Ok
|
||||
# // - Zero: Failed to load
|
||||
# LLAMA_API size_t llama_state_seq_set_data(
|
||||
# struct llama_context * ctx,
|
||||
# const uint8_t * src,
|
||||
# llama_seq_id dest_seq_id);
|
||||
@ctypes_function(
|
||||
"llama_state_seq_set_data",
|
||||
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8), llama_seq_id],
|
||||
ctypes.c_size_t,
|
||||
)
|
||||
def llama_state_seq_set_data(
|
||||
ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], dest_seq_id: llama_seq_id, /
|
||||
) -> int:
|
||||
"""Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence"""
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API size_t llama_state_seq_save_file(
|
||||
# struct llama_context * ctx,
|
||||
# const char * filepath,
|
||||
# llama_seq_id seq_id,
|
||||
# const llama_token * tokens,
|
||||
# size_t n_token_count);
|
||||
@ctypes_function(
|
||||
"llama_state_seq_save_file",
|
||||
[
|
||||
llama_context_p_ctypes,
|
||||
ctypes.c_char_p,
|
||||
llama_seq_id,
|
||||
llama_token_p,
|
||||
ctypes.c_size_t,
|
||||
],
|
||||
ctypes.c_size_t,
|
||||
)
|
||||
def llama_state_seq_save_file(
|
||||
ctx: llama_context_p,
|
||||
filepath: bytes,
|
||||
seq_id: llama_seq_id,
|
||||
tokens: CtypesArray[llama_token],
|
||||
n_token_count: Union[ctypes.c_size_t, int],
|
||||
/,
|
||||
) -> int:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API size_t llama_state_seq_load_file(
|
||||
# struct llama_context * ctx,
|
||||
# const char * filepath,
|
||||
# llama_seq_id dest_seq_id,
|
||||
# llama_token * tokens_out,
|
||||
# size_t n_token_capacity,
|
||||
# size_t * n_token_count_out);
|
||||
@ctypes_function(
|
||||
"llama_state_seq_load_file",
|
||||
[
|
||||
llama_context_p_ctypes,
|
||||
ctypes.c_char_p,
|
||||
llama_seq_id,
|
||||
llama_token_p,
|
||||
ctypes.c_size_t,
|
||||
ctypes.POINTER(ctypes.c_size_t),
|
||||
],
|
||||
ctypes.c_size_t,
|
||||
)
|
||||
def llama_state_seq_load_file(
|
||||
ctx: llama_context_p,
|
||||
filepath: bytes,
|
||||
dest_seq_id: llama_seq_id,
|
||||
tokens_out: CtypesArray[llama_token],
|
||||
n_token_capacity: Union[ctypes.c_size_t, int],
|
||||
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
|
||||
/,
|
||||
) -> int:
|
||||
...
|
||||
|
||||
|
||||
# //
|
||||
|
@ -1930,8 +2274,9 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]:
|
|||
...
|
||||
|
||||
|
||||
# // Logits for the ith token. Equivalent to:
|
||||
# // Logits for the ith token. For positive indices, Equivalent to:
|
||||
# // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
|
||||
# // Negative indices can be used to access logits in reverse order, -1 is the last logit.
|
||||
# // returns NULL for invalid ids.
|
||||
# LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
|
||||
@ctypes_function(
|
||||
|
@ -1963,8 +2308,9 @@ def llama_get_embeddings(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]
|
|||
...
|
||||
|
||||
|
||||
# // Get the embeddings for the ith token. Equivalent to:
|
||||
# // Get the embeddings for the ith token. For positive indices, Equivalent to:
|
||||
# // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
|
||||
# // Negative indices can be used to access embeddings in reverse order, -1 is the last embedding.
|
||||
# // shape: [n_embd] (1-dimensional)
|
||||
# // returns NULL for invalid ids.
|
||||
# LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
||||
|
@ -2010,7 +2356,8 @@ def llama_get_embeddings_seq(
|
|||
)
|
||||
def llama_token_get_text(
|
||||
model: llama_model_p, token: Union[llama_token, int], /
|
||||
) -> bytes: ...
|
||||
) -> bytes:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
|
||||
|
@ -2019,7 +2366,8 @@ def llama_token_get_text(
|
|||
)
|
||||
def llama_token_get_score(
|
||||
model: llama_model_p, token: Union[llama_token, int], /
|
||||
) -> float: ...
|
||||
) -> float:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
|
||||
|
@ -2028,7 +2376,20 @@ def llama_token_get_score(
|
|||
)
|
||||
def llama_token_get_type(
|
||||
model: llama_model_p, token: Union[llama_token, int], /
|
||||
) -> int: ...
|
||||
) -> int:
|
||||
...
|
||||
|
||||
|
||||
# // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
|
||||
# LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
|
||||
@ctypes_function(
|
||||
"llama_token_is_eog", [llama_model_p_ctypes, llama_token], ctypes.c_bool
|
||||
)
|
||||
def llama_token_is_eog(
|
||||
model: llama_model_p, token: Union[llama_token, int], /
|
||||
) -> bool:
|
||||
"""Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)"""
|
||||
...
|
||||
|
||||
|
||||
# // Special tokens
|
||||
|
@ -2048,6 +2409,20 @@ def llama_token_eos(model: llama_model_p, /) -> int:
|
|||
...
|
||||
|
||||
|
||||
# LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
|
||||
@ctypes_function("llama_token_cls", [llama_model_p_ctypes], llama_token)
|
||||
def llama_token_cls(model: llama_model_p, /) -> int:
|
||||
"""classification"""
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
|
||||
@ctypes_function("llama_token_sep", [llama_model_p_ctypes], llama_token)
|
||||
def llama_token_sep(model: llama_model_p, /) -> int:
|
||||
"""sentence separator"""
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
|
||||
@ctypes_function("llama_token_nl", [llama_model_p_ctypes], llama_token)
|
||||
def llama_token_nl(model: llama_model_p, /) -> int:
|
||||
|
@ -2071,7 +2446,7 @@ def llama_add_eos_token(model: llama_model_p, /) -> int:
|
|||
...
|
||||
|
||||
|
||||
# // codellama infill tokens
|
||||
# // Codellama infill tokens
|
||||
# LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
|
||||
@ctypes_function("llama_token_prefix", [llama_model_p_ctypes], llama_token)
|
||||
def llama_token_prefix(model: llama_model_p) -> int:
|
||||
|
@ -2081,17 +2456,20 @@ def llama_token_prefix(model: llama_model_p) -> int:
|
|||
|
||||
# LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
|
||||
@ctypes_function("llama_token_middle", [llama_model_p_ctypes], llama_token)
|
||||
def llama_token_middle(model: llama_model_p, /) -> int: ...
|
||||
def llama_token_middle(model: llama_model_p, /) -> int:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
|
||||
@ctypes_function("llama_token_suffix", [llama_model_p_ctypes], llama_token)
|
||||
def llama_token_suffix(model: llama_model_p, /) -> int: ...
|
||||
def llama_token_suffix(model: llama_model_p, /) -> int:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
|
||||
@ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token)
|
||||
def llama_token_eot(model: llama_model_p, /) -> int: ...
|
||||
def llama_token_eot(model: llama_model_p, /) -> int:
|
||||
...
|
||||
|
||||
|
||||
# //
|
||||
|
@ -2103,16 +2481,16 @@ def llama_token_eot(model: llama_model_p, /) -> int: ...
|
|||
# /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
|
||||
# /// @return Returns the number of tokens on success, no more than n_tokens_max
|
||||
# /// @return Returns a negative number on failure - the number of tokens that would have been returned
|
||||
# /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
|
||||
# /// Does not insert a leading space.
|
||||
# /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
|
||||
# /// as plaintext. Does not insert a leading space.
|
||||
# LLAMA_API int32_t llama_tokenize(
|
||||
# const struct llama_model * model,
|
||||
# const char * text,
|
||||
# int32_t text_len,
|
||||
# llama_token * tokens,
|
||||
# int32_t n_tokens_max,
|
||||
# bool add_bos,
|
||||
# bool special);
|
||||
# bool add_special,
|
||||
# bool parse_special);
|
||||
@ctypes_function(
|
||||
"llama_tokenize",
|
||||
[
|
||||
|
@ -2132,8 +2510,8 @@ def llama_tokenize(
|
|||
text_len: Union[ctypes.c_int, int],
|
||||
tokens: CtypesArray[llama_token],
|
||||
n_tokens_max: Union[ctypes.c_int, int],
|
||||
add_bos: Union[ctypes.c_bool, bool],
|
||||
special: Union[ctypes.c_bool, bool],
|
||||
add_special: Union[ctypes.c_bool, bool],
|
||||
parse_special: Union[ctypes.c_bool, bool],
|
||||
/,
|
||||
) -> int:
|
||||
"""Convert the provided text into tokens.
|
||||
|
@ -2144,9 +2522,8 @@ def llama_tokenize(
|
|||
text_len: The length of the text.
|
||||
tokens: The tokens pointer must be large enough to hold the resulting tokens.
|
||||
n_max_tokens: The maximum number of tokens to return.
|
||||
add_bos: Whether to add a beginning-of-sentence token.
|
||||
special: Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
|
||||
Does not insert a leading space.
|
||||
add_special: Whether to add special tokens such as BOS to the output (takes the place of the former add_bos argument).
|
||||
parse_special: Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.
|
||||
|
||||
Returns:
|
||||
Returns the number of tokens on success, no more than n_tokens_max
|
||||
|
@ -2159,11 +2536,13 @@ def llama_tokenize(
|
|||
# // Uses the vocabulary in the provided context.
|
||||
# // Does not write null terminator to the buffer.
|
||||
# // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
|
||||
# // @param special If true, special tokens are rendered in the output.
|
||||
# LLAMA_API int32_t llama_token_to_piece(
|
||||
# const struct llama_model * model,
|
||||
# llama_token token,
|
||||
# char * buf,
|
||||
# int32_t length);
|
||||
# int32_t length,
|
||||
# bool special);
|
||||
@ctypes_function(
|
||||
"llama_token_to_piece",
|
||||
[
|
||||
|
@ -2171,6 +2550,7 @@ def llama_tokenize(
|
|||
llama_token,
|
||||
ctypes.c_char_p,
|
||||
ctypes.c_int32,
|
||||
ctypes.c_bool,
|
||||
],
|
||||
ctypes.c_int32,
|
||||
)
|
||||
|
@ -2179,13 +2559,20 @@ def llama_token_to_piece(
|
|||
token: Union[llama_token, int],
|
||||
buf: Union[ctypes.c_char_p, bytes, CtypesArray[ctypes.c_char]],
|
||||
length: Union[ctypes.c_int, int],
|
||||
special: Union[ctypes.c_bool, bool],
|
||||
/,
|
||||
) -> int:
|
||||
"""Token Id -> Piece.
|
||||
Uses the vocabulary in the provided context.
|
||||
Does not write null terminator to the buffer.
|
||||
User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
|
||||
"""
|
||||
|
||||
Args:
|
||||
model: The model to use for tokenization.
|
||||
token: The token to convert.
|
||||
buf: The buffer to write the token to.
|
||||
length: The length of the buffer.
|
||||
special: If true, special tokens are rendered in the output."""
|
||||
...
|
||||
|
||||
|
||||
|
@ -2223,7 +2610,8 @@ def llama_chat_apply_template(
|
|||
chat: CtypesArray[llama_chat_message],
|
||||
n_msg: int,
|
||||
/,
|
||||
) -> int: ...
|
||||
) -> int:
|
||||
...
|
||||
|
||||
|
||||
# //
|
||||
|
@ -2753,6 +3141,12 @@ def llama_grammar_accept_token(
|
|||
# bool eob; // Callback should set this to true when a beam is at end-of-beam.
|
||||
# };
|
||||
class llama_beam_view(ctypes.Structure):
|
||||
if TYPE_CHECKING:
|
||||
tokens: CtypesArray[llama_token]
|
||||
n_tokens: int
|
||||
p: float
|
||||
eob: bool
|
||||
|
||||
_fields_ = [
|
||||
("tokens", llama_token_p),
|
||||
("n_tokens", ctypes.c_size_t),
|
||||
|
@ -2772,6 +3166,12 @@ class llama_beam_view(ctypes.Structure):
|
|||
# bool last_call; // True iff this is the last callback invocation.
|
||||
# };
|
||||
class llama_beams_state(ctypes.Structure):
|
||||
if TYPE_CHECKING:
|
||||
beam_views: CtypesArray[llama_beam_view]
|
||||
n_beams: int
|
||||
common_prefix_length: int
|
||||
last_call: bool
|
||||
|
||||
_fields_ = [
|
||||
("beam_views", ctypes.POINTER(llama_beam_view)),
|
||||
("n_beams", ctypes.c_size_t),
|
||||
|
@ -2824,7 +3224,8 @@ def llama_beam_search(
|
|||
n_past: Union[ctypes.c_int, int],
|
||||
n_predict: Union[ctypes.c_int, int],
|
||||
/,
|
||||
): ...
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
# /// @details Build a split GGUF final path for this chunk.
|
||||
|
@ -2943,4 +3344,5 @@ def llama_log_set(
|
|||
[ctypes.c_void_p, llama_context_p_ctypes],
|
||||
None,
|
||||
)
|
||||
def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /): ...
|
||||
def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /):
|
||||
...
|
||||
|
|
|
@ -5,11 +5,12 @@ from pathlib import Path
|
|||
import sys
|
||||
from ctypes import * # type: ignore
|
||||
from enum import Enum
|
||||
from itertools import islice
|
||||
from itertools import islice, groupby
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Set,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
|
@ -1391,145 +1392,561 @@ from typing import List, Optional
|
|||
# whitespace. Also maybe improves generation quality?
|
||||
SPACE_RULE = '" "?'
|
||||
|
||||
PRIMITIVE_RULES = {
|
||||
"boolean": '("true" | "false") space',
|
||||
"number": '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space',
|
||||
"integer": '("-"? ([0-9] | [1-9] [0-9]*)) space',
|
||||
"string": r""" "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space """,
|
||||
"null": '"null" space',
|
||||
}
|
||||
|
||||
INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
|
||||
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
|
||||
GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'}
|
||||
|
||||
# whitespace is constrained to a single space char to prevent model "running away" in
|
||||
# whitespace. Also maybe improves generation quality?
|
||||
SPACE_RULE = '" "?'
|
||||
|
||||
|
||||
def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False):
|
||||
if not separator_rule:
|
||||
if min_items == 0 and max_items == 1:
|
||||
return f'{item_rule}?'
|
||||
elif min_items == 1 and max_items is None:
|
||||
return f'{item_rule}+'
|
||||
|
||||
result = ''
|
||||
|
||||
if min_items > 0:
|
||||
if item_rule_is_literal and separator_rule is None:
|
||||
result = '"' + (item_rule[1:-1] * min_items) + '"'
|
||||
else:
|
||||
result = (f' {separator_rule} ' if separator_rule else ' ').join([item_rule] * min_items)
|
||||
|
||||
def opt_repetitions(up_to_n, prefix_with_sep=False):
|
||||
'''
|
||||
- n=4, no sep: '(a (a (a (a)?)?)?)?'
|
||||
- n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?'
|
||||
- n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?'
|
||||
'''
|
||||
|
||||
content = f'{separator_rule} {item_rule}' if prefix_with_sep and separator_rule else item_rule
|
||||
if up_to_n == 0:
|
||||
return ''
|
||||
elif up_to_n == 1:
|
||||
return f'({content})?'
|
||||
elif separator_rule and not prefix_with_sep:
|
||||
return f'({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?'
|
||||
else:
|
||||
return (f'({content} ' * up_to_n).rstrip() + (')?' * up_to_n)
|
||||
|
||||
if min_items > 0 and max_items != min_items:
|
||||
result += ' '
|
||||
|
||||
if max_items is not None:
|
||||
result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0)
|
||||
else:
|
||||
item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})'
|
||||
|
||||
if min_items == 0 and separator_rule:
|
||||
result = f'({item_rule} {item_operator}*)?'
|
||||
else:
|
||||
result += f'{item_operator}*'
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
class BuiltinRule:
    """A predefined grammar rule plus the names of the rules it refers to.

    Attributes:
        content: The rule body text.
        deps: Names of other rules referenced by ``content``.
    """

    def __init__(self, content: str, deps: Optional[List[str]] = None):
        self.content = content
        # A falsy ``deps`` (None or an empty list) is replaced by a fresh
        # list, so instances never share a default.
        self.deps = deps or []
|
||||
|
||||
_up_to_15_digits = _build_repetition('[0-9]', 0, 15)
|
||||
|
||||
PRIMITIVE_RULES = {
|
||||
'boolean' : BuiltinRule('("true" | "false") space', []),
|
||||
'decimal-part' : BuiltinRule('[0-9] ' + _up_to_15_digits, []),
|
||||
'integral-part': BuiltinRule('[0-9] | [1-9] ' + _up_to_15_digits, []),
|
||||
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
|
||||
'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']),
|
||||
'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
|
||||
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
|
||||
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
|
||||
'uuid' : BuiltinRule(r'"\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + r' "\"" space', []),
|
||||
'char' : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])', []),
|
||||
'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']),
|
||||
'null' : BuiltinRule('"null" space', []),
|
||||
}
|
||||
|
||||
# TODO: support "uri", "email" string formats
|
||||
STRING_FORMAT_RULES = {
|
||||
'date' : BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
|
||||
'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
|
||||
'date-time' : BuiltinRule('date "T" time', ['date', 'time']),
|
||||
'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']),
|
||||
'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']),
|
||||
'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
|
||||
}
|
||||
|
||||
DOTALL = '[\\U00000000-\\U0010FFFF]'
|
||||
DOT = '[^\\x0A\\x0D]'
|
||||
|
||||
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])
|
||||
|
||||
|
||||
NON_LITERAL_SET = set('|.()[]{}*+?')
|
||||
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?')
|
||||
|
||||
|
||||
|
||||
|
||||
class SchemaConverter:
|
||||
def __init__(self, prop_order):
|
||||
def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
|
||||
self._prop_order = prop_order
|
||||
self._rules = {"space": SPACE_RULE}
|
||||
self._defs: Dict[str, Any] = {}
|
||||
self._allow_fetch = allow_fetch
|
||||
self._dotall = dotall
|
||||
self._raw_pattern = raw_pattern
|
||||
self._rules = {
|
||||
'space': SPACE_RULE,
|
||||
}
|
||||
self._refs = {}
|
||||
self._refs_being_resolved = set()
|
||||
|
||||
def _format_literal(self, literal: str):
|
||||
escaped: str = GRAMMAR_LITERAL_ESCAPE_RE.sub(
|
||||
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal)
|
||||
def _format_literal(self, literal):
|
||||
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
|
||||
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
|
||||
)
|
||||
return f'"{escaped}"'
|
||||
|
||||
def _add_rule(self, name: str, rule: str):
|
||||
esc_name = INVALID_RULE_CHARS_RE.sub("-", name)
|
||||
def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str:
|
||||
'''
|
||||
not_literal('a') -> '[^a]'
|
||||
not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
|
||||
'''
|
||||
assert len(literal) > 0, 'Empty literal not supported'
|
||||
def recurse(i: int):
|
||||
c = literal[i]
|
||||
if maybe_escaped_underscores and c == '_':
|
||||
yield f'[^{c}\\\\]'
|
||||
yield ' | '
|
||||
yield f'"\\\\"? "{c}"'
|
||||
else:
|
||||
yield f'[^{c}]'
|
||||
if i < len(literal) - 1:
|
||||
yield ' | '
|
||||
yield self._format_literal(c)
|
||||
yield ' ('
|
||||
yield from recurse(i + 1)
|
||||
yield ')?'
|
||||
|
||||
return ''.join(('(', *recurse(0), ')'))
|
||||
|
||||
def _add_rule(self, name, rule):
|
||||
esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
|
||||
if esc_name not in self._rules or self._rules[esc_name] == rule:
|
||||
key = esc_name
|
||||
else:
|
||||
i = 0
|
||||
while f"{esc_name}{i}" in self._rules:
|
||||
while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule:
|
||||
i += 1
|
||||
key = f"{esc_name}{i}"
|
||||
key = f'{esc_name}{i}'
|
||||
self._rules[key] = rule
|
||||
return key
|
||||
|
||||
def visit(self, schema: Dict[str, Any], name: str) -> str:
|
||||
rule_name = name or "root"
|
||||
def resolve_refs(self, schema: dict, url: str):
|
||||
'''
|
||||
Resolves all $ref fields in the given schema, fetching any remote schemas,
|
||||
replacing $ref with absolute reference URL and populating self._refs with the
|
||||
respective referenced (sub)schema dictionaries.
|
||||
'''
|
||||
def visit(n: dict):
|
||||
if isinstance(n, list):
|
||||
return [visit(x) for x in n]
|
||||
elif isinstance(n, dict):
|
||||
ref = n.get('$ref')
|
||||
if ref is not None and ref not in self._refs:
|
||||
if ref.startswith('https://'):
|
||||
assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
|
||||
import requests
|
||||
|
||||
if "$defs" in schema:
|
||||
# add defs to self._defs for later inlining
|
||||
for def_name, def_schema in schema["$defs"].items():
|
||||
self._defs[def_name] = def_schema
|
||||
frag_split = ref.split('#')
|
||||
base_url = frag_split[0]
|
||||
|
||||
if "oneOf" in schema or "anyOf" in schema:
|
||||
rule = " | ".join(
|
||||
(
|
||||
self.visit(alt_schema, f'{name}{"-" if name else ""}{i}')
|
||||
for i, alt_schema in enumerate(
|
||||
schema.get("oneOf") or schema["anyOf"]
|
||||
)
|
||||
)
|
||||
target = self._refs.get(base_url)
|
||||
if target is None:
|
||||
target = self.resolve_refs(requests.get(ref).json(), base_url)
|
||||
self._refs[base_url] = target
|
||||
|
||||
if len(frag_split) == 1 or frag_split[-1] == '':
|
||||
return target
|
||||
elif ref.startswith('#/'):
|
||||
target = schema
|
||||
ref = f'{url}{ref}'
|
||||
n['$ref'] = ref
|
||||
else:
|
||||
raise ValueError(f'Unsupported ref {ref}')
|
||||
|
||||
for sel in ref.split('#')[-1].split('/')[1:]:
|
||||
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
|
||||
target = target[sel]
|
||||
|
||||
self._refs[ref] = target
|
||||
else:
|
||||
for v in n.values():
|
||||
visit(v)
|
||||
|
||||
return n
|
||||
return visit(schema)
|
||||
|
||||
def _generate_union_rule(self, name, alt_schemas):
|
||||
return ' | '.join((
|
||||
self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
|
||||
for i, alt_schema in enumerate(alt_schemas)
|
||||
))
|
||||
|
||||
def _visit_pattern(self, pattern, name):
|
||||
'''
|
||||
Transforms a regular expression pattern into a GBNF rule.
|
||||
|
||||
Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
|
||||
Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
|
||||
|
||||
Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
|
||||
|
||||
Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
|
||||
we define sub-rules to keep the output lean.
|
||||
'''
|
||||
|
||||
assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
|
||||
pattern = pattern[1:-1]
|
||||
sub_rule_ids = {}
|
||||
|
||||
i = 0
|
||||
length = len(pattern)
|
||||
|
||||
def to_rule(s: Tuple[str, bool]) -> str:
|
||||
(txt, is_literal) = s
|
||||
return "\"" + txt + "\"" if is_literal else txt
|
||||
|
||||
def transform() -> Tuple[str, bool]:
|
||||
'''
|
||||
Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
|
||||
'''
|
||||
nonlocal i
|
||||
nonlocal pattern
|
||||
nonlocal sub_rule_ids
|
||||
|
||||
start = i
|
||||
# For each component of this sequence, store its string representation and whether it's a literal.
|
||||
# We only need a flat structure here to apply repetition operators to the last item, and
|
||||
# to merge literals at the end (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
|
||||
# (GBNF's syntax is luckily very close to regular expressions!)
|
||||
seq: list[Tuple[str, bool]] = []
|
||||
|
||||
def get_dot():
|
||||
if self._dotall:
|
||||
rule = DOTALL
|
||||
else:
|
||||
# Accept any character... except \n and \r line break chars (\x0A and \x0D)
|
||||
rule = DOT
|
||||
return self._add_rule(f'dot', rule)
|
||||
|
||||
def join_seq():
|
||||
nonlocal seq
|
||||
ret = []
|
||||
for is_literal, g in groupby(seq, lambda x: x[1]):
|
||||
if is_literal:
|
||||
ret.append((''.join(x[0] for x in g), True))
|
||||
else:
|
||||
ret.extend(g)
|
||||
if len(ret) == 1:
|
||||
return ret[0]
|
||||
return (' '.join(to_rule(x) for x in seq), False)
|
||||
|
||||
while i < length:
|
||||
c = pattern[i]
|
||||
if c == '.':
|
||||
seq.append((get_dot(), False))
|
||||
i += 1
|
||||
elif c == '(':
|
||||
i += 1
|
||||
if i < length:
|
||||
assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
|
||||
seq.append((f'({to_rule(transform())})', False))
|
||||
elif c == ')':
|
||||
i += 1
|
||||
assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
|
||||
return join_seq()
|
||||
elif c == '[':
|
||||
square_brackets = c
|
||||
i += 1
|
||||
while i < length and pattern[i] != ']':
|
||||
if pattern[i] == '\\':
|
||||
square_brackets += pattern[i:i+2]
|
||||
i += 2
|
||||
else:
|
||||
square_brackets += pattern[i]
|
||||
i += 1
|
||||
assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
|
||||
square_brackets += ']'
|
||||
i += 1
|
||||
seq.append((square_brackets, False))
|
||||
elif c == '|':
|
||||
seq.append(('|', False))
|
||||
i += 1
|
||||
elif c in ('*', '+', '?'):
|
||||
seq[-1] = (to_rule(seq[-1]) + c, False)
|
||||
i += 1
|
||||
elif c == '{':
|
||||
curly_brackets = c
|
||||
i += 1
|
||||
while i < length and pattern[i] != '}':
|
||||
curly_brackets += pattern[i]
|
||||
i += 1
|
||||
assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
|
||||
curly_brackets += '}'
|
||||
i += 1
|
||||
nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
|
||||
min_times = 0
|
||||
max_times = None
|
||||
try:
|
||||
if len(nums) == 1:
|
||||
min_times = int(nums[0])
|
||||
max_times = min_times
|
||||
else:
|
||||
assert len(nums) == 2
|
||||
min_times = int(nums[0]) if nums[0] else 0
|
||||
max_times = int(nums[1]) if nums[1] else None
|
||||
except ValueError:
|
||||
raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')
|
||||
|
||||
(sub, sub_is_literal) = seq[-1]
|
||||
|
||||
if not sub_is_literal:
|
||||
id = sub_rule_ids.get(sub)
|
||||
if id is None:
|
||||
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
|
||||
sub_rule_ids[sub] = id
|
||||
sub = id
|
||||
|
||||
seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times, item_rule_is_literal=sub_is_literal), False)
|
||||
else:
|
||||
literal = ''
|
||||
while i < length:
|
||||
if pattern[i] == '\\' and i < length - 1:
|
||||
next = pattern[i + 1]
|
||||
if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
|
||||
i += 1
|
||||
literal += pattern[i]
|
||||
i += 1
|
||||
else:
|
||||
literal += pattern[i:i+2]
|
||||
i += 2
|
||||
elif pattern[i] == '"' and not self._raw_pattern:
|
||||
literal += '\\"'
|
||||
i += 1
|
||||
elif pattern[i] not in NON_LITERAL_SET and \
|
||||
(i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET):
|
||||
literal += pattern[i]
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
if literal:
|
||||
seq.append((literal, True))
|
||||
|
||||
return join_seq()
|
||||
|
||||
return self._add_rule(
|
||||
name,
|
||||
to_rule(transform()) if self._raw_pattern \
|
||||
else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space")
|
||||
|
||||
|
||||
def _resolve_ref(self, ref):
|
||||
ref_name = ref.split('/')[-1]
|
||||
if ref_name not in self._rules and ref not in self._refs_being_resolved:
|
||||
self._refs_being_resolved.add(ref)
|
||||
resolved = self._refs[ref]
|
||||
ref_name = self.visit(resolved, ref_name)
|
||||
self._refs_being_resolved.remove(ref)
|
||||
return ref_name
|
||||
|
||||
def _generate_constant_rule(self, value):
|
||||
return self._format_literal(json.dumps(value))
|
||||
|
||||
def visit(self, schema, name):
|
||||
schema_type = schema.get('type')
|
||||
schema_format = schema.get('format')
|
||||
rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'
|
||||
|
||||
if (ref := schema.get('$ref')) is not None:
|
||||
return self._add_rule(rule_name, self._resolve_ref(ref))
|
||||
|
||||
elif 'oneOf' in schema or 'anyOf' in schema:
|
||||
return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))
|
||||
|
||||
elif isinstance(schema_type, list):
|
||||
return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type]))
|
||||
|
||||
elif 'const' in schema:
|
||||
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))
|
||||
|
||||
elif 'enum' in schema:
|
||||
rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum']))
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
elif schema_type in (None, 'object') and \
|
||||
('properties' in schema or \
|
||||
('additionalProperties' in schema and schema['additionalProperties'] is not True)):
|
||||
required = set(schema.get('required', []))
|
||||
properties = list(schema.get('properties', {}).items())
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
|
||||
|
||||
elif schema_type in (None, 'object') and 'allOf' in schema:
|
||||
required = set()
|
||||
properties = []
|
||||
hybrid_name = name
|
||||
def add_component(comp_schema, is_required):
|
||||
if (ref := comp_schema.get('$ref')) is not None:
|
||||
comp_schema = self._refs[ref]
|
||||
|
||||
if 'properties' in comp_schema:
|
||||
for prop_name, prop_schema in comp_schema['properties'].items():
|
||||
properties.append((prop_name, prop_schema))
|
||||
if is_required:
|
||||
required.add(prop_name)
|
||||
|
||||
for t in schema['allOf']:
|
||||
if 'anyOf' in t:
|
||||
for tt in t['anyOf']:
|
||||
add_component(tt, is_required=False)
|
||||
else:
|
||||
add_component(t, is_required=True)
|
||||
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[]))
|
||||
|
||||
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
|
||||
items = schema.get('items') or schema['prefixItems']
|
||||
if isinstance(items, list):
|
||||
return self._add_rule(
|
||||
rule_name,
|
||||
'"[" space ' +
|
||||
' "," space '.join(
|
||||
self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
|
||||
for i, item in enumerate(items)) +
|
||||
' "]" space')
|
||||
else:
|
||||
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
|
||||
min_items = schema.get("minItems", 0)
|
||||
max_items = schema.get("maxItems")
|
||||
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')
|
||||
|
||||
elif schema_type in (None, 'string') and 'pattern' in schema:
|
||||
return self._visit_pattern(schema['pattern'], rule_name)
|
||||
|
||||
elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
|
||||
return self._add_primitive(
|
||||
'root' if rule_name == 'root' else schema_format,
|
||||
PRIMITIVE_RULES['uuid']
|
||||
)
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
elif "const" in schema:
|
||||
return self._add_rule(rule_name, self._format_literal(schema["const"]))
|
||||
elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES:
|
||||
prim_name = f'{schema_format}-string'
|
||||
return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]))
|
||||
|
||||
elif "enum" in schema:
|
||||
rule = " | ".join((self._format_literal(v) for v in schema["enum"]))
|
||||
return self._add_rule(rule_name, rule)
|
||||
elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema):
|
||||
char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
|
||||
min_len = schema.get('minLength', 0)
|
||||
max_len = schema.get('maxLength')
|
||||
|
||||
elif "$ref" in schema:
|
||||
ref = schema["$ref"]
|
||||
assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}"
|
||||
# inline $defs
|
||||
def_name = ref[len("#/$defs/") :]
|
||||
def_schema = self._defs[def_name]
|
||||
return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}')
|
||||
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')
|
||||
|
||||
|
||||
schema_type: Optional[str] = schema.get("type") # type: ignore
|
||||
assert isinstance(schema_type, str), f"Unrecognized schema: {schema}"
|
||||
|
||||
if schema_type == "object" and "properties" in schema:
|
||||
# TODO: `required` keyword
|
||||
if self._prop_order:
|
||||
prop_order = self._prop_order
|
||||
prop_pairs = sorted(
|
||||
schema["properties"].items(),
|
||||
# sort by position in prop_order (if specified) then by key
|
||||
key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]),
|
||||
)
|
||||
else:
|
||||
prop_pairs = schema["properties"].items()
|
||||
|
||||
rule = '"{" space'
|
||||
for i, (prop_name, prop_schema) in enumerate(prop_pairs):
|
||||
prop_rule_name = self.visit(
|
||||
prop_schema, f'{name}{"-" if name else ""}{prop_name}'
|
||||
)
|
||||
if i > 0:
|
||||
rule += ' "," space'
|
||||
rule += rf' {self._format_literal(prop_name)} space ":" space {prop_rule_name}'
|
||||
rule += ' "}" space'
|
||||
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
elif schema_type == "array" and "items" in schema:
|
||||
# TODO `prefixItems` keyword
|
||||
item_rule_name = self.visit(
|
||||
schema["items"], f'{name}{"-" if name else ""}item'
|
||||
)
|
||||
list_item_operator = f'("," space {item_rule_name})'
|
||||
successive_items = ""
|
||||
min_items = schema.get("minItems", 0)
|
||||
if min_items > 0:
|
||||
first_item = f"({item_rule_name})"
|
||||
successive_items = list_item_operator * (min_items - 1)
|
||||
min_items -= 1
|
||||
else:
|
||||
first_item = f"({item_rule_name})?"
|
||||
max_items = schema.get("maxItems")
|
||||
if max_items is not None and max_items > min_items:
|
||||
successive_items += (list_item_operator + "?") * (max_items - min_items - 1)
|
||||
else:
|
||||
successive_items += list_item_operator + "*"
|
||||
rule = f'"[" space {first_item} {successive_items} "]" space'
|
||||
return self._add_rule(rule_name, rule)
|
||||
elif (schema_type == 'object') or (len(schema) == 0):
|
||||
return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))
|
||||
|
||||
else:
|
||||
assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
|
||||
return self._add_rule(
|
||||
"root" if rule_name == "root" else schema_type,
|
||||
PRIMITIVE_RULES[schema_type],
|
||||
assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
|
||||
# TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
|
||||
return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type])
|
||||
|
||||
def _add_primitive(self, name: str, rule: BuiltinRule):
|
||||
n = self._add_rule(name, rule.content)
|
||||
|
||||
for dep in rule.deps:
|
||||
dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
|
||||
assert dep_rule, f'Rule {dep} not known'
|
||||
if dep not in self._rules:
|
||||
self._add_primitive(dep, dep_rule)
|
||||
return n
|
||||
|
||||
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
|
||||
prop_order = self._prop_order
|
||||
# sort by position in prop_order (if specified) then by original order
|
||||
sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]
|
||||
|
||||
prop_kv_rule_names = {}
|
||||
for prop_name, prop_schema in properties:
|
||||
prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
|
||||
prop_kv_rule_names[prop_name] = self._add_rule(
|
||||
f'{name}{"-" if name else ""}{prop_name}-kv',
|
||||
fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
|
||||
)
|
||||
required_props = [k for k in sorted_props if k in required]
|
||||
optional_props = [k for k in sorted_props if k not in required]
|
||||
|
||||
if additional_properties == True or isinstance(additional_properties, dict):
|
||||
sub_name = f'{name}{"-" if name else ""}additional'
|
||||
value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value')
|
||||
prop_kv_rule_names["*"] = self._add_rule(
|
||||
f'{sub_name}-kv',
|
||||
self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
|
||||
)
|
||||
optional_props.append("*")
|
||||
|
||||
rule = '"{" space '
|
||||
rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)
|
||||
|
||||
if optional_props:
|
||||
rule += ' ('
|
||||
if required_props:
|
||||
rule += ' "," space ( '
|
||||
|
||||
def get_recursive_refs(ks, first_is_optional):
|
||||
[k, *rest] = ks
|
||||
kv_rule_name = prop_kv_rule_names[k]
|
||||
if k == '*':
|
||||
res = self._add_rule(
|
||||
f'{name}{"-" if name else ""}additional-kvs',
|
||||
f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*'
|
||||
)
|
||||
elif first_is_optional:
|
||||
res = f'( "," space {kv_rule_name} )?'
|
||||
else:
|
||||
res = kv_rule_name
|
||||
if len(rest) > 0:
|
||||
res += ' ' + self._add_rule(
|
||||
f'{name}{"-" if name else ""}{k}-rest',
|
||||
get_recursive_refs(rest, first_is_optional=True)
|
||||
)
|
||||
return res
|
||||
|
||||
rule += ' | '.join(
|
||||
get_recursive_refs(optional_props[i:], first_is_optional=False)
|
||||
for i in range(len(optional_props))
|
||||
)
|
||||
if required_props:
|
||||
rule += ' )'
|
||||
rule += ' )?'
|
||||
|
||||
rule += ' "}" space'
|
||||
|
||||
return rule
|
||||
|
||||
def format_grammar(self):
|
||||
return "\n".join((f"{name} ::= {rule}" for name, rule in self._rules.items()))
|
||||
|
||||
|
||||
return '\n'.join(
|
||||
f'{name} ::= {rule}'
|
||||
for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
|
||||
)
|
||||
def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
|
||||
prop_order = prop_order or []
|
||||
schema = json.loads(schema)
|
||||
prop_order = {name: idx for idx, name in enumerate(prop_order)}
|
||||
converter = SchemaConverter(prop_order)
|
||||
converter = SchemaConverter(prop_order=prop_order, allow_fetch=False, dotall=False, raw_pattern=False)
|
||||
schema = converter.resolve_refs(schema, "stdin")
|
||||
converter.visit(schema, "")
|
||||
return converter.format_grammar()
|
||||
|
|
|
@ -59,7 +59,16 @@ def main():
|
|||
if not os.path.exists(config_file):
|
||||
raise ValueError(f"Config file {config_file} not found!")
|
||||
with open(config_file, "rb") as f:
|
||||
config_file_settings = ConfigFileSettings.model_validate_json(f.read())
|
||||
# Check if yaml file
|
||||
if config_file.endswith(".yaml") or config_file.endswith(".yml"):
|
||||
import yaml
|
||||
import json
|
||||
|
||||
config_file_settings = ConfigFileSettings.model_validate_json(
|
||||
json.dumps(yaml.safe_load(f))
|
||||
)
|
||||
else:
|
||||
config_file_settings = ConfigFileSettings.model_validate_json(f.read())
|
||||
server_settings = ServerSettings.model_validate(config_file_settings)
|
||||
model_settings = config_file_settings.models
|
||||
else:
|
||||
|
|
|
@ -87,6 +87,13 @@ def get_llama_proxy():
|
|||
llama_outer_lock.release()
|
||||
|
||||
|
||||
_ping_message_factory = None
|
||||
|
||||
def set_ping_message_factory(factory):
|
||||
global _ping_message_factory
|
||||
_ping_message_factory = factory
|
||||
|
||||
|
||||
def create_app(
|
||||
settings: Settings | None = None,
|
||||
server_settings: ServerSettings | None = None,
|
||||
|
@ -97,7 +104,15 @@ def create_app(
|
|||
if not os.path.exists(config_file):
|
||||
raise ValueError(f"Config file {config_file} not found!")
|
||||
with open(config_file, "rb") as f:
|
||||
config_file_settings = ConfigFileSettings.model_validate_json(f.read())
|
||||
# Check if yaml file
|
||||
if config_file.endswith(".yaml") or config_file.endswith(".yml"):
|
||||
import yaml
|
||||
|
||||
config_file_settings = ConfigFileSettings.model_validate_json(
|
||||
json.dumps(yaml.safe_load(f))
|
||||
)
|
||||
else:
|
||||
config_file_settings = ConfigFileSettings.model_validate_json(f.read())
|
||||
server_settings = ServerSettings.model_validate(config_file_settings)
|
||||
model_settings = config_file_settings.models
|
||||
|
||||
|
@ -130,6 +145,9 @@ def create_app(
|
|||
assert model_settings is not None
|
||||
set_llama_proxy(model_settings=model_settings)
|
||||
|
||||
if server_settings.disable_ping_events:
|
||||
set_ping_message_factory(lambda: bytes())
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
@ -294,6 +312,7 @@ async def create_completion(
|
|||
iterator=iterator(),
|
||||
),
|
||||
sep="\n",
|
||||
ping_message_factory=_ping_message_factory,
|
||||
)
|
||||
else:
|
||||
return iterator_or_completion
|
||||
|
@ -462,6 +481,7 @@ async def create_chat_completion(
|
|||
iterator=iterator(),
|
||||
),
|
||||
sep="\n",
|
||||
ping_message_factory=_ping_message_factory,
|
||||
)
|
||||
else:
|
||||
return iterator_or_completion
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
import multiprocessing
|
||||
|
||||
from typing import Optional, List, Literal, Union
|
||||
from pydantic import Field
|
||||
from pydantic import Field, root_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
import llama_cpp
|
||||
|
@ -67,12 +67,12 @@ class ModelSettings(BaseSettings):
|
|||
n_threads: int = Field(
|
||||
default=max(multiprocessing.cpu_count() // 2, 1),
|
||||
ge=1,
|
||||
description="The number of threads to use.",
|
||||
description="The number of threads to use. Use -1 for max cpu threads",
|
||||
)
|
||||
n_threads_batch: int = Field(
|
||||
default=max(multiprocessing.cpu_count() // 2, 1),
|
||||
default=max(multiprocessing.cpu_count(), 1),
|
||||
ge=0,
|
||||
description="The number of threads to use when batch processing.",
|
||||
description="The number of threads to use when batch processing. Use -1 for max cpu threads",
|
||||
)
|
||||
rope_scaling_type: int = Field(
|
||||
default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
|
||||
|
@ -173,6 +173,16 @@ class ModelSettings(BaseSettings):
|
|||
default=True, description="Whether to print debug information."
|
||||
)
|
||||
|
||||
@root_validator(pre=True) # pre=True to ensure this runs before any other validation
|
||||
def set_dynamic_defaults(cls, values):
|
||||
# If n_threads or n_threads_batch is -1, set it to multiprocessing.cpu_count()
|
||||
cpu_count = multiprocessing.cpu_count()
|
||||
if values.get('n_threads', 0) == -1:
|
||||
values['n_threads'] = cpu_count
|
||||
if values.get('n_threads_batch', 0) == -1:
|
||||
values['n_threads_batch'] = cpu_count
|
||||
return values
|
||||
|
||||
|
||||
class ServerSettings(BaseSettings):
|
||||
"""Server settings used to configure the FastAPI and Uvicorn server."""
|
||||
|
@ -195,6 +205,10 @@ class ServerSettings(BaseSettings):
|
|||
default=True,
|
||||
description="Whether to interrupt requests when a new request is received.",
|
||||
)
|
||||
disable_ping_events: bool = Field(
|
||||
default=False,
|
||||
description="Disable EventSource pings (may be needed for some clients).",
|
||||
)
|
||||
|
||||
|
||||
class Settings(ServerSettings, ModelSettings):
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[build-system]
|
||||
requires = ["scikit-build-core[pyproject]>=0.5.1"]
|
||||
requires = ["scikit-build-core[pyproject]>=0.9.2"]
|
||||
build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
|
@ -35,6 +35,7 @@ server = [
|
|||
"pydantic-settings>=2.0.1",
|
||||
"sse-starlette>=1.6.1",
|
||||
"starlette-context>=0.3.6,<0.4",
|
||||
"PyYAML>=5.1",
|
||||
]
|
||||
test = [
|
||||
"pytest>=7.4.0",
|
||||
|
|
2
vendor/llama.cpp
vendored
2
vendor/llama.cpp
vendored
|
@ -1 +1 @@
|
|||
Subproject commit 75cd4c77292034ecec587ecb401366f57338f7c0
|
||||
Subproject commit 4e96a812b3ce7322a29a3008db2ed73d9087b176
|
Loading…
Reference in a new issue