Convert missed llama.cpp constants into standard python types
This commit is contained in:
parent
c4c440ba2d
commit
517f9ed80b
2 changed files with 86 additions and 86 deletions
|
@ -343,11 +343,11 @@ class Llama:
|
|||
if self.lora_path:
|
||||
if llama_cpp.llama_model_apply_lora_from_file(
|
||||
self.model,
|
||||
llama_cpp.c_char_p(self.lora_path.encode("utf-8")),
|
||||
llama_cpp.c_char_p(self.lora_base.encode("utf-8"))
|
||||
self.lora_path.encode("utf-8"),
|
||||
self.lora_base.encode("utf-8")
|
||||
if self.lora_base is not None
|
||||
else llama_cpp.c_char_p(0),
|
||||
llama_cpp.c_int(self.n_threads),
|
||||
self.n_threads,
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}"
|
||||
|
@ -358,8 +358,8 @@ class Llama:
|
|||
|
||||
self._n_vocab = self.n_vocab()
|
||||
self._n_ctx = self.n_ctx()
|
||||
size = llama_cpp.c_size_t(self._n_vocab)
|
||||
sorted = llama_cpp.c_bool(False)
|
||||
size = self._n_vocab
|
||||
sorted = False
|
||||
self._candidates_data = np.array(
|
||||
[],
|
||||
dtype=np.dtype(
|
||||
|
@ -422,8 +422,8 @@ class Llama:
|
|||
self.model,
|
||||
text,
|
||||
tokens,
|
||||
llama_cpp.c_int(n_ctx),
|
||||
llama_cpp.c_bool(add_bos),
|
||||
n_ctx,
|
||||
add_bos,
|
||||
)
|
||||
if n_tokens < 0:
|
||||
n_tokens = abs(n_tokens)
|
||||
|
@ -432,8 +432,8 @@ class Llama:
|
|||
self.model,
|
||||
text,
|
||||
tokens,
|
||||
llama_cpp.c_int(n_tokens),
|
||||
llama_cpp.c_bool(add_bos),
|
||||
n_tokens,
|
||||
add_bos,
|
||||
)
|
||||
if n_tokens < 0:
|
||||
raise RuntimeError(
|
||||
|
@ -491,9 +491,9 @@ class Llama:
|
|||
return_code = llama_cpp.llama_eval(
|
||||
ctx=self.ctx,
|
||||
tokens=(llama_cpp.llama_token * len(batch))(*batch),
|
||||
n_tokens=llama_cpp.c_int(n_tokens),
|
||||
n_past=llama_cpp.c_int(n_past),
|
||||
n_threads=llama_cpp.c_int(self.n_threads),
|
||||
n_tokens=n_tokens,
|
||||
n_past=n_past,
|
||||
n_threads=self.n_threads,
|
||||
)
|
||||
if return_code != 0:
|
||||
raise RuntimeError(f"llama_eval returned {return_code}")
|
||||
|
@ -514,17 +514,17 @@ class Llama:
|
|||
def _sample(
|
||||
self,
|
||||
last_n_tokens_data, # type: llama_cpp.Array[llama_cpp.llama_token]
|
||||
last_n_tokens_size: llama_cpp.c_int,
|
||||
top_k: llama_cpp.c_int,
|
||||
top_p: llama_cpp.c_float,
|
||||
temp: llama_cpp.c_float,
|
||||
tfs_z: llama_cpp.c_float,
|
||||
repeat_penalty: llama_cpp.c_float,
|
||||
frequency_penalty: llama_cpp.c_float,
|
||||
presence_penalty: llama_cpp.c_float,
|
||||
mirostat_mode: llama_cpp.c_int,
|
||||
mirostat_tau: llama_cpp.c_float,
|
||||
mirostat_eta: llama_cpp.c_float,
|
||||
last_n_tokens_size: int,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
temp: float,
|
||||
tfs_z: float,
|
||||
repeat_penalty: float,
|
||||
frequency_penalty: float,
|
||||
presence_penalty: float,
|
||||
mirostat_mode: float,
|
||||
mirostat_tau: float,
|
||||
mirostat_eta: float,
|
||||
penalize_nl: bool = True,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
grammar: Optional[LlamaGrammar] = None,
|
||||
|
@ -533,10 +533,10 @@ class Llama:
|
|||
assert self.n_tokens > 0
|
||||
n_vocab = self._n_vocab
|
||||
n_ctx = self._n_ctx
|
||||
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
|
||||
top_k = n_vocab if top_k <= 0 else top_k
|
||||
last_n_tokens_size = (
|
||||
llama_cpp.c_int(n_ctx)
|
||||
if last_n_tokens_size.value < 0
|
||||
n_ctx
|
||||
if last_n_tokens_size < 0
|
||||
else last_n_tokens_size
|
||||
)
|
||||
logits: npt.NDArray[np.single] = self._scores[-1, :]
|
||||
|
@ -578,13 +578,13 @@ class Llama:
|
|||
grammar=grammar.grammar,
|
||||
)
|
||||
|
||||
if temp.value == 0.0:
|
||||
if temp == 0.0:
|
||||
id = llama_cpp.llama_sample_token_greedy(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
||||
)
|
||||
elif mirostat_mode.value == 1:
|
||||
mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
|
||||
elif mirostat_mode == 1:
|
||||
mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau)
|
||||
mirostat_m = llama_cpp.c_int(100)
|
||||
llama_cpp.llama_sample_temperature(
|
||||
ctx=self.ctx,
|
||||
|
@ -599,8 +599,8 @@ class Llama:
|
|||
mu=llama_cpp.ctypes.byref(mirostat_mu), # type: ignore
|
||||
m=mirostat_m,
|
||||
)
|
||||
elif mirostat_mode.value == 2:
|
||||
mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
|
||||
elif mirostat_mode== 2:
|
||||
mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau)
|
||||
llama_cpp.llama_sample_temperature(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
||||
|
@ -690,17 +690,17 @@ class Llama:
|
|||
last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
|
||||
*last_n_tokens_data
|
||||
),
|
||||
last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size),
|
||||
top_k=llama_cpp.c_int(top_k),
|
||||
top_p=llama_cpp.c_float(top_p),
|
||||
temp=llama_cpp.c_float(temp),
|
||||
tfs_z=llama_cpp.c_float(tfs_z),
|
||||
repeat_penalty=llama_cpp.c_float(repeat_penalty),
|
||||
frequency_penalty=llama_cpp.c_float(frequency_penalty),
|
||||
presence_penalty=llama_cpp.c_float(presence_penalty),
|
||||
mirostat_mode=llama_cpp.c_int(mirostat_mode),
|
||||
mirostat_tau=llama_cpp.c_float(mirostat_tau),
|
||||
mirostat_eta=llama_cpp.c_float(mirostat_eta),
|
||||
last_n_tokens_size=self.last_n_tokens_size,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temp=temp,
|
||||
tfs_z=tfs_z,
|
||||
repeat_penalty=repeat_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
mirostat_mode=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
penalize_nl=penalize_nl,
|
||||
logits_processor=logits_processor,
|
||||
grammar=grammar,
|
||||
|
|
|
@ -91,15 +91,15 @@ GGML_CUDA_MAX_DEVICES = 16
|
|||
LLAMA_MAX_DEVICES = GGML_CUDA_MAX_DEVICES if GGML_USE_CUBLAS else 1
|
||||
|
||||
# define LLAMA_DEFAULT_SEED 0xFFFFFFFF
|
||||
LLAMA_DEFAULT_SEED = ctypes.c_int(0xFFFFFFFF)
|
||||
LLAMA_DEFAULT_SEED = 0xFFFFFFFF
|
||||
|
||||
# define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
||||
LLAMA_FILE_MAGIC_GGSN = ctypes.c_uint(0x6767736E)
|
||||
LLAMA_FILE_MAGIC_GGSN = 0x6767736E
|
||||
|
||||
# define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||
LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN
|
||||
# define LLAMA_SESSION_VERSION 1
|
||||
LLAMA_SESSION_VERSION = ctypes.c_int(1)
|
||||
LLAMA_SESSION_VERSION = 1
|
||||
|
||||
|
||||
# struct llama_model;
|
||||
|
@ -118,16 +118,16 @@ llama_token_p = POINTER(llama_token)
|
|||
# LLAMA_LOG_LEVEL_WARN = 3,
|
||||
# LLAMA_LOG_LEVEL_INFO = 4
|
||||
# };
|
||||
LLAMA_LOG_LEVEL_ERROR = c_int(2)
|
||||
LLAMA_LOG_LEVEL_WARN = c_int(3)
|
||||
LLAMA_LOG_LEVEL_INFO = c_int(4)
|
||||
LLAMA_LOG_LEVEL_ERROR = 2
|
||||
LLAMA_LOG_LEVEL_WARN = 3
|
||||
LLAMA_LOG_LEVEL_INFO = 4
|
||||
|
||||
# enum llama_vocab_type {
|
||||
# LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
|
||||
# LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
|
||||
# };
|
||||
LLAMA_VOCAB_TYPE_SPM = c_int(0)
|
||||
LLAMA_VOCAB_TYPE_BPE = c_int(1)
|
||||
LLAMA_VOCAB_TYPE_SPM = 0
|
||||
LLAMA_VOCAB_TYPE_BPE = 1
|
||||
|
||||
|
||||
# enum llama_token_type {
|
||||
|
@ -139,13 +139,13 @@ LLAMA_VOCAB_TYPE_BPE = c_int(1)
|
|||
# LLAMA_TOKEN_TYPE_UNUSED = 5,
|
||||
# LLAMA_TOKEN_TYPE_BYTE = 6,
|
||||
# };
|
||||
LLAMA_TOKEN_TYPE_UNDEFINED = c_int(0)
|
||||
LLAMA_TOKEN_TYPE_NORMAL = c_int(1)
|
||||
LLAMA_TOKEN_TYPE_UNKNOWN = c_int(2)
|
||||
LLAMA_TOKEN_TYPE_CONTROL = c_int(3)
|
||||
LLAMA_TOKEN_TYPE_USER_DEFINED = c_int(4)
|
||||
LLAMA_TOKEN_TYPE_UNUSED = c_int(5)
|
||||
LLAMA_TOKEN_TYPE_BYTE = c_int(6)
|
||||
LLAMA_TOKEN_TYPE_UNDEFINED = 0
|
||||
LLAMA_TOKEN_TYPE_NORMAL = 1
|
||||
LLAMA_TOKEN_TYPE_UNKNOWN = 2
|
||||
LLAMA_TOKEN_TYPE_CONTROL = 3
|
||||
LLAMA_TOKEN_TYPE_USER_DEFINED = 4
|
||||
LLAMA_TOKEN_TYPE_UNUSED = 5
|
||||
LLAMA_TOKEN_TYPE_BYTE = 6
|
||||
|
||||
# enum llama_ftype {
|
||||
# LLAMA_FTYPE_ALL_F32 = 0,
|
||||
|
@ -170,24 +170,24 @@ LLAMA_TOKEN_TYPE_BYTE = c_int(6)
|
|||
#
|
||||
# LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
||||
# };
|
||||
LLAMA_FTYPE_ALL_F32 = c_int(0)
|
||||
LLAMA_FTYPE_MOSTLY_F16 = c_int(1)
|
||||
LLAMA_FTYPE_MOSTLY_Q4_0 = c_int(2)
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1 = c_int(3)
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = c_int(4)
|
||||
LLAMA_FTYPE_MOSTLY_Q8_0 = c_int(7)
|
||||
LLAMA_FTYPE_MOSTLY_Q5_0 = c_int(8)
|
||||
LLAMA_FTYPE_MOSTLY_Q5_1 = c_int(9)
|
||||
LLAMA_FTYPE_MOSTLY_Q2_K = c_int(10)
|
||||
LLAMA_FTYPE_MOSTLY_Q3_K_S = c_int(11)
|
||||
LLAMA_FTYPE_MOSTLY_Q3_K_M = c_int(12)
|
||||
LLAMA_FTYPE_MOSTLY_Q3_K_L = c_int(13)
|
||||
LLAMA_FTYPE_MOSTLY_Q4_K_S = c_int(14)
|
||||
LLAMA_FTYPE_MOSTLY_Q4_K_M = c_int(15)
|
||||
LLAMA_FTYPE_MOSTLY_Q5_K_S = c_int(16)
|
||||
LLAMA_FTYPE_MOSTLY_Q5_K_M = c_int(17)
|
||||
LLAMA_FTYPE_MOSTLY_Q6_K = c_int(18)
|
||||
LLAMA_FTYPE_GUESSED = c_int(1024)
|
||||
LLAMA_FTYPE_ALL_F32 = 0
|
||||
LLAMA_FTYPE_MOSTLY_F16 = 1
|
||||
LLAMA_FTYPE_MOSTLY_Q4_0 = 2
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1 = 3
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4
|
||||
LLAMA_FTYPE_MOSTLY_Q8_0 = 7
|
||||
LLAMA_FTYPE_MOSTLY_Q5_0 = 8
|
||||
LLAMA_FTYPE_MOSTLY_Q5_1 = 9
|
||||
LLAMA_FTYPE_MOSTLY_Q2_K = 10
|
||||
LLAMA_FTYPE_MOSTLY_Q3_K_S = 11
|
||||
LLAMA_FTYPE_MOSTLY_Q3_K_M = 12
|
||||
LLAMA_FTYPE_MOSTLY_Q3_K_L = 13
|
||||
LLAMA_FTYPE_MOSTLY_Q4_K_S = 14
|
||||
LLAMA_FTYPE_MOSTLY_Q4_K_M = 15
|
||||
LLAMA_FTYPE_MOSTLY_Q5_K_S = 16
|
||||
LLAMA_FTYPE_MOSTLY_Q5_K_M = 17
|
||||
LLAMA_FTYPE_MOSTLY_Q6_K = 18
|
||||
LLAMA_FTYPE_GUESSED = 1024
|
||||
|
||||
|
||||
# typedef struct llama_token_data {
|
||||
|
@ -589,7 +589,7 @@ _lib.llama_model_n_embd.restype = c_int
|
|||
|
||||
# // Get a string describing the model type
|
||||
# LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
|
||||
def llama_model_desc(model: llama_model_p, buf: bytes, buf_size: c_size_t) -> int:
|
||||
def llama_model_desc(model: llama_model_p, buf: bytes, buf_size: Union[c_size_t, int]) -> int:
|
||||
return _lib.llama_model_desc(model, buf, buf_size)
|
||||
|
||||
|
||||
|
@ -957,8 +957,8 @@ def llama_tokenize(
|
|||
ctx: llama_context_p,
|
||||
text: bytes,
|
||||
tokens, # type: Array[llama_token]
|
||||
n_max_tokens: c_int,
|
||||
add_bos: c_bool,
|
||||
n_max_tokens: Union[c_int, int],
|
||||
add_bos: Union[c_bool, int],
|
||||
) -> int:
|
||||
return _lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)
|
||||
|
||||
|
@ -977,8 +977,8 @@ def llama_tokenize_with_model(
|
|||
model: llama_model_p,
|
||||
text: bytes,
|
||||
tokens, # type: Array[llama_token]
|
||||
n_max_tokens: c_int,
|
||||
add_bos: c_bool,
|
||||
n_max_tokens: Union[c_int, int],
|
||||
add_bos: Union[c_bool, bool],
|
||||
) -> int:
|
||||
return _lib.llama_tokenize_with_model(model, text, tokens, n_max_tokens, add_bos)
|
||||
|
||||
|
@ -1003,7 +1003,7 @@ _lib.llama_tokenize_with_model.restype = c_int
|
|||
# char * buf,
|
||||
# int length);
|
||||
def llama_token_to_piece(
|
||||
ctx: llama_context_p, token: llama_token, buf: bytes, length: c_int
|
||||
ctx: llama_context_p, token: llama_token, buf: bytes, length: Union[c_int, int]
|
||||
) -> int:
|
||||
return _lib.llama_token_to_piece(ctx, token, buf, length)
|
||||
|
||||
|
@ -1018,7 +1018,7 @@ _lib.llama_token_to_piece.restype = c_int
|
|||
# char * buf,
|
||||
# int length);
|
||||
def llama_token_to_piece_with_model(
|
||||
model: llama_model_p, token: llama_token, buf: bytes, length: c_int
|
||||
model: llama_model_p, token: llama_token, buf: bytes, length: Union[c_int, int]
|
||||
) -> int:
|
||||
return _lib.llama_token_to_piece_with_model(model, token, buf, length)
|
||||
|
||||
|
@ -1453,10 +1453,10 @@ def llama_beam_search(
|
|||
ctx: llama_context_p,
|
||||
callback: "ctypes._CFuncPtr[None, c_void_p, llama_beams_state]", # type: ignore
|
||||
callback_data: c_void_p,
|
||||
n_beams: c_size_t,
|
||||
n_past: c_int,
|
||||
n_predict: c_int,
|
||||
n_threads: c_int,
|
||||
n_beams: Union[c_size_t, int],
|
||||
n_past: Union[c_int, int],
|
||||
n_predict: Union[c_int, int],
|
||||
n_threads: Union[c_int, int],
|
||||
):
|
||||
return _lib.llama_beam_search(
|
||||
ctx, callback, callback_data, n_beams, n_past, n_predict, n_threads
|
||||
|
|
Loading…
Add table
Reference in a new issue