This commit is contained in:
commit
ce85be97e2
14 changed files with 1201 additions and 191 deletions
31
.github/workflows/build-and-release.yaml
vendored
31
.github/workflows/build-and-release.yaml
vendored
|
@ -41,6 +41,35 @@ jobs:
|
||||||
with:
|
with:
|
||||||
path: ./wheelhouse/*.whl
|
path: ./wheelhouse/*.whl
|
||||||
|
|
||||||
|
build_arm64_wheels:
|
||||||
|
name: Build arm64 wheels
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
submodules: "recursive"
|
||||||
|
|
||||||
|
- name: Set up QEMU
|
||||||
|
uses: docker/setup-qemu-action@v3
|
||||||
|
with:
|
||||||
|
platforms: linux/arm64
|
||||||
|
|
||||||
|
- name: Build wheels
|
||||||
|
uses: pypa/cibuildwheel@v2.16.5
|
||||||
|
env:
|
||||||
|
CIBW_SKIP: "*musllinux* pp*"
|
||||||
|
CIBW_REPAIR_WHEEL_COMMAND: ""
|
||||||
|
CIBW_ARCHS: "aarch64"
|
||||||
|
CIBW_BUILD: "cp38-* cp39-* cp310-* cp311-* cp312-*"
|
||||||
|
with:
|
||||||
|
output-dir: wheelhouse/
|
||||||
|
|
||||||
|
- name: Upload wheels as artifacts
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
# This job declares no `strategy.matrix`, so `${{ matrix.version }}` would
# expand to an empty string and the artifact would be named just "wheels-".
# Use a fixed name that keeps the "wheels-" prefix.
name: wheels-arm64
|
||||||
|
path: wheelhouse/*.whl
|
||||||
|
|
||||||
build_sdist:
|
build_sdist:
|
||||||
name: Build source distribution
|
name: Build source distribution
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
@ -65,7 +94,7 @@ jobs:
|
||||||
|
|
||||||
release:
|
release:
|
||||||
name: Release
|
name: Release
|
||||||
needs: [build_wheels, build_sdist]
|
needs: [build_wheels, build_arm64_wheels, build_sdist]
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
|
|
29
CHANGELOG.md
29
CHANGELOG.md
|
@ -7,6 +7,35 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
|
## [0.2.64]
|
||||||
|
|
||||||
|
- feat: Update llama.cpp to ggerganov/llama.cpp@4e96a812b3ce7322a29a3008db2ed73d9087b176
|
||||||
|
- feat: Add `llama-3` chat format by @andreabak in #1371
|
||||||
|
- feat: Use new llama_token_is_eog in create_completions by @abetlen in d40a250ef3cfaa8224d12c83776a2f1de96ae3d1
|
||||||
|
- feat(server): Provide ability to dynamically allocate all threads if desired using -1 by @sean-bailey in #1364
|
||||||
|
- ci: Build arm64 wheels by @gaby in 611781f5319719a3d05fefccbbf0cc321742a026
|
||||||
|
- fix: Update scikit-build-core build dependency to avoid bug in 0.9.1 by @evelkey in #1370
|
||||||
|
|
||||||
|
## [0.2.63]
|
||||||
|
|
||||||
|
- feat: Update llama.cpp to ggerganov/llama.cpp@0e4802b2ecbaab04b4f829fde4a3096ca19c84b5
|
||||||
|
- feat: Add stopping_criteria to ChatFormatter, allow stopping on arbitrary token ids, fixes llama3 instruct by @abetlen in cc81afebf04d26ca1ac3cf72f23f18da6ab58588
|
||||||
|
|
||||||
|
## [0.2.62]
|
||||||
|
|
||||||
|
- feat: Update llama.cpp to ggerganov/llama.cpp@3b8f1ec4b18770531d0b1d792f3edf08254e4f0c
|
||||||
|
- feat: update grammar schema converter to match llama.cpp by @themrzmaster in #1353
|
||||||
|
- feat: add disable_ping_events flag by @khimaros in #1257
|
||||||
|
- feat: Make saved state more compact on-disk by @tc-wolf in #1296
|
||||||
|
- feat: Use all available CPUs for batch processing by @ddh0 in #1345
|
||||||
|
|
||||||
|
## [0.2.61]
|
||||||
|
|
||||||
|
- feat: Update llama.cpp to ggerganov/llama.cpp@ba5e134e073ec6837078c874aba44a702944a676
|
||||||
|
- fix: pass correct type to chat handlers for chat completion logprobs by @abetlen in bb65b4d76411112c6fb0bf759efd746f99ef3c6b
|
||||||
|
- feat: Add support for yaml based server configs by @abetlen in 060bfa64d529ade2af9b1f4e207a3937bbc4138f
|
||||||
|
- feat: Add typechecking for ctypes structure attributes by @abetlen in 1347e1d050fc5a9a32ffe0bb3e22858da28003bd
|
||||||
|
|
||||||
## [0.2.60]
|
## [0.2.60]
|
||||||
|
|
||||||
- feat: Update llama.cpp to ggerganov/llama.cpp@75cd4c77292034ecec587ecb401366f57338f7c0
|
- feat: Update llama.cpp to ggerganov/llama.cpp@75cd4c77292034ecec587ecb401366f57338f7c0
|
||||||
|
|
30
examples/batch-processing/server.py
Normal file
30
examples/batch-processing/server.py
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
"""llama-cpp-python server from scratch in a single file.
"""

# Imports are grouped at the top of the module so that no executable
# statement precedes them.
from fastapi import FastAPI

# NOTE(review): `types` is not referenced anywhere in this file yet; it is kept
# because the original script imports it.
import openai.types.chat as types

# Model/context loading sketch, currently disabled:
#
# import llama_cpp
#
# path = b"../../models/Qwen1.5-0.5B-Chat-GGUF/qwen1_5-0_5b-chat-q8_0.gguf"
#
# model_params = llama_cpp.llama_model_default_params()
# model = llama_cpp.llama_load_model_from_file(path, model_params)
#
# if model is None:
#     raise RuntimeError(f"Failed to load model from file: {path}")
#
#
# ctx_params = llama_cpp.llama_context_default_params()
# ctx = llama_cpp.llama_new_context_with_model(model, ctx_params)
#
# if ctx is None:
#     raise RuntimeError("Failed to create context")

app = FastAPI()


# Placeholder handler for POST /v1/chat/completions: it takes no request body
# and always returns the same static payload.
@app.post("/v1/chat/completions")
def create_chat_completions():
    return {"message": "Hello World"}
|
|
@ -1,4 +1,4 @@
|
||||||
from .llama_cpp import *
|
from .llama_cpp import *
|
||||||
from .llama import *
|
from .llama import *
|
||||||
|
|
||||||
__version__ = "0.2.60"
|
__version__ = "0.2.64"
|
|
@ -181,20 +181,20 @@ class _LlamaModel:
|
||||||
)
|
)
|
||||||
return list(tokens[:n_tokens])
|
return list(tokens[:n_tokens])
|
||||||
|
|
||||||
def token_to_piece(self, token: int) -> bytes:
|
def token_to_piece(self, token: int, special: bool = False) -> bytes:
|
||||||
assert self.model is not None
|
assert self.model is not None
|
||||||
buf = ctypes.create_string_buffer(32)
|
buf = ctypes.create_string_buffer(32)
|
||||||
llama_cpp.llama_token_to_piece(self.model, token, buf, 32)
|
llama_cpp.llama_token_to_piece(self.model, token, buf, 32, special)
|
||||||
return bytes(buf)
|
return bytes(buf)
|
||||||
|
|
||||||
def detokenize(self, tokens: List[int]) -> bytes:
|
def detokenize(self, tokens: List[int], special: bool = False) -> bytes:
|
||||||
assert self.model is not None
|
assert self.model is not None
|
||||||
output = b""
|
output = b""
|
||||||
size = 32
|
size = 32
|
||||||
buffer = (ctypes.c_char * size)()
|
buffer = (ctypes.c_char * size)()
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
n = llama_cpp.llama_token_to_piece(
|
n = llama_cpp.llama_token_to_piece(
|
||||||
self.model, llama_cpp.llama_token(token), buffer, size
|
self.model, llama_cpp.llama_token(token), buffer, size, special
|
||||||
)
|
)
|
||||||
assert n <= size
|
assert n <= size
|
||||||
output += bytes(buffer[:n])
|
output += bytes(buffer[:n])
|
||||||
|
@ -597,13 +597,13 @@ def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> li
|
||||||
return list(result)
|
return list(result)
|
||||||
|
|
||||||
|
|
||||||
def _token_to_piece(model: _LlamaModel, token: int) -> str:
|
def _token_to_piece(model: _LlamaModel, token: int, special: bool = False) -> str:
|
||||||
assert model.model is not None
|
assert model.model is not None
|
||||||
result = (ctypes.c_char * 8)(0)
|
result = (ctypes.c_char * 8)(0)
|
||||||
n_tokens = llama_cpp.llama_token_to_piece(model.model, token, result, len(result))
|
n_tokens = llama_cpp.llama_token_to_piece(model.model, token, result, len(result), special)
|
||||||
if n_tokens < 0:
|
if n_tokens < 0:
|
||||||
result = (ctypes.c_char * -n_tokens)(0)
|
result = (ctypes.c_char * -n_tokens)(0)
|
||||||
check = llama_cpp.llama_token_to_piece(model.model, token, result, len(result))
|
check = llama_cpp.llama_token_to_piece(model.model, token, result, len(result), special)
|
||||||
if check != -n_tokens:
|
if check != -n_tokens:
|
||||||
raise RuntimeError(f"Failed to get piece: token={token}")
|
raise RuntimeError(f"Failed to get piece: token={token}")
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -18,6 +18,7 @@ from typing import (
|
||||||
Iterator,
|
Iterator,
|
||||||
Deque,
|
Deque,
|
||||||
Callable,
|
Callable,
|
||||||
|
Dict,
|
||||||
)
|
)
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -262,9 +263,7 @@ class Llama:
|
||||||
|
|
||||||
self.n_batch = min(n_ctx, n_batch) # ???
|
self.n_batch = min(n_ctx, n_batch) # ???
|
||||||
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
|
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
|
||||||
self.n_threads_batch = n_threads_batch or max(
|
self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()
|
||||||
multiprocessing.cpu_count() // 2, 1
|
|
||||||
)
|
|
||||||
|
|
||||||
# Context Params
|
# Context Params
|
||||||
self.context_params = llama_cpp.llama_context_default_params()
|
self.context_params = llama_cpp.llama_context_default_params()
|
||||||
|
@ -427,7 +426,10 @@ class Llama:
|
||||||
print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
|
print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
|
||||||
|
|
||||||
self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
|
self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
|
||||||
template=template, eos_token=eos_token, bos_token=bos_token
|
template=template,
|
||||||
|
eos_token=eos_token,
|
||||||
|
bos_token=bos_token,
|
||||||
|
stop_token_ids=[eos_token_id],
|
||||||
).to_chat_handler()
|
).to_chat_handler()
|
||||||
|
|
||||||
if self.chat_format is None and self.chat_handler is None:
|
if self.chat_format is None and self.chat_handler is None:
|
||||||
|
@ -1032,7 +1034,8 @@ class Llama:
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
grammar=grammar,
|
grammar=grammar,
|
||||||
):
|
):
|
||||||
if token == self._token_eos:
|
assert self._model.model is not None
|
||||||
|
if llama_cpp.llama_token_is_eog(self._model.model, token):
|
||||||
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
|
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
|
||||||
finish_reason = "stop"
|
finish_reason = "stop"
|
||||||
break
|
break
|
||||||
|
@ -1664,7 +1667,8 @@ class Llama:
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
min_p=min_p,
|
min_p=min_p,
|
||||||
typical_p=typical_p,
|
typical_p=typical_p,
|
||||||
logprobs=top_logprobs if logprobs else None,
|
logprobs=logprobs,
|
||||||
|
top_logprobs=top_logprobs,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
@ -1792,7 +1796,7 @@ class Llama:
|
||||||
file=sys.stderr,
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
return LlamaState(
|
return LlamaState(
|
||||||
scores=self.scores.copy(),
|
scores=self._scores.copy(),
|
||||||
input_ids=self.input_ids.copy(),
|
input_ids=self.input_ids.copy(),
|
||||||
n_tokens=self.n_tokens,
|
n_tokens=self.n_tokens,
|
||||||
llama_state=bytes(llama_state_compact),
|
llama_state=bytes(llama_state_compact),
|
||||||
|
@ -1801,7 +1805,9 @@ class Llama:
|
||||||
|
|
||||||
def load_state(self, state: LlamaState) -> None:
|
def load_state(self, state: LlamaState) -> None:
|
||||||
assert self._ctx.ctx is not None
|
assert self._ctx.ctx is not None
|
||||||
self.scores = state.scores.copy()
|
# Only filling in up to `n_tokens` and then zero-ing out the rest
|
||||||
|
self.scores[: state.n_tokens, :] = state.scores.copy()
|
||||||
|
self.scores[state.n_tokens :, :] = 0.0
|
||||||
self.input_ids = state.input_ids.copy()
|
self.input_ids = state.input_ids.copy()
|
||||||
self.n_tokens = state.n_tokens
|
self.n_tokens = state.n_tokens
|
||||||
state_size = state.llama_state_size
|
state_size = state.llama_state_size
|
||||||
|
@ -1952,7 +1958,6 @@ class Llama:
|
||||||
local_dir_use_symlinks=local_dir_use_symlinks,
|
local_dir_use_symlinks=local_dir_use_symlinks,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
local_files_only=True,
|
local_files_only=True,
|
||||||
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model_path = os.path.join(local_dir, filename)
|
model_path = os.path.join(local_dir, filename)
|
||||||
|
|
|
@ -10,6 +10,9 @@ from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, P
|
||||||
|
|
||||||
import jinja2
|
import jinja2
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
|
||||||
import llama_cpp.llama as llama
|
import llama_cpp.llama as llama
|
||||||
import llama_cpp.llama_types as llama_types
|
import llama_cpp.llama_types as llama_types
|
||||||
import llama_cpp.llama_grammar as llama_grammar
|
import llama_cpp.llama_grammar as llama_grammar
|
||||||
|
@ -32,6 +35,9 @@ MISTRAL_INSTRUCT_EOS_TOKEN = "</s>"
|
||||||
# Source: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
|
# Source: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
|
||||||
MIXTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
|
MIXTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
|
||||||
|
|
||||||
|
# Source: https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
|
||||||
|
LLAMA3_INSTRUCT_CHAT_TEMPLATE = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
|
||||||
|
|
||||||
### Chat Completion Handler ###
|
### Chat Completion Handler ###
|
||||||
|
|
||||||
|
|
||||||
|
@ -77,6 +83,8 @@ class LlamaChatCompletionHandler(Protocol):
|
||||||
mirostat_eta: float = 0.1,
|
mirostat_eta: float = 0.1,
|
||||||
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||||
grammar: Optional[llama.LlamaGrammar] = None,
|
grammar: Optional[llama.LlamaGrammar] = None,
|
||||||
|
logprobs: Optional[bool] = None,
|
||||||
|
top_logprobs: Optional[int] = None,
|
||||||
**kwargs, # type: ignore
|
**kwargs, # type: ignore
|
||||||
) -> Union[
|
) -> Union[
|
||||||
llama_types.CreateChatCompletionResponse,
|
llama_types.CreateChatCompletionResponse,
|
||||||
|
@ -148,6 +156,7 @@ class ChatFormatterResponse:
|
||||||
|
|
||||||
prompt: str
|
prompt: str
|
||||||
stop: Optional[Union[str, List[str]]] = None
|
stop: Optional[Union[str, List[str]]] = None
|
||||||
|
stopping_criteria: Optional[llama.StoppingCriteriaList] = None
|
||||||
|
|
||||||
|
|
||||||
class ChatFormatter(Protocol):
|
class ChatFormatter(Protocol):
|
||||||
|
@ -171,12 +180,14 @@ class Jinja2ChatFormatter(ChatFormatter):
|
||||||
eos_token: str,
|
eos_token: str,
|
||||||
bos_token: str,
|
bos_token: str,
|
||||||
add_generation_prompt: bool = True,
|
add_generation_prompt: bool = True,
|
||||||
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
"""A chat formatter that uses jinja2 templates to format the prompt."""
|
"""A chat formatter that uses jinja2 templates to format the prompt."""
|
||||||
self.template = template
|
self.template = template
|
||||||
self.eos_token = eos_token
|
self.eos_token = eos_token
|
||||||
self.bos_token = bos_token
|
self.bos_token = bos_token
|
||||||
self.add_generation_prompt = add_generation_prompt
|
self.add_generation_prompt = add_generation_prompt
|
||||||
|
self.stop_token_ids = set(stop_token_ids) if stop_token_ids is not None else None
|
||||||
|
|
||||||
self._environment = jinja2.Environment(
|
self._environment = jinja2.Environment(
|
||||||
loader=jinja2.BaseLoader(),
|
loader=jinja2.BaseLoader(),
|
||||||
|
@ -209,7 +220,16 @@ class Jinja2ChatFormatter(ChatFormatter):
|
||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
|
stopping_criteria = None
|
||||||
|
if self.stop_token_ids is not None:
|
||||||
|
def stop_on_last_token(
|
||||||
|
tokens: npt.NDArray[np.intc],
|
||||||
|
logits: npt.NDArray[np.single]
|
||||||
|
) -> bool:
|
||||||
|
return tokens[-1] in self.stop_token_ids
|
||||||
|
stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])
|
||||||
|
|
||||||
|
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria)
|
||||||
|
|
||||||
def to_chat_handler(self) -> LlamaChatCompletionHandler:
|
def to_chat_handler(self) -> LlamaChatCompletionHandler:
|
||||||
return chat_formatter_to_chat_completion_handler(self)
|
return chat_formatter_to_chat_completion_handler(self)
|
||||||
|
@ -338,7 +358,7 @@ def _convert_completion_to_chat_function(
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
"logprobs": None,
|
"logprobs": completion["choices"][0]["logprobs"],
|
||||||
"finish_reason": "tool_calls",
|
"finish_reason": "tool_calls",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -391,7 +411,7 @@ def _convert_completion_to_chat_function(
|
||||||
{
|
{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"finish_reason": None,
|
"finish_reason": None,
|
||||||
"logprobs": None,
|
"logprobs": chunk["choices"][0]["logprobs"],
|
||||||
"delta": {
|
"delta": {
|
||||||
"role": None,
|
"role": None,
|
||||||
"content": None,
|
"content": None,
|
||||||
|
@ -426,7 +446,7 @@ def _convert_completion_to_chat_function(
|
||||||
{
|
{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"finish_reason": None,
|
"finish_reason": None,
|
||||||
"logprobs": None,
|
"logprobs": chunk["choices"][0]["logprobs"],
|
||||||
"delta": {
|
"delta": {
|
||||||
"role": None,
|
"role": None,
|
||||||
"content": None,
|
"content": None,
|
||||||
|
@ -491,7 +511,6 @@ def chat_formatter_to_chat_completion_handler(
|
||||||
temperature: float = 0.2,
|
temperature: float = 0.2,
|
||||||
top_p: float = 0.95,
|
top_p: float = 0.95,
|
||||||
top_k: int = 40,
|
top_k: int = 40,
|
||||||
logprobs: int = 0,
|
|
||||||
min_p: float = 0.05,
|
min_p: float = 0.05,
|
||||||
typical_p: float = 1.0,
|
typical_p: float = 1.0,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
|
@ -512,6 +531,8 @@ def chat_formatter_to_chat_completion_handler(
|
||||||
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||||
grammar: Optional[llama.LlamaGrammar] = None,
|
grammar: Optional[llama.LlamaGrammar] = None,
|
||||||
logit_bias: Optional[Dict[str, float]] = None,
|
logit_bias: Optional[Dict[str, float]] = None,
|
||||||
|
logprobs: Optional[bool] = None,
|
||||||
|
top_logprobs: Optional[int] = None,
|
||||||
**kwargs, # type: ignore
|
**kwargs, # type: ignore
|
||||||
) -> Union[
|
) -> Union[
|
||||||
llama_types.CreateChatCompletionResponse,
|
llama_types.CreateChatCompletionResponse,
|
||||||
|
@ -530,6 +551,10 @@ def chat_formatter_to_chat_completion_handler(
|
||||||
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
|
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
|
||||||
stop = stop + rstop
|
stop = stop + rstop
|
||||||
|
|
||||||
|
stopping_criteria = None
|
||||||
|
if result.stopping_criteria is not None:
|
||||||
|
stopping_criteria = result.stopping_criteria
|
||||||
|
|
||||||
if response_format is not None and response_format["type"] == "json_object":
|
if response_format is not None and response_format["type"] == "json_object":
|
||||||
grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)
|
grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)
|
||||||
|
|
||||||
|
@ -581,7 +606,7 @@ def chat_formatter_to_chat_completion_handler(
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
min_p=min_p,
|
min_p=min_p,
|
||||||
typical_p=typical_p,
|
typical_p=typical_p,
|
||||||
logprobs=logprobs,
|
logprobs=top_logprobs if logprobs else None,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
@ -595,6 +620,7 @@ def chat_formatter_to_chat_completion_handler(
|
||||||
mirostat_eta=mirostat_eta,
|
mirostat_eta=mirostat_eta,
|
||||||
model=model,
|
model=model,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
|
stopping_criteria=stopping_criteria,
|
||||||
grammar=grammar,
|
grammar=grammar,
|
||||||
logit_bias=logit_bias,
|
logit_bias=logit_bias,
|
||||||
)
|
)
|
||||||
|
@ -706,6 +732,9 @@ def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[s
|
||||||
metadata["tokenizer.chat_template"] == MIXTRAL_INSTRUCT_CHAT_TEMPLATE):
|
metadata["tokenizer.chat_template"] == MIXTRAL_INSTRUCT_CHAT_TEMPLATE):
|
||||||
return "mistral-instruct"
|
return "mistral-instruct"
|
||||||
|
|
||||||
|
if metadata["tokenizer.chat_template"] == LLAMA3_INSTRUCT_CHAT_TEMPLATE:
|
||||||
|
return "llama-3"
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@ -897,6 +926,26 @@ def format_llama2(
|
||||||
return ChatFormatterResponse(prompt=_prompt)
|
return ChatFormatterResponse(prompt=_prompt)
|
||||||
|
|
||||||
|
|
||||||
|
# Chat format for Llama-3 models, see more details at:
|
||||||
|
# https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py#L202-L229
|
||||||
|
@register_chat_format("llama-3")
def format_llama3(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Format chat messages into a prompt for the "llama-3" chat format.

    Args:
        messages: The chat messages to render.
        **kwargs: Accepted but not used by this formatter.

    Returns:
        A ChatFormatterResponse carrying the rendered prompt, with
        "<|eot_id|>" as the stop string.
    """
    end_of_turn = "<|eot_id|>"
    # Header emitted in front of each message, keyed by role.
    role_headers = {
        "system": "<|start_header_id|>system<|end_header_id|>\n\n",
        "user": "<|start_header_id|>user<|end_header_id|>\n\n",
        "assistant": "<|start_header_id|>assistant<|end_header_id|>\n\n",
    }
    turns = _map_roles(messages, role_headers)
    # Finish with an assistant header that has no content yet.
    turns.append((role_headers["assistant"], None))
    rendered = _format_no_colon_single("<|begin_of_text|>", turns, end_of_turn)
    return ChatFormatterResponse(prompt=rendered, stop=end_of_turn)
|
||||||
|
|
||||||
|
|
||||||
@register_chat_format("alpaca")
|
@register_chat_format("alpaca")
|
||||||
def format_alpaca(
|
def format_alpaca(
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
|
@ -1628,7 +1677,7 @@ def functionary_chat_handler(
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
"logprobs": None,
|
"logprobs": completion["choices"][0]["logprobs"],
|
||||||
"finish_reason": "tool_calls",
|
"finish_reason": "tool_calls",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -2085,7 +2134,7 @@ def functionary_v1_v2_chat_handler(
|
||||||
choices=[
|
choices=[
|
||||||
{
|
{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"logprobs": None,
|
"logprobs": completion["choices"][0]["logprobs"],
|
||||||
"message": {
|
"message": {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": None if content == "" else content,
|
"content": None if content == "" else content,
|
||||||
|
@ -2311,11 +2360,14 @@ def chatml_function_calling(
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||||
grammar: Optional[llama.LlamaGrammar] = None,
|
grammar: Optional[llama.LlamaGrammar] = None,
|
||||||
|
logprobs: Optional[bool] = None,
|
||||||
|
top_logprobs: Optional[int] = None,
|
||||||
**kwargs, # type: ignore
|
**kwargs, # type: ignore
|
||||||
) -> Union[
|
) -> Union[
|
||||||
llama_types.CreateChatCompletionResponse,
|
llama_types.CreateChatCompletionResponse,
|
||||||
Iterator[llama_types.CreateChatCompletionStreamResponse],
|
Iterator[llama_types.CreateChatCompletionStreamResponse],
|
||||||
]:
|
]:
|
||||||
|
print(logprobs)
|
||||||
function_calling_template = (
|
function_calling_template = (
|
||||||
"{% for message in messages %}"
|
"{% for message in messages %}"
|
||||||
"<|im_start|>{{ message.role }}\n"
|
"<|im_start|>{{ message.role }}\n"
|
||||||
|
@ -2437,6 +2489,7 @@ def chatml_function_calling(
|
||||||
model=model,
|
model=model,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
grammar=grammar,
|
grammar=grammar,
|
||||||
|
logprobs=top_logprobs if logprobs else None,
|
||||||
),
|
),
|
||||||
stream=stream,
|
stream=stream,
|
||||||
)
|
)
|
||||||
|
@ -2549,6 +2602,7 @@ def chatml_function_calling(
|
||||||
typical_p=typical_p,
|
typical_p=typical_p,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
stop=["<|im_end|>"],
|
stop=["<|im_end|>"],
|
||||||
|
logprobs=top_logprobs if logprobs else None,
|
||||||
max_tokens=None,
|
max_tokens=None,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
|
@ -2660,7 +2714,7 @@ def chatml_function_calling(
|
||||||
{
|
{
|
||||||
"finish_reason": "tool_calls",
|
"finish_reason": "tool_calls",
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"logprobs": None,
|
"logprobs": completion["choices"][0]["logprobs"],
|
||||||
"message": {
|
"message": {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": None,
|
"content": None,
|
||||||
|
|
|
@ -237,11 +237,18 @@ LLAMA_FILE_MAGIC_GGLA = 0x67676C61
|
||||||
# define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
# define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
||||||
LLAMA_FILE_MAGIC_GGSN = 0x6767736E
|
LLAMA_FILE_MAGIC_GGSN = 0x6767736E
|
||||||
|
|
||||||
|
# define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
|
||||||
|
LLAMA_FILE_MAGIC_GGSQ = 0x67677371
|
||||||
|
|
||||||
# define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
# define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||||
LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN
|
LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN
|
||||||
# define LLAMA_SESSION_VERSION 5
|
# define LLAMA_SESSION_VERSION 5
|
||||||
LLAMA_SESSION_VERSION = 5
|
LLAMA_SESSION_VERSION = 5
|
||||||
|
|
||||||
|
# define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
||||||
|
LLAMA_STATE_SEQ_MAGIC = LLAMA_FILE_MAGIC_GGSQ
|
||||||
|
# define LLAMA_STATE_SEQ_VERSION 1
|
||||||
|
LLAMA_STATE_SEQ_VERSION = 1
|
||||||
|
|
||||||
# struct llama_model;
|
# struct llama_model;
|
||||||
llama_model_p = NewType("llama_model_p", int)
|
llama_model_p = NewType("llama_model_p", int)
|
||||||
|
@ -424,6 +431,11 @@ class llama_token_data(ctypes.Structure):
|
||||||
logit (float): log-odds of the token
|
logit (float): log-odds of the token
|
||||||
p (float): probability of the token"""
|
p (float): probability of the token"""
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
id: llama_token
|
||||||
|
logit: float
|
||||||
|
p: float
|
||||||
|
|
||||||
_fields_ = [
|
_fields_ = [
|
||||||
("id", llama_token),
|
("id", llama_token),
|
||||||
("logit", ctypes.c_float),
|
("logit", ctypes.c_float),
|
||||||
|
@ -447,6 +459,11 @@ class llama_token_data_array(ctypes.Structure):
|
||||||
size (int): size of the array
|
size (int): size of the array
|
||||||
sorted (bool): whether the array is sorted"""
|
sorted (bool): whether the array is sorted"""
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
data: CtypesArray[llama_token_data]
|
||||||
|
size: int
|
||||||
|
sorted: bool
|
||||||
|
|
||||||
_fields_ = [
|
_fields_ = [
|
||||||
("data", llama_token_data_p),
|
("data", llama_token_data_p),
|
||||||
("size", ctypes.c_size_t),
|
("size", ctypes.c_size_t),
|
||||||
|
@ -508,6 +525,15 @@ class llama_batch(ctypes.Structure):
|
||||||
logits (ctypes.Array[ctypes.ctypes.c_int8]): if zero, the logits for the respective token will not be output
|
logits (ctypes.Array[ctypes.ctypes.c_int8]): if zero, the logits for the respective token will not be output
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
n_tokens: int
|
||||||
|
token: CtypesArray[llama_token]
|
||||||
|
embd: CtypesArray[ctypes.c_float]
|
||||||
|
pos: CtypesArray[CtypesArray[llama_pos]]
|
||||||
|
n_seq_id: CtypesArray[ctypes.c_int]
|
||||||
|
seq_id: CtypesArray[CtypesArray[llama_seq_id]]
|
||||||
|
logits: CtypesArray[ctypes.c_int8]
|
||||||
|
|
||||||
_fields_ = [
|
_fields_ = [
|
||||||
("n_tokens", ctypes.c_int32),
|
("n_tokens", ctypes.c_int32),
|
||||||
("token", ctypes.POINTER(llama_token)),
|
("token", ctypes.POINTER(llama_token)),
|
||||||
|
@ -602,6 +628,18 @@ class llama_model_params(ctypes.Structure):
|
||||||
use_mmap (bool): use mmap if possible
|
use_mmap (bool): use mmap if possible
|
||||||
use_mlock (bool): force system to keep model in RAM"""
|
use_mlock (bool): force system to keep model in RAM"""
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
n_gpu_layers: int
|
||||||
|
split_mode: int
|
||||||
|
main_gpu: int
|
||||||
|
tensor_split: CtypesArray[ctypes.c_float]
|
||||||
|
progress_callback: Callable[[float, ctypes.c_void_p], bool]
|
||||||
|
progress_callback_user_data: ctypes.c_void_p
|
||||||
|
kv_overrides: CtypesArray[llama_model_kv_override]
|
||||||
|
vocab_only: bool
|
||||||
|
use_mmap: bool
|
||||||
|
use_mlock: bool
|
||||||
|
|
||||||
_fields_ = [
|
_fields_ = [
|
||||||
("n_gpu_layers", ctypes.c_int32),
|
("n_gpu_layers", ctypes.c_int32),
|
||||||
("split_mode", ctypes.c_int),
|
("split_mode", ctypes.c_int),
|
||||||
|
@ -689,6 +727,34 @@ class llama_context_params(ctypes.Structure):
|
||||||
abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback
|
abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
seed: int
|
||||||
|
n_ctx: int
|
||||||
|
n_batch: int
|
||||||
|
n_ubatch: int
|
||||||
|
n_seq_max: int
|
||||||
|
n_threads: int
|
||||||
|
n_threads_batch: int
|
||||||
|
rope_scaling_type: int
|
||||||
|
pooling_type: int
|
||||||
|
rope_freq_base: float
|
||||||
|
rope_freq_scale: float
|
||||||
|
yarn_ext_factor: float
|
||||||
|
yarn_attn_factor: float
|
||||||
|
yarn_beta_fast: float
|
||||||
|
yarn_beta_slow: float
|
||||||
|
yarn_orig_ctx: int
|
||||||
|
defrag_thold: float
|
||||||
|
cb_eval: Callable[[ctypes.c_void_p, bool], bool]
|
||||||
|
cb_eval_user_data: ctypes.c_void_p
|
||||||
|
type_k: int
|
||||||
|
type_v: int
|
||||||
|
logits_all: bool
|
||||||
|
embeddings: bool
|
||||||
|
offload_kqv: bool
|
||||||
|
abort_callback: Callable[[ctypes.c_void_p], bool]
|
||||||
|
abort_callback_data: ctypes.c_void_p
|
||||||
|
|
||||||
_fields_ = [
|
_fields_ = [
|
||||||
("seed", ctypes.c_uint32),
|
("seed", ctypes.c_uint32),
|
||||||
("n_ctx", ctypes.c_uint32),
|
("n_ctx", ctypes.c_uint32),
|
||||||
|
@ -764,6 +830,18 @@ class llama_model_quantize_params(ctypes.Structure):
|
||||||
kv_overrides (ctypes.c_void_p): pointer to vector containing overrides
|
kv_overrides (ctypes.c_void_p): pointer to vector containing overrides
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
nthread: int
|
||||||
|
ftype: int
|
||||||
|
output_tensor_type: int
|
||||||
|
token_embedding_type: int
|
||||||
|
allow_requantize: bool
|
||||||
|
quantize_output_tensor: bool
|
||||||
|
only_copy: bool
|
||||||
|
pure: bool
|
||||||
|
imatrix: ctypes.c_void_p
|
||||||
|
kv_overrides: ctypes.c_void_p
|
||||||
|
|
||||||
_fields_ = [
|
_fields_ = [
|
||||||
("nthread", ctypes.c_int32),
|
("nthread", ctypes.c_int32),
|
||||||
("ftype", ctypes.c_int),
|
("ftype", ctypes.c_int),
|
||||||
|
@ -821,6 +899,10 @@ LLAMA_GRETYPE_CHAR_ALT = 6
|
||||||
# uint32_t value; // Unicode code point or rule ID
|
# uint32_t value; // Unicode code point or rule ID
|
||||||
# } llama_grammar_element;
|
# } llama_grammar_element;
|
||||||
class llama_grammar_element(ctypes.Structure):
|
class llama_grammar_element(ctypes.Structure):
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
type: int
|
||||||
|
value: int
|
||||||
|
|
||||||
_fields_ = [
|
_fields_ = [
|
||||||
("type", ctypes.c_int),
|
("type", ctypes.c_int),
|
||||||
("value", ctypes.c_uint32),
|
("value", ctypes.c_uint32),
|
||||||
|
@ -844,6 +926,17 @@ llama_grammar_element_p = ctypes.POINTER(llama_grammar_element)
|
||||||
# int32_t n_eval;
|
# int32_t n_eval;
|
||||||
# };
|
# };
|
||||||
class llama_timings(ctypes.Structure):
|
class llama_timings(ctypes.Structure):
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
t_start_ms: float
|
||||||
|
t_end_ms: float
|
||||||
|
t_load_ms: float
|
||||||
|
t_sample_ms: float
|
||||||
|
t_p_eval_ms: float
|
||||||
|
t_eval_ms: float
|
||||||
|
n_sample: int
|
||||||
|
n_p_eval: int
|
||||||
|
n_eval: int
|
||||||
|
|
||||||
_fields_ = [
|
_fields_ = [
|
||||||
("t_start_ms", ctypes.c_double),
|
("t_start_ms", ctypes.c_double),
|
||||||
("t_end_ms", ctypes.c_double),
|
("t_end_ms", ctypes.c_double),
|
||||||
|
@ -944,7 +1037,8 @@ GGML_NUMA_STRATEGY_COUNT = 5
|
||||||
[ctypes.c_int],
|
[ctypes.c_int],
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
def llama_numa_init(numa: int, /): ...
|
def llama_numa_init(numa: int, /):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# // Call once at the end of the program - currently only used for MPI
|
# // Call once at the end of the program - currently only used for MPI
|
||||||
|
@ -969,7 +1063,8 @@ def llama_backend_free():
|
||||||
)
|
)
|
||||||
def llama_load_model_from_file(
|
def llama_load_model_from_file(
|
||||||
path_model: bytes, params: llama_model_params, /
|
path_model: bytes, params: llama_model_params, /
|
||||||
) -> Optional[llama_model_p]: ...
|
) -> Optional[llama_model_p]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API void llama_free_model(struct llama_model * model);
|
# LLAMA_API void llama_free_model(struct llama_model * model);
|
||||||
|
@ -978,7 +1073,8 @@ def llama_load_model_from_file(
|
||||||
[llama_model_p_ctypes],
|
[llama_model_p_ctypes],
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
def llama_free_model(model: llama_model_p, /): ...
|
def llama_free_model(model: llama_model_p, /):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API struct llama_context * llama_new_context_with_model(
|
# LLAMA_API struct llama_context * llama_new_context_with_model(
|
||||||
|
@ -991,7 +1087,8 @@ def llama_free_model(model: llama_model_p, /): ...
|
||||||
)
|
)
|
||||||
def llama_new_context_with_model(
|
def llama_new_context_with_model(
|
||||||
model: llama_model_p, params: llama_context_params, /
|
model: llama_model_p, params: llama_context_params, /
|
||||||
) -> Optional[llama_context_p]: ...
|
) -> Optional[llama_context_p]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# // Frees all allocated memory
|
# // Frees all allocated memory
|
||||||
|
@ -1012,82 +1109,98 @@ def llama_free(ctx: llama_context_p, /):
|
||||||
[],
|
[],
|
||||||
ctypes.c_int64,
|
ctypes.c_int64,
|
||||||
)
|
)
|
||||||
def llama_time_us() -> int: ...
|
def llama_time_us() -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API size_t llama_max_devices(void);
|
# LLAMA_API size_t llama_max_devices(void);
|
||||||
@ctypes_function("llama_max_devices", [], ctypes.c_size_t)
|
@ctypes_function("llama_max_devices", [], ctypes.c_size_t)
|
||||||
def llama_max_devices() -> int: ...
|
def llama_max_devices() -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API bool llama_supports_mmap (void);
|
# LLAMA_API bool llama_supports_mmap (void);
|
||||||
@ctypes_function("llama_supports_mmap", [], ctypes.c_bool)
|
@ctypes_function("llama_supports_mmap", [], ctypes.c_bool)
|
||||||
def llama_supports_mmap() -> bool: ...
|
def llama_supports_mmap() -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API bool llama_supports_mlock (void);
|
# LLAMA_API bool llama_supports_mlock (void);
|
||||||
@ctypes_function("llama_supports_mlock", [], ctypes.c_bool)
|
@ctypes_function("llama_supports_mlock", [], ctypes.c_bool)
|
||||||
def llama_supports_mlock() -> bool: ...
|
def llama_supports_mlock() -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API bool llama_supports_gpu_offload(void);
|
# LLAMA_API bool llama_supports_gpu_offload(void);
|
||||||
@ctypes_function("llama_supports_gpu_offload", [], ctypes.c_bool)
|
@ctypes_function("llama_supports_gpu_offload", [], ctypes.c_bool)
|
||||||
def llama_supports_gpu_offload() -> bool: ...
|
def llama_supports_gpu_offload() -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
|
# LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
|
||||||
@ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes)
|
@ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes)
|
||||||
def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: ...
|
def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
|
# LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
|
||||||
@ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32)
|
@ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32)
|
||||||
def llama_n_ctx(ctx: llama_context_p, /) -> int: ...
|
def llama_n_ctx(ctx: llama_context_p, /) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
|
# LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
|
||||||
@ctypes_function("llama_n_batch", [llama_context_p_ctypes], ctypes.c_uint32)
|
@ctypes_function("llama_n_batch", [llama_context_p_ctypes], ctypes.c_uint32)
|
||||||
def llama_n_batch(ctx: llama_context_p, /) -> int: ...
|
def llama_n_batch(ctx: llama_context_p, /) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
|
# LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
|
||||||
@ctypes_function("llama_n_ubatch", [llama_context_p_ctypes], ctypes.c_uint32)
|
@ctypes_function("llama_n_ubatch", [llama_context_p_ctypes], ctypes.c_uint32)
|
||||||
def llama_n_ubatch(ctx: llama_context_p, /) -> int: ...
|
def llama_n_ubatch(ctx: llama_context_p, /) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
|
# LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
|
||||||
@ctypes_function("llama_n_seq_max", [llama_context_p_ctypes], ctypes.c_uint32)
|
@ctypes_function("llama_n_seq_max", [llama_context_p_ctypes], ctypes.c_uint32)
|
||||||
def llama_n_seq_max(ctx: llama_context_p, /) -> int: ...
|
def llama_n_seq_max(ctx: llama_context_p, /) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
|
# LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
|
||||||
@ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int)
|
@ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int)
|
||||||
def llama_vocab_type(model: llama_model_p, /) -> int: ...
|
def llama_vocab_type(model: llama_model_p, /) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
|
# LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
|
||||||
@ctypes_function("llama_rope_type", [llama_model_p_ctypes], ctypes.c_int)
|
@ctypes_function("llama_rope_type", [llama_model_p_ctypes], ctypes.c_int)
|
||||||
def llama_rope_type(model: llama_model_p, /) -> int: ...
|
def llama_rope_type(model: llama_model_p, /) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
|
# LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
|
||||||
@ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32)
|
@ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32)
|
||||||
def llama_n_vocab(model: llama_model_p, /) -> int: ...
|
def llama_n_vocab(model: llama_model_p, /) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
|
# LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
|
||||||
@ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32)
|
@ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32)
|
||||||
def llama_n_ctx_train(model: llama_model_p, /) -> int: ...
|
def llama_n_ctx_train(model: llama_model_p, /) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
|
# LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
|
||||||
@ctypes_function("llama_n_embd", [llama_model_p_ctypes], ctypes.c_int32)
|
@ctypes_function("llama_n_embd", [llama_model_p_ctypes], ctypes.c_int32)
|
||||||
def llama_n_embd(model: llama_model_p, /) -> int: ...
|
def llama_n_embd(model: llama_model_p, /) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
|
# LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
|
||||||
@ctypes_function("llama_n_layer", [llama_model_p_ctypes], ctypes.c_int32)
|
@ctypes_function("llama_n_layer", [llama_model_p_ctypes], ctypes.c_int32)
|
||||||
def llama_n_layer(model: llama_model_p, /) -> int: ...
|
def llama_n_layer(model: llama_model_p, /) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# // Get the model's RoPE frequency scaling factor
|
# // Get the model's RoPE frequency scaling factor
|
||||||
|
@ -1351,6 +1464,9 @@ class llama_kv_cache_view_cell(ctypes.Structure):
|
||||||
pos (llama_pos): The position for this cell. Takes KV cache shifts into account.
|
pos (llama_pos): The position for this cell. Takes KV cache shifts into account.
|
||||||
May be negative if the cell is not populated."""
|
May be negative if the cell is not populated."""
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
pos: llama_pos
|
||||||
|
|
||||||
_fields_ = [("pos", llama_pos)]
|
_fields_ = [("pos", llama_pos)]
|
||||||
|
|
||||||
|
|
||||||
|
@ -1387,6 +1503,16 @@ class llama_kv_cache_view_cell(ctypes.Structure):
|
||||||
# llama_seq_id * cells_sequences;
|
# llama_seq_id * cells_sequences;
|
||||||
# };
|
# };
|
||||||
class llama_kv_cache_view(ctypes.Structure):
|
class llama_kv_cache_view(ctypes.Structure):
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
n_cells: int
|
||||||
|
n_max_seq: int
|
||||||
|
token_count: int
|
||||||
|
used_cells: int
|
||||||
|
max_contiguous: int
|
||||||
|
max_contiguous_idx: int
|
||||||
|
cells: CtypesArray[llama_kv_cache_view_cell]
|
||||||
|
cells_sequences: CtypesArray[llama_seq_id]
|
||||||
|
|
||||||
_fields_ = [
|
_fields_ = [
|
||||||
("n_cells", ctypes.c_int32),
|
("n_cells", ctypes.c_int32),
|
||||||
("n_max_seq", ctypes.c_int32),
|
("n_max_seq", ctypes.c_int32),
|
||||||
|
@ -1467,6 +1593,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /):
|
||||||
|
|
||||||
|
|
||||||
# // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
# // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||||
|
# // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
|
||||||
# // seq_id < 0 : match any sequence
|
# // seq_id < 0 : match any sequence
|
||||||
# // p0 < 0 : [0, p1]
|
# // p0 < 0 : [0, p1]
|
||||||
# // p1 < 0 : [p0, inf)
|
# // p1 < 0 : [p0, inf)
|
||||||
|
@ -1493,6 +1620,9 @@ def llama_kv_cache_seq_rm(
|
||||||
/,
|
/,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
"""Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||||
|
|
||||||
|
Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
|
||||||
|
|
||||||
seq_id < 0 : match any sequence
|
seq_id < 0 : match any sequence
|
||||||
p0 < 0 : [0, p1]
|
p0 < 0 : [0, p1]
|
||||||
p1 < 0 : [p0, inf)"""
|
p1 < 0 : [p0, inf)"""
|
||||||
|
@ -1652,7 +1782,16 @@ def llama_kv_cache_update(ctx: llama_context_p, /):
|
||||||
|
|
||||||
# Returns the maximum size in bytes of the state (rng, logits, embedding
|
# Returns the maximum size in bytes of the state (rng, logits, embedding
|
||||||
# and kv_cache) - will often be smaller after compacting tokens
|
# and kv_cache) - will often be smaller after compacting tokens
|
||||||
# LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
|
# LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
|
||||||
|
@ctypes_function("llama_state_get_size", [llama_context_p_ctypes], ctypes.c_size_t)
|
||||||
|
def llama_state_get_size(ctx: llama_context_p, /) -> int:
|
||||||
|
"""Returns the maximum size in bytes of the state (rng, logits, embedding
|
||||||
|
and kv_cache) - will often be smaller after compacting tokens"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx),
|
||||||
|
# "use llama_state_get_size instead");
|
||||||
@ctypes_function("llama_get_state_size", [llama_context_p_ctypes], ctypes.c_size_t)
|
@ctypes_function("llama_get_state_size", [llama_context_p_ctypes], ctypes.c_size_t)
|
||||||
def llama_get_state_size(ctx: llama_context_p, /) -> int:
|
def llama_get_state_size(ctx: llama_context_p, /) -> int:
|
||||||
"""Returns the maximum size in bytes of the state (rng, logits, embedding
|
"""Returns the maximum size in bytes of the state (rng, logits, embedding
|
||||||
|
@ -1663,9 +1802,30 @@ def llama_get_state_size(ctx: llama_context_p, /) -> int:
|
||||||
# Copies the state to the specified destination address.
|
# Copies the state to the specified destination address.
|
||||||
# Destination needs to have allocated enough memory.
|
# Destination needs to have allocated enough memory.
|
||||||
# Returns the number of bytes copied
|
# Returns the number of bytes copied
|
||||||
# LLAMA_API size_t llama_copy_state_data(
|
# LLAMA_API size_t llama_state_get_data(
|
||||||
# struct llama_context * ctx,
|
# struct llama_context * ctx,
|
||||||
# uint8_t * dst);
|
# uint8_t * dst);
|
||||||
|
@ctypes_function(
|
||||||
|
"llama_state_get_data",
|
||||||
|
[
|
||||||
|
llama_context_p_ctypes,
|
||||||
|
ctypes.POINTER(ctypes.c_uint8),
|
||||||
|
],
|
||||||
|
ctypes.c_size_t,
|
||||||
|
)
|
||||||
|
def llama_state_get_data(
|
||||||
|
ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], /
|
||||||
|
) -> int:
|
||||||
|
"""Copies the state to the specified destination address.
|
||||||
|
Destination needs to have allocated enough memory.
|
||||||
|
Returns the number of bytes copied"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# LLAMA_API DEPRECATED(size_t llama_copy_state_data(
|
||||||
|
# struct llama_context * ctx,
|
||||||
|
# uint8_t * dst),
|
||||||
|
# "use llama_state_get_data instead");
|
||||||
@ctypes_function(
|
@ctypes_function(
|
||||||
"llama_copy_state_data",
|
"llama_copy_state_data",
|
||||||
[
|
[
|
||||||
|
@ -1685,9 +1845,26 @@ def llama_copy_state_data(
|
||||||
|
|
||||||
# // Set the state reading from the specified address
|
# // Set the state reading from the specified address
|
||||||
# // Returns the number of bytes read
|
# // Returns the number of bytes read
|
||||||
# LLAMA_API size_t llama_set_state_data(
|
# LLAMA_API size_t llama_state_set_data(
|
||||||
# struct llama_context * ctx,
|
# struct llama_context * ctx,
|
||||||
# const uint8_t * src);
|
# const uint8_t * src);
|
||||||
|
@ctypes_function(
|
||||||
|
"llama_state_set_data",
|
||||||
|
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)],
|
||||||
|
ctypes.c_size_t,
|
||||||
|
)
|
||||||
|
def llama_state_set_data(
|
||||||
|
ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], /
|
||||||
|
) -> int:
|
||||||
|
"""Set the state reading from the specified address
|
||||||
|
Returns the number of bytes read"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# LLAMA_API DEPRECATED(size_t llama_set_state_data(
|
||||||
|
# struct llama_context * ctx,
|
||||||
|
# const uint8_t * src),
|
||||||
|
# "use llama_state_set_data instead");
|
||||||
@ctypes_function(
|
@ctypes_function(
|
||||||
"llama_set_state_data",
|
"llama_set_state_data",
|
||||||
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)],
|
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)],
|
||||||
|
@ -1701,12 +1878,41 @@ def llama_set_state_data(
|
||||||
|
|
||||||
|
|
||||||
# Save/load session file
|
# Save/load session file
|
||||||
# LLAMA_API bool llama_load_session_file(
|
# LLAMA_API bool llama_state_load_file(
|
||||||
# struct llama_context * ctx,
|
# struct llama_context * ctx,
|
||||||
# const char * path_session,
|
# const char * path_session,
|
||||||
# llama_token * tokens_out,
|
# llama_token * tokens_out,
|
||||||
# size_t n_token_capacity,
|
# size_t n_token_capacity,
|
||||||
# size_t * n_token_count_out);
|
# size_t * n_token_count_out);
|
||||||
|
@ctypes_function(
|
||||||
|
"llama_state_load_file",
|
||||||
|
[
|
||||||
|
llama_context_p_ctypes,
|
||||||
|
ctypes.c_char_p,
|
||||||
|
llama_token_p,
|
||||||
|
ctypes.c_size_t,
|
||||||
|
ctypes.POINTER(ctypes.c_size_t),
|
||||||
|
],
|
||||||
|
ctypes.c_bool,
|
||||||
|
)
|
||||||
|
def llama_state_load_file(
|
||||||
|
ctx: llama_context_p,
|
||||||
|
path_session: bytes,
|
||||||
|
tokens_out: CtypesArray[llama_token],
|
||||||
|
n_token_capacity: Union[ctypes.c_size_t, int],
|
||||||
|
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
|
||||||
|
/,
|
||||||
|
) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# LLAMA_API DEPRECATED(bool llama_load_session_file(
|
||||||
|
# struct llama_context * ctx,
|
||||||
|
# const char * path_session,
|
||||||
|
# llama_token * tokens_out,
|
||||||
|
# size_t n_token_capacity,
|
||||||
|
# size_t * n_token_count_out),
|
||||||
|
# "use llama_state_load_file instead");
|
||||||
@ctypes_function(
|
@ctypes_function(
|
||||||
"llama_load_session_file",
|
"llama_load_session_file",
|
||||||
[
|
[
|
||||||
|
@ -1725,14 +1931,41 @@ def llama_load_session_file(
|
||||||
n_token_capacity: Union[ctypes.c_size_t, int],
|
n_token_capacity: Union[ctypes.c_size_t, int],
|
||||||
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
|
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
|
||||||
/,
|
/,
|
||||||
) -> int: ...
|
) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API bool llama_save_session_file(
|
# LLAMA_API bool llama_state_save_file(
|
||||||
# struct llama_context * ctx,
|
# struct llama_context * ctx,
|
||||||
# const char * path_session,
|
# const char * path_session,
|
||||||
# const llama_token * tokens,
|
# const llama_token * tokens,
|
||||||
# size_t n_token_count);
|
# size_t n_token_count);
|
||||||
|
@ctypes_function(
|
||||||
|
"llama_state_save_file",
|
||||||
|
[
|
||||||
|
llama_context_p_ctypes,
|
||||||
|
ctypes.c_char_p,
|
||||||
|
llama_token_p,
|
||||||
|
ctypes.c_size_t,
|
||||||
|
],
|
||||||
|
ctypes.c_bool,
|
||||||
|
)
|
||||||
|
def llama_state_save_file(
|
||||||
|
ctx: llama_context_p,
|
||||||
|
path_session: bytes,
|
||||||
|
tokens: CtypesArray[llama_token],
|
||||||
|
n_token_count: Union[ctypes.c_size_t, int],
|
||||||
|
/,
|
||||||
|
) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# LLAMA_API DEPRECATED(bool llama_save_session_file(
|
||||||
|
# struct llama_context * ctx,
|
||||||
|
# const char * path_session,
|
||||||
|
# const llama_token * tokens,
|
||||||
|
# size_t n_token_count),
|
||||||
|
# "use llama_state_save_file instead");
|
||||||
@ctypes_function(
|
@ctypes_function(
|
||||||
"llama_save_session_file",
|
"llama_save_session_file",
|
||||||
[
|
[
|
||||||
|
@ -1749,7 +1982,118 @@ def llama_save_session_file(
|
||||||
tokens: CtypesArray[llama_token],
|
tokens: CtypesArray[llama_token],
|
||||||
n_token_count: Union[ctypes.c_size_t, int],
|
n_token_count: Union[ctypes.c_size_t, int],
|
||||||
/,
|
/,
|
||||||
) -> int: ...
|
) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# // Get the exact size needed to copy the KV cache of a single sequence
|
||||||
|
# LLAMA_API size_t llama_state_seq_get_size(
|
||||||
|
# struct llama_context * ctx,
|
||||||
|
# llama_seq_id seq_id);
|
||||||
|
@ctypes_function(
|
||||||
|
"llama_state_seq_get_size",
|
||||||
|
[llama_context_p_ctypes, llama_seq_id],
|
||||||
|
ctypes.c_size_t,
|
||||||
|
)
|
||||||
|
def llama_state_seq_get_size(ctx: llama_context_p, seq_id: llama_seq_id, /) -> int:
|
||||||
|
"""Get the exact size needed to copy the KV cache of a single sequence"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# // Copy the KV cache of a single sequence into the specified buffer
|
||||||
|
# LLAMA_API size_t llama_state_seq_get_data(
|
||||||
|
# struct llama_context * ctx,
|
||||||
|
# uint8_t * dst,
|
||||||
|
# llama_seq_id seq_id);
|
||||||
|
@ctypes_function(
|
||||||
|
"llama_state_seq_get_data",
|
||||||
|
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8), llama_seq_id],
|
||||||
|
ctypes.c_size_t,
|
||||||
|
)
|
||||||
|
def llama_state_seq_get_data(
|
||||||
|
ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], seq_id: llama_seq_id, /
|
||||||
|
) -> int:
|
||||||
|
"""Copy the KV cache of a single sequence into the specified buffer"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
|
||||||
|
# // Returns:
|
||||||
|
# // - Positive: Ok
|
||||||
|
# // - Zero: Failed to load
|
||||||
|
# LLAMA_API size_t llama_state_seq_set_data(
|
||||||
|
# struct llama_context * ctx,
|
||||||
|
# const uint8_t * src,
|
||||||
|
# llama_seq_id dest_seq_id);
|
||||||
|
@ctypes_function(
|
||||||
|
"llama_state_seq_set_data",
|
||||||
|
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8), llama_seq_id],
|
||||||
|
ctypes.c_size_t,
|
||||||
|
)
|
||||||
|
def llama_state_seq_set_data(
|
||||||
|
ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], dest_seq_id: llama_seq_id, /
|
||||||
|
) -> int:
|
||||||
|
"""Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# LLAMA_API size_t llama_state_seq_save_file(
|
||||||
|
# struct llama_context * ctx,
|
||||||
|
# const char * filepath,
|
||||||
|
# llama_seq_id seq_id,
|
||||||
|
# const llama_token * tokens,
|
||||||
|
# size_t n_token_count);
|
||||||
|
@ctypes_function(
|
||||||
|
"llama_state_seq_save_file",
|
||||||
|
[
|
||||||
|
llama_context_p_ctypes,
|
||||||
|
ctypes.c_char_p,
|
||||||
|
llama_seq_id,
|
||||||
|
llama_token_p,
|
||||||
|
ctypes.c_size_t,
|
||||||
|
],
|
||||||
|
ctypes.c_size_t,
|
||||||
|
)
|
||||||
|
def llama_state_seq_save_file(
|
||||||
|
ctx: llama_context_p,
|
||||||
|
filepath: bytes,
|
||||||
|
seq_id: llama_seq_id,
|
||||||
|
tokens: CtypesArray[llama_token],
|
||||||
|
n_token_count: Union[ctypes.c_size_t, int],
|
||||||
|
/,
|
||||||
|
) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# LLAMA_API size_t llama_state_seq_load_file(
|
||||||
|
# struct llama_context * ctx,
|
||||||
|
# const char * filepath,
|
||||||
|
# llama_seq_id dest_seq_id,
|
||||||
|
# llama_token * tokens_out,
|
||||||
|
# size_t n_token_capacity,
|
||||||
|
# size_t * n_token_count_out);
|
||||||
|
@ctypes_function(
|
||||||
|
"llama_state_seq_load_file",
|
||||||
|
[
|
||||||
|
llama_context_p_ctypes,
|
||||||
|
ctypes.c_char_p,
|
||||||
|
llama_seq_id,
|
||||||
|
llama_token_p,
|
||||||
|
ctypes.c_size_t,
|
||||||
|
ctypes.POINTER(ctypes.c_size_t),
|
||||||
|
],
|
||||||
|
ctypes.c_size_t,
|
||||||
|
)
|
||||||
|
def llama_state_seq_load_file(
|
||||||
|
ctx: llama_context_p,
|
||||||
|
filepath: bytes,
|
||||||
|
dest_seq_id: llama_seq_id,
|
||||||
|
tokens_out: CtypesArray[llama_token],
|
||||||
|
n_token_capacity: Union[ctypes.c_size_t, int],
|
||||||
|
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
|
||||||
|
/,
|
||||||
|
) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# //
|
# //
|
||||||
|
@ -1930,8 +2274,9 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
# // Logits for the ith token. Equivalent to:
|
# // Logits for the ith token. For positive indices, equivalent to:
|
||||||
# // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
|
# // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
|
||||||
|
# // Negative indices can be used to access logits in reverse order, -1 is the last logit.
|
||||||
# // returns NULL for invalid ids.
|
# // returns NULL for invalid ids.
|
||||||
# LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
|
# LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
|
||||||
@ctypes_function(
|
@ctypes_function(
|
||||||
|
@ -1963,8 +2308,9 @@ def llama_get_embeddings(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
# // Get the embeddings for the ith token. Equivalent to:
|
# // Get the embeddings for the ith token. For positive indices, equivalent to:
|
||||||
# // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
|
# // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
|
||||||
|
# // Negative indices can be used to access embeddings in reverse order, -1 is the last embedding.
|
||||||
# // shape: [n_embd] (1-dimensional)
|
# // shape: [n_embd] (1-dimensional)
|
||||||
# // returns NULL for invalid ids.
|
# // returns NULL for invalid ids.
|
||||||
# LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
# LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
||||||
|
@ -2010,7 +2356,8 @@ def llama_get_embeddings_seq(
|
||||||
)
|
)
|
||||||
def llama_token_get_text(
|
def llama_token_get_text(
|
||||||
model: llama_model_p, token: Union[llama_token, int], /
|
model: llama_model_p, token: Union[llama_token, int], /
|
||||||
) -> bytes: ...
|
) -> bytes:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
|
# LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
|
||||||
|
@ -2019,7 +2366,8 @@ def llama_token_get_text(
|
||||||
)
|
)
|
||||||
def llama_token_get_score(
|
def llama_token_get_score(
|
||||||
model: llama_model_p, token: Union[llama_token, int], /
|
model: llama_model_p, token: Union[llama_token, int], /
|
||||||
) -> float: ...
|
) -> float:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
|
# LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
|
||||||
|
@ -2028,7 +2376,20 @@ def llama_token_get_score(
|
||||||
)
|
)
|
||||||
def llama_token_get_type(
|
def llama_token_get_type(
|
||||||
model: llama_model_p, token: Union[llama_token, int], /
|
model: llama_model_p, token: Union[llama_token, int], /
|
||||||
) -> int: ...
|
) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
|
||||||
|
# LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
|
||||||
|
@ctypes_function(
|
||||||
|
"llama_token_is_eog", [llama_model_p_ctypes, llama_token], ctypes.c_bool
|
||||||
|
)
|
||||||
|
def llama_token_is_eog(
|
||||||
|
model: llama_model_p, token: Union[llama_token, int], /
|
||||||
|
) -> bool:
|
||||||
|
"""Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# // Special tokens
|
# // Special tokens
|
||||||
|
@ -2048,6 +2409,20 @@ def llama_token_eos(model: llama_model_p, /) -> int:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
|
||||||
|
@ctypes_function("llama_token_cls", [llama_model_p_ctypes], llama_token)
|
||||||
|
def llama_token_cls(model: llama_model_p, /) -> int:
|
||||||
|
"""classification"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
|
||||||
|
@ctypes_function("llama_token_sep", [llama_model_p_ctypes], llama_token)
|
||||||
|
def llama_token_sep(model: llama_model_p, /) -> int:
|
||||||
|
"""sentence separator"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
|
# LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
|
||||||
@ctypes_function("llama_token_nl", [llama_model_p_ctypes], llama_token)
|
@ctypes_function("llama_token_nl", [llama_model_p_ctypes], llama_token)
|
||||||
def llama_token_nl(model: llama_model_p, /) -> int:
|
def llama_token_nl(model: llama_model_p, /) -> int:
|
||||||
|
@ -2071,7 +2446,7 @@ def llama_add_eos_token(model: llama_model_p, /) -> int:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
# // codellama infill tokens
|
# // Codellama infill tokens
|
||||||
# LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
|
# LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
|
||||||
@ctypes_function("llama_token_prefix", [llama_model_p_ctypes], llama_token)
|
@ctypes_function("llama_token_prefix", [llama_model_p_ctypes], llama_token)
|
||||||
def llama_token_prefix(model: llama_model_p) -> int:
|
def llama_token_prefix(model: llama_model_p) -> int:
|
||||||
|
@ -2081,17 +2456,20 @@ def llama_token_prefix(model: llama_model_p) -> int:
|
||||||
|
|
||||||
# LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
|
# LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
|
||||||
@ctypes_function("llama_token_middle", [llama_model_p_ctypes], llama_token)
|
@ctypes_function("llama_token_middle", [llama_model_p_ctypes], llama_token)
|
||||||
def llama_token_middle(model: llama_model_p, /) -> int: ...
|
def llama_token_middle(model: llama_model_p, /) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
|
# LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
|
||||||
@ctypes_function("llama_token_suffix", [llama_model_p_ctypes], llama_token)
|
@ctypes_function("llama_token_suffix", [llama_model_p_ctypes], llama_token)
|
||||||
def llama_token_suffix(model: llama_model_p, /) -> int: ...
|
def llama_token_suffix(model: llama_model_p, /) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
|
# LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
|
||||||
@ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token)
|
@ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token)
|
||||||
def llama_token_eot(model: llama_model_p, /) -> int: ...
|
def llama_token_eot(model: llama_model_p, /) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# //
|
# //
|
||||||
|
@ -2103,16 +2481,16 @@ def llama_token_eot(model: llama_model_p, /) -> int: ...
|
||||||
# /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
|
# /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
|
||||||
# /// @return Returns the number of tokens on success, no more than n_tokens_max
|
# /// @return Returns the number of tokens on success, no more than n_tokens_max
|
||||||
# /// @return Returns a negative number on failure - the number of tokens that would have been returned
|
# /// @return Returns a negative number on failure - the number of tokens that would have been returned
|
||||||
# /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
|
# /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
|
||||||
# /// Does not insert a leading space.
|
# /// as plaintext. Does not insert a leading space.
|
||||||
# LLAMA_API int32_t llama_tokenize(
|
# LLAMA_API int32_t llama_tokenize(
|
||||||
# const struct llama_model * model,
|
# const struct llama_model * model,
|
||||||
# const char * text,
|
# const char * text,
|
||||||
# int32_t text_len,
|
# int32_t text_len,
|
||||||
# llama_token * tokens,
|
# llama_token * tokens,
|
||||||
# int32_t n_tokens_max,
|
# int32_t n_tokens_max,
|
||||||
# bool add_bos,
|
# bool add_special,
|
||||||
# bool special);
|
# bool parse_special);
|
||||||
@ctypes_function(
|
@ctypes_function(
|
||||||
"llama_tokenize",
|
"llama_tokenize",
|
||||||
[
|
[
|
||||||
|
@ -2132,8 +2510,8 @@ def llama_tokenize(
|
||||||
text_len: Union[ctypes.c_int, int],
|
text_len: Union[ctypes.c_int, int],
|
||||||
tokens: CtypesArray[llama_token],
|
tokens: CtypesArray[llama_token],
|
||||||
n_tokens_max: Union[ctypes.c_int, int],
|
n_tokens_max: Union[ctypes.c_int, int],
|
||||||
add_bos: Union[ctypes.c_bool, bool],
|
add_special: Union[ctypes.c_bool, bool],
|
||||||
special: Union[ctypes.c_bool, bool],
|
parse_special: Union[ctypes.c_bool, bool],
|
||||||
/,
|
/,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Convert the provided text into tokens.
|
"""Convert the provided text into tokens.
|
||||||
|
@ -2144,9 +2522,8 @@ def llama_tokenize(
|
||||||
text_len: The length of the text.
|
text_len: The length of the text.
|
||||||
tokens: The tokens pointer must be large enough to hold the resulting tokens.
|
tokens: The tokens pointer must be large enough to hold the resulting tokens.
|
||||||
n_tokens_max: The maximum number of tokens to return.
|
n_tokens_max: The maximum number of tokens to return.
|
||||||
add_bos: Whether to add a beginning-of-sentence token.
|
add_special: Allow adding special tokens (such as BOS and EOS) if the model is configured to do so.
|
||||||
special: Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
|
parse_special: Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.
|
||||||
Does not insert a leading space.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Returns the number of tokens on success, no more than n_tokens_max
|
Returns the number of tokens on success, no more than n_tokens_max
|
||||||
|
@ -2159,11 +2536,13 @@ def llama_tokenize(
|
||||||
# // Uses the vocabulary in the provided context.
|
# // Uses the vocabulary in the provided context.
|
||||||
# // Does not write null terminator to the buffer.
|
# // Does not write null terminator to the buffer.
|
||||||
# // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
|
# // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
|
||||||
|
# // @param special If true, special tokens are rendered in the output.
|
||||||
# LLAMA_API int32_t llama_token_to_piece(
|
# LLAMA_API int32_t llama_token_to_piece(
|
||||||
# const struct llama_model * model,
|
# const struct llama_model * model,
|
||||||
# llama_token token,
|
# llama_token token,
|
||||||
# char * buf,
|
# char * buf,
|
||||||
# int32_t length);
|
# int32_t length,
|
||||||
|
# bool special);
|
||||||
@ctypes_function(
|
@ctypes_function(
|
||||||
"llama_token_to_piece",
|
"llama_token_to_piece",
|
||||||
[
|
[
|
||||||
|
@ -2171,6 +2550,7 @@ def llama_tokenize(
|
||||||
llama_token,
|
llama_token,
|
||||||
ctypes.c_char_p,
|
ctypes.c_char_p,
|
||||||
ctypes.c_int32,
|
ctypes.c_int32,
|
||||||
|
ctypes.c_bool,
|
||||||
],
|
],
|
||||||
ctypes.c_int32,
|
ctypes.c_int32,
|
||||||
)
|
)
|
||||||
|
@ -2179,13 +2559,20 @@ def llama_token_to_piece(
|
||||||
token: Union[llama_token, int],
|
token: Union[llama_token, int],
|
||||||
buf: Union[ctypes.c_char_p, bytes, CtypesArray[ctypes.c_char]],
|
buf: Union[ctypes.c_char_p, bytes, CtypesArray[ctypes.c_char]],
|
||||||
length: Union[ctypes.c_int, int],
|
length: Union[ctypes.c_int, int],
|
||||||
|
special: Union[ctypes.c_bool, bool],
|
||||||
/,
|
/,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Token Id -> Piece.
|
"""Token Id -> Piece.
|
||||||
Uses the vocabulary in the provided context.
|
Uses the vocabulary in the provided context.
|
||||||
Does not write null terminator to the buffer.
|
Does not write null terminator to the buffer.
|
||||||
User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
|
User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
|
||||||
"""
|
|
||||||
|
Args:
|
||||||
|
model: The model to use for tokenization.
|
||||||
|
token: The token to convert.
|
||||||
|
buf: The buffer to write the token to.
|
||||||
|
length: The length of the buffer.
|
||||||
|
special: If true, special tokens are rendered in the output."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@ -2223,7 +2610,8 @@ def llama_chat_apply_template(
|
||||||
chat: CtypesArray[llama_chat_message],
|
chat: CtypesArray[llama_chat_message],
|
||||||
n_msg: int,
|
n_msg: int,
|
||||||
/,
|
/,
|
||||||
) -> int: ...
|
) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# //
|
# //
|
||||||
|
@ -2753,6 +3141,12 @@ def llama_grammar_accept_token(
|
||||||
# bool eob; // Callback should set this to true when a beam is at end-of-beam.
|
# bool eob; // Callback should set this to true when a beam is at end-of-beam.
|
||||||
# };
|
# };
|
||||||
class llama_beam_view(ctypes.Structure):
|
class llama_beam_view(ctypes.Structure):
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
tokens: CtypesArray[llama_token]
|
||||||
|
n_tokens: int
|
||||||
|
p: float
|
||||||
|
eob: bool
|
||||||
|
|
||||||
_fields_ = [
|
_fields_ = [
|
||||||
("tokens", llama_token_p),
|
("tokens", llama_token_p),
|
||||||
("n_tokens", ctypes.c_size_t),
|
("n_tokens", ctypes.c_size_t),
|
||||||
|
@ -2772,6 +3166,12 @@ class llama_beam_view(ctypes.Structure):
|
||||||
# bool last_call; // True iff this is the last callback invocation.
|
# bool last_call; // True iff this is the last callback invocation.
|
||||||
# };
|
# };
|
||||||
class llama_beams_state(ctypes.Structure):
|
class llama_beams_state(ctypes.Structure):
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
beam_views: CtypesArray[llama_beam_view]
|
||||||
|
n_beams: int
|
||||||
|
common_prefix_length: int
|
||||||
|
last_call: bool
|
||||||
|
|
||||||
_fields_ = [
|
_fields_ = [
|
||||||
("beam_views", ctypes.POINTER(llama_beam_view)),
|
("beam_views", ctypes.POINTER(llama_beam_view)),
|
||||||
("n_beams", ctypes.c_size_t),
|
("n_beams", ctypes.c_size_t),
|
||||||
|
@ -2824,7 +3224,8 @@ def llama_beam_search(
|
||||||
n_past: Union[ctypes.c_int, int],
|
n_past: Union[ctypes.c_int, int],
|
||||||
n_predict: Union[ctypes.c_int, int],
|
n_predict: Union[ctypes.c_int, int],
|
||||||
/,
|
/,
|
||||||
): ...
|
):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# /// @details Build a split GGUF final path for this chunk.
|
# /// @details Build a split GGUF final path for this chunk.
|
||||||
|
@ -2943,4 +3344,5 @@ def llama_log_set(
|
||||||
[ctypes.c_void_p, llama_context_p_ctypes],
|
[ctypes.c_void_p, llama_context_p_ctypes],
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /): ...
|
def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /):
|
||||||
|
...
|
||||||
|
|
|
@ -5,11 +5,12 @@ from pathlib import Path
|
||||||
import sys
|
import sys
|
||||||
from ctypes import * # type: ignore
|
from ctypes import * # type: ignore
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from itertools import islice
|
from itertools import islice, groupby
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
|
Set,
|
||||||
Generic,
|
Generic,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
@ -1391,145 +1392,561 @@ from typing import List, Optional
|
||||||
# whitespace. Also maybe improves generation quality?
|
# whitespace. Also maybe improves generation quality?
|
||||||
SPACE_RULE = '" "?'
|
SPACE_RULE = '" "?'
|
||||||
|
|
||||||
PRIMITIVE_RULES = {
|
|
||||||
"boolean": '("true" | "false") space',
|
|
||||||
"number": '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space',
|
|
||||||
"integer": '("-"? ([0-9] | [1-9] [0-9]*)) space',
|
|
||||||
"string": r""" "\"" (
|
|
||||||
[^"\\] |
|
|
||||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
|
||||||
)* "\"" space """,
|
|
||||||
"null": '"null" space',
|
|
||||||
}
|
|
||||||
|
|
||||||
INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
|
INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
|
||||||
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
|
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
|
||||||
GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'}
|
GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'}
|
||||||
|
|
||||||
|
# whitespace is constrained to a single space char to prevent model "running away" in
|
||||||
|
# whitespace. Also maybe improves generation quality?
|
||||||
|
SPACE_RULE = '" "?'
|
||||||
|
|
||||||
|
|
||||||
|
def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False):
|
||||||
|
if not separator_rule:
|
||||||
|
if min_items == 0 and max_items == 1:
|
||||||
|
return f'{item_rule}?'
|
||||||
|
elif min_items == 1 and max_items is None:
|
||||||
|
return f'{item_rule}+'
|
||||||
|
|
||||||
|
result = ''
|
||||||
|
|
||||||
|
if min_items > 0:
|
||||||
|
if item_rule_is_literal and separator_rule is None:
|
||||||
|
result = '"' + (item_rule[1:-1] * min_items) + '"'
|
||||||
|
else:
|
||||||
|
result = (f' {separator_rule} ' if separator_rule else ' ').join([item_rule] * min_items)
|
||||||
|
|
||||||
|
def opt_repetitions(up_to_n, prefix_with_sep=False):
|
||||||
|
'''
|
||||||
|
- n=4, no sep: '(a (a (a (a)?)?)?)?'
|
||||||
|
- n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?'
|
||||||
|
- n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?'
|
||||||
|
'''
|
||||||
|
|
||||||
|
content = f'{separator_rule} {item_rule}' if prefix_with_sep and separator_rule else item_rule
|
||||||
|
if up_to_n == 0:
|
||||||
|
return ''
|
||||||
|
elif up_to_n == 1:
|
||||||
|
return f'({content})?'
|
||||||
|
elif separator_rule and not prefix_with_sep:
|
||||||
|
return f'({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?'
|
||||||
|
else:
|
||||||
|
return (f'({content} ' * up_to_n).rstrip() + (')?' * up_to_n)
|
||||||
|
|
||||||
|
if min_items > 0 and max_items != min_items:
|
||||||
|
result += ' '
|
||||||
|
|
||||||
|
if max_items is not None:
|
||||||
|
result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0)
|
||||||
|
else:
|
||||||
|
item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})'
|
||||||
|
|
||||||
|
if min_items == 0 and separator_rule:
|
||||||
|
result = f'({item_rule} {item_operator}*)?'
|
||||||
|
else:
|
||||||
|
result += f'{item_operator}*'
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BuiltinRule:
|
||||||
|
def __init__(self, content: str, deps: list = None):
|
||||||
|
self.content = content
|
||||||
|
self.deps = deps or []
|
||||||
|
|
||||||
|
_up_to_15_digits = _build_repetition('[0-9]', 0, 15)
|
||||||
|
|
||||||
|
PRIMITIVE_RULES = {
|
||||||
|
'boolean' : BuiltinRule('("true" | "false") space', []),
|
||||||
|
'decimal-part' : BuiltinRule('[0-9] ' + _up_to_15_digits, []),
|
||||||
|
'integral-part': BuiltinRule('[0-9] | [1-9] ' + _up_to_15_digits, []),
|
||||||
|
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
|
||||||
|
'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']),
|
||||||
|
'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
|
||||||
|
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
|
||||||
|
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
|
||||||
|
'uuid' : BuiltinRule(r'"\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + r' "\"" space', []),
|
||||||
|
'char' : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])', []),
|
||||||
|
'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']),
|
||||||
|
'null' : BuiltinRule('"null" space', []),
|
||||||
|
}
|
||||||
|
|
||||||
|
# TODO: support "uri", "email" string formats
|
||||||
|
STRING_FORMAT_RULES = {
|
||||||
|
'date' : BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
|
||||||
|
'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
|
||||||
|
'date-time' : BuiltinRule('date "T" time', ['date', 'time']),
|
||||||
|
'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']),
|
||||||
|
'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']),
|
||||||
|
'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
|
||||||
|
}
|
||||||
|
|
||||||
|
DOTALL = '[\\U00000000-\\U0010FFFF]'
|
||||||
|
DOT = '[^\\x0A\\x0D]'
|
||||||
|
|
||||||
|
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])
|
||||||
|
|
||||||
|
|
||||||
|
NON_LITERAL_SET = set('|.()[]{}*+?')
|
||||||
|
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SchemaConverter:
|
class SchemaConverter:
|
||||||
def __init__(self, prop_order):
|
def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
|
||||||
self._prop_order = prop_order
|
self._prop_order = prop_order
|
||||||
self._rules = {"space": SPACE_RULE}
|
self._allow_fetch = allow_fetch
|
||||||
self._defs: Dict[str, Any] = {}
|
self._dotall = dotall
|
||||||
|
self._raw_pattern = raw_pattern
|
||||||
|
self._rules = {
|
||||||
|
'space': SPACE_RULE,
|
||||||
|
}
|
||||||
|
self._refs = {}
|
||||||
|
self._refs_being_resolved = set()
|
||||||
|
|
||||||
def _format_literal(self, literal: str):
|
def _format_literal(self, literal):
|
||||||
escaped: str = GRAMMAR_LITERAL_ESCAPE_RE.sub(
|
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
|
||||||
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal)
|
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
|
||||||
)
|
)
|
||||||
return f'"{escaped}"'
|
return f'"{escaped}"'
|
||||||
|
|
||||||
def _add_rule(self, name: str, rule: str):
|
def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str:
|
||||||
esc_name = INVALID_RULE_CHARS_RE.sub("-", name)
|
'''
|
||||||
|
not_literal('a') -> '([^a])'
|
||||||
|
not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)'
|
||||||
|
'''
|
||||||
|
assert len(literal) > 0, 'Empty literal not supported'
|
||||||
|
def recurse(i: int):
|
||||||
|
c = literal[i]
|
||||||
|
if maybe_escaped_underscores and c == '_':
|
||||||
|
yield f'[^{c}\\\\]'
|
||||||
|
yield ' | '
|
||||||
|
yield f'"\\\\"? "{c}"'
|
||||||
|
else:
|
||||||
|
yield f'[^{c}]'
|
||||||
|
if i < len(literal) - 1:
|
||||||
|
yield ' | '
|
||||||
|
yield self._format_literal(c)
|
||||||
|
yield ' ('
|
||||||
|
yield from recurse(i + 1)
|
||||||
|
yield ')?'
|
||||||
|
|
||||||
|
return ''.join(('(', *recurse(0), ')'))
|
||||||
|
|
||||||
|
def _add_rule(self, name, rule):
|
||||||
|
esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
|
||||||
if esc_name not in self._rules or self._rules[esc_name] == rule:
|
if esc_name not in self._rules or self._rules[esc_name] == rule:
|
||||||
key = esc_name
|
key = esc_name
|
||||||
else:
|
else:
|
||||||
i = 0
|
i = 0
|
||||||
while f"{esc_name}{i}" in self._rules:
|
while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule:
|
||||||
i += 1
|
i += 1
|
||||||
key = f"{esc_name}{i}"
|
key = f'{esc_name}{i}'
|
||||||
self._rules[key] = rule
|
self._rules[key] = rule
|
||||||
return key
|
return key
|
||||||
|
|
||||||
def visit(self, schema: Dict[str, Any], name: str) -> str:
|
def resolve_refs(self, schema: dict, url: str):
|
||||||
rule_name = name or "root"
|
'''
|
||||||
|
Resolves all $ref fields in the given schema, fetching any remote schemas,
|
||||||
|
replacing $ref with absolute reference URL and populating self._refs with the
|
||||||
|
respective referenced (sub)schema dictionaries.
|
||||||
|
'''
|
||||||
|
def visit(n: dict):
|
||||||
|
if isinstance(n, list):
|
||||||
|
return [visit(x) for x in n]
|
||||||
|
elif isinstance(n, dict):
|
||||||
|
ref = n.get('$ref')
|
||||||
|
if ref is not None and ref not in self._refs:
|
||||||
|
if ref.startswith('https://'):
|
||||||
|
assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
|
||||||
|
import requests
|
||||||
|
|
||||||
if "$defs" in schema:
|
frag_split = ref.split('#')
|
||||||
# add defs to self._defs for later inlining
|
base_url = frag_split[0]
|
||||||
for def_name, def_schema in schema["$defs"].items():
|
|
||||||
self._defs[def_name] = def_schema
|
|
||||||
|
|
||||||
if "oneOf" in schema or "anyOf" in schema:
|
target = self._refs.get(base_url)
|
||||||
rule = " | ".join(
|
if target is None:
|
||||||
(
|
target = self.resolve_refs(requests.get(ref).json(), base_url)
|
||||||
self.visit(alt_schema, f'{name}{"-" if name else ""}{i}')
|
self._refs[base_url] = target
|
||||||
for i, alt_schema in enumerate(
|
|
||||||
schema.get("oneOf") or schema["anyOf"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return self._add_rule(rule_name, rule)
|
|
||||||
|
|
||||||
elif "const" in schema:
|
if len(frag_split) == 1 or frag_split[-1] == '':
|
||||||
return self._add_rule(rule_name, self._format_literal(schema["const"]))
|
return target
|
||||||
|
elif ref.startswith('#/'):
|
||||||
elif "enum" in schema:
|
target = schema
|
||||||
rule = " | ".join((self._format_literal(v) for v in schema["enum"]))
|
ref = f'{url}{ref}'
|
||||||
return self._add_rule(rule_name, rule)
|
n['$ref'] = ref
|
||||||
|
|
||||||
elif "$ref" in schema:
|
|
||||||
ref = schema["$ref"]
|
|
||||||
assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}"
|
|
||||||
# inline $defs
|
|
||||||
def_name = ref[len("#/$defs/") :]
|
|
||||||
def_schema = self._defs[def_name]
|
|
||||||
return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}')
|
|
||||||
|
|
||||||
|
|
||||||
schema_type: Optional[str] = schema.get("type") # type: ignore
|
|
||||||
assert isinstance(schema_type, str), f"Unrecognized schema: {schema}"
|
|
||||||
|
|
||||||
if schema_type == "object" and "properties" in schema:
|
|
||||||
# TODO: `required` keyword
|
|
||||||
if self._prop_order:
|
|
||||||
prop_order = self._prop_order
|
|
||||||
prop_pairs = sorted(
|
|
||||||
schema["properties"].items(),
|
|
||||||
# sort by position in prop_order (if specified) then by key
|
|
||||||
key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
prop_pairs = schema["properties"].items()
|
raise ValueError(f'Unsupported ref {ref}')
|
||||||
|
|
||||||
rule = '"{" space'
|
for sel in ref.split('#')[-1].split('/')[1:]:
|
||||||
for i, (prop_name, prop_schema) in enumerate(prop_pairs):
|
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
|
||||||
prop_rule_name = self.visit(
|
target = target[sel]
|
||||||
prop_schema, f'{name}{"-" if name else ""}{prop_name}'
|
|
||||||
|
self._refs[ref] = target
|
||||||
|
else:
|
||||||
|
for v in n.values():
|
||||||
|
visit(v)
|
||||||
|
|
||||||
|
return n
|
||||||
|
return visit(schema)
|
||||||
|
|
||||||
|
def _generate_union_rule(self, name, alt_schemas):
|
||||||
|
return ' | '.join((
|
||||||
|
self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
|
||||||
|
for i, alt_schema in enumerate(alt_schemas)
|
||||||
|
))
|
||||||
|
|
||||||
|
def _visit_pattern(self, pattern, name):
|
||||||
|
'''
|
||||||
|
Transforms a regular expression pattern into a GBNF rule.
|
||||||
|
|
||||||
|
Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
|
||||||
|
Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
|
||||||
|
|
||||||
|
Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
|
||||||
|
|
||||||
|
Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
|
||||||
|
we define sub-rules to keep the output lean.
|
||||||
|
'''
|
||||||
|
|
||||||
|
assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
|
||||||
|
pattern = pattern[1:-1]
|
||||||
|
sub_rule_ids = {}
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
length = len(pattern)
|
||||||
|
|
||||||
|
def to_rule(s: Tuple[str, bool]) -> str:
|
||||||
|
(txt, is_literal) = s
|
||||||
|
return "\"" + txt + "\"" if is_literal else txt
|
||||||
|
|
||||||
|
def transform() -> Tuple[str, bool]:
|
||||||
|
'''
|
||||||
|
Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
|
||||||
|
'''
|
||||||
|
nonlocal i
|
||||||
|
nonlocal pattern
|
||||||
|
nonlocal sub_rule_ids
|
||||||
|
|
||||||
|
start = i
|
||||||
|
# For each component of this sequence, store its string representation and whether it's a literal.
|
||||||
|
# We only need a flat structure here to apply repetition operators to the last item, and
|
||||||
|
# to merge literals at the end (we're parsing grouped ( sequences ) recursively and don't treat '|' specially)
|
||||||
|
# (GBNF's syntax is luckily very close to regular expressions!)
|
||||||
|
seq: list[Tuple[str, bool]] = []
|
||||||
|
|
||||||
|
def get_dot():
|
||||||
|
if self._dotall:
|
||||||
|
rule = DOTALL
|
||||||
|
else:
|
||||||
|
# Accept any character... except \n and \r line break chars (\x0A and \x0D)
|
||||||
|
rule = DOT
|
||||||
|
return self._add_rule(f'dot', rule)
|
||||||
|
|
||||||
|
def join_seq():
|
||||||
|
nonlocal seq
|
||||||
|
ret = []
|
||||||
|
for is_literal, g in groupby(seq, lambda x: x[1]):
|
||||||
|
if is_literal:
|
||||||
|
ret.append((''.join(x[0] for x in g), True))
|
||||||
|
else:
|
||||||
|
ret.extend(g)
|
||||||
|
if len(ret) == 1:
|
||||||
|
return ret[0]
|
||||||
|
return (' '.join(to_rule(x) for x in seq), False)
|
||||||
|
|
||||||
|
while i < length:
|
||||||
|
c = pattern[i]
|
||||||
|
if c == '.':
|
||||||
|
seq.append((get_dot(), False))
|
||||||
|
i += 1
|
||||||
|
elif c == '(':
|
||||||
|
i += 1
|
||||||
|
if i < length:
|
||||||
|
assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
|
||||||
|
seq.append((f'({to_rule(transform())})', False))
|
||||||
|
elif c == ')':
|
||||||
|
i += 1
|
||||||
|
assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
|
||||||
|
return join_seq()
|
||||||
|
elif c == '[':
|
||||||
|
square_brackets = c
|
||||||
|
i += 1
|
||||||
|
while i < length and pattern[i] != ']':
|
||||||
|
if pattern[i] == '\\':
|
||||||
|
square_brackets += pattern[i:i+2]
|
||||||
|
i += 2
|
||||||
|
else:
|
||||||
|
square_brackets += pattern[i]
|
||||||
|
i += 1
|
||||||
|
assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
|
||||||
|
square_brackets += ']'
|
||||||
|
i += 1
|
||||||
|
seq.append((square_brackets, False))
|
||||||
|
elif c == '|':
|
||||||
|
seq.append(('|', False))
|
||||||
|
i += 1
|
||||||
|
elif c in ('*', '+', '?'):
|
||||||
|
seq[-1] = (to_rule(seq[-1]) + c, False)
|
||||||
|
i += 1
|
||||||
|
elif c == '{':
|
||||||
|
curly_brackets = c
|
||||||
|
i += 1
|
||||||
|
while i < length and pattern[i] != '}':
|
||||||
|
curly_brackets += pattern[i]
|
||||||
|
i += 1
|
||||||
|
assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
|
||||||
|
curly_brackets += '}'
|
||||||
|
i += 1
|
||||||
|
nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
|
||||||
|
min_times = 0
|
||||||
|
max_times = None
|
||||||
|
try:
|
||||||
|
if len(nums) == 1:
|
||||||
|
min_times = int(nums[0])
|
||||||
|
max_times = min_times
|
||||||
|
else:
|
||||||
|
assert len(nums) == 2
|
||||||
|
min_times = int(nums[0]) if nums[0] else 0
|
||||||
|
max_times = int(nums[1]) if nums[1] else None
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')
|
||||||
|
|
||||||
|
(sub, sub_is_literal) = seq[-1]
|
||||||
|
|
||||||
|
if not sub_is_literal:
|
||||||
|
id = sub_rule_ids.get(sub)
|
||||||
|
if id is None:
|
||||||
|
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
|
||||||
|
sub_rule_ids[sub] = id
|
||||||
|
sub = id
|
||||||
|
|
||||||
|
seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times, item_rule_is_literal=sub_is_literal), False)
|
||||||
|
else:
|
||||||
|
literal = ''
|
||||||
|
while i < length:
|
||||||
|
if pattern[i] == '\\' and i < length - 1:
|
||||||
|
next = pattern[i + 1]
|
||||||
|
if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
|
||||||
|
i += 1
|
||||||
|
literal += pattern[i]
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
literal += pattern[i:i+2]
|
||||||
|
i += 2
|
||||||
|
elif pattern[i] == '"' and not self._raw_pattern:
|
||||||
|
literal += '\\"'
|
||||||
|
i += 1
|
||||||
|
elif pattern[i] not in NON_LITERAL_SET and \
|
||||||
|
(i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET):
|
||||||
|
literal += pattern[i]
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
if literal:
|
||||||
|
seq.append((literal, True))
|
||||||
|
|
||||||
|
return join_seq()
|
||||||
|
|
||||||
|
return self._add_rule(
|
||||||
|
name,
|
||||||
|
to_rule(transform()) if self._raw_pattern \
|
||||||
|
else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space")
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_ref(self, ref):
|
||||||
|
ref_name = ref.split('/')[-1]
|
||||||
|
if ref_name not in self._rules and ref not in self._refs_being_resolved:
|
||||||
|
self._refs_being_resolved.add(ref)
|
||||||
|
resolved = self._refs[ref]
|
||||||
|
ref_name = self.visit(resolved, ref_name)
|
||||||
|
self._refs_being_resolved.remove(ref)
|
||||||
|
return ref_name
|
||||||
|
|
||||||
|
def _generate_constant_rule(self, value):
|
||||||
|
return self._format_literal(json.dumps(value))
|
||||||
|
|
||||||
|
def visit(self, schema, name):
    """Translate one JSON-schema node into grammar rules.

    The schema is dispatched on its keywords in a fixed priority order
    (``$ref``, ``oneOf``/``anyOf``, list-valued ``type``, ``const``, ``enum``,
    object shapes, ``allOf``, arrays, string ``pattern``/``format``/length
    limits, generic object, then primitives).

    Args:
        schema: Mapping holding the schema node (accessed via ``.get`` / ``in``).
        name: Base name for the generated rule; an empty value maps to
            ``'root'`` and names found in ``RESERVED_NAMES`` get a ``'-'``
            suffix.

    Returns:
        Whatever ``_add_rule`` / ``_add_primitive`` / ``_visit_pattern``
        returns for the generated rule; callers splice it into other rule
        bodies as a rule reference.

    Raises:
        AssertionError: if no branch matches and ``type`` is not a key of
            ``PRIMITIVE_RULES``.
    """
    schema_type = schema.get('type')
    schema_format = schema.get('format')
    rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'

    if (ref := schema.get('$ref')) is not None:
        return self._add_rule(rule_name, self._resolve_ref(ref))

    elif 'oneOf' in schema or 'anyOf' in schema:
        return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))

    elif isinstance(schema_type, list):
        # A list of types is handled as a union of single-type schemas.
        return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type]))

    elif 'const' in schema:
        return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))

    elif 'enum' in schema:
        rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum']))
        return self._add_rule(rule_name, rule)

    elif schema_type in (None, 'object') and \
         ('properties' in schema or \
          ('additionalProperties' in schema and schema['additionalProperties'] is not True)):
        required = set(schema.get('required', []))
        properties = list(schema.get('properties', {}).items())
        return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))

    elif schema_type in (None, 'object') and 'allOf' in schema:
        required = set()
        properties = []
        hybrid_name = name

        def add_component(comp_schema, is_required):
            # Components given as a $ref are looked up directly in self._refs.
            if (ref := comp_schema.get('$ref')) is not None:
                comp_schema = self._refs[ref]

            if 'properties' in comp_schema:
                for prop_name, prop_schema in comp_schema['properties'].items():
                    properties.append((prop_name, prop_schema))
                    if is_required:
                        required.add(prop_name)

        for t in schema['allOf']:
            if 'anyOf' in t:
                # Properties coming from an anyOf alternative are optional.
                for tt in t['anyOf']:
                    add_component(tt, is_required=False)
            else:
                add_component(t, is_required=True)

        return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[]))

    elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
        # `items` may be present but falsy (e.g. the empty schema `{}`); only
        # fall back to `prefixItems` when that key actually exists, instead of
        # indexing it unconditionally and raising KeyError.
        items = schema.get('items')
        if not items and 'prefixItems' in schema:
            items = schema['prefixItems']
        if isinstance(items, list):
            # Tuple form: one sub-rule per position, comma separated.
            return self._add_rule(
                rule_name,
                '"[" space ' +
                ' "," space '.join(
                    self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
                    for i, item in enumerate(items)) +
                ' "]" space')
        else:
            item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
            min_items = schema.get("minItems", 0)
            max_items = schema.get("maxItems")
            return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')

    elif schema_type in (None, 'string') and 'pattern' in schema:
        return self._visit_pattern(schema['pattern'], rule_name)

    elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
        return self._add_primitive(
            'root' if rule_name == 'root' else schema_format,
            PRIMITIVE_RULES['uuid']
        )

    elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES:
        prim_name = f'{schema_format}-string'
        return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]))

    elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema):
        char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
        min_len = schema.get('minLength', 0)
        max_len = schema.get('maxLength')

        return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')

    elif (schema_type == 'object') or (len(schema) == 0):
        # Unconstrained object, or an empty schema: use the generic object rule.
        return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))

    else:
        assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
        # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
        return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type])
|
||||||
|
|
||||||
|
def _add_primitive(self, name: str, rule: BuiltinRule):
    """Register a built-in rule together with every rule it depends on.

    ``rule.content`` is added under ``name`` via ``_add_rule``. Each entry of
    ``rule.deps`` is then looked up in ``PRIMITIVE_RULES`` (falling back to
    ``STRING_FORMAT_RULES``) and registered recursively unless a rule with
    that name already exists in ``self._rules``.

    Returns:
        The value ``_add_rule`` returned for ``name``.

    Raises:
        AssertionError: if a dependency is found in neither rule table.
    """
    registered_name = self._add_rule(name, rule.content)

    for dep_name in rule.deps:
        dep_rule = PRIMITIVE_RULES.get(dep_name) or STRING_FORMAT_RULES.get(dep_name)
        assert dep_rule, f'Rule {dep_name} not known'
        if dep_name in self._rules:
            # Already emitted; nothing more to do for this dependency.
            continue
        self._add_primitive(dep_name, dep_rule)

    return registered_name
|
||||||
|
|
||||||
|
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
    """Assemble the grammar rule body for a JSON object.

    Args:
        properties: ``(property name, property schema)`` pairs.
        required: Names of the properties that must always be present.
        name: Prefix used for every helper rule generated here.
        additional_properties: ``True`` or a schema dict to allow extra
            key/value pairs; any other value disables them.

    Returns:
        The rule body text (the caller registers it). Helper rules for the
        individual key/value pairs are registered through ``_add_rule`` as a
        side effect.
    """
    order_hint = self._prop_order
    prefix = f'{name}{"-" if name else ""}'

    def sort_key(indexed):
        # Position in the requested order first (unknown names go last),
        # then the original position to keep the sort stable.
        position, (prop_name, _prop_schema) = indexed
        return (order_hint.get(prop_name, len(order_hint)), position)

    sorted_props = [pair[0] for _, pair in sorted(enumerate(properties), key=sort_key)]

    # One "<key> : <value>" rule per declared property, in declaration order.
    kv_rule_names = {}
    for prop_name, prop_schema in properties:
        prop_rule_name = self.visit(prop_schema, f'{prefix}{prop_name}')
        kv_rule_names[prop_name] = self._add_rule(
            f'{prefix}{prop_name}-kv',
            fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
        )

    required_props = []
    optional_props = []
    for prop_name in sorted_props:
        (required_props if prop_name in required else optional_props).append(prop_name)

    # Equality (not identity) with True is kept on purpose to match inputs as before.
    allow_any = additional_properties == True
    if allow_any or isinstance(additional_properties, dict):
        sub_name = f'{prefix}additional'
        value_rule = self.visit({} if allow_any else additional_properties, f'{sub_name}-value')
        key_rule = self._add_primitive('string', PRIMITIVE_RULES['string'])
        kv_rule_names["*"] = self._add_rule(
            f'{sub_name}-kv',
            key_rule + f' ":" space {value_rule}'
        )
        # "*" stands for "any additional key" and is always optional.
        optional_props.append("*")

    def optional_chain(keys, first_is_optional):
        # Rule text for `keys[0]` followed by the (optional) remaining keys.
        head, *tail = keys
        kv_rule_name = kv_rule_names[head]
        if head == '*':
            text = self._add_rule(
                f'{prefix}additional-kvs',
                f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*'
            )
        elif first_is_optional:
            text = f'( "," space {kv_rule_name} )?'
        else:
            text = kv_rule_name
        if tail:
            text += ' ' + self._add_rule(
                f'{prefix}{head}-rest',
                optional_chain(tail, first_is_optional=True)
            )
        return text

    pieces = ['"{" space ', ' "," space '.join(kv_rule_names[k] for k in required_props)]

    if optional_props:
        pieces.append(' (')
        if required_props:
            pieces.append(' "," space ( ')

        # One alternative per possible first optional property.
        pieces.append(' | '.join(
            optional_chain(optional_props[start:], first_is_optional=False)
            for start in range(len(optional_props))
        ))
        if required_props:
            pieces.append(' )')
        pieces.append(' )?')

    pieces.append(' "}" space')

    return ''.join(pieces)
|
||||||
|
|
||||||
elif schema_type == "array" and "items" in schema:
|
|
||||||
# TODO `prefixItems` keyword
|
|
||||||
item_rule_name = self.visit(
|
|
||||||
schema["items"], f'{name}{"-" if name else ""}item'
|
|
||||||
)
|
|
||||||
list_item_operator = f'("," space {item_rule_name})'
|
|
||||||
successive_items = ""
|
|
||||||
min_items = schema.get("minItems", 0)
|
|
||||||
if min_items > 0:
|
|
||||||
first_item = f"({item_rule_name})"
|
|
||||||
successive_items = list_item_operator * (min_items - 1)
|
|
||||||
min_items -= 1
|
|
||||||
else:
|
|
||||||
first_item = f"({item_rule_name})?"
|
|
||||||
max_items = schema.get("maxItems")
|
|
||||||
if max_items is not None and max_items > min_items:
|
|
||||||
successive_items += (list_item_operator + "?") * (max_items - min_items - 1)
|
|
||||||
else:
|
|
||||||
successive_items += list_item_operator + "*"
|
|
||||||
rule = f'"[" space {first_item} {successive_items} "]" space'
|
|
||||||
return self._add_rule(rule_name, rule)
|
|
||||||
|
|
||||||
else:
|
|
||||||
assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
|
|
||||||
return self._add_rule(
|
|
||||||
"root" if rule_name == "root" else schema_type,
|
|
||||||
PRIMITIVE_RULES[schema_type],
|
|
||||||
)
|
|
||||||
|
|
||||||
def format_grammar(self):
    """Render every collected rule as ``name ::= rule``, one per line.

    Rules are emitted in ascending order of their names and joined with
    newlines; no trailing newline is added.
    """
    lines = [f'{rule_name} ::= {self._rules[rule_name]}' for rule_name in sorted(self._rules)]
    return '\n'.join(lines)
|
||||||
def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
    """Convert a JSON schema, given as JSON text, into a grammar string.

    Args:
        schema: The schema serialized as a JSON document.
        prop_order: Optional property names; each name's index in this list
            is passed to the converter as its ordering hint.

    Returns:
        The text produced by ``SchemaConverter.format_grammar()`` after the
        schema's references are resolved and the schema is visited.
    """
    ordered_names = prop_order or []
    parsed_schema = json.loads(schema)
    order_index = {prop_name: position for position, prop_name in enumerate(ordered_names)}

    converter = SchemaConverter(prop_order=order_index, allow_fetch=False, dotall=False, raw_pattern=False)
    resolved_schema = converter.resolve_refs(parsed_schema, "stdin")
    converter.visit(resolved_schema, "")
    return converter.format_grammar()
|
||||||
|
|
|
@ -59,6 +59,15 @@ def main():
|
||||||
if not os.path.exists(config_file):
|
if not os.path.exists(config_file):
|
||||||
raise ValueError(f"Config file {config_file} not found!")
|
raise ValueError(f"Config file {config_file} not found!")
|
||||||
with open(config_file, "rb") as f:
|
with open(config_file, "rb") as f:
|
||||||
|
# Check if yaml file
|
||||||
|
if config_file.endswith(".yaml") or config_file.endswith(".yml"):
|
||||||
|
import yaml
|
||||||
|
import json
|
||||||
|
|
||||||
|
config_file_settings = ConfigFileSettings.model_validate_json(
|
||||||
|
json.dumps(yaml.safe_load(f))
|
||||||
|
)
|
||||||
|
else:
|
||||||
config_file_settings = ConfigFileSettings.model_validate_json(f.read())
|
config_file_settings = ConfigFileSettings.model_validate_json(f.read())
|
||||||
server_settings = ServerSettings.model_validate(config_file_settings)
|
server_settings = ServerSettings.model_validate(config_file_settings)
|
||||||
model_settings = config_file_settings.models
|
model_settings = config_file_settings.models
|
||||||
|
|
|
@ -87,6 +87,13 @@ def get_llama_proxy():
|
||||||
llama_outer_lock.release()
|
llama_outer_lock.release()
|
||||||
|
|
||||||
|
|
||||||
|
# Module-wide callable handed to the streaming responses as
# `ping_message_factory`; None until set_ping_message_factory() is called.
_ping_message_factory = None


def set_ping_message_factory(factory):
    """Store `factory` in the module-level `_ping_message_factory` slot."""
    global _ping_message_factory
    _ping_message_factory = factory
|
||||||
|
|
||||||
|
|
||||||
def create_app(
|
def create_app(
|
||||||
settings: Settings | None = None,
|
settings: Settings | None = None,
|
||||||
server_settings: ServerSettings | None = None,
|
server_settings: ServerSettings | None = None,
|
||||||
|
@ -97,6 +104,14 @@ def create_app(
|
||||||
if not os.path.exists(config_file):
|
if not os.path.exists(config_file):
|
||||||
raise ValueError(f"Config file {config_file} not found!")
|
raise ValueError(f"Config file {config_file} not found!")
|
||||||
with open(config_file, "rb") as f:
|
with open(config_file, "rb") as f:
|
||||||
|
# Check if yaml file
|
||||||
|
if config_file.endswith(".yaml") or config_file.endswith(".yml"):
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
config_file_settings = ConfigFileSettings.model_validate_json(
|
||||||
|
json.dumps(yaml.safe_load(f))
|
||||||
|
)
|
||||||
|
else:
|
||||||
config_file_settings = ConfigFileSettings.model_validate_json(f.read())
|
config_file_settings = ConfigFileSettings.model_validate_json(f.read())
|
||||||
server_settings = ServerSettings.model_validate(config_file_settings)
|
server_settings = ServerSettings.model_validate(config_file_settings)
|
||||||
model_settings = config_file_settings.models
|
model_settings = config_file_settings.models
|
||||||
|
@ -130,6 +145,9 @@ def create_app(
|
||||||
assert model_settings is not None
|
assert model_settings is not None
|
||||||
set_llama_proxy(model_settings=model_settings)
|
set_llama_proxy(model_settings=model_settings)
|
||||||
|
|
||||||
|
if server_settings.disable_ping_events:
|
||||||
|
set_ping_message_factory(lambda: bytes())
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@ -294,6 +312,7 @@ async def create_completion(
|
||||||
iterator=iterator(),
|
iterator=iterator(),
|
||||||
),
|
),
|
||||||
sep="\n",
|
sep="\n",
|
||||||
|
ping_message_factory=_ping_message_factory,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return iterator_or_completion
|
return iterator_or_completion
|
||||||
|
@ -462,6 +481,7 @@ async def create_chat_completion(
|
||||||
iterator=iterator(),
|
iterator=iterator(),
|
||||||
),
|
),
|
||||||
sep="\n",
|
sep="\n",
|
||||||
|
ping_message_factory=_ping_message_factory,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return iterator_or_completion
|
return iterator_or_completion
|
||||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
|
||||||
from typing import Optional, List, Literal, Union
|
from typing import Optional, List, Literal, Union
|
||||||
from pydantic import Field
|
from pydantic import Field, root_validator
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
import llama_cpp
|
import llama_cpp
|
||||||
|
@ -67,12 +67,12 @@ class ModelSettings(BaseSettings):
|
||||||
n_threads: int = Field(
|
n_threads: int = Field(
|
||||||
default=max(multiprocessing.cpu_count() // 2, 1),
|
default=max(multiprocessing.cpu_count() // 2, 1),
|
||||||
ge=1,
|
ge=1,
|
||||||
description="The number of threads to use.",
|
description="The number of threads to use. Use -1 for max cpu threads",
|
||||||
)
|
)
|
||||||
n_threads_batch: int = Field(
|
n_threads_batch: int = Field(
|
||||||
default=max(multiprocessing.cpu_count() // 2, 1),
|
default=max(multiprocessing.cpu_count(), 1),
|
||||||
ge=0,
|
ge=0,
|
||||||
description="The number of threads to use when batch processing.",
|
description="The number of threads to use when batch processing. Use -1 for max cpu threads",
|
||||||
)
|
)
|
||||||
rope_scaling_type: int = Field(
|
rope_scaling_type: int = Field(
|
||||||
default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
|
default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
|
||||||
|
@ -173,6 +173,16 @@ class ModelSettings(BaseSettings):
|
||||||
default=True, description="Whether to print debug information."
|
default=True, description="Whether to print debug information."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@root_validator(pre=True)  # pre=True to ensure this runs before any other validation
def set_dynamic_defaults(cls, values):
    """Replace the ``-1`` sentinel thread counts with the machine's CPU count.

    Applies to ``n_threads`` and ``n_threads_batch``; fields that are missing
    or hold any other value are left untouched. Returns the (mutated)
    ``values`` mapping.
    """
    cpu_count = multiprocessing.cpu_count()
    for field_name in ('n_threads', 'n_threads_batch'):
        # A missing key defaults to 0, which never matches the sentinel.
        if values.get(field_name, 0) == -1:
            values[field_name] = cpu_count
    return values
|
||||||
|
|
||||||
|
|
||||||
class ServerSettings(BaseSettings):
|
class ServerSettings(BaseSettings):
|
||||||
"""Server settings used to configure the FastAPI and Uvicorn server."""
|
"""Server settings used to configure the FastAPI and Uvicorn server."""
|
||||||
|
@ -195,6 +205,10 @@ class ServerSettings(BaseSettings):
|
||||||
default=True,
|
default=True,
|
||||||
description="Whether to interrupt requests when a new request is received.",
|
description="Whether to interrupt requests when a new request is received.",
|
||||||
)
|
)
|
||||||
|
disable_ping_events: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Disable EventSource pings (may be needed for some clients).",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Settings(ServerSettings, ModelSettings):
|
class Settings(ServerSettings, ModelSettings):
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["scikit-build-core[pyproject]>=0.5.1"]
|
requires = ["scikit-build-core[pyproject]>=0.9.2"]
|
||||||
build-backend = "scikit_build_core.build"
|
build-backend = "scikit_build_core.build"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
|
@ -35,6 +35,7 @@ server = [
|
||||||
"pydantic-settings>=2.0.1",
|
"pydantic-settings>=2.0.1",
|
||||||
"sse-starlette>=1.6.1",
|
"sse-starlette>=1.6.1",
|
||||||
"starlette-context>=0.3.6,<0.4",
|
"starlette-context>=0.3.6,<0.4",
|
||||||
|
"PyYAML>=5.1",
|
||||||
]
|
]
|
||||||
test = [
|
test = [
|
||||||
"pytest>=7.4.0",
|
"pytest>=7.4.0",
|
||||||
|
|
2
vendor/llama.cpp
vendored
2
vendor/llama.cpp
vendored
|
@ -1 +1 @@
|
||||||
Subproject commit 75cd4c77292034ecec587ecb401366f57338f7c0
|
Subproject commit 4e96a812b3ce7322a29a3008db2ed73d9087b176
|
Loading…
Reference in a new issue