misc: use typesafe byref for internal classes
This commit is contained in:
parent
a0ce429dc0
commit
b9aca612af
1 changed files with 20 additions and 20 deletions
|
@ -82,7 +82,7 @@ class _LlamaModel:
|
|||
def desc(self) -> str:
|
||||
assert self.model is not None
|
||||
buf = ctypes.create_string_buffer(1024)
|
||||
llama_cpp.llama_model_desc(self.model, buf, 1024) # type: ignore
|
||||
llama_cpp.llama_model_desc(self.model, buf, 1024)
|
||||
return buf.value.decode("utf-8")
|
||||
|
||||
def size(self) -> int:
|
||||
|
@ -184,7 +184,7 @@ class _LlamaModel:
|
|||
def token_to_piece(self, token: int) -> bytes:
|
||||
assert self.model is not None
|
||||
buf = ctypes.create_string_buffer(32)
|
||||
llama_cpp.llama_token_to_piece(self.model, token, buf, 32) # type: ignore
|
||||
llama_cpp.llama_token_to_piece(self.model, token, buf, 32)
|
||||
return bytes(buf)
|
||||
|
||||
def detokenize(self, tokens: List[int]) -> bytes:
|
||||
|
@ -349,7 +349,7 @@ class _LlamaContext:
|
|||
assert self.ctx is not None
|
||||
llama_cpp.llama_sample_repetition_penalties(
|
||||
self.ctx,
|
||||
ctypes.byref(candidates.candidates), # type: ignore
|
||||
llama_cpp.byref(candidates.candidates),
|
||||
last_tokens_data,
|
||||
penalty_last_n,
|
||||
penalty_repeat,
|
||||
|
@ -367,7 +367,7 @@ class _LlamaContext:
|
|||
assert guidance_ctx.ctx is not None
|
||||
llama_cpp.llama_sample_classifier_free_guidance(
|
||||
self.ctx,
|
||||
ctypes.byref(candidates.candidates), # type: ignore
|
||||
llama_cpp.byref(candidates.candidates),
|
||||
guidance_ctx.ctx,
|
||||
scale,
|
||||
)
|
||||
|
@ -376,25 +376,25 @@ class _LlamaContext:
|
|||
assert self.ctx is not None
|
||||
llama_cpp.llama_sample_softmax(
|
||||
self.ctx,
|
||||
ctypes.byref(candidates.candidates), # type: ignore
|
||||
llama_cpp.byref(candidates.candidates),
|
||||
)
|
||||
|
||||
def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
|
||||
assert self.ctx is not None
|
||||
llama_cpp.llama_sample_top_k(
|
||||
self.ctx, ctypes.byref(candidates.candidates), k, min_keep # type: ignore
|
||||
self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep
|
||||
)
|
||||
|
||||
def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
|
||||
assert self.ctx is not None
|
||||
llama_cpp.llama_sample_top_p(
|
||||
self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
|
||||
self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
|
||||
)
|
||||
|
||||
def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
|
||||
assert self.ctx is not None
|
||||
llama_cpp.llama_sample_min_p(
|
||||
self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
|
||||
self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
|
||||
)
|
||||
|
||||
def sample_tail_free(
|
||||
|
@ -402,7 +402,7 @@ class _LlamaContext:
|
|||
):
|
||||
assert self.ctx is not None
|
||||
llama_cpp.llama_sample_tail_free(
|
||||
self.ctx, ctypes.byref(candidates.candidates), z, min_keep # type: ignore
|
||||
self.ctx, llama_cpp.byref(candidates.candidates), z, min_keep
|
||||
)
|
||||
|
||||
def sample_typical(
|
||||
|
@ -410,13 +410,13 @@ class _LlamaContext:
|
|||
):
|
||||
assert self.ctx is not None
|
||||
llama_cpp.llama_sample_typical(
|
||||
self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
|
||||
self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
|
||||
)
|
||||
|
||||
def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
|
||||
assert self.ctx is not None
|
||||
llama_cpp.llama_sample_temp(
|
||||
self.ctx, ctypes.byref(candidates.candidates), temp # type: ignore
|
||||
self.ctx, llama_cpp.byref(candidates.candidates), temp
|
||||
)
|
||||
|
||||
def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
|
||||
|
@ -424,7 +424,7 @@ class _LlamaContext:
|
|||
assert grammar.grammar is not None
|
||||
llama_cpp.llama_sample_grammar(
|
||||
self.ctx,
|
||||
ctypes.byref(candidates.candidates), # type: ignore
|
||||
llama_cpp.byref(candidates.candidates),
|
||||
grammar.grammar,
|
||||
)
|
||||
|
||||
|
@ -434,12 +434,12 @@ class _LlamaContext:
|
|||
tau: float,
|
||||
eta: float,
|
||||
m: int,
|
||||
mu: ctypes._Pointer[ctypes.c_float], # type: ignore
|
||||
mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
|
||||
) -> int:
|
||||
assert self.ctx is not None
|
||||
return llama_cpp.llama_sample_token_mirostat(
|
||||
self.ctx,
|
||||
ctypes.byref(candidates.candidates), # type: ignore
|
||||
llama_cpp.byref(candidates.candidates),
|
||||
tau,
|
||||
eta,
|
||||
m,
|
||||
|
@ -447,12 +447,12 @@ class _LlamaContext:
|
|||
)
|
||||
|
||||
def sample_token_mirostat_v2(
|
||||
self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: ctypes._Pointer[ctypes.c_float] # type: ignore
|
||||
self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float]
|
||||
) -> int:
|
||||
assert self.ctx is not None
|
||||
return llama_cpp.llama_sample_token_mirostat_v2(
|
||||
self.ctx,
|
||||
ctypes.byref(candidates.candidates), # type: ignore
|
||||
llama_cpp.byref(candidates.candidates),
|
||||
tau,
|
||||
eta,
|
||||
mu,
|
||||
|
@ -462,14 +462,14 @@ class _LlamaContext:
|
|||
assert self.ctx is not None
|
||||
return llama_cpp.llama_sample_token_greedy(
|
||||
self.ctx,
|
||||
ctypes.byref(candidates.candidates), # type: ignore
|
||||
llama_cpp.byref(candidates.candidates),
|
||||
)
|
||||
|
||||
def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
|
||||
assert self.ctx is not None
|
||||
return llama_cpp.llama_sample_token(
|
||||
self.ctx,
|
||||
ctypes.byref(candidates.candidates), # type: ignore
|
||||
llama_cpp.byref(candidates.candidates),
|
||||
)
|
||||
|
||||
# Grammar
|
||||
|
@ -566,7 +566,7 @@ class _LlamaTokenDataArray:
|
|||
size=self.n_vocab,
|
||||
sorted=False,
|
||||
)
|
||||
self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc)
|
||||
self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc) # type: ignore
|
||||
self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)
|
||||
|
||||
def copy_logits(self, logits: npt.NDArray[np.single]):
|
||||
|
@ -754,7 +754,7 @@ class _LlamaSamplingContext:
|
|||
ctx_main.sample_repetition_penalties(
|
||||
token_data_array,
|
||||
# TODO: Only create this once
|
||||
(llama_cpp.llama_token * len(self.prev))(*self.prev), # type: ignore
|
||||
(llama_cpp.llama_token * len(self.prev))(*self.prev),
|
||||
self.params.penalty_last_n,
|
||||
self.params.penalty_repeat,
|
||||
self.params.penalty_freq,
|
||||
|
|
Loading…
Add table
Reference in a new issue