320a5d7ea5
* feat: add explicit methods to free model

  This commit introduces a `close` method to both `Llama` and `_LlamaModel`,
  allowing users to explicitly free the model from RAM/VRAM.

  The previous implementation relied on the destructor of `_LlamaModel` to free
  the model. However, in Python the timing of destructor calls is unclear: for
  instance, the `del` statement does not guarantee immediate invocation of the
  destructor. This commit provides an explicit method to release the model,
  which takes effect immediately and allows the user to load another model
  without running into memory issues.

  Additionally, this commit implements a context manager in the `Llama` class,
  enabling automatic closure of the `Llama` object when used with the `with`
  statement.

* feat: Implement ContextManager in _LlamaModel, _LlamaContext, and _LlamaBatch

  This commit enables automatic resource management by implementing the
  `ContextManager` protocol in `_LlamaModel`, `_LlamaContext`, and
  `_LlamaBatch`. This ensures that resources are properly managed and released
  within a `with` statement, improving robustness and safety in resource
  handling.

* feat: add ExitStack for Llama's internal class closure

  This update uses `contextlib.ExitStack` to manage and close Llama's internal
  classes, enabling efficient and safe resource management.

* Use contextlib ExitStack and closing

* Explicitly free model when closing resources on server

---------

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
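A minimal usage sketch of the API this commit describes (the model path and prompt below are placeholders, not part of the change): either wrap `Llama` in a `with` block so its resources are released when the block exits, or call `close()` explicitly to free RAM/VRAM right away before loading another model.

from llama_cpp import Llama

# Option 1: context manager; the model and context are freed when the block exits.
with Llama(model_path="./models/example.gguf", verbose=False) as llm:
    out = llm("Q: 2 + 2 = ", max_tokens=4)
    print(out["choices"][0]["text"])

# Option 2: explicit close(); frees the model immediately instead of waiting
# for the garbage collector, so a second model can be loaded safely afterwards.
llm = Llama(model_path="./models/example.gguf", verbose=False)
try:
    llm("Hello", max_tokens=4)
finally:
    llm.close()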
835 lines
27 KiB
Python
from __future__ import annotations

import os
import ctypes

from typing import (
    List,
    Optional,
    Sequence,
)
from dataclasses import dataclass, field
from contextlib import ExitStack

import numpy as np
import numpy.typing as npt

from .llama_types import *
from .llama_grammar import LlamaGrammar
from ._utils import suppress_stdout_stderr

import llama_cpp.llama_cpp as llama_cpp


# Python wrappers over llama.h structs


class _LlamaModel:
    """Intermediate Python wrapper for a llama.cpp llama_model.
    NOTE: For stability it's recommended you use the Llama class instead."""

    def __init__(
        self,
        *,
        path_model: str,
        params: llama_cpp.llama_model_params,
        verbose: bool = True,
    ):
        self.path_model = path_model
        self.params = params
        self.verbose = verbose
        self._exit_stack = ExitStack()

        self.model = None

        if not os.path.exists(path_model):
            raise ValueError(f"Model path does not exist: {path_model}")

        with suppress_stdout_stderr(disable=verbose):
            self.model = llama_cpp.llama_load_model_from_file(
                self.path_model.encode("utf-8"), self.params
            )

        if self.model is None:
            raise ValueError(f"Failed to load model from file: {path_model}")

        def free_model():
            if self.model is None:
                return
            llama_cpp.llama_free_model(self.model)
            self.model = None

        self._exit_stack.callback(free_model)

    def close(self):
        self._exit_stack.close()

    def vocab_type(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_vocab_type(self.model)

    def n_vocab(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_n_vocab(self.model)

    def n_ctx_train(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_n_ctx_train(self.model)

    def n_embd(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_n_embd(self.model)

    def rope_freq_scale_train(self) -> float:
        assert self.model is not None
        return llama_cpp.llama_rope_freq_scale_train(self.model)

    def desc(self) -> str:
        assert self.model is not None
        buf = ctypes.create_string_buffer(1024)
        llama_cpp.llama_model_desc(self.model, buf, 1024)
        return buf.value.decode("utf-8")

    def size(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_model_size(self.model)

    def n_params(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_model_n_params(self.model)

    def get_tensor(self, name: str) -> ctypes.c_void_p:
        assert self.model is not None
        return llama_cpp.llama_get_model_tensor(self.model, name.encode("utf-8"))

    def apply_lora_from_file(
        self,
        lora_path: str,
        scale: float,
        path_base_model: Optional[str],
        n_threads: int,
    ):
        assert self.model is not None
        return llama_cpp.llama_model_apply_lora_from_file(
            self.model,
            lora_path.encode("utf-8"),
            scale,
            path_base_model.encode("utf-8")
            if path_base_model is not None
            else ctypes.c_char_p(0),
            n_threads,
        )

    # Vocab

    def token_get_text(self, token: int) -> str:
        # TODO: Fix
        assert self.model is not None
        return llama_cpp.llama_token_get_text(self.model, token).decode("utf-8")

    def token_get_score(self, token: int) -> float:
        assert self.model is not None
        return llama_cpp.llama_token_get_score(self.model, token)

    def token_get_attr(self, token: int) -> int:
        assert self.model is not None
        return llama_cpp.llama_token_get_attr(self.model, token)

    # Special tokens

    def token_bos(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_token_bos(self.model)

    def token_eos(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_token_eos(self.model)

    def token_cls(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_token_cls(self.model)

    def token_sep(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_token_sep(self.model)

    def token_nl(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_token_nl(self.model)

    def token_prefix(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_token_prefix(self.model)

    def token_middle(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_token_middle(self.model)

    def token_suffix(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_token_suffix(self.model)

    def token_eot(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_token_eot(self.model)

    def add_bos_token(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_add_bos_token(self.model)

    def add_eos_token(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_add_eos_token(self.model)

    # Tokenization

    def tokenize(self, text: bytes, add_bos: bool, special: bool):
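        # llama_tokenize returns a negative value when the supplied buffer is too
        # small; its magnitude is the required token count, hence the retry below.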
        assert self.model is not None
        n_ctx = self.n_ctx_train()
        tokens = (llama_cpp.llama_token * n_ctx)()
        n_tokens = llama_cpp.llama_tokenize(
            self.model, text, len(text), tokens, n_ctx, add_bos, special
        )
        if n_tokens < 0:
            n_tokens = abs(n_tokens)
            tokens = (llama_cpp.llama_token * n_tokens)()
            n_tokens = llama_cpp.llama_tokenize(
                self.model, text, len(text), tokens, n_tokens, add_bos, special
            )
            if n_tokens < 0:
                raise RuntimeError(
                    f'Failed to tokenize: text="{text}" n_tokens={n_tokens}'
                )
        return list(tokens[:n_tokens])

    def token_to_piece(self, token: int, special: bool = False) -> bytes:
        assert self.model is not None
        buf = ctypes.create_string_buffer(32)
        llama_cpp.llama_token_to_piece(self.model, token, buf, 32, special)
        return bytes(buf)

    def detokenize(self, tokens: List[int], special: bool = False) -> bytes:
        assert self.model is not None
        output = b""
        size = 32
        buffer = (ctypes.c_char * size)()
        for token in tokens:
            n = llama_cpp.llama_token_to_piece(
                self.model, llama_cpp.llama_token(token), buffer, size, special
            )
            assert n <= size
            output += bytes(buffer[:n])
        # NOTE: Llama1 models automatically added a space at the start of the prompt
        # this line removes a leading space if the first token is a beginning of sentence token
        return (
            output[1:] if len(tokens) > 0 and tokens[0] == self.token_bos() and output[0:1] == b' ' else output
        )

    # Extra
    def metadata(self) -> Dict[str, str]:
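        # Each key and value is read into a shared buffer; if the reported length
        # exceeds the current buffer size, the buffer is grown and the lookup retried.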
        assert self.model is not None
        metadata: Dict[str, str] = {}
        buffer_size = 1024
        buffer = ctypes.create_string_buffer(buffer_size)
        # zero the buffer
        buffer.value = b'\0' * buffer_size
        # iterate over model keys
        for i in range(llama_cpp.llama_model_meta_count(self.model)):
            nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
            if nbytes > buffer_size:
                buffer_size = nbytes + 1
                buffer = ctypes.create_string_buffer(buffer_size)
                nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
            key = buffer.value.decode("utf-8")
            nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
            if nbytes > buffer_size:
                buffer_size = nbytes + 1
                buffer = ctypes.create_string_buffer(buffer_size)
                nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
            value = buffer.value.decode("utf-8")
            metadata[key] = value
        return metadata

    @staticmethod
    def default_params():
        """Get the default llama_model_params."""
        return llama_cpp.llama_model_default_params()


class _LlamaContext:
    """Intermediate Python wrapper for a llama.cpp llama_context.
    NOTE: For stability it's recommended you use the Llama class instead."""

    def __init__(
        self,
        *,
        model: _LlamaModel,
        params: llama_cpp.llama_context_params,
        verbose: bool = True,
    ):
        self.model = model
        self.params = params
        self.verbose = verbose
        self._exit_stack = ExitStack()

        self.ctx = None

        assert self.model.model is not None

        self.ctx = llama_cpp.llama_new_context_with_model(self.model.model, self.params)

        if self.ctx is None:
            raise ValueError("Failed to create llama_context")

        def free_ctx():
            if self.ctx is None:
                return
            llama_cpp.llama_free(self.ctx)
            self.ctx = None

        self._exit_stack.callback(free_ctx)

    def close(self):
        self._exit_stack.close()

    def n_ctx(self) -> int:
        assert self.ctx is not None
        return llama_cpp.llama_n_ctx(self.ctx)

    def pooling_type(self) -> int:
        assert self.ctx is not None
        return llama_cpp.llama_pooling_type(self.ctx)

    def kv_cache_clear(self):
        assert self.ctx is not None
        llama_cpp.llama_kv_cache_clear(self.ctx)

    def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int):
        assert self.ctx is not None
        llama_cpp.llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1)

    def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int):
        assert self.ctx is not None
        llama_cpp.llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1)

    def kv_cache_seq_keep(self, seq_id: int):
        assert self.ctx is not None
        llama_cpp.llama_kv_cache_seq_keep(self.ctx, seq_id)

    def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
        assert self.ctx is not None
        llama_cpp.llama_kv_cache_seq_add(self.ctx, seq_id, p0, p1, shift)

    def get_state_size(self) -> int:
        assert self.ctx is not None
        return llama_cpp.llama_get_state_size(self.ctx)

    # TODO: copy_state_data

    # TODO: set_state_data

    # TODO: llama_load_session_file

    # TODO: llama_save_session_file

    def decode(self, batch: "_LlamaBatch"):
        assert self.ctx is not None
        assert batch.batch is not None
        return_code = llama_cpp.llama_decode(
            self.ctx,
            batch.batch,
        )
        if return_code != 0:
            raise RuntimeError(f"llama_decode returned {return_code}")

    def set_n_threads(self, n_threads: int, n_threads_batch: int):
        assert self.ctx is not None
        llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch)

    def get_logits(self):
        assert self.ctx is not None
        return llama_cpp.llama_get_logits(self.ctx)

    def get_logits_ith(self, i: int):
        assert self.ctx is not None
        return llama_cpp.llama_get_logits_ith(self.ctx, i)

    def get_embeddings(self):
        assert self.ctx is not None
        return llama_cpp.llama_get_embeddings(self.ctx)

    # Sampling functions

    def set_rng_seed(self, seed: int):
        assert self.ctx is not None
        llama_cpp.llama_set_rng_seed(self.ctx, seed)

    def sample_repetition_penalties(
        self,
        candidates: "_LlamaTokenDataArray",
        last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]",
        penalty_last_n: int,
        penalty_repeat: float,
        penalty_freq: float,
        penalty_present: float,
    ):
        assert self.ctx is not None
        llama_cpp.llama_sample_repetition_penalties(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
            last_tokens_data,
            penalty_last_n,
            penalty_repeat,
            penalty_freq,
            penalty_present,
        )

    def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
        assert self.ctx is not None
        llama_cpp.llama_sample_softmax(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
        )

    def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
        assert self.ctx is not None
        llama_cpp.llama_sample_top_k(
            self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep
        )

    def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
        assert self.ctx is not None
        llama_cpp.llama_sample_top_p(
            self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
        )

    def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
        assert self.ctx is not None
        llama_cpp.llama_sample_min_p(
            self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
        )

    def sample_tail_free(
        self, candidates: "_LlamaTokenDataArray", z: float, min_keep: int
    ):
        assert self.ctx is not None
        llama_cpp.llama_sample_tail_free(
            self.ctx, llama_cpp.byref(candidates.candidates), z, min_keep
        )

    def sample_typical(
        self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
    ):
        assert self.ctx is not None
        llama_cpp.llama_sample_typical(
            self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
        )

    def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
        assert self.ctx is not None
        llama_cpp.llama_sample_temp(
            self.ctx, llama_cpp.byref(candidates.candidates), temp
        )

    def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
        assert self.ctx is not None
        assert grammar.grammar is not None
        llama_cpp.llama_sample_grammar(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
            grammar.grammar,
        )

    def sample_token_mirostat(
        self,
        candidates: "_LlamaTokenDataArray",
        tau: float,
        eta: float,
        m: int,
        mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
    ) -> int:
        assert self.ctx is not None
        return llama_cpp.llama_sample_token_mirostat(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
            tau,
            eta,
            m,
            mu,
        )

    def sample_token_mirostat_v2(
        self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float]
    ) -> int:
        assert self.ctx is not None
        return llama_cpp.llama_sample_token_mirostat_v2(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
            tau,
            eta,
            mu,
        )

    def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
        assert self.ctx is not None
        return llama_cpp.llama_sample_token_greedy(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
        )

    def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
        assert self.ctx is not None
        return llama_cpp.llama_sample_token(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
        )

    # Grammar
    def grammar_accept_token(self, grammar: LlamaGrammar, token: int):
        assert self.ctx is not None
        assert grammar.grammar is not None
        llama_cpp.llama_grammar_accept_token(self.ctx, grammar.grammar, token)

    def reset_timings(self):
        assert self.ctx is not None
        llama_cpp.llama_reset_timings(self.ctx)

    def print_timings(self):
        assert self.ctx is not None
        llama_cpp.llama_print_timings(self.ctx)

    # Utility functions
    @staticmethod
    def default_params():
        """Get the default llama_context_params."""
        return llama_cpp.llama_context_default_params()


class _LlamaBatch:
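    # Thin wrapper over llama_batch; the underlying C buffers are released by the
    # ExitStack callback registered in __init__, triggered via close().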
    def __init__(
        self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
    ):
        self._n_tokens = n_tokens
        self.embd = embd
        self.n_seq_max = n_seq_max
        self.verbose = verbose
        self._exit_stack = ExitStack()

        self.batch = None
        self.batch = llama_cpp.llama_batch_init(
            self._n_tokens, self.embd, self.n_seq_max
        )

        def free_batch():
            if self.batch is None:
                return
            llama_cpp.llama_batch_free(self.batch)
            self.batch = None

        self._exit_stack.callback(free_batch)

    def close(self):
        self._exit_stack.close()

    def n_tokens(self) -> int:
        assert self.batch is not None
        return self.batch.n_tokens

    def reset(self):
        assert self.batch is not None
        self.batch.n_tokens = 0

    def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
        assert self.batch is not None
        n_tokens = len(batch)
        self.batch.n_tokens = n_tokens
        for i in range(n_tokens):
            self.batch.token[i] = batch[i]
            self.batch.pos[i] = n_past + i
            self.batch.seq_id[i][0] = 0
            self.batch.n_seq_id[i] = 1
            self.batch.logits[i] = logits_all
        self.batch.logits[n_tokens - 1] = True

    def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
        assert self.batch is not None
        n_tokens = len(batch)
        n_tokens0 = self.batch.n_tokens
        self.batch.n_tokens += n_tokens
        for i in range(n_tokens):
            j = n_tokens0 + i
            self.batch.token[j] = batch[i]
            self.batch.pos[j] = i
            self.batch.seq_id[j][0] = seq_id
            self.batch.n_seq_id[j] = 1
            self.batch.logits[j] = logits_all
        self.batch.logits[n_tokens - 1] = True


class _LlamaTokenDataArray:
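    # NumPy-backed llama_token_data_array handed to the llama.cpp sampling
    # functions; copy_logits() refills it with fresh logits before each use.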
    def __init__(self, *, n_vocab: int):
        self.n_vocab = n_vocab
        self.candidates_data = np.recarray(
            (self.n_vocab,),
            dtype=np.dtype(
                [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
            ),
        )
        self.candidates = llama_cpp.llama_token_data_array(
            data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
            size=self.n_vocab,
            sorted=False,
        )
        self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc)  # type: ignore
        self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)

    def copy_logits(self, logits: npt.NDArray[np.single]):
        self.candidates_data.id[:] = self.default_candidates_data_id
        self.candidates_data.logit[:] = logits
        self.candidates_data.p[:] = self.default_candidates_data_p
        self.candidates.sorted = False
        self.candidates.size = self.n_vocab


# Python wrappers over common/common
def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> list[int]:
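    # Same retry pattern as _LlamaModel.tokenize: a negative return value is the
    # (negated) number of tokens actually required.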
    assert model.model is not None
    n_tokens = len(text) + 1 if add_bos else len(text)
    result = (llama_cpp.llama_token * n_tokens)()
    n_tokens = llama_cpp.llama_tokenize(
        model.model,
        text.encode("utf-8"),
        len(text),
        result,
        n_tokens,
        add_bos,
        special,
    )
    if n_tokens < 0:
        result = (llama_cpp.llama_token * -n_tokens)()
        check = llama_cpp.llama_tokenize(
            model.model,
            text.encode("utf-8"),
            len(text),
            result,
            len(result),
            add_bos,
            special,
        )
        if check != -n_tokens:
            raise RuntimeError(f'Failed to tokenize: text="{text}" n_tokens={n_tokens}')
    else:
        result = result[:n_tokens]
    return list(result)


def _token_to_piece(model: _LlamaModel, token: int, special: bool = False) -> str:
    assert model.model is not None
    result = (ctypes.c_char * 8)(0)
    n_tokens = llama_cpp.llama_token_to_piece(model.model, token, result, len(result), special)
    if n_tokens < 0:
        result = (ctypes.c_char * -n_tokens)(0)
        check = llama_cpp.llama_token_to_piece(model.model, token, result, len(result), special)
        if check != -n_tokens:
            raise RuntimeError(f"Failed to get piece: token={token}")
    else:
        result = result[:n_tokens]
    return bytes(result).decode("utf-8")


def _detokenize_spm(model: _LlamaModel, tokens: List[int]) -> str:
    bos_id = model.token_bos()
    result = ""
    for i, token in enumerate(tokens):
        piece = _token_to_piece(model, token)
        if (
            (tokens[0] == bos_id and i == 1) or (tokens[0] != bos_id and i == 0)
        ) and piece[0] == " ":
            piece = piece[1:]
        result += piece
    return result


def _detokenize_bpe(model: _LlamaModel, tokens: List[int]) -> str:
    result = ""
    for token in tokens:
        piece = _token_to_piece(model, token)
        result += piece
    return result


def _should_add_bos(model: _LlamaModel) -> bool:
    assert model.model is not None
    add_bos = llama_cpp.llama_add_bos_token(model.model)
    if add_bos != -1:
        return add_bos != 0
    else:
        return llama_cpp.llama_vocab_type(model.model) == llama_cpp.LLAMA_VOCAB_TYPE_SPM


# Embedding functions


def _normalize_embedding(embedding):
    norm = float(np.linalg.norm(embedding))
    if norm == 0.0:
        return embedding
    return [v / norm for v in embedding]


# Python wrappers over common/sampling structs


@dataclass
class _LlamaSamplingParams:
    n_prev: int = 64
    n_probs: int = 0
    top_k: int = 40
    top_p: float = 0.95
    min_p: float = 0.05
    tfs_z: float = 1.00
    typical_p: float = 1.00
    temp: float = 0.80
    penalty_last_n: int = 64
    penalty_repeat: float = 1.10
    penalty_freq: float = 0.00
    penalty_present: float = 0.00
    mirostat: int = 0
    mirostat_tau: float = 5.00
    mirostat_eta: float = 0.10
    penalize_nl: bool = True

    grammar: str = ""

    cfg_negative_prompt: str = ""
    cfg_scale: float = 1.00

    logit_bias: dict[int, float] = field(default_factory=dict)


@dataclass
class _LlamaSamplingContext:
    params: _LlamaSamplingParams = field(default_factory=_LlamaSamplingParams)
    mirostat_mu: ctypes.c_float = field(default_factory=ctypes.c_float)
    grammar: Optional[LlamaGrammar] = None
    # NOTE: Missing parsed_grammar
    prev: list[int] = field(default_factory=list)
    cur: list[llama_cpp.llama_token_data] = field(default_factory=list)

    def reset(self):
        self.prev = []
        self.cur = []
        if self.grammar is not None:
            self.grammar.reset()

    def cp(self):
        return _LlamaSamplingContext(
            params=self.params,
            mirostat_mu=self.mirostat_mu,
            grammar=self.grammar,
            prev=self.prev.copy(),
            cur=self.cur.copy(),
        )

    def last(self) -> Optional[int]:
        if len(self.prev) > 0:
            return self.prev[-1]
        else:
            return None

    def prev_str(self, ctx_main: _LlamaContext, n: int) -> str:
        return ctx_main.model.detokenize(self.prev[-n:]).decode("utf-8")

    def sample(
        self, ctx_main: _LlamaContext, idx: int = 0, logits_array: Optional[npt.NDArray[np.single]] = None
    ):
        n_vocab = ctx_main.model.n_vocab()
        id: int = 0

        if logits_array is None:
            logits = ctx_main.get_logits_ith(idx)
            logits_array = np.array(
                ctypes.cast(logits, ctypes.POINTER(ctypes.c_float * n_vocab)).contents,
                dtype=np.single,
            )

        # apply logit_bias
        for token, logit_bias in self.params.logit_bias.items():
            logits_array[token] += logit_bias

        token_data_array = _LlamaTokenDataArray(
            n_vocab=n_vocab
        )  # TODO: Only create this once
        token_data_array.copy_logits(logits_array)

        # apply penalties
        if len(self.prev) > 0:
            nl_token = ctx_main.model.token_nl()
            nl_logit = logits_array[nl_token]
            last_tokens = self.prev[-self.params.penalty_last_n:]
            last_tokens_size = min(len(last_tokens), self.params.penalty_last_n)
            if last_tokens_size > 0:
                last_tokens_p = (llama_cpp.llama_token * len(last_tokens))(*last_tokens)
                ctx_main.sample_repetition_penalties(
                    token_data_array,
                    last_tokens_p,
                    last_tokens_size,
                    self.params.penalty_repeat,
                    self.params.penalty_freq,
                    self.params.penalty_present,
                )
            if not self.params.penalize_nl:
                token_data_array.candidates_data.logit[nl_token] = nl_logit

        if self.grammar is not None:
            ctx_main.sample_grammar(token_data_array, self.grammar)

        if self.params.temp < 0:
            ctx_main.sample_softmax(token_data_array)
            id = token_data_array.candidates_data.id[0]
        elif self.params.temp == 0:
            id = ctx_main.sample_token_greedy(token_data_array)
        else:
            if self.params.mirostat == 1:
                mirostat_m = 100
                ctx_main.sample_temp(token_data_array, self.params.temp)
                id = ctx_main.sample_token_mirostat(
                    token_data_array,
                    self.params.mirostat_tau,
                    self.params.mirostat_eta,
                    mirostat_m,
                    ctypes.pointer(self.mirostat_mu),
                )
            elif self.params.mirostat == 2:
                ctx_main.sample_temp(token_data_array, self.params.temp)
                id = ctx_main.sample_token_mirostat_v2(
                    token_data_array,
                    self.params.mirostat_tau,
                    self.params.mirostat_eta,
                    ctypes.pointer(self.mirostat_mu),
                )
            else:
                min_keep = max(1, self.params.n_probs)
                ctx_main.sample_top_k(
                    token_data_array, self.params.top_k, min_keep=min_keep
                )
                ctx_main.sample_tail_free(
                    token_data_array, self.params.tfs_z, min_keep=min_keep
                )
                ctx_main.sample_typical(
                    token_data_array, self.params.typical_p, min_keep=min_keep
                )
                ctx_main.sample_top_p(
                    token_data_array, self.params.top_p, min_keep=min_keep
                )
                ctx_main.sample_min_p(
                    token_data_array, self.params.min_p, min_keep=min_keep
                )
                ctx_main.sample_temp(token_data_array, self.params.temp)
                id = ctx_main.sample_token(token_data_array)
        return id

    def accept(self, ctx_main: _LlamaContext, id: int, apply_grammar: bool):
        if apply_grammar and self.grammar is not None:
            ctx_main.grammar_accept_token(self.grammar, id)
        self.prev.append(id)