misc: use typesafe byref for internal classes

Andrei Betlen 2024-02-23 03:40:07 -05:00
parent a0ce429dc0
commit b9aca612af


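Note: every hunk below replaces an untyped ctypes.byref(...) call (each of which needed a per-call # type: ignore) with llama_cpp.byref(...), a type-annotated wrapper exposed by the llama_cpp bindings. The wrapper's definition is not part of this diff; the following is only a minimal sketch of how such a typed byref can be declared, assuming helper names CtypesRef and CtypesPointerOrRef (the offset parameter mirrors ctypes.byref):

    import ctypes
    from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union

    CtypesCData = TypeVar("CtypesCData", bound="ctypes._CData")

    if TYPE_CHECKING:
        # ctypes.byref() returns a private CArgObject at runtime, so model it
        # as a generic marker class that exists only for the type checker.
        class CtypesRef(Generic[CtypesCData]):
            pass

        # Signatures may accept either a real pointer or a byref() reference.
        CtypesPointerOrRef = Union[
            "ctypes._Pointer[CtypesCData]", CtypesRef[CtypesCData]
        ]

    def byref(obj: CtypesCData, offset: Optional[int] = None) -> "CtypesRef[CtypesCData]":
        """Type-annotated stand-in for ctypes.byref (signature only)."""
        ...

    # At runtime the annotated stub is swapped for the real ctypes.byref.
    byref = ctypes.byref  # type: ignore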
@@ -82,7 +82,7 @@ class _LlamaModel:
     def desc(self) -> str:
         assert self.model is not None
         buf = ctypes.create_string_buffer(1024)
-        llama_cpp.llama_model_desc(self.model, buf, 1024)  # type: ignore
+        llama_cpp.llama_model_desc(self.model, buf, 1024)
         return buf.value.decode("utf-8")
 
     def size(self) -> int:
@@ -184,7 +184,7 @@ class _LlamaModel:
     def token_to_piece(self, token: int) -> bytes:
         assert self.model is not None
         buf = ctypes.create_string_buffer(32)
-        llama_cpp.llama_token_to_piece(self.model, token, buf, 32)  # type: ignore
+        llama_cpp.llama_token_to_piece(self.model, token, buf, 32)
         return bytes(buf)
 
     def detokenize(self, tokens: List[int]) -> bytes:
@@ -349,7 +349,7 @@ class _LlamaContext:
         assert self.ctx is not None
         llama_cpp.llama_sample_repetition_penalties(
             self.ctx,
-            ctypes.byref(candidates.candidates),  # type: ignore
+            llama_cpp.byref(candidates.candidates),
             last_tokens_data,
             penalty_last_n,
             penalty_repeat,
@@ -367,7 +367,7 @@ class _LlamaContext:
         assert guidance_ctx.ctx is not None
         llama_cpp.llama_sample_classifier_free_guidance(
             self.ctx,
-            ctypes.byref(candidates.candidates),  # type: ignore
+            llama_cpp.byref(candidates.candidates),
             guidance_ctx.ctx,
             scale,
         )
@@ -376,25 +376,25 @@ class _LlamaContext:
         assert self.ctx is not None
         llama_cpp.llama_sample_softmax(
             self.ctx,
-            ctypes.byref(candidates.candidates),  # type: ignore
+            llama_cpp.byref(candidates.candidates),
         )
 
     def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
         assert self.ctx is not None
         llama_cpp.llama_sample_top_k(
-            self.ctx, ctypes.byref(candidates.candidates), k, min_keep  # type: ignore
+            self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep
         )
 
     def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
         assert self.ctx is not None
         llama_cpp.llama_sample_top_p(
-            self.ctx, ctypes.byref(candidates.candidates), p, min_keep  # type: ignore
+            self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
         )
 
     def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
         assert self.ctx is not None
         llama_cpp.llama_sample_min_p(
-            self.ctx, ctypes.byref(candidates.candidates), p, min_keep  # type: ignore
+            self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
         )
 
     def sample_tail_free(
@@ -402,7 +402,7 @@ class _LlamaContext:
     ):
         assert self.ctx is not None
         llama_cpp.llama_sample_tail_free(
-            self.ctx, ctypes.byref(candidates.candidates), z, min_keep  # type: ignore
+            self.ctx, llama_cpp.byref(candidates.candidates), z, min_keep
         )
 
     def sample_typical(
@@ -410,13 +410,13 @@ class _LlamaContext:
     ):
         assert self.ctx is not None
         llama_cpp.llama_sample_typical(
-            self.ctx, ctypes.byref(candidates.candidates), p, min_keep  # type: ignore
+            self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
         )
 
     def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
         assert self.ctx is not None
         llama_cpp.llama_sample_temp(
-            self.ctx, ctypes.byref(candidates.candidates), temp  # type: ignore
+            self.ctx, llama_cpp.byref(candidates.candidates), temp
         )
 
     def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
@@ -424,7 +424,7 @@ class _LlamaContext:
         assert grammar.grammar is not None
         llama_cpp.llama_sample_grammar(
             self.ctx,
-            ctypes.byref(candidates.candidates),  # type: ignore
+            llama_cpp.byref(candidates.candidates),
             grammar.grammar,
         )
 
@@ -434,12 +434,12 @@ class _LlamaContext:
         tau: float,
         eta: float,
         m: int,
-        mu: ctypes._Pointer[ctypes.c_float],  # type: ignore
+        mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
     ) -> int:
         assert self.ctx is not None
         return llama_cpp.llama_sample_token_mirostat(
             self.ctx,
-            ctypes.byref(candidates.candidates),  # type: ignore
+            llama_cpp.byref(candidates.candidates),
             tau,
             eta,
             m,
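The mirostat samplers mutate a running mu accumulator in place through this pointer, which is why the annotation is widened to CtypesPointerOrRef: callers may pass either a real pointer or a byref() reference. A hypothetical caller, assuming ctx and token_data_array are the _LlamaContext and _LlamaTokenDataArray wrappers above, and using llama.cpp's conventional initialization mu = 2 * tau:

    import ctypes

    tau, eta = 5.0, 0.1
    mu = ctypes.c_float(2.0 * tau)  # updated in place by llama.cpp on each call
    token = ctx.sample_token_mirostat_v2(
        token_data_array,
        tau,
        eta,
        ctypes.pointer(mu),  # llama_cpp.byref(mu) also satisfies CtypesPointerOrRef
    )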
@@ -447,12 +447,12 @@ class _LlamaContext:
         )
 
     def sample_token_mirostat_v2(
-        self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: ctypes._Pointer[ctypes.c_float]  # type: ignore
+        self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float]
     ) -> int:
         assert self.ctx is not None
         return llama_cpp.llama_sample_token_mirostat_v2(
             self.ctx,
-            ctypes.byref(candidates.candidates),  # type: ignore
+            llama_cpp.byref(candidates.candidates),
             tau,
             eta,
             mu,
@@ -462,14 +462,14 @@ class _LlamaContext:
         assert self.ctx is not None
         return llama_cpp.llama_sample_token_greedy(
             self.ctx,
-            ctypes.byref(candidates.candidates),  # type: ignore
+            llama_cpp.byref(candidates.candidates),
         )
 
     def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
         assert self.ctx is not None
         return llama_cpp.llama_sample_token(
             self.ctx,
-            ctypes.byref(candidates.candidates),  # type: ignore
+            llama_cpp.byref(candidates.candidates),
         )
 
     # Grammar
@@ -566,7 +566,7 @@ class _LlamaTokenDataArray:
             size=self.n_vocab,
             sorted=False,
         )
-        self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc)
+        self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc)  # type: ignore
         self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)
 
     def copy_logits(self, logits: npt.NDArray[np.single]):
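This is the one hunk that adds a suppression rather than removing one, presumably because numpy's stubs do not infer np.arange(self.n_vocab, dtype=np.intc) as the exact array type the attribute is treated as elsewhere. A self-contained illustration of the intent (the sizes and the explicit annotations are assumptions, not part of this diff):

    import numpy as np
    import numpy.typing as npt

    n_vocab = 32000  # hypothetical vocabulary size
    # Default token ids 0..n_vocab-1 and zeroed probabilities, used to reset
    # the candidates buffer before each sampling pass.
    default_ids: npt.NDArray[np.intc] = np.arange(n_vocab, dtype=np.intc)  # type: ignore
    default_ps: npt.NDArray[np.single] = np.zeros(n_vocab, dtype=np.single)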
@@ -754,7 +754,7 @@ class _LlamaSamplingContext:
         ctx_main.sample_repetition_penalties(
             token_data_array,
             # TODO: Only create this once
-            (llama_cpp.llama_token * len(self.prev))(*self.prev),  # type: ignore
+            (llama_cpp.llama_token * len(self.prev))(*self.prev),
             self.params.penalty_last_n,
             self.params.penalty_repeat,
             self.params.penalty_freq,
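For reference, the (llama_cpp.llama_token * len(self.prev))(*self.prev) expression in this last hunk builds a C array from the Python token history: multiplying a ctypes type by n produces an array type, and calling that type instantiates it. A standalone sketch (llama_token is int32_t in llama.cpp; the variable names here are illustrative):

    import ctypes

    llama_token = ctypes.c_int32  # matches llama.cpp's typedef int32_t llama_token
    prev = [1, 42, 7]             # recent token ids, newest last

    # (type * n) creates an array *type*; calling it with values instantiates it.
    last_tokens_data = (llama_token * len(prev))(*prev)
    assert list(last_tokens_data) == prev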