Merge branch 'main' into v0.2-wip

This commit is contained in:
Andrei Betlen 2023-08-24 00:30:51 -04:00
commit cf405f6764
11 changed files with 1730 additions and 330 deletions

View file

@@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [0.1.78]
### Added
- Grammar based sampling via LlamaGrammar which can be passed to completions
- Make n_gpu_layers == -1 offload all layers
## [0.1.77]
- (llama.cpp) Update llama.cpp add support for LLaMa 2 70B
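The two new 0.1.78 features combine naturally; a minimal usage sketch, assuming the `LlamaGrammar.from_string` constructor from the new `llama_cpp/llama_grammar.py` module and an illustrative GBNF grammar and model path:

```python
from llama_cpp import Llama, LlamaGrammar

# Constrain generation to "yes" or "no" via grammar-based sampling (new in 0.1.78).
grammar = LlamaGrammar.from_string(r'root ::= "yes" | "no"')

# n_gpu_layers=-1 now offloads every layer to the GPU (also new in 0.1.78).
llm = Llama(model_path="./models/7B/ggml-model.bin", n_gpu_layers=-1)

out = llm.create_completion("Is water wet? Answer: ", max_tokens=4, grammar=grammar)
print(out["choices"][0]["text"])
```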

View file

@@ -140,7 +140,7 @@ llm = Llama(model_path="./models/7B/ggml-model.bin", n_ctx=2048)
Llama2 70b must set the `n_gqa` parameter (grouped-query attention factor) to 8 when loading:
```python
-llm = Llama(model_path="./models/7B/ggml-model.bin", n_gqa=8)
+llm = Llama(model_path="./models/70B/ggml-model.bin", n_gqa=8)
```
## Web Server
@@ -169,7 +169,7 @@ docker run --rm -it -p 8000:8000 -v /path/to/models:/models -e MODEL=/models/ggm
## Low-level API
The low-level API is a direct [`ctypes`](https://docs.python.org/3/library/ctypes.html) binding to the C API provided by `llama.cpp`.
-The entire lowe-level API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and directly mirrors the C API in [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h).
+The entire low-level API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and directly mirrors the C API in [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h).
Below is a short example demonstrating how to use the low-level API to tokenize a prompt:
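The README's example itself is not part of this hunk; here is a hedged sketch of low-level tokenization built only from bindings that appear later in this diff (`llama_context_default_params`, `llama_load_model_from_file`, `llama_new_context_with_model`, `llama_tokenize`, `llama_free`, `llama_free_model`). The model path and the `numa` argument to `llama_backend_init` are assumptions:

```python
import ctypes
import llama_cpp

llama_cpp.llama_backend_init(ctypes.c_bool(False))  # numa=False (assumed signature)
params = llama_cpp.llama_context_default_params()
model = llama_cpp.llama_load_model_from_file(b"./models/7B/ggml-model.bin", params)
ctx = llama_cpp.llama_new_context_with_model(model, params)

text = b"Q: Name the planets in the solar system? A: "
max_tokens = params.n_ctx
tokens = (llama_cpp.llama_token * int(max_tokens))()
# Returns the number of tokens written, or a negative count on failure.
n_tokens = llama_cpp.llama_tokenize(ctx, text, tokens, ctypes.c_int(max_tokens), ctypes.c_bool(True))
print(list(tokens[:n_tokens]))

llama_cpp.llama_free(ctx)
llama_cpp.llama_free_model(model)
```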

View file

@@ -9,7 +9,7 @@ COPY . .
RUN apt update && apt install -y libopenblas-dev ninja-build build-essential
RUN python -m pip install --upgrade pip pytest cmake scikit-build setuptools fastapi uvicorn sse-starlette pydantic-settings
-RUN LLAMA_OPENBLAS=1 pip install llama_cpp_python --verbose
+RUN CMAKE_ARGS="-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS" pip install llama_cpp_python --verbose
# Run the server
CMD python3 -m llama_cpp.server

View file

@@ -23,10 +23,12 @@ import ctypes
from . import llama_cpp
from .llama_types import *
from .llama_grammar import LlamaGrammar
import numpy as np
import numpy.typing as npt
from .utils import suppress_stdout_stderr
class BaseLlamaCache(ABC):
"""Base cache class for a llama.cpp model."""
@@ -231,7 +233,8 @@ class Llama:
rope_freq_base: float = 10000.0,
rope_freq_scale: float = 1.0,
n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b
rms_norm_eps: Optional[float] = None, # (TEMPORARY)
mul_mat_q: Optional[bool] = None, # (TEMPORARY)
verbose: bool = True,
):
"""Load a llama.cpp model from `model_path`.
@@ -241,6 +244,7 @@ class Llama:
n_ctx: Maximum context size.
n_parts: Number of parts to split the model into. If -1, the number of parts is automatically determined.
seed: Random seed. -1 for random.
n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
f16_kv: Use half-precision for key/value cache.
logits_all: Return logits for all tokens, not just the last token.
vocab_only: Only load the vocabulary no weights.
@@ -269,7 +273,7 @@ class Llama:
self.params = llama_cpp.llama_context_default_params()
self.params.n_ctx = n_ctx
-self.params.n_gpu_layers = n_gpu_layers
+self.params.n_gpu_layers = 0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers # 0x7FFFFFFF is INT32 max, will be auto set to all layers
self.params.seed = seed
self.params.f16_kv = f16_kv
self.params.logits_all = logits_all
@@ -280,7 +284,7 @@ class Llama:
self.params.low_vram = low_vram
self.tensor_split = tensor_split
-self._c_tensor_split = None
+self._p_tensor_split = None
if self.tensor_split is not None:
# Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
@@ -299,6 +303,9 @@ class Llama:
if rms_norm_eps is not None:
self.params.rms_norm_eps = rms_norm_eps
if mul_mat_q is not None:
self.params.mul_mat_q = mul_mat_q
self.last_n_tokens_size = last_n_tokens_size
self.n_batch = min(n_ctx, n_batch)
@@ -316,12 +323,25 @@ class Llama:
if not os.path.exists(model_path):
raise ValueError(f"Model path does not exist: {model_path}")
-self.model = llama_cpp.llama_load_model_from_file(
-self.model_path.encode("utf-8"), self.params
-)
+if verbose:
+self.model = llama_cpp.llama_load_model_from_file(
+self.model_path.encode("utf-8"), self.params
+)
+else:
+with suppress_stdout_stderr():
+self.model = llama_cpp.llama_load_model_from_file(
+self.model_path.encode("utf-8"), self.params
+)
assert self.model is not None
-self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params)
+if verbose:
+self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params)
+else:
+with suppress_stdout_stderr():
+print("here")
+self.ctx = llama_cpp.llama_new_context_with_model(
+self.model, self.params
+)
assert self.ctx is not None
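In practice the `verbose` flag is what users touch; a minimal sketch, with an illustrative model path, showing native load-time output being silenced through the new `suppress_stdout_stderr` helper (defined in `llama_cpp/utils.py`, added later in this commit):

```python
from llama_cpp import Llama

# verbose=False routes llama.cpp's model-loading chatter on stdout/stderr to /dev/null.
llm = Llama(model_path="./models/7B/ggml-model.bin", verbose=False)
```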
@@ -358,8 +378,8 @@ class Llama:
sorted=sorted,
)
self._candidates = candidates
-self._token_nl = Llama.token_nl()
-self._token_eos = Llama.token_eos()
+self._token_nl = self.token_nl()
+self._token_eos = self.token_eos()
self._candidates_data_id = np.arange(self._n_vocab, dtype=np.intc) # type: ignore
self._candidates_data_p = np.zeros(self._n_vocab, dtype=np.single)
@@ -437,10 +457,14 @@ class Llama:
"""
assert self.ctx is not None
output = b""
buffer_size = 32
buffer = (ctypes.c_char * buffer_size)()
for token in tokens:
-output += llama_cpp.llama_token_to_str(
-self.ctx, llama_cpp.llama_token(token)
+n = llama_cpp.llama_token_to_str(
+self.ctx, llama_cpp.llama_token(token), buffer, buffer_size
)
assert n <= buffer_size
output += bytes(buffer[:n])
return output
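For reference, the buffer-based C call used above can also be exercised directly; a hedged sketch where `llm` is assumed to be a loaded `Llama` instance and `token_id` any valid token id:

```python
import ctypes
import llama_cpp

buffer_size = 32
buf = (ctypes.c_char * buffer_size)()
# llama_token_to_str now writes into a caller-supplied buffer and returns the number of
# bytes written (no NUL terminator), instead of returning a const char * as before.
n = llama_cpp.llama_token_to_str(llm.ctx, llama_cpp.llama_token(token_id), buf, buffer_size)
piece = bytes(buf[:n])
```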
def set_cache(self, cache: Optional[BaseLlamaCache]):
@@ -506,6 +530,7 @@ class Llama:
mirostat_eta: llama_cpp.c_float,
penalize_nl: bool = True,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
):
assert self.ctx is not None
assert self.n_tokens > 0
@@ -548,8 +573,16 @@ class Llama:
)
if not penalize_nl:
candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
if grammar is not None:
llama_cpp.llama_sample_grammar(
ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
grammar=grammar.grammar,
)
if temp.value == 0.0:
-return llama_cpp.llama_sample_token_greedy(
+id = llama_cpp.llama_sample_token_greedy(
ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
)
@@ -561,7 +594,7 @@ class Llama:
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
temp=temp,
)
-return llama_cpp.llama_sample_token_mirostat(
+id = llama_cpp.llama_sample_token_mirostat(
ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
tau=mirostat_tau,
@@ -576,7 +609,7 @@ class Llama:
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
temp=temp,
)
-return llama_cpp.llama_sample_token_mirostat_v2(
+id = llama_cpp.llama_sample_token_mirostat_v2(
ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
tau=mirostat_tau,
@@ -613,10 +646,17 @@ class Llama:
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
temp=temp,
)
-return llama_cpp.llama_sample_token(
+id = llama_cpp.llama_sample_token(
ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
)
if grammar is not None:
llama_cpp.llama_grammar_accept_token(
ctx=self.ctx,
grammar=grammar.grammar,
token=llama_cpp.ctypes.c_int(id),
)
return id
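The grammar plumbing above follows the low-level call sequence exposed by the bindings; a hedged standalone sketch, assuming `ctx`, `candidates`, and a `LlamaGrammar` instance `grammar` prepared exactly as in `Llama._sample`:

```python
# 1. Mask candidate tokens that the grammar cannot accept.
llama_cpp.llama_sample_grammar(
    ctx=ctx,
    candidates=llama_cpp.ctypes.byref(candidates),
    grammar=grammar.grammar,
)
# 2. Sample from the (now constrained) distribution.
token_id = llama_cpp.llama_sample_token(
    ctx=ctx,
    candidates=llama_cpp.ctypes.byref(candidates),
)
# 3. Advance the grammar state with the accepted token.
llama_cpp.llama_grammar_accept_token(
    ctx=ctx,
    grammar=grammar.grammar,
    token=llama_cpp.ctypes.c_int(token_id),
)
```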
def sample(
self,
@@ -632,6 +672,7 @@ class Llama:
mirostat_tau: float = 5.0,
penalize_nl: bool = True,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
):
"""Sample a token from the model.
@@ -665,6 +706,7 @@ class Llama:
mirostat_eta=llama_cpp.c_float(mirostat_eta),
penalize_nl=penalize_nl,
logits_processor=logits_processor,
grammar=grammar,
)
def generate(
@@ -683,6 +725,7 @@ class Llama:
mirostat_eta: float = 0.1,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
grammar: Optional[LlamaGrammar] = None,
) -> Generator[int, Optional[Sequence[int]], None]:
"""Create a generator of tokens from a prompt.
@@ -704,7 +747,6 @@ class Llama:
The generated tokens.
"""
assert self.ctx is not None
if reset and len(self._input_ids) > 0:
longest_prefix = 0
for a, b in zip(self._input_ids, tokens[:-1]):
@@ -722,6 +764,9 @@ class Llama:
if reset:
self.reset()
if grammar is not None:
grammar.reset()
while True:
self.eval(tokens)
token = self.sample(
@@ -736,6 +781,7 @@ class Llama:
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
logits_processor=logits_processor,
grammar=grammar,
)
if stopping_criteria is not None and stopping_criteria(
self._input_ids, self._scores[-1, :]
@@ -838,6 +884,7 @@ class Llama:
model: Optional[str] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
assert self.ctx is not None
@@ -915,6 +962,7 @@ class Llama:
repeat_penalty=repeat_penalty,
stopping_criteria=stopping_criteria,
logits_processor=logits_processor,
grammar=grammar,
):
if token == self._token_eos:
text = self.detokenize(completion_tokens)
@@ -965,9 +1013,7 @@ class Llama:
for token in remaining_tokens:
token_end_position += len(self.detokenize([token]))
# Check if stop sequence is in the token
-if token_end_position >= (
-remaining_length - first_stop_position
-):
+if token_end_position >= (remaining_length - first_stop_position):
break
logprobs_or_none: Optional[CompletionLogprobs] = None
if logprobs is not None:
@@ -1261,6 +1307,7 @@ class Llama:
model: Optional[str] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
) -> Union[Completion, Iterator[CompletionChunk]]:
"""Generate text from a prompt.
@@ -1305,6 +1352,7 @@ class Llama:
model=model,
stopping_criteria=stopping_criteria,
logits_processor=logits_processor,
grammar=grammar
)
if stream:
chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -1334,6 +1382,7 @@ class Llama:
model: Optional[str] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
) -> Union[Completion, Iterator[CompletionChunk]]:
"""Generate text from a prompt.
@@ -1378,6 +1427,7 @@ class Llama:
model=model,
stopping_criteria=stopping_criteria,
logits_processor=logits_processor,
grammar=grammar,
)
def _convert_text_completion_to_chat(
@@ -1460,6 +1510,7 @@ class Llama:
mirostat_eta: float = 0.1,
model: Optional[str] = None,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
"""Generate a chat completion from a list of messages.
@@ -1502,6 +1553,7 @@ class Llama:
mirostat_eta=mirostat_eta,
model=model,
logits_processor=logits_processor,
grammar=grammar,
)
if stream:
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore
@@ -1511,10 +1563,10 @@ class Llama:
return self._convert_text_completion_to_chat(completion)
def __del__(self):
-if self.model is not None:
+if hasattr(self, "model") and self.model is not None:
llama_cpp.llama_free_model(self.model)
self.model = None
-if self.ctx is not None:
+if hasattr(self, "ctx") and self.ctx is not None:
llama_cpp.llama_free(self.ctx)
self.ctx = None
@@ -1638,20 +1690,20 @@ class Llama:
assert self.ctx is not None
return LlamaTokenizer(self)
-@staticmethod
-def token_eos() -> int:
+def token_eos(self) -> int:
"""Return the end-of-sequence token."""
-return llama_cpp.llama_token_eos()
+assert self.ctx is not None
+return llama_cpp.llama_token_eos(self.ctx)
-@staticmethod
-def token_bos() -> int:
+def token_bos(self) -> int:
"""Return the beginning-of-sequence token."""
-return llama_cpp.llama_token_bos()
+assert self.ctx is not None
+return llama_cpp.llama_token_bos(self.ctx)
-@staticmethod
-def token_nl() -> int:
+def token_nl(self) -> int:
"""Return the newline token."""
-return llama_cpp.llama_token_nl()
+assert self.ctx is not None
+return llama_cpp.llama_token_nl(self.ctx)
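Since the special-token helpers now query the loaded context rather than being static, they must be called on an instance; a brief usage sketch (model path illustrative):

```python
from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin")
# Special token ids now come from the model's own vocabulary via the context.
print(llm.token_bos(), llm.token_eos(), llm.token_nl())
```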
@staticmethod
def logits_to_logprobs(logits: List[float]) -> List[float]:

View file

@@ -90,26 +90,17 @@ GGML_USE_CUBLAS = hasattr(_lib, "ggml_init_cublas")
GGML_CUDA_MAX_DEVICES = 16
LLAMA_MAX_DEVICES = GGML_CUDA_MAX_DEVICES if GGML_USE_CUBLAS else 1
-# #define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt'
-LLAMA_FILE_MAGIC_GGJT = 0x67676A74
-# #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
-LLAMA_FILE_MAGIC_GGLA = 0x67676C61
-# #define LLAMA_FILE_MAGIC_GGMF 0x67676d66u // 'ggmf'
-LLAMA_FILE_MAGIC_GGMF = 0x67676D66
-# #define LLAMA_FILE_MAGIC_GGML 0x67676d6cu // 'ggml'
-LLAMA_FILE_MAGIC_GGML = 0x67676D6C
-# #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
-LLAMA_FILE_MAGIC_GGSN = 0x6767736E
-# #define LLAMA_FILE_VERSION 3
-LLAMA_FILE_VERSION = 3
-LLAMA_FILE_MAGIC = LLAMA_FILE_MAGIC_GGJT
-LLAMA_FILE_MAGIC_UNVERSIONED = LLAMA_FILE_MAGIC_GGML
+# define LLAMA_DEFAULT_SEED 0xFFFFFFFF
+LLAMA_DEFAULT_SEED = ctypes.c_int(0xFFFFFFFF)
+# define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
+LLAMA_FILE_MAGIC_GGSN = ctypes.c_uint(0x6767736E)
+# define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN
-LLAMA_SESSION_VERSION = 1
-# #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
-LLAMA_DEFAULT_SEED = 0xFFFFFFFF
+# define LLAMA_SESSION_VERSION 1
+LLAMA_SESSION_VERSION = ctypes.c_int(1)
# struct llama_model;
llama_model_p = c_void_p
@@ -122,6 +113,82 @@ llama_context_p = c_void_p
llama_token = c_int
llama_token_p = POINTER(llama_token)
# enum llama_log_level {
# LLAMA_LOG_LEVEL_ERROR = 2,
# LLAMA_LOG_LEVEL_WARN = 3,
# LLAMA_LOG_LEVEL_INFO = 4
# };
LLAMA_LOG_LEVEL_ERROR = c_int(2)
LLAMA_LOG_LEVEL_WARN = c_int(3)
LLAMA_LOG_LEVEL_INFO = c_int(4)
# enum llama_vocab_type {
# LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
# LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
# };
LLAMA_VOCAB_TYPE_SPM = c_int(0)
LLAMA_VOCAB_TYPE_BPE = c_int(1)
# enum llama_token_type {
# LLAMA_TOKEN_TYPE_UNDEFINED = 0,
# LLAMA_TOKEN_TYPE_NORMAL = 1,
# LLAMA_TOKEN_TYPE_UNKNOWN = 2,
# LLAMA_TOKEN_TYPE_CONTROL = 3,
# LLAMA_TOKEN_TYPE_USER_DEFINED = 4,
# LLAMA_TOKEN_TYPE_UNUSED = 5,
# LLAMA_TOKEN_TYPE_BYTE = 6,
# };
LLAMA_TOKEN_TYPE_UNDEFINED = c_int(0)
LLAMA_TOKEN_TYPE_NORMAL = c_int(1)
LLAMA_TOKEN_TYPE_UNKNOWN = c_int(2)
LLAMA_TOKEN_TYPE_CONTROL = c_int(3)
LLAMA_TOKEN_TYPE_USER_DEFINED = c_int(4)
LLAMA_TOKEN_TYPE_UNUSED = c_int(5)
LLAMA_TOKEN_TYPE_BYTE = c_int(6)
# enum llama_ftype {
# LLAMA_FTYPE_ALL_F32 = 0,
# LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
# // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
# // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
# LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors
#
# LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
# };
LLAMA_FTYPE_ALL_F32 = c_int(0)
LLAMA_FTYPE_MOSTLY_F16 = c_int(1)
LLAMA_FTYPE_MOSTLY_Q4_0 = c_int(2)
LLAMA_FTYPE_MOSTLY_Q4_1 = c_int(3)
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = c_int(4)
LLAMA_FTYPE_MOSTLY_Q8_0 = c_int(7)
LLAMA_FTYPE_MOSTLY_Q5_0 = c_int(8)
LLAMA_FTYPE_MOSTLY_Q5_1 = c_int(9)
LLAMA_FTYPE_MOSTLY_Q2_K = c_int(10)
LLAMA_FTYPE_MOSTLY_Q3_K_S = c_int(11)
LLAMA_FTYPE_MOSTLY_Q3_K_M = c_int(12)
LLAMA_FTYPE_MOSTLY_Q3_K_L = c_int(13)
LLAMA_FTYPE_MOSTLY_Q4_K_S = c_int(14)
LLAMA_FTYPE_MOSTLY_Q4_K_M = c_int(15)
LLAMA_FTYPE_MOSTLY_Q5_K_S = c_int(16)
LLAMA_FTYPE_MOSTLY_Q5_K_M = c_int(17)
LLAMA_FTYPE_MOSTLY_Q6_K = c_int(18)
LLAMA_FTYPE_GUESSED = c_int(1024)
# typedef struct llama_token_data {
# llama_token id; // token id
@@ -157,16 +224,13 @@ llama_token_data_array_p = POINTER(llama_token_data_array)
# typedef void (*llama_progress_callback)(float progress, void *ctx);
llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p)
# struct llama_context_params {
# uint32_t seed; // RNG seed, -1 for random
# int32_t n_ctx; // text context
# int32_t n_batch; // prompt processing batch size
# int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams)
# float rms_norm_eps; // rms norm epsilon (TEMP - will be moved to model hparams)
# int32_t n_gpu_layers; // number of layers to store in VRAM
# int32_t main_gpu; // the GPU that is used for scratch and small tensors
#
# const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
# // ref: https://github.com/ggerganov/llama.cpp/pull/2054
@@ -181,6 +245,7 @@ llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p)
# // Keep the booleans together to avoid misalignment during copy-by-value.
# bool low_vram; // if true, reduce VRAM usage at the cost of performance
# bool mul_mat_q; // if true, use experimental mul_mat_q kernels
# bool f16_kv; // use fp16 for KV cache
# bool logits_all; // the llama_eval() call computes all logits, not just the last one
# bool vocab_only; // only load the vocabulary, no weights
@@ -193,16 +258,15 @@ class llama_context_params(Structure):
("seed", c_uint32),
("n_ctx", c_int32),
("n_batch", c_int32),
("n_gqa", c_int32),
("rms_norm_eps", c_float),
("n_gpu_layers", c_int32),
("main_gpu", c_int32),
-("tensor_split", POINTER(c_float)),
+("tensor_split", c_float_p),
("rope_freq_base", c_float),
("rope_freq_scale", c_float),
("progress_callback", llama_progress_callback),
("progress_callback_user_data", c_void_p),
("low_vram", c_bool),
("mul_mat_q", c_bool),
("f16_kv", c_bool),
("logits_all", c_bool),
("vocab_only", c_bool),
@@ -214,50 +278,20 @@ class llama_context_params(Structure):
llama_context_params_p = POINTER(llama_context_params)
-# enum llama_ftype {
-# LLAMA_FTYPE_ALL_F32 = 0,
-# LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
-# // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
-# // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
+# // Signature for logging events
+# // Note that text includes the new line character at the end for most events.
+# // If your logging mechanism cannot handle that, check if the last character is '\n' and strip it
+# // if it exists.
+# // It might not exist for progress report where '.' is output repeatedly.
+# typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data);
+llama_log_callback = ctypes.CFUNCTYPE(None, c_int, c_char_p, c_void_p)
# LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors
# };
LLAMA_FTYPE_ALL_F32 = 0
LLAMA_FTYPE_MOSTLY_F16 = 1
LLAMA_FTYPE_MOSTLY_Q4_0 = 2
LLAMA_FTYPE_MOSTLY_Q4_1 = 3
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4
LLAMA_FTYPE_MOSTLY_Q8_0 = 7
LLAMA_FTYPE_MOSTLY_Q5_0 = 8
LLAMA_FTYPE_MOSTLY_Q5_1 = 9
LLAMA_FTYPE_MOSTLY_Q2_K = 10
LLAMA_FTYPE_MOSTLY_Q3_K_S = 11
LLAMA_FTYPE_MOSTLY_Q3_K_M = 12
LLAMA_FTYPE_MOSTLY_Q3_K_L = 13
LLAMA_FTYPE_MOSTLY_Q4_K_S = 14
LLAMA_FTYPE_MOSTLY_Q4_K_M = 15
LLAMA_FTYPE_MOSTLY_Q5_K_S = 16
LLAMA_FTYPE_MOSTLY_Q5_K_M = 17
LLAMA_FTYPE_MOSTLY_Q6_K = 18
# // model quantization parameters
# typedef struct llama_model_quantize_params {
# int nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
# enum llama_ftype ftype; // quantize to this llama_ftype
# bool allow_requantize; // allow quantizing non-f32/f16 tensors
# bool quantize_output_tensor; // quantize output.weight
# } llama_model_quantize_params;
@@ -349,16 +383,7 @@ class llama_timings(Structure):
]
-# LLAMA_API int llama_max_devices();
-def llama_max_devices() -> int:
-return _lib.llama_max_devices()
-_lib.llama_max_devices.argtypes = []
-_lib.llama_max_devices.restype = c_int
-# LLAMA_API struct llama_context_params llama_context_default_params();
+# LLAMA_API struct llama_context_params llama_context_default_params(void);
def llama_context_default_params() -> llama_context_params:
return _lib.llama_context_default_params()
@@ -367,7 +392,7 @@ _lib.llama_context_default_params.argtypes = []
_lib.llama_context_default_params.restype = llama_context_params
-# LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params();
+# LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
def llama_model_quantize_default_params() -> llama_model_quantize_params:
return _lib.llama_model_quantize_default_params()
@@ -376,25 +401,6 @@ _lib.llama_model_quantize_default_params.argtypes = []
_lib.llama_model_quantize_default_params.restype = llama_model_quantize_params
# LLAMA_API bool llama_mmap_supported();
def llama_mmap_supported() -> bool:
return _lib.llama_mmap_supported()
_lib.llama_mmap_supported.argtypes = []
_lib.llama_mmap_supported.restype = c_bool
# LLAMA_API bool llama_mlock_supported();
def llama_mlock_supported() -> bool:
return _lib.llama_mlock_supported()
_lib.llama_mlock_supported.argtypes = []
_lib.llama_mlock_supported.restype = c_bool
# // TODO: not great API - very likely to change
# // Initialize the llama + ggml backend
# // If numa is true, use NUMA optimizations
# // Call once at the start of the program
@@ -408,7 +414,7 @@ _lib.llama_backend_init.restype = None
# // Call once at the end of the program - currently only used for MPI
-# LLAMA_API void llama_backend_free();
+# LLAMA_API void llama_backend_free(void);
def llama_backend_free():
return _lib.llama_backend_free()
@@ -418,7 +424,7 @@ _lib.llama_backend_free.restype = None
# LLAMA_API struct llama_model * llama_load_model_from_file(
# const char * path_model,
# struct llama_context_params params);
def llama_load_model_from_file(
path_model: bytes, params: llama_context_params
@@ -440,7 +446,7 @@ _lib.llama_free_model.restype = None
# LLAMA_API struct llama_context * llama_new_context_with_model(
# struct llama_model * model,
# struct llama_context_params params);
def llama_new_context_with_model(
model: llama_model_p, params: llama_context_params
@@ -452,7 +458,17 @@ _lib.llama_new_context_with_model.argtypes = [llama_model_p, llama_context_param
_lib.llama_new_context_with_model.restype = llama_context_p
-# LLAMA_API int64_t llama_time_us();
+# // Frees all allocated memory
# LLAMA_API void llama_free(struct llama_context * ctx);
def llama_free(ctx: llama_context_p):
return _lib.llama_free(ctx)
_lib.llama_free.argtypes = [llama_context_p]
_lib.llama_free.restype = None
# LLAMA_API int64_t llama_time_us(void);
def llama_time_us() -> int:
return _lib.llama_time_us()
@@ -461,30 +477,95 @@ _lib.llama_time_us.argtypes = []
_lib.llama_time_us.restype = ctypes.c_int64
-# // Various functions for loading a ggml llama model.
-# // Allocate (almost) all memory needed for the model.
-# // Return NULL on failure
+# LLAMA_API int llama_max_devices (void);
+def llama_max_devices() -> int:
+return _lib.llama_max_devices()
# LLAMA_API struct llama_context * llama_init_from_file(
# const char * path_model,
# struct llama_context_params params);
def llama_init_from_file(
path_model: bytes, params: llama_context_params
) -> llama_context_p:
return _lib.llama_init_from_file(path_model, params)
-_lib.llama_init_from_file.argtypes = [c_char_p, llama_context_params]
-_lib.llama_init_from_file.restype = llama_context_p
-# Frees all allocated memory
-# LLAMA_API void llama_free(struct llama_context * ctx);
-def llama_free(ctx: llama_context_p):
-return _lib.llama_free(ctx)
-_lib.llama_free.argtypes = [llama_context_p]
-_lib.llama_free.restype = None
+_lib.llama_max_devices.argtypes = []
+_lib.llama_max_devices.restype = c_int
+# LLAMA_API bool llama_mmap_supported (void);
+def llama_mmap_supported() -> bool:
+return _lib.llama_mmap_supported()
+_lib.llama_mmap_supported.argtypes = []
+_lib.llama_mmap_supported.restype = c_bool
# LLAMA_API bool llama_mlock_supported(void);
def llama_mlock_supported() -> bool:
return _lib.llama_mlock_supported()
_lib.llama_mlock_supported.argtypes = []
_lib.llama_mlock_supported.restype = c_bool
# LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
def llama_n_vocab(ctx: llama_context_p) -> int:
return _lib.llama_n_vocab(ctx)
_lib.llama_n_vocab.argtypes = [llama_context_p]
_lib.llama_n_vocab.restype = c_int
# LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
def llama_n_ctx(ctx: llama_context_p) -> int:
return _lib.llama_n_ctx(ctx)
_lib.llama_n_ctx.argtypes = [llama_context_p]
_lib.llama_n_ctx.restype = c_int
# LLAMA_API int llama_n_embd (const struct llama_context * ctx);
def llama_n_embd(ctx: llama_context_p) -> int:
return _lib.llama_n_embd(ctx)
_lib.llama_n_embd.argtypes = [llama_context_p]
_lib.llama_n_embd.restype = c_int
# LLAMA_API int llama_model_n_vocab(const struct llama_model * model);
def llama_model_n_vocab(model: llama_model_p) -> int:
return _lib.llama_model_n_vocab(model)
_lib.llama_model_n_vocab.argtypes = [llama_model_p]
_lib.llama_model_n_vocab.restype = c_int
# LLAMA_API int llama_model_n_ctx (const struct llama_model * model);
def llama_model_n_ctx(model: llama_model_p) -> int:
return _lib.llama_model_n_ctx(model)
_lib.llama_model_n_ctx.argtypes = [llama_model_p]
_lib.llama_model_n_ctx.restype = c_int
# LLAMA_API int llama_model_n_embd (const struct llama_model * model);
def llama_model_n_embd(model: llama_model_p) -> int:
return _lib.llama_model_n_embd(model)
_lib.llama_model_n_embd.argtypes = [llama_model_p]
_lib.llama_model_n_embd.restype = c_int
# // Get a string describing the model type
# LLAMA_API int llama_model_type(const struct llama_model * model, char * buf, size_t buf_size);
def llama_model_type(model: llama_model_p, buf: bytes, buf_size: c_size_t) -> int:
return _lib.llama_model_type(model, buf, buf_size)
_lib.llama_model_type.argtypes = [llama_model_p, c_char_p, c_size_t]
_lib.llama_model_type.restype = c_int
# // Returns 0 on success
@@ -703,147 +784,17 @@ _lib.llama_eval_embd.argtypes = [llama_context_p, c_float_p, c_int, c_int, c_int
_lib.llama_eval_embd.restype = c_int
-# Convert the provided text into tokens.
-# The tokens pointer must be large enough to hold the resulting tokens.
-# Returns the number of tokens on success, no more than n_max_tokens
-# Returns a negative number on failure - the number of tokens that would have been returned
-# TODO: not sure if correct
-# LLAMA_API int llama_tokenize(
-# struct llama_context * ctx,
+# // Export a static computation graph for context of 511 and batch size of 1
+# // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
+# // parameters here to keep things simple
+# // IMPORTANT: do not use for anything else other than debugging and testing!
+# LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname);
+def llama_eval_export(ctx: llama_context_p, fname: bytes) -> int:
+return _lib.llama_eval_export(ctx, fname)
# const char * text,
# llama_token * tokens,
# int n_max_tokens,
# bool add_bos);
def llama_tokenize(
ctx: llama_context_p,
text: bytes,
tokens, # type: Array[llama_token]
n_max_tokens: Union[c_int, int],
add_bos: Union[c_bool, bool],
) -> int:
return _lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)
-_lib.llama_tokenize.argtypes = [llama_context_p, c_char_p, llama_token_p, c_int, c_bool]
-_lib.llama_tokenize.restype = c_int
+_lib.llama_eval_export.argtypes = [llama_context_p, c_char_p]
+_lib.llama_eval_export.restype = c_int
# LLAMA_API int llama_tokenize_with_model(
# const struct llama_model * model,
# const char * text,
# llama_token * tokens,
# int n_max_tokens,
# bool add_bos);
def llama_tokenize_with_model(
model: llama_model_p,
text: bytes,
tokens, # type: Array[llama_token]
n_max_tokens: Union[c_int, int],
add_bos: Union[c_bool, bool],
) -> int:
return _lib.llama_tokenize_with_model(model, text, tokens, n_max_tokens, add_bos)
# LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
def llama_n_vocab(ctx: llama_context_p) -> int:
return _lib.llama_n_vocab(ctx)
_lib.llama_n_vocab.argtypes = [llama_context_p]
_lib.llama_n_vocab.restype = c_int
# LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
def llama_n_ctx(ctx: llama_context_p) -> int:
return _lib.llama_n_ctx(ctx)
_lib.llama_n_ctx.argtypes = [llama_context_p]
_lib.llama_n_ctx.restype = c_int
# LLAMA_API int llama_n_embd (const struct llama_context * ctx);
def llama_n_embd(ctx: llama_context_p) -> int:
return _lib.llama_n_embd(ctx)
_lib.llama_n_embd.argtypes = [llama_context_p]
_lib.llama_n_embd.restype = c_int
# LLAMA_API int llama_n_vocab_from_model(const struct llama_model * model);
def llama_n_vocab_from_model(model: llama_model_p) -> int:
return _lib.llama_n_vocab_from_model(model)
_lib.llama_n_vocab_from_model.argtypes = [llama_model_p]
_lib.llama_n_vocab_from_model.restype = c_int
# LLAMA_API int llama_n_ctx_from_model (const struct llama_model * model);
def llama_n_ctx_from_model(model: llama_model_p) -> int:
return _lib.llama_n_ctx_from_model(model)
_lib.llama_n_ctx_from_model.argtypes = [llama_model_p]
_lib.llama_n_ctx_from_model.restype = c_int
# LLAMA_API int llama_n_embd_from_model (const struct llama_model * model);
def llama_n_embd_from_model(model: llama_model_p) -> int:
return _lib.llama_n_embd_from_model(model)
_lib.llama_n_embd_from_model.argtypes = [llama_model_p]
_lib.llama_n_embd_from_model.restype = c_int
# // Get the vocabulary as output parameters.
# // Returns number of results.
# LLAMA_API int llama_get_vocab(
# const struct llama_context * ctx,
# const char * * strings,
# float * scores,
# int capacity);
def llama_get_vocab(
ctx: llama_context_p,
strings, # type: Array[c_char_p] # type: ignore
scores, # type: Array[c_float] # type: ignore
capacity: Union[c_int, int],
) -> int:
return _lib.llama_get_vocab(ctx, strings, scores, capacity)
_lib.llama_get_vocab.argtypes = [
llama_context_p,
POINTER(c_char_p),
POINTER(c_float),
c_int,
]
_lib.llama_get_vocab.restype = c_int
# LLAMA_API int llama_get_vocab_from_model(
# const struct llama_model * model,
# const char * * strings,
# float * scores,
# int capacity);
def llama_get_vocab_from_model(
model: llama_model_p,
strings, # type: Array[c_char_p] # type: ignore
scores, # type: Array[c_float] # type: ignore
capacity: Union[c_int, int],
) -> int:
return _lib.llama_get_vocab_from_model(model, strings, scores, capacity)
_lib.llama_get_vocab_from_model.argtypes = [
llama_model_p,
POINTER(c_char_p),
POINTER(c_float),
c_int,
]
_lib.llama_get_vocab_from_model.restype = c_int
# Token logits obtained from the last call to llama_eval()
@@ -875,16 +826,186 @@ _lib.llama_get_embeddings.argtypes = [llama_context_p]
_lib.llama_get_embeddings.restype = c_float_p
# //
# // Vocab
# //
# LLAMA_API const char * llama_token_get_text(const struct llama_context * ctx, llama_token token);
def llama_token_get_text(ctx: llama_context_p, token: llama_token) -> bytes:
return _lib.llama_token_get_text(ctx, token)
_lib.llama_token_get_text.argtypes = [llama_context_p, llama_token]
_lib.llama_token_get_text.restype = c_char_p
# LLAMA_API float llama_token_get_score(const struct llama_context * ctx, llama_token token);
def llama_token_get_score(ctx: llama_context_p, token: llama_token) -> float:
return _lib.llama_token_get_score(ctx, token)
_lib.llama_token_get_score.argtypes = [llama_context_p, llama_token]
_lib.llama_token_get_score.restype = c_float
# LLAMA_API llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token);
def llama_token_get_type(ctx: llama_context_p, token: llama_token) -> int:
return _lib.llama_token_get_type(ctx, token)
_lib.llama_token_get_type.argtypes = [llama_context_p, llama_token]
_lib.llama_token_get_type.restype = ctypes.c_int
# // Special tokens
# LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence
def llama_token_bos(ctx: llama_context_p) -> llama_token:
return _lib.llama_token_bos(ctx)
_lib.llama_token_bos.argtypes = [llama_context_p]
_lib.llama_token_bos.restype = llama_token
# LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence
def llama_token_eos(ctx: llama_context_p) -> llama_token:
return _lib.llama_token_eos(ctx)
_lib.llama_token_eos.argtypes = [llama_context_p]
_lib.llama_token_eos.restype = llama_token
# LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line
def llama_token_nl(ctx: llama_context_p) -> llama_token:
return _lib.llama_token_nl(ctx)
_lib.llama_token_nl.argtypes = [llama_context_p]
_lib.llama_token_nl.restype = llama_token
# //
# // Tokenization
# //
# Convert the provided text into tokens.
# The tokens pointer must be large enough to hold the resulting tokens.
# Returns the number of tokens on success, no more than n_max_tokens
# Returns a negative number on failure - the number of tokens that would have been returned
# TODO: not sure if correct
# LLAMA_API int llama_tokenize(
# struct llama_context * ctx,
# const char * text,
# llama_token * tokens,
# int n_max_tokens,
# bool add_bos);
def llama_tokenize(
ctx: llama_context_p,
text: bytes,
tokens, # type: Array[llama_token]
n_max_tokens: c_int,
add_bos: c_bool,
) -> int:
return _lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)
_lib.llama_tokenize.argtypes = [llama_context_p, c_char_p, llama_token_p, c_int, c_bool]
_lib.llama_tokenize.restype = c_int
# LLAMA_API int llama_tokenize_bpe(
# struct llama_context * ctx,
# const char * text,
# llama_token * tokens,
# int n_max_tokens,
# bool add_bos);
def llama_tokenize_bpe(
ctx: llama_context_p,
text: bytes,
tokens, # type: Array[llama_token]
n_max_tokens: c_int,
add_bos: c_bool,
) -> int:
return _lib.llama_tokenize_bpe(ctx, text, tokens, n_max_tokens, add_bos)
_lib.llama_tokenize_bpe.argtypes = [
llama_context_p,
c_char_p,
llama_token_p,
c_int,
c_bool,
]
_lib.llama_tokenize_bpe.restype = c_int
# LLAMA_API int llama_tokenize_with_model(
# const struct llama_model * model,
# const char * text,
# llama_token * tokens,
# int n_max_tokens,
# bool add_bos);
def llama_tokenize_with_model(
model: llama_model_p,
text: bytes,
tokens, # type: Array[llama_token]
n_max_tokens: c_int,
add_bos: c_bool,
) -> int:
return _lib.llama_tokenize_with_model(model, text, tokens, n_max_tokens, add_bos)
_lib.llama_tokenize_with_model.argtypes = [
llama_model_p,
c_char_p,
llama_token_p,
c_int,
c_bool,
]
_lib.llama_tokenize_with_model.restype = c_int
# // Token Id -> String. Uses the vocabulary in the provided context
-# LLAMA_API const char * llama_token_to_str(
+# // Does not write null terminator to the buffer
+# LLAMA_API int llama_token_to_str(
# const struct llama_context * ctx,
-# llama_token token);
-def llama_token_to_str(ctx: llama_context_p, token: llama_token) -> bytes:
-return _lib.llama_token_to_str(ctx, token)
+# llama_token token,
+# char * buf,
+# int length);
def llama_token_to_str(
ctx: llama_context_p, token: llama_token, buf: bytes, length: c_int
) -> int:
return _lib.llama_token_to_str(ctx, token, buf, length)
-_lib.llama_token_to_str.argtypes = [llama_context_p, llama_token]
-_lib.llama_token_to_str.restype = c_char_p
+_lib.llama_tokenize_with_model.argtypes = [
+llama_model_p,
c_char_p,
llama_token_p,
c_int,
c_bool,
]
_lib.llama_tokenize_with_model.restype = c_int
# LLAMA_API int llama_token_to_str_bpe(
# const struct llama_context * ctx,
# llama_token token,
# char * buf,
# int length);
def llama_token_to_str_bpe(
ctx: llama_context_p, token: llama_token, buf: bytes, length: c_int
) -> int:
return _lib.llama_token_to_str_bpe(ctx, token, buf, length)
_lib.llama_token_to_str_bpe.argtypes = [llama_context_p, llama_token, c_char_p, c_int]
_lib.llama_token_to_str_bpe.restype = c_int
# LLAMA_API const char * llama_token_to_str_with_model(
@@ -897,38 +1018,12 @@ def llama_token_to_str_with_model(model: llama_model_p, token: llama_token) -> b
_lib.llama_token_to_str_with_model.argtypes = [llama_model_p, llama_token]
_lib.llama_token_to_str_with_model.restype = c_char_p
# Special tokens
# LLAMA_API llama_token llama_token_bos(); // beginning-of-sentence
def llama_token_bos() -> int:
return _lib.llama_token_bos()
_lib.llama_token_bos.argtypes = []
_lib.llama_token_bos.restype = llama_token
# LLAMA_API llama_token llama_token_eos(); // end-of-sentence
def llama_token_eos() -> int:
return _lib.llama_token_eos()
_lib.llama_token_eos.argtypes = []
_lib.llama_token_eos.restype = llama_token
# LLAMA_API llama_token llama_token_nl(); // next-line
def llama_token_nl() -> int:
return _lib.llama_token_nl()
_lib.llama_token_nl.argtypes = []
_lib.llama_token_nl.restype = llama_token
# //
# // Grammar
# //
# LLAMA_API struct llama_grammar * llama_grammar_init(
# const llama_grammar_element ** rules,
# size_t n_rules,
@@ -958,7 +1053,9 @@ _lib.llama_grammar_free.argtypes = [llama_grammar_p]
_lib.llama_grammar_free.restype = None
-# Sampling functions
+# //
# // Sampling functions
# //
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
@@ -1157,12 +1254,11 @@ _lib.llama_sample_temperature.argtypes = [
_lib.llama_sample_temperature.restype = None
# /// @details Apply constraints from grammar
# LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
def llama_sample_grammar(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
-grammar: llama_grammar_p,
+grammar, # type: llama_grammar_p
):
return _lib.llama_sample_grammar(ctx, candidates, grammar)
@@ -1265,9 +1361,11 @@ _lib.llama_sample_token.restype = llama_token
# /// @details Accepts the sampled token into the grammar
# LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
def llama_grammar_accept_token(
-ctx: llama_context_p, grammar: llama_grammar_p, token: llama_token
-):
-return _lib.llama_grammar_accept_token(ctx, grammar, token)
+ctx: llama_context_p,
+grammar: llama_grammar_p,
+token: llama_token,
+) -> None:
+_lib.llama_grammar_accept_token(ctx, grammar, token)
_lib.llama_grammar_accept_token.argtypes = [
@@ -1316,6 +1414,19 @@ def llama_print_system_info() -> bytes:
_lib.llama_print_system_info.argtypes = []
_lib.llama_print_system_info.restype = c_char_p
# // Set callback for all future logging events.
# // If this is not called, or NULL is supplied, everything is output on stderr.
# LLAMA_API void llama_log_set(llama_log_callback log_callback, void * user_data);
def llama_log_set(
log_callback: "ctypes._FuncPointer", user_data: c_void_p # type: ignore
):
return _lib.llama_log_set(log_callback, user_data)
_lib.llama_log_set.argtypes = [llama_log_callback, c_void_p]
_lib.llama_log_set.restype = None
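A hedged sketch of wiring the new `llama_log_set` binding to a Python callback; keeping a module-level reference to the wrapped function matters so ctypes does not garbage-collect it while llama.cpp still holds the pointer:

```python
import ctypes
import llama_cpp

def _py_log(level: int, text: bytes, user_data) -> None:
    # level corresponds to LLAMA_LOG_LEVEL_ERROR/WARN/INFO; text usually ends with '\n'.
    print(f"[llama.cpp:{level}] {text.decode('utf-8', errors='replace')}", end="")

# Wrap with the CFUNCTYPE prototype defined above and keep a reference alive.
_log_callback = llama_cpp.llama_log_callback(_py_log)
llama_cpp.llama_log_set(_log_callback, ctypes.c_void_p(0))
```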
###################################################################################################

llama_cpp/llama_grammar.py (new file, 1188 lines; file diff suppressed because it is too large)

llama_cpp/py.typed (new file, 0 lines)

View file

@@ -108,6 +108,10 @@ class Settings(BaseSettings):
default=None,
description="TEMPORARY",
)
mul_mat_q: Optional[bool] = Field(
default=None,
description="TEMPORARY",
)
class ErrorResponse(TypedDict):
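Like the other `Settings` fields, the new flag can be supplied through the environment when launching the server; a hedged sketch (the `MUL_MAT_Q` variable name is an assumption based on pydantic-settings' default field-to-env mapping, and `create_app` is the app factory this module already exposes):

```python
import os

os.environ["MODEL"] = "./models/7B/ggml-model.bin"  # illustrative model path
os.environ["MUL_MAT_Q"] = "true"                     # toggle the experimental mul_mat_q kernels

import uvicorn
from llama_cpp.server.app import create_app

uvicorn.run(create_app(), host="0.0.0.0", port=8000)
```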

llama_cpp/utils.py (new file, 38 lines)

@@ -0,0 +1,38 @@
import os
import sys
class suppress_stdout_stderr(object):
# Oddly enough this works better than the contextlib version
def __enter__(self):
self.outnull_file = open(os.devnull, "w")
self.errnull_file = open(os.devnull, "w")
self.old_stdout_fileno_undup = sys.stdout.fileno()
self.old_stderr_fileno_undup = sys.stderr.fileno()
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
self.old_stdout = sys.stdout
self.old_stderr = sys.stderr
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
sys.stdout = self.outnull_file
sys.stderr = self.errnull_file
return self
def __exit__(self, *_):
sys.stdout = self.old_stdout
sys.stderr = self.old_stderr
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
os.close(self.old_stdout_fileno)
os.close(self.old_stderr_fileno)
self.outnull_file.close()
self.errnull_file.close()
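Usage is a plain context manager; because the helper `dup2`s the underlying file descriptors rather than just rebinding `sys.stdout`/`sys.stderr`, it also silences output written directly by the C library, which `contextlib.redirect_stdout` cannot do. A minimal sketch (the noisy call is illustrative):

```python
from llama_cpp.utils import suppress_stdout_stderr
import llama_cpp

with suppress_stdout_stderr():
    # Anything printed here, from Python or from native llama.cpp code, goes to /dev/null.
    print(llama_cpp.llama_print_system_info())
```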

View file

@@ -4,7 +4,7 @@ build-backend = "scikit_build_core.build"
[project]
name = "llama_cpp_python"
-version = "0.1.77"
+version = "0.1.78"
description = "Python bindings for the llama.cpp library"
readme = "README.md"
license = { text = "MIT" }

vendor/llama.cpp (vendored submodule, 2 changed lines)

@@ -1 +1 @@
-Subproject commit 41c674161fb2459bdf7806d1eebead15bc5d046e
+Subproject commit f5fe98d11bdf9e7797bcfb05c0c3601ffc4b9d26