misc: additional type annotations for low level api
This commit is contained in:
parent
3632241e98
commit
aefcb8f71a
1 changed files with 227 additions and 125 deletions
|
@ -1,3 +1,5 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import ctypes
|
import ctypes
|
||||||
|
@ -6,7 +8,16 @@ from ctypes import (
|
||||||
Array,
|
Array,
|
||||||
)
|
)
|
||||||
import pathlib
|
import pathlib
|
||||||
from typing import List, Union, NewType, Optional
|
from typing import (
|
||||||
|
List,
|
||||||
|
Union,
|
||||||
|
NewType,
|
||||||
|
Optional,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
TypeVar,
|
||||||
|
TypeAlias,
|
||||||
|
Generic,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Load the library
|
# Load the library
|
||||||
|
@ -71,14 +82,39 @@ _lib_base_name = "llama"
|
||||||
# Load the library
|
# Load the library
|
||||||
_lib = _load_shared_library(_lib_base_name)
|
_lib = _load_shared_library(_lib_base_name)
|
||||||
|
|
||||||
# Misc
|
|
||||||
c_float_p = ctypes.POINTER(ctypes.c_float)
|
# ctypes sane type hint helpers
|
||||||
c_uint8_p = ctypes.POINTER(ctypes.c_uint8)
|
#
|
||||||
c_size_t_p = ctypes.POINTER(ctypes.c_size_t)
|
# - Generic Pointer and Array types
|
||||||
|
# - PointerOrRef type with a type hinted byref function
|
||||||
|
#
|
||||||
|
# NOTE: Only use these for static type checking not for runtime checks
|
||||||
|
# no good will come of that
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData) # type: ignore
|
||||||
|
|
||||||
|
CtypesArray: TypeAlias = ctypes.Array[CtypesCData] # type: ignore
|
||||||
|
|
||||||
|
CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData] # type: ignore
|
||||||
|
|
||||||
|
CtypesVoidPointer: TypeAlias = ctypes.c_void_p
|
||||||
|
|
||||||
|
class CtypesRef(Generic[CtypesCData]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
CtypesPointerOrRef: TypeAlias = Union[
|
||||||
|
CtypesPointer[CtypesCData], CtypesRef[CtypesCData]
|
||||||
|
]
|
||||||
|
|
||||||
|
CtypesFuncPointer: TypeAlias = ctypes._FuncPointer # type: ignore
|
||||||
|
|
||||||
|
|
||||||
# from ggml-backend.h
|
# from ggml-backend.h
|
||||||
# typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
|
# typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
|
||||||
ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_void_p, ctypes.c_bool, ctypes.c_void_p)
|
ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(
|
||||||
|
ctypes.c_bool, ctypes.c_void_p, ctypes.c_bool, ctypes.c_void_p
|
||||||
|
)
|
||||||
|
|
||||||
# llama.h bindings
|
# llama.h bindings
|
||||||
|
|
||||||
|
@ -286,7 +322,9 @@ class llama_token_data_array(ctypes.Structure):
|
||||||
llama_token_data_array_p = ctypes.POINTER(llama_token_data_array)
|
llama_token_data_array_p = ctypes.POINTER(llama_token_data_array)
|
||||||
|
|
||||||
# typedef bool (*llama_progress_callback)(float progress, void *ctx);
|
# typedef bool (*llama_progress_callback)(float progress, void *ctx);
|
||||||
llama_progress_callback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_float, ctypes.c_void_p)
|
llama_progress_callback = ctypes.CFUNCTYPE(
|
||||||
|
ctypes.c_bool, ctypes.c_float, ctypes.c_void_p
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# // Input data for llama_decode
|
# // Input data for llama_decode
|
||||||
|
@ -336,7 +374,7 @@ class llama_batch(ctypes.Structure):
|
||||||
_fields_ = [
|
_fields_ = [
|
||||||
("n_tokens", ctypes.c_int32),
|
("n_tokens", ctypes.c_int32),
|
||||||
("token", ctypes.POINTER(llama_token)),
|
("token", ctypes.POINTER(llama_token)),
|
||||||
("embd", c_float_p),
|
("embd", ctypes.POINTER(ctypes.c_float)),
|
||||||
("pos", ctypes.POINTER(llama_pos)),
|
("pos", ctypes.POINTER(llama_pos)),
|
||||||
("n_seq_id", ctypes.POINTER(ctypes.c_int32)),
|
("n_seq_id", ctypes.POINTER(ctypes.c_int32)),
|
||||||
("seq_id", ctypes.POINTER(ctypes.POINTER(llama_seq_id))),
|
("seq_id", ctypes.POINTER(ctypes.POINTER(llama_seq_id))),
|
||||||
|
@ -431,7 +469,7 @@ class llama_model_params(ctypes.Structure):
|
||||||
("n_gpu_layers", ctypes.c_int32),
|
("n_gpu_layers", ctypes.c_int32),
|
||||||
("split_mode", ctypes.c_int),
|
("split_mode", ctypes.c_int),
|
||||||
("main_gpu", ctypes.c_int32),
|
("main_gpu", ctypes.c_int32),
|
||||||
("tensor_split", c_float_p),
|
("tensor_split", ctypes.POINTER(ctypes.c_float)),
|
||||||
("progress_callback", llama_progress_callback),
|
("progress_callback", llama_progress_callback),
|
||||||
("progress_callback_user_data", ctypes.c_void_p),
|
("progress_callback_user_data", ctypes.c_void_p),
|
||||||
("kv_overrides", ctypes.POINTER(llama_model_kv_override)),
|
("kv_overrides", ctypes.POINTER(llama_model_kv_override)),
|
||||||
|
@ -532,7 +570,9 @@ class llama_context_params(ctypes.Structure):
|
||||||
# // if it exists.
|
# // if it exists.
|
||||||
# // It might not exist for progress report where '.' is output repeatedly.
|
# // It might not exist for progress report where '.' is output repeatedly.
|
||||||
# typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data);
|
# typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data);
|
||||||
llama_log_callback = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p)
|
llama_log_callback = ctypes.CFUNCTYPE(
|
||||||
|
None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p
|
||||||
|
)
|
||||||
"""Signature for logging events
|
"""Signature for logging events
|
||||||
Note that text includes the new line character at the end for most events.
|
Note that text includes the new line character at the end for most events.
|
||||||
If your logging mechanism cannot handle that, check if the last character is '\n' and strip it
|
If your logging mechanism cannot handle that, check if the last character is '\n' and strip it
|
||||||
|
@ -966,14 +1006,23 @@ llama_rope_freq_scale_train.restype = ctypes.c_float
|
||||||
# // Get metadata value as a string by key name
|
# // Get metadata value as a string by key name
|
||||||
# LLAMA_API int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size);
|
# LLAMA_API int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size);
|
||||||
def llama_model_meta_val_str(
|
def llama_model_meta_val_str(
|
||||||
model: llama_model_p, key: Union[ctypes.c_char_p, bytes], buf: bytes, buf_size: int, /
|
model: llama_model_p,
|
||||||
|
key: Union[ctypes.c_char_p, bytes],
|
||||||
|
buf: bytes,
|
||||||
|
buf_size: int,
|
||||||
|
/,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Get metadata value as a string by key name"""
|
"""Get metadata value as a string by key name"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
llama_model_meta_val_str = _lib.llama_model_meta_val_str
|
llama_model_meta_val_str = _lib.llama_model_meta_val_str
|
||||||
llama_model_meta_val_str.argtypes = [llama_model_p_ctypes, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_size_t]
|
llama_model_meta_val_str.argtypes = [
|
||||||
|
llama_model_p_ctypes,
|
||||||
|
ctypes.c_char_p,
|
||||||
|
ctypes.c_char_p,
|
||||||
|
ctypes.c_size_t,
|
||||||
|
]
|
||||||
llama_model_meta_val_str.restype = ctypes.c_int32
|
llama_model_meta_val_str.restype = ctypes.c_int32
|
||||||
|
|
||||||
|
|
||||||
|
@ -1087,8 +1136,8 @@ llama_get_model_tensor.restype = ctypes.c_void_p
|
||||||
def llama_model_quantize(
|
def llama_model_quantize(
|
||||||
fname_inp: bytes,
|
fname_inp: bytes,
|
||||||
fname_out: bytes,
|
fname_out: bytes,
|
||||||
params, # type: ctypes.POINTER(llama_model_quantize_params) # type: ignore
|
params: CtypesPointerOrRef[llama_model_quantize_params],
|
||||||
/
|
/,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Returns 0 on success"""
|
"""Returns 0 on success"""
|
||||||
...
|
...
|
||||||
|
@ -1121,8 +1170,8 @@ def llama_apply_lora_from_file(
|
||||||
path_lora: Union[ctypes.c_char_p, bytes],
|
path_lora: Union[ctypes.c_char_p, bytes],
|
||||||
scale: Union[ctypes.c_float, float],
|
scale: Union[ctypes.c_float, float],
|
||||||
path_base_model: Union[ctypes.c_char_p, bytes],
|
path_base_model: Union[ctypes.c_char_p, bytes],
|
||||||
n_threads: Union[ctypes.c_int, int],
|
n_threads: Union[ctypes.c_int32, int],
|
||||||
/
|
/,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Apply a LoRA adapter to a loaded model
|
"""Apply a LoRA adapter to a loaded model
|
||||||
path_base_model is the path to a higher quality model to use as a base for
|
path_base_model is the path to a higher quality model to use as a base for
|
||||||
|
@ -1155,8 +1204,8 @@ def llama_model_apply_lora_from_file(
|
||||||
path_lora: Union[ctypes.c_char_p, bytes],
|
path_lora: Union[ctypes.c_char_p, bytes],
|
||||||
scale: Union[ctypes.c_float, float],
|
scale: Union[ctypes.c_float, float],
|
||||||
path_base_model: Union[ctypes.c_char_p, bytes],
|
path_base_model: Union[ctypes.c_char_p, bytes],
|
||||||
n_threads: Union[ctypes.c_int, int],
|
n_threads: Union[ctypes.c_int32, int],
|
||||||
/
|
/,
|
||||||
) -> int:
|
) -> int:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@ -1262,7 +1311,7 @@ llama_kv_cache_view_free.restype = None
|
||||||
|
|
||||||
# // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
|
# // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
|
||||||
# LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
|
# LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
|
||||||
def llama_kv_cache_view_update(ctx: llama_context_p, view: "ctypes.pointer[llama_kv_cache_view]", /): # type: ignore
|
def llama_kv_cache_view_update(ctx: llama_context_p, view: CtypesPointerOrRef[llama_kv_cache_view], /): # type: ignore
|
||||||
"""Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)"""
|
"""Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@ -1326,7 +1375,7 @@ def llama_kv_cache_seq_rm(
|
||||||
seq_id: Union[llama_seq_id, int],
|
seq_id: Union[llama_seq_id, int],
|
||||||
p0: Union[llama_pos, int],
|
p0: Union[llama_pos, int],
|
||||||
p1: Union[llama_pos, int],
|
p1: Union[llama_pos, int],
|
||||||
/
|
/,
|
||||||
):
|
):
|
||||||
"""Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
"""Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||||
seq_id < 0 : match any sequence
|
seq_id < 0 : match any sequence
|
||||||
|
@ -1361,7 +1410,7 @@ def llama_kv_cache_seq_cp(
|
||||||
seq_id_dst: Union[llama_seq_id, int],
|
seq_id_dst: Union[llama_seq_id, int],
|
||||||
p0: Union[llama_pos, int],
|
p0: Union[llama_pos, int],
|
||||||
p1: Union[llama_pos, int],
|
p1: Union[llama_pos, int],
|
||||||
/
|
/,
|
||||||
):
|
):
|
||||||
"""Copy all tokens that belong to the specified sequence to another sequence
|
"""Copy all tokens that belong to the specified sequence to another sequence
|
||||||
Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
|
Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
|
||||||
|
@ -1385,11 +1434,7 @@ llama_kv_cache_seq_cp.restype = None
|
||||||
# LLAMA_API void llama_kv_cache_seq_keep(
|
# LLAMA_API void llama_kv_cache_seq_keep(
|
||||||
# struct llama_context * ctx,
|
# struct llama_context * ctx,
|
||||||
# llama_seq_id seq_id);
|
# llama_seq_id seq_id);
|
||||||
def llama_kv_cache_seq_keep(
|
def llama_kv_cache_seq_keep(ctx: llama_context_p, seq_id: Union[llama_seq_id, int], /):
|
||||||
ctx: llama_context_p,
|
|
||||||
seq_id: Union[llama_seq_id, int],
|
|
||||||
/
|
|
||||||
):
|
|
||||||
"""Removes all tokens that do not belong to the specified sequence"""
|
"""Removes all tokens that do not belong to the specified sequence"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@ -1415,7 +1460,7 @@ def llama_kv_cache_seq_shift(
|
||||||
p0: Union[llama_pos, int],
|
p0: Union[llama_pos, int],
|
||||||
p1: Union[llama_pos, int],
|
p1: Union[llama_pos, int],
|
||||||
delta: Union[llama_pos, int],
|
delta: Union[llama_pos, int],
|
||||||
/
|
/,
|
||||||
):
|
):
|
||||||
"""Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
"""Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||||
If the KV cache is RoPEd, the KV data is updated accordingly
|
If the KV cache is RoPEd, the KV data is updated accordingly
|
||||||
|
@ -1451,7 +1496,7 @@ def llama_kv_cache_seq_div(
|
||||||
p0: Union[llama_pos, int],
|
p0: Union[llama_pos, int],
|
||||||
p1: Union[llama_pos, int],
|
p1: Union[llama_pos, int],
|
||||||
d: Union[ctypes.c_int, int],
|
d: Union[ctypes.c_int, int],
|
||||||
/
|
/,
|
||||||
):
|
):
|
||||||
"""Integer division of the positions by factor of `d > 1`
|
"""Integer division of the positions by factor of `d > 1`
|
||||||
If the KV cache is RoPEd, the KV data is updated accordingly
|
If the KV cache is RoPEd, the KV data is updated accordingly
|
||||||
|
@ -1496,8 +1541,7 @@ llama_get_state_size.restype = ctypes.c_size_t
|
||||||
# struct llama_context * ctx,
|
# struct llama_context * ctx,
|
||||||
# uint8_t * dst);
|
# uint8_t * dst);
|
||||||
def llama_copy_state_data(
|
def llama_copy_state_data(
|
||||||
ctx: llama_context_p, dst, # type: Array[ctypes.c_uint8]
|
ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], /
|
||||||
/
|
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Copies the state to the specified destination address.
|
"""Copies the state to the specified destination address.
|
||||||
Destination needs to have allocated enough memory.
|
Destination needs to have allocated enough memory.
|
||||||
|
@ -1506,7 +1550,10 @@ def llama_copy_state_data(
|
||||||
|
|
||||||
|
|
||||||
llama_copy_state_data = _lib.llama_copy_state_data
|
llama_copy_state_data = _lib.llama_copy_state_data
|
||||||
llama_copy_state_data.argtypes = [llama_context_p_ctypes, c_uint8_p]
|
llama_copy_state_data.argtypes = [
|
||||||
|
llama_context_p_ctypes,
|
||||||
|
ctypes.POINTER(ctypes.c_uint8),
|
||||||
|
]
|
||||||
llama_copy_state_data.restype = ctypes.c_size_t
|
llama_copy_state_data.restype = ctypes.c_size_t
|
||||||
|
|
||||||
|
|
||||||
|
@ -1516,15 +1563,14 @@ llama_copy_state_data.restype = ctypes.c_size_t
|
||||||
# struct llama_context * ctx,
|
# struct llama_context * ctx,
|
||||||
# uint8_t * src);
|
# uint8_t * src);
|
||||||
def llama_set_state_data(
|
def llama_set_state_data(
|
||||||
ctx: llama_context_p, src, # type: Array[ctypes.c_uint8]
|
ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], /
|
||||||
/
|
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Set the state reading from the specified address"""
|
"""Set the state reading from the specified address"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
llama_set_state_data = _lib.llama_set_state_data
|
llama_set_state_data = _lib.llama_set_state_data
|
||||||
llama_set_state_data.argtypes = [llama_context_p_ctypes, c_uint8_p]
|
llama_set_state_data.argtypes = [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)]
|
||||||
llama_set_state_data.restype = ctypes.c_size_t
|
llama_set_state_data.restype = ctypes.c_size_t
|
||||||
|
|
||||||
|
|
||||||
|
@ -1538,10 +1584,10 @@ llama_set_state_data.restype = ctypes.c_size_t
|
||||||
def llama_load_session_file(
|
def llama_load_session_file(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
path_session: bytes,
|
path_session: bytes,
|
||||||
tokens_out, # type: Array[llama_token]
|
tokens_out: CtypesArray[llama_token],
|
||||||
n_token_capacity: Union[ctypes.c_size_t, int],
|
n_token_capacity: Union[ctypes.c_size_t, int],
|
||||||
n_token_count_out, # type: _Pointer[ctypes.c_size_t]
|
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
|
||||||
/
|
/,
|
||||||
) -> int:
|
) -> int:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@ -1552,7 +1598,7 @@ llama_load_session_file.argtypes = [
|
||||||
ctypes.c_char_p,
|
ctypes.c_char_p,
|
||||||
llama_token_p,
|
llama_token_p,
|
||||||
ctypes.c_size_t,
|
ctypes.c_size_t,
|
||||||
c_size_t_p,
|
ctypes.POINTER(ctypes.c_size_t),
|
||||||
]
|
]
|
||||||
llama_load_session_file.restype = ctypes.c_size_t
|
llama_load_session_file.restype = ctypes.c_size_t
|
||||||
|
|
||||||
|
@ -1565,9 +1611,9 @@ llama_load_session_file.restype = ctypes.c_size_t
|
||||||
def llama_save_session_file(
|
def llama_save_session_file(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
path_session: bytes,
|
path_session: bytes,
|
||||||
tokens, # type: Array[llama_token]
|
tokens: CtypesArray[llama_token],
|
||||||
n_token_count: Union[ctypes.c_size_t, int],
|
n_token_count: Union[ctypes.c_size_t, int],
|
||||||
/
|
/,
|
||||||
) -> int:
|
) -> int:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@ -1599,10 +1645,10 @@ llama_save_session_file.restype = ctypes.c_size_t
|
||||||
# "use llama_decode() instead");
|
# "use llama_decode() instead");
|
||||||
def llama_eval(
|
def llama_eval(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
tokens, # type: Array[llama_token]
|
tokens: CtypesArray[llama_token],
|
||||||
n_tokens: Union[ctypes.c_int, int],
|
n_tokens: Union[ctypes.c_int, int],
|
||||||
n_past: Union[ctypes.c_int, int],
|
n_past: Union[ctypes.c_int, int],
|
||||||
/
|
/,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Run the llama inference to obtain the logits and probabilities for the next token(s).
|
"""Run the llama inference to obtain the logits and probabilities for the next token(s).
|
||||||
tokens + n_tokens is the provided batch of new tokens to process
|
tokens + n_tokens is the provided batch of new tokens to process
|
||||||
|
@ -1613,7 +1659,12 @@ def llama_eval(
|
||||||
|
|
||||||
|
|
||||||
llama_eval = _lib.llama_eval
|
llama_eval = _lib.llama_eval
|
||||||
llama_eval.argtypes = [llama_context_p_ctypes, llama_token_p, ctypes.c_int32, ctypes.c_int32]
|
llama_eval.argtypes = [
|
||||||
|
llama_context_p_ctypes,
|
||||||
|
llama_token_p,
|
||||||
|
ctypes.c_int32,
|
||||||
|
ctypes.c_int32,
|
||||||
|
]
|
||||||
llama_eval.restype = ctypes.c_int
|
llama_eval.restype = ctypes.c_int
|
||||||
|
|
||||||
|
|
||||||
|
@ -1627,10 +1678,10 @@ llama_eval.restype = ctypes.c_int
|
||||||
# "use llama_decode() instead");
|
# "use llama_decode() instead");
|
||||||
def llama_eval_embd(
|
def llama_eval_embd(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
embd, # type: Array[ctypes.c_float]
|
embd: CtypesArray[ctypes.c_float],
|
||||||
n_tokens: Union[ctypes.c_int, int],
|
n_tokens: Union[ctypes.c_int, int],
|
||||||
n_past: Union[ctypes.c_int, int],
|
n_past: Union[ctypes.c_int, int],
|
||||||
/
|
/,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Same as llama_eval, but use float matrix input directly.
|
"""Same as llama_eval, but use float matrix input directly.
|
||||||
DEPRECATED: use llama_decode() instead"""
|
DEPRECATED: use llama_decode() instead"""
|
||||||
|
@ -1638,7 +1689,12 @@ def llama_eval_embd(
|
||||||
|
|
||||||
|
|
||||||
llama_eval_embd = _lib.llama_eval_embd
|
llama_eval_embd = _lib.llama_eval_embd
|
||||||
llama_eval_embd.argtypes = [llama_context_p_ctypes, c_float_p, ctypes.c_int32, ctypes.c_int32]
|
llama_eval_embd.argtypes = [
|
||||||
|
llama_context_p_ctypes,
|
||||||
|
ctypes.POINTER(ctypes.c_float),
|
||||||
|
ctypes.c_int32,
|
||||||
|
ctypes.c_int32,
|
||||||
|
]
|
||||||
llama_eval_embd.restype = ctypes.c_int
|
llama_eval_embd.restype = ctypes.c_int
|
||||||
|
|
||||||
|
|
||||||
|
@ -1652,11 +1708,11 @@ llama_eval_embd.restype = ctypes.c_int
|
||||||
# llama_pos pos_0,
|
# llama_pos pos_0,
|
||||||
# llama_seq_id seq_id);
|
# llama_seq_id seq_id);
|
||||||
def llama_batch_get_one(
|
def llama_batch_get_one(
|
||||||
tokens, # type: Array[llama_token]
|
tokens: CtypesArray[llama_token],
|
||||||
n_tokens: Union[ctypes.c_int, int],
|
n_tokens: Union[ctypes.c_int, int],
|
||||||
pos_0: Union[llama_pos, int],
|
pos_0: Union[llama_pos, int],
|
||||||
seq_id: llama_seq_id,
|
seq_id: llama_seq_id,
|
||||||
/
|
/,
|
||||||
) -> llama_batch:
|
) -> llama_batch:
|
||||||
"""Return batch for single sequence of tokens starting at pos_0
|
"""Return batch for single sequence of tokens starting at pos_0
|
||||||
|
|
||||||
|
@ -1690,7 +1746,7 @@ def llama_batch_init(
|
||||||
n_tokens: Union[ctypes.c_int32, int],
|
n_tokens: Union[ctypes.c_int32, int],
|
||||||
embd: Union[ctypes.c_int32, int],
|
embd: Union[ctypes.c_int32, int],
|
||||||
n_seq_max: Union[ctypes.c_int32, int],
|
n_seq_max: Union[ctypes.c_int32, int],
|
||||||
/
|
/,
|
||||||
) -> llama_batch:
|
) -> llama_batch:
|
||||||
"""Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
|
"""Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
|
||||||
Each token can be assigned up to n_seq_max sequence ids
|
Each token can be assigned up to n_seq_max sequence ids
|
||||||
|
@ -1747,7 +1803,7 @@ def llama_set_n_threads(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
n_threads: Union[ctypes.c_uint32, int],
|
n_threads: Union[ctypes.c_uint32, int],
|
||||||
n_threads_batch: Union[ctypes.c_uint32, int],
|
n_threads_batch: Union[ctypes.c_uint32, int],
|
||||||
/
|
/,
|
||||||
):
|
):
|
||||||
"""Set the number of threads used for decoding
|
"""Set the number of threads used for decoding
|
||||||
n_threads is the number of threads used for generation (single token)
|
n_threads is the number of threads used for generation (single token)
|
||||||
|
@ -1757,7 +1813,11 @@ def llama_set_n_threads(
|
||||||
|
|
||||||
|
|
||||||
llama_set_n_threads = _lib.llama_set_n_threads
|
llama_set_n_threads = _lib.llama_set_n_threads
|
||||||
llama_set_n_threads.argtypes = [llama_context_p_ctypes, ctypes.c_uint32, ctypes.c_uint32]
|
llama_set_n_threads.argtypes = [
|
||||||
|
llama_context_p_ctypes,
|
||||||
|
ctypes.c_uint32,
|
||||||
|
ctypes.c_uint32,
|
||||||
|
]
|
||||||
llama_set_n_threads.restype = None
|
llama_set_n_threads.restype = None
|
||||||
|
|
||||||
|
|
||||||
|
@ -1768,8 +1828,7 @@ llama_set_n_threads.restype = None
|
||||||
# // Cols: n_vocab
|
# // Cols: n_vocab
|
||||||
# LLAMA_API float * llama_get_logits(struct llama_context * ctx);
|
# LLAMA_API float * llama_get_logits(struct llama_context * ctx);
|
||||||
def llama_get_logits(
|
def llama_get_logits(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p, /
|
||||||
/
|
|
||||||
): # type: (...) -> Array[float] # type: ignore
|
): # type: (...) -> Array[float] # type: ignore
|
||||||
"""Token logits obtained from the last call to llama_eval()
|
"""Token logits obtained from the last call to llama_eval()
|
||||||
The logits for the last token are stored in the last row
|
The logits for the last token are stored in the last row
|
||||||
|
@ -1781,15 +1840,14 @@ def llama_get_logits(
|
||||||
|
|
||||||
llama_get_logits = _lib.llama_get_logits
|
llama_get_logits = _lib.llama_get_logits
|
||||||
llama_get_logits.argtypes = [llama_context_p_ctypes]
|
llama_get_logits.argtypes = [llama_context_p_ctypes]
|
||||||
llama_get_logits.restype = c_float_p
|
llama_get_logits.restype = ctypes.POINTER(ctypes.c_float)
|
||||||
|
|
||||||
|
|
||||||
# // Logits for the ith token. Equivalent to:
|
# // Logits for the ith token. Equivalent to:
|
||||||
# // llama_get_logits(ctx) + i*n_vocab
|
# // llama_get_logits(ctx) + i*n_vocab
|
||||||
# LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
|
# LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
|
||||||
def llama_get_logits_ith(
|
def llama_get_logits_ith(
|
||||||
ctx: llama_context_p, i: Union[ctypes.c_int32, int]
|
ctx: llama_context_p, i: Union[ctypes.c_int32, int], /
|
||||||
, /
|
|
||||||
): # type: (...) -> Array[float] # type: ignore
|
): # type: (...) -> Array[float] # type: ignore
|
||||||
"""Logits for the ith token. Equivalent to:
|
"""Logits for the ith token. Equivalent to:
|
||||||
llama_get_logits(ctx) + i*n_vocab"""
|
llama_get_logits(ctx) + i*n_vocab"""
|
||||||
|
@ -1798,7 +1856,7 @@ def llama_get_logits_ith(
|
||||||
|
|
||||||
llama_get_logits_ith = _lib.llama_get_logits_ith
|
llama_get_logits_ith = _lib.llama_get_logits_ith
|
||||||
llama_get_logits_ith.argtypes = [llama_context_p_ctypes, ctypes.c_int32]
|
llama_get_logits_ith.argtypes = [llama_context_p_ctypes, ctypes.c_int32]
|
||||||
llama_get_logits_ith.restype = c_float_p
|
llama_get_logits_ith.restype = ctypes.POINTER(ctypes.c_float)
|
||||||
|
|
||||||
|
|
||||||
# Get the embeddings for the input
|
# Get the embeddings for the input
|
||||||
|
@ -1814,7 +1872,7 @@ def llama_get_embeddings(
|
||||||
|
|
||||||
llama_get_embeddings = _lib.llama_get_embeddings
|
llama_get_embeddings = _lib.llama_get_embeddings
|
||||||
llama_get_embeddings.argtypes = [llama_context_p_ctypes]
|
llama_get_embeddings.argtypes = [llama_context_p_ctypes]
|
||||||
llama_get_embeddings.restype = c_float_p
|
llama_get_embeddings.restype = ctypes.POINTER(ctypes.c_float)
|
||||||
|
|
||||||
|
|
||||||
# // Get the embeddings for the ith sequence
|
# // Get the embeddings for the ith sequence
|
||||||
|
@ -1830,7 +1888,7 @@ def llama_get_embeddings_ith(
|
||||||
|
|
||||||
llama_get_embeddings_ith = _lib.llama_get_embeddings_ith
|
llama_get_embeddings_ith = _lib.llama_get_embeddings_ith
|
||||||
llama_get_embeddings_ith.argtypes = [llama_context_p_ctypes, ctypes.c_int32]
|
llama_get_embeddings_ith.argtypes = [llama_context_p_ctypes, ctypes.c_int32]
|
||||||
llama_get_embeddings_ith.restype = c_float_p
|
llama_get_embeddings_ith.restype = ctypes.POINTER(ctypes.c_float)
|
||||||
|
|
||||||
|
|
||||||
# //
|
# //
|
||||||
|
@ -1839,7 +1897,9 @@ llama_get_embeddings_ith.restype = c_float_p
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token);
|
# LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token);
|
||||||
def llama_token_get_text(model: llama_model_p, token: Union[llama_token, int], /) -> bytes:
|
def llama_token_get_text(
|
||||||
|
model: llama_model_p, token: Union[llama_token, int], /
|
||||||
|
) -> bytes:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@ -1861,7 +1921,9 @@ llama_token_get_score.restype = ctypes.c_float
|
||||||
|
|
||||||
|
|
||||||
# LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
|
# LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
|
||||||
def llama_token_get_type(model: llama_model_p, token: Union[llama_token, int], /) -> int:
|
def llama_token_get_type(
|
||||||
|
model: llama_model_p, token: Union[llama_token, int], /
|
||||||
|
) -> int:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@ -1995,11 +2057,11 @@ def llama_tokenize(
|
||||||
model: llama_model_p,
|
model: llama_model_p,
|
||||||
text: bytes,
|
text: bytes,
|
||||||
text_len: Union[ctypes.c_int, int],
|
text_len: Union[ctypes.c_int, int],
|
||||||
tokens, # type: Array[llama_token]
|
tokens: CtypesArray[llama_token],
|
||||||
n_max_tokens: Union[ctypes.c_int, int],
|
n_max_tokens: Union[ctypes.c_int, int],
|
||||||
add_bos: Union[ctypes.c_bool, bool],
|
add_bos: Union[ctypes.c_bool, bool],
|
||||||
special: Union[ctypes.c_bool, bool],
|
special: Union[ctypes.c_bool, bool],
|
||||||
/
|
/,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Convert the provided text into tokens."""
|
"""Convert the provided text into tokens."""
|
||||||
...
|
...
|
||||||
|
@ -2032,7 +2094,7 @@ def llama_token_to_piece(
|
||||||
token: Union[llama_token, int],
|
token: Union[llama_token, int],
|
||||||
buf: Union[ctypes.c_char_p, bytes],
|
buf: Union[ctypes.c_char_p, bytes],
|
||||||
length: Union[ctypes.c_int, int],
|
length: Union[ctypes.c_int, int],
|
||||||
/
|
/,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Token Id -> Piece.
|
"""Token Id -> Piece.
|
||||||
Uses the vocabulary in the provided context.
|
Uses the vocabulary in the provided context.
|
||||||
|
@ -2043,7 +2105,12 @@ def llama_token_to_piece(
|
||||||
|
|
||||||
|
|
||||||
llama_token_to_piece = _lib.llama_token_to_piece
|
llama_token_to_piece = _lib.llama_token_to_piece
|
||||||
llama_token_to_piece.argtypes = [llama_model_p_ctypes, llama_token, ctypes.c_char_p, ctypes.c_int32]
|
llama_token_to_piece.argtypes = [
|
||||||
|
llama_model_p_ctypes,
|
||||||
|
llama_token,
|
||||||
|
ctypes.c_char_p,
|
||||||
|
ctypes.c_int32,
|
||||||
|
]
|
||||||
llama_token_to_piece.restype = ctypes.c_int32
|
llama_token_to_piece.restype = ctypes.c_int32
|
||||||
|
|
||||||
|
|
||||||
|
@ -2068,23 +2135,23 @@ llama_token_to_piece.restype = ctypes.c_int32
|
||||||
def llama_chat_apply_template(
|
def llama_chat_apply_template(
|
||||||
model: llama_model_p,
|
model: llama_model_p,
|
||||||
tmpl: bytes,
|
tmpl: bytes,
|
||||||
chat: "ctypes._Pointer[llama_chat_message]", # type: ignore
|
chat: CtypesArray[llama_chat_message],
|
||||||
n_msg: int,
|
n_msg: int,
|
||||||
/
|
/,
|
||||||
) -> int:
|
) -> int:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
llama_chat_apply_template = _lib.llama_chat_apply_template
|
llama_chat_apply_template = _lib.llama_chat_apply_template
|
||||||
llama_chat_apply_template.argtypes = [
|
llama_chat_apply_template.argtypes = [
|
||||||
ctypes.c_void_p,
|
ctypes.c_void_p,
|
||||||
ctypes.c_char_p,
|
ctypes.c_char_p,
|
||||||
ctypes.POINTER(llama_chat_message),
|
ctypes.POINTER(llama_chat_message),
|
||||||
ctypes.c_size_t
|
ctypes.c_size_t,
|
||||||
]
|
]
|
||||||
llama_chat_apply_template.restype = ctypes.c_int32
|
llama_chat_apply_template.restype = ctypes.c_int32
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# //
|
# //
|
||||||
# // Grammar
|
# // Grammar
|
||||||
# //
|
# //
|
||||||
|
@ -2095,10 +2162,12 @@ llama_chat_apply_template.restype = ctypes.c_int32
|
||||||
# size_t n_rules,
|
# size_t n_rules,
|
||||||
# size_t start_rule_index);
|
# size_t start_rule_index);
|
||||||
def llama_grammar_init(
|
def llama_grammar_init(
|
||||||
rules, # type: Array[llama_grammar_element_p] # type: ignore
|
rules: CtypesArray[
|
||||||
|
CtypesPointer[llama_grammar_element]
|
||||||
|
], # NOTE: This might be wrong type sig
|
||||||
n_rules: Union[ctypes.c_size_t, int],
|
n_rules: Union[ctypes.c_size_t, int],
|
||||||
start_rule_index: Union[ctypes.c_size_t, int],
|
start_rule_index: Union[ctypes.c_size_t, int],
|
||||||
/
|
/,
|
||||||
) -> llama_grammar_p:
|
) -> llama_grammar_p:
|
||||||
"""Initialize a grammar from a set of rules."""
|
"""Initialize a grammar from a set of rules."""
|
||||||
...
|
...
|
||||||
|
@ -2163,13 +2232,15 @@ llama_set_rng_seed.restype = None
|
||||||
# float penalty_present);
|
# float penalty_present);
|
||||||
def llama_sample_repetition_penalties(
|
def llama_sample_repetition_penalties(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
candidates, # type: _Pointer[llama_token_data_array]
|
candidates: Union[
|
||||||
last_tokens_data, # type: Array[llama_token]
|
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||||
|
],
|
||||||
|
last_tokens_data: CtypesArray[llama_token],
|
||||||
penalty_last_n: Union[ctypes.c_size_t, int],
|
penalty_last_n: Union[ctypes.c_size_t, int],
|
||||||
penalty_repeat: Union[ctypes.c_float, float],
|
penalty_repeat: Union[ctypes.c_float, float],
|
||||||
penalty_freq: Union[ctypes.c_float, float],
|
penalty_freq: Union[ctypes.c_float, float],
|
||||||
penalty_present: Union[ctypes.c_float, float],
|
penalty_present: Union[ctypes.c_float, float],
|
||||||
/
|
/,
|
||||||
):
|
):
|
||||||
"""Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
"""Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
||||||
Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
||||||
|
@ -2201,10 +2272,10 @@ llama_sample_repetition_penalties.restype = None
|
||||||
# float scale);
|
# float scale);
|
||||||
def llama_sample_apply_guidance(
|
def llama_sample_apply_guidance(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
logits, # type: _Pointer[ctypes.c_float]
|
logits: CtypesArray[ctypes.c_float],
|
||||||
logits_guidance, # type: _Pointer[ctypes.c_float]
|
logits_guidance: CtypesArray[ctypes.c_float],
|
||||||
scale: Union[ctypes.c_float, float],
|
scale: Union[ctypes.c_float, float],
|
||||||
/
|
/,
|
||||||
):
|
):
|
||||||
"""Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806"""
|
"""Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806"""
|
||||||
...
|
...
|
||||||
|
@ -2213,8 +2284,8 @@ def llama_sample_apply_guidance(
|
||||||
llama_sample_apply_guidance = _lib.llama_sample_apply_guidance
|
llama_sample_apply_guidance = _lib.llama_sample_apply_guidance
|
||||||
llama_sample_apply_guidance.argtypes = [
|
llama_sample_apply_guidance.argtypes = [
|
||||||
llama_context_p_ctypes,
|
llama_context_p_ctypes,
|
||||||
c_float_p,
|
ctypes.POINTER(ctypes.c_float),
|
||||||
c_float_p,
|
ctypes.POINTER(ctypes.c_float),
|
||||||
ctypes.c_float,
|
ctypes.c_float,
|
||||||
]
|
]
|
||||||
llama_sample_apply_guidance.restype = None
|
llama_sample_apply_guidance.restype = None
|
||||||
|
@ -2228,10 +2299,12 @@ llama_sample_apply_guidance.restype = None
|
||||||
# "use llama_sample_apply_guidance() instead");
|
# "use llama_sample_apply_guidance() instead");
|
||||||
def llama_sample_classifier_free_guidance(
|
def llama_sample_classifier_free_guidance(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
candidates, # type: _Pointer[llama_token_data_array]
|
candidates: Union[
|
||||||
|
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||||
|
],
|
||||||
guidance_ctx: llama_context_p,
|
guidance_ctx: llama_context_p,
|
||||||
scale: Union[ctypes.c_float, float],
|
scale: Union[ctypes.c_float, float],
|
||||||
/
|
/,
|
||||||
):
|
):
|
||||||
"""Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806"""
|
"""Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806"""
|
||||||
...
|
...
|
||||||
|
@ -2252,8 +2325,11 @@ llama_sample_classifier_free_guidance.restype = None
|
||||||
# struct llama_context * ctx,
|
# struct llama_context * ctx,
|
||||||
# llama_token_data_array * candidates);
|
# llama_token_data_array * candidates);
|
||||||
def llama_sample_softmax(
|
def llama_sample_softmax(
|
||||||
ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data]
|
ctx: llama_context_p,
|
||||||
/
|
candidates: Union[
|
||||||
|
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||||
|
],
|
||||||
|
/,
|
||||||
):
|
):
|
||||||
"""Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits."""
|
"""Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits."""
|
||||||
...
|
...
|
||||||
|
@ -2275,10 +2351,12 @@ llama_sample_softmax.restype = None
|
||||||
# size_t min_keep);
|
# size_t min_keep);
|
||||||
def llama_sample_top_k(
|
def llama_sample_top_k(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
candidates, # type: _Pointer[llama_token_data_array]
|
candidates: Union[
|
||||||
|
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||||
|
],
|
||||||
k: Union[ctypes.c_int, int],
|
k: Union[ctypes.c_int, int],
|
||||||
min_keep: Union[ctypes.c_size_t, int],
|
min_keep: Union[ctypes.c_size_t, int],
|
||||||
/
|
/,
|
||||||
):
|
):
|
||||||
"""Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751"""
|
"""Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751"""
|
||||||
...
|
...
|
||||||
|
@ -2302,10 +2380,12 @@ llama_sample_top_k.restype = None
|
||||||
# size_t min_keep);
|
# size_t min_keep);
|
||||||
def llama_sample_top_p(
|
def llama_sample_top_p(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
candidates, # type: _Pointer[llama_token_data_array]
|
candidates: Union[
|
||||||
|
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||||
|
],
|
||||||
p: Union[ctypes.c_float, float],
|
p: Union[ctypes.c_float, float],
|
||||||
min_keep: Union[ctypes.c_size_t, int],
|
min_keep: Union[ctypes.c_size_t, int],
|
||||||
/
|
/,
|
||||||
):
|
):
|
||||||
"""Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751"""
|
"""Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751"""
|
||||||
...
|
...
|
||||||
|
@ -2329,10 +2409,12 @@ llama_sample_top_p.restype = None
|
||||||
# size_t min_keep);
|
# size_t min_keep);
|
||||||
def llama_sample_min_p(
|
def llama_sample_min_p(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
candidates, # type: _Pointer[llama_token_data_array]
|
candidates: Union[
|
||||||
|
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||||
|
],
|
||||||
p: Union[ctypes.c_float, float],
|
p: Union[ctypes.c_float, float],
|
||||||
min_keep: Union[ctypes.c_size_t, int],
|
min_keep: Union[ctypes.c_size_t, int],
|
||||||
/
|
/,
|
||||||
):
|
):
|
||||||
"""Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841"""
|
"""Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841"""
|
||||||
...
|
...
|
||||||
|
@ -2356,10 +2438,12 @@ llama_sample_min_p.restype = None
|
||||||
# size_t min_keep);
|
# size_t min_keep);
|
||||||
def llama_sample_tail_free(
|
def llama_sample_tail_free(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
candidates, # type: _Pointer[llama_token_data_array]
|
candidates: Union[
|
||||||
|
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||||
|
],
|
||||||
z: Union[ctypes.c_float, float],
|
z: Union[ctypes.c_float, float],
|
||||||
min_keep: Union[ctypes.c_size_t, int],
|
min_keep: Union[ctypes.c_size_t, int],
|
||||||
/
|
/,
|
||||||
):
|
):
|
||||||
"""Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/."""
|
"""Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/."""
|
||||||
...
|
...
|
||||||
|
@ -2383,10 +2467,12 @@ llama_sample_tail_free.restype = None
|
||||||
# size_t min_keep);
|
# size_t min_keep);
|
||||||
def llama_sample_typical(
|
def llama_sample_typical(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
candidates, # type: _Pointer[llama_token_data_array]
|
candidates: Union[
|
||||||
|
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||||
|
],
|
||||||
p: Union[ctypes.c_float, float],
|
p: Union[ctypes.c_float, float],
|
||||||
min_keep: Union[ctypes.c_size_t, int],
|
min_keep: Union[ctypes.c_size_t, int],
|
||||||
/
|
/,
|
||||||
):
|
):
|
||||||
"""Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666."""
|
"""Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666."""
|
||||||
...
|
...
|
||||||
|
@ -2411,11 +2497,13 @@ llama_sample_typical.restype = None
|
||||||
# float exponent_val);
|
# float exponent_val);
|
||||||
def llama_sample_entropy(
|
def llama_sample_entropy(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
candidates, # type: _Pointer[llama_token_data_array]
|
candidates: Union[
|
||||||
|
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||||
|
],
|
||||||
min_temp: Union[ctypes.c_float, float],
|
min_temp: Union[ctypes.c_float, float],
|
||||||
max_temp: Union[ctypes.c_float, float],
|
max_temp: Union[ctypes.c_float, float],
|
||||||
exponent_val: Union[ctypes.c_float, float],
|
exponent_val: Union[ctypes.c_float, float],
|
||||||
/
|
/,
|
||||||
):
|
):
|
||||||
"""Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772."""
|
"""Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772."""
|
||||||
...
|
...
|
||||||
|
@ -2438,9 +2526,11 @@ llama_sample_entropy.restype = None
|
||||||
# float temp);
|
# float temp);
|
||||||
def llama_sample_temp(
|
def llama_sample_temp(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
candidates, # type: _Pointer[llama_token_data_array]
|
candidates: Union[
|
||||||
|
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||||
|
],
|
||||||
temp: Union[ctypes.c_float, float],
|
temp: Union[ctypes.c_float, float],
|
||||||
/
|
/,
|
||||||
):
|
):
|
||||||
"""Temperature sampling described in academic paper "Generating Long Sequences with Sparse Transformers" https://arxiv.org/abs/1904.10509
|
"""Temperature sampling described in academic paper "Generating Long Sequences with Sparse Transformers" https://arxiv.org/abs/1904.10509
|
||||||
|
|
||||||
|
@ -2467,9 +2557,11 @@ llama_sample_temp.restype = None
|
||||||
# "use llama_sample_temp instead");
|
# "use llama_sample_temp instead");
|
||||||
def llama_sample_temperature(
|
def llama_sample_temperature(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
candidates, # type: _Pointer[llama_token_data_array]
|
candidates: Union[
|
||||||
|
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||||
|
],
|
||||||
temp: Union[ctypes.c_float, float],
|
temp: Union[ctypes.c_float, float],
|
||||||
/
|
/,
|
||||||
):
|
):
|
||||||
"""use llama_sample_temp instead"""
|
"""use llama_sample_temp instead"""
|
||||||
...
|
...
|
||||||
|
@ -2491,9 +2583,11 @@ llama_sample_temperature.restype = None
|
||||||
# const struct llama_grammar * grammar);
|
# const struct llama_grammar * grammar);
|
||||||
def llama_sample_grammar(
|
def llama_sample_grammar(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
candidates, # type: _Pointer[llama_token_data_array]
|
candidates: Union[
|
||||||
|
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||||
|
],
|
||||||
grammar, # type: llama_grammar_p
|
grammar, # type: llama_grammar_p
|
||||||
/
|
/,
|
||||||
):
|
):
|
||||||
"""Apply constraints from grammar
|
"""Apply constraints from grammar
|
||||||
|
|
||||||
|
@ -2528,12 +2622,14 @@ llama_sample_grammar.restype = None
|
||||||
# float * mu);
|
# float * mu);
|
||||||
def llama_sample_token_mirostat(
|
def llama_sample_token_mirostat(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
candidates, # type: _Pointer[llama_token_data_array]
|
candidates: Union[
|
||||||
|
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||||
|
],
|
||||||
tau: Union[ctypes.c_float, float],
|
tau: Union[ctypes.c_float, float],
|
||||||
eta: Union[ctypes.c_float, float],
|
eta: Union[ctypes.c_float, float],
|
||||||
m: Union[ctypes.c_int, int],
|
m: Union[ctypes.c_int, int],
|
||||||
mu, # type: _Pointer[ctypes.c_float]
|
mu: CtypesPointerOrRef[ctypes.c_float],
|
||||||
/
|
/,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
"""Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||||
|
|
||||||
|
@ -2554,7 +2650,7 @@ llama_sample_token_mirostat.argtypes = [
|
||||||
ctypes.c_float,
|
ctypes.c_float,
|
||||||
ctypes.c_float,
|
ctypes.c_float,
|
||||||
ctypes.c_int32,
|
ctypes.c_int32,
|
||||||
c_float_p,
|
ctypes.POINTER(ctypes.c_float),
|
||||||
]
|
]
|
||||||
llama_sample_token_mirostat.restype = llama_token
|
llama_sample_token_mirostat.restype = llama_token
|
||||||
|
|
||||||
|
@ -2572,11 +2668,13 @@ llama_sample_token_mirostat.restype = llama_token
|
||||||
# float * mu);
|
# float * mu);
|
||||||
def llama_sample_token_mirostat_v2(
|
def llama_sample_token_mirostat_v2(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
candidates, # type: _Pointer[llama_token_data_array]
|
candidates: Union[
|
||||||
|
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||||
|
],
|
||||||
tau: Union[ctypes.c_float, float],
|
tau: Union[ctypes.c_float, float],
|
||||||
eta: Union[ctypes.c_float, float],
|
eta: Union[ctypes.c_float, float],
|
||||||
mu, # type: _Pointer[ctypes.c_float]
|
mu, # type: _Pointer[ctypes.c_float]
|
||||||
/
|
/,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
"""Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||||
|
|
||||||
|
@ -2595,7 +2693,7 @@ llama_sample_token_mirostat_v2.argtypes = [
|
||||||
llama_token_data_array_p,
|
llama_token_data_array_p,
|
||||||
ctypes.c_float,
|
ctypes.c_float,
|
||||||
ctypes.c_float,
|
ctypes.c_float,
|
||||||
c_float_p,
|
ctypes.POINTER(ctypes.c_float),
|
||||||
]
|
]
|
||||||
llama_sample_token_mirostat_v2.restype = llama_token
|
llama_sample_token_mirostat_v2.restype = llama_token
|
||||||
|
|
||||||
|
@ -2607,8 +2705,10 @@ llama_sample_token_mirostat_v2.restype = llama_token
|
||||||
# llama_token_data_array * candidates);
|
# llama_token_data_array * candidates);
|
||||||
def llama_sample_token_greedy(
|
def llama_sample_token_greedy(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
candidates, # type: _Pointer[llama_token_data_array]
|
candidates: Union[
|
||||||
/
|
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||||
|
],
|
||||||
|
/,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Selects the token with the highest probability."""
|
"""Selects the token with the highest probability."""
|
||||||
...
|
...
|
||||||
|
@ -2628,8 +2728,10 @@ llama_sample_token_greedy.restype = llama_token
|
||||||
# llama_token_data_array * candidates);
|
# llama_token_data_array * candidates);
|
||||||
def llama_sample_token(
|
def llama_sample_token(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
candidates, # type: _Pointer[llama_token_data_array]
|
candidates: Union[
|
||||||
/
|
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||||
|
],
|
||||||
|
/,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Randomly selects a token from the candidates based on their probabilities."""
|
"""Randomly selects a token from the candidates based on their probabilities."""
|
||||||
...
|
...
|
||||||
|
@ -2649,10 +2751,7 @@ llama_sample_token.restype = llama_token
|
||||||
# struct llama_grammar * grammar,
|
# struct llama_grammar * grammar,
|
||||||
# llama_token token);
|
# llama_token token);
|
||||||
def llama_grammar_accept_token(
|
def llama_grammar_accept_token(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p, grammar: llama_grammar_p, token: Union[llama_token, int], /
|
||||||
grammar: llama_grammar_p,
|
|
||||||
token: Union[llama_token, int],
|
|
||||||
/
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Accepts the sampled token into the grammar"""
|
"""Accepts the sampled token into the grammar"""
|
||||||
...
|
...
|
||||||
|
@ -2711,7 +2810,9 @@ class llama_beams_state(ctypes.Structure):
|
||||||
# // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
|
# // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
|
||||||
# // passed back to beam_search_callback. This avoids having to use global variables in the callback.
|
# // passed back to beam_search_callback. This avoids having to use global variables in the callback.
|
||||||
# typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
|
# typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
|
||||||
llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, ctypes.c_void_p, llama_beams_state)
|
llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(
|
||||||
|
None, ctypes.c_void_p, llama_beams_state
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# /// @details Deterministically returns entire sentence constructed by a beam search.
|
# /// @details Deterministically returns entire sentence constructed by a beam search.
|
||||||
|
@ -2731,12 +2832,12 @@ llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, ctypes.c_void_p, llama_
|
||||||
# int32_t n_predict);
|
# int32_t n_predict);
|
||||||
def llama_beam_search(
|
def llama_beam_search(
|
||||||
ctx: llama_context_p,
|
ctx: llama_context_p,
|
||||||
callback: "ctypes._CFuncPtr[None, ctypes.c_void_p, llama_beams_state]", # type: ignore
|
callback: CtypesFuncPointer,
|
||||||
callback_data: ctypes.c_void_p,
|
callback_data: ctypes.c_void_p,
|
||||||
n_beams: Union[ctypes.c_size_t, int],
|
n_beams: Union[ctypes.c_size_t, int],
|
||||||
n_past: Union[ctypes.c_int, int],
|
n_past: Union[ctypes.c_int, int],
|
||||||
n_predict: Union[ctypes.c_int, int],
|
n_predict: Union[ctypes.c_int, int],
|
||||||
/
|
/,
|
||||||
):
|
):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@ -2806,8 +2907,9 @@ llama_print_system_info.restype = ctypes.c_char_p
|
||||||
# // If this is not called, or NULL is supplied, everything is output on stderr.
|
# // If this is not called, or NULL is supplied, everything is output on stderr.
|
||||||
# LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
|
# LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
|
||||||
def llama_log_set(
|
def llama_log_set(
|
||||||
log_callback: Union["ctypes._FuncPointer", ctypes.c_void_p], user_data: ctypes.c_void_p, # type: ignore
|
log_callback: Optional[CtypesFuncPointer],
|
||||||
/
|
user_data: ctypes.c_void_p, # type: ignore
|
||||||
|
/,
|
||||||
):
|
):
|
||||||
"""Set callback for all future logging events.
|
"""Set callback for all future logging events.
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue