misc: use typesafe byref for internal classes
This commit is contained in:
parent
a0ce429dc0
commit
b9aca612af
1 changed files with 20 additions and 20 deletions
|
@ -82,7 +82,7 @@ class _LlamaModel:
|
||||||
def desc(self) -> str:
|
def desc(self) -> str:
|
||||||
assert self.model is not None
|
assert self.model is not None
|
||||||
buf = ctypes.create_string_buffer(1024)
|
buf = ctypes.create_string_buffer(1024)
|
||||||
llama_cpp.llama_model_desc(self.model, buf, 1024) # type: ignore
|
llama_cpp.llama_model_desc(self.model, buf, 1024)
|
||||||
return buf.value.decode("utf-8")
|
return buf.value.decode("utf-8")
|
||||||
|
|
||||||
def size(self) -> int:
|
def size(self) -> int:
|
||||||
|
@ -184,7 +184,7 @@ class _LlamaModel:
|
||||||
def token_to_piece(self, token: int) -> bytes:
|
def token_to_piece(self, token: int) -> bytes:
|
||||||
assert self.model is not None
|
assert self.model is not None
|
||||||
buf = ctypes.create_string_buffer(32)
|
buf = ctypes.create_string_buffer(32)
|
||||||
llama_cpp.llama_token_to_piece(self.model, token, buf, 32) # type: ignore
|
llama_cpp.llama_token_to_piece(self.model, token, buf, 32)
|
||||||
return bytes(buf)
|
return bytes(buf)
|
||||||
|
|
||||||
def detokenize(self, tokens: List[int]) -> bytes:
|
def detokenize(self, tokens: List[int]) -> bytes:
|
||||||
|
@ -349,7 +349,7 @@ class _LlamaContext:
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
llama_cpp.llama_sample_repetition_penalties(
|
llama_cpp.llama_sample_repetition_penalties(
|
||||||
self.ctx,
|
self.ctx,
|
||||||
ctypes.byref(candidates.candidates), # type: ignore
|
llama_cpp.byref(candidates.candidates),
|
||||||
last_tokens_data,
|
last_tokens_data,
|
||||||
penalty_last_n,
|
penalty_last_n,
|
||||||
penalty_repeat,
|
penalty_repeat,
|
||||||
|
@ -367,7 +367,7 @@ class _LlamaContext:
|
||||||
assert guidance_ctx.ctx is not None
|
assert guidance_ctx.ctx is not None
|
||||||
llama_cpp.llama_sample_classifier_free_guidance(
|
llama_cpp.llama_sample_classifier_free_guidance(
|
||||||
self.ctx,
|
self.ctx,
|
||||||
ctypes.byref(candidates.candidates), # type: ignore
|
llama_cpp.byref(candidates.candidates),
|
||||||
guidance_ctx.ctx,
|
guidance_ctx.ctx,
|
||||||
scale,
|
scale,
|
||||||
)
|
)
|
||||||
|
@ -376,25 +376,25 @@ class _LlamaContext:
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
llama_cpp.llama_sample_softmax(
|
llama_cpp.llama_sample_softmax(
|
||||||
self.ctx,
|
self.ctx,
|
||||||
ctypes.byref(candidates.candidates), # type: ignore
|
llama_cpp.byref(candidates.candidates),
|
||||||
)
|
)
|
||||||
|
|
||||||
def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
|
def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
llama_cpp.llama_sample_top_k(
|
llama_cpp.llama_sample_top_k(
|
||||||
self.ctx, ctypes.byref(candidates.candidates), k, min_keep # type: ignore
|
self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep
|
||||||
)
|
)
|
||||||
|
|
||||||
def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
|
def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
llama_cpp.llama_sample_top_p(
|
llama_cpp.llama_sample_top_p(
|
||||||
self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
|
self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
|
||||||
)
|
)
|
||||||
|
|
||||||
def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
|
def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
llama_cpp.llama_sample_min_p(
|
llama_cpp.llama_sample_min_p(
|
||||||
self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
|
self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
|
||||||
)
|
)
|
||||||
|
|
||||||
def sample_tail_free(
|
def sample_tail_free(
|
||||||
|
@ -402,7 +402,7 @@ class _LlamaContext:
|
||||||
):
|
):
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
llama_cpp.llama_sample_tail_free(
|
llama_cpp.llama_sample_tail_free(
|
||||||
self.ctx, ctypes.byref(candidates.candidates), z, min_keep # type: ignore
|
self.ctx, llama_cpp.byref(candidates.candidates), z, min_keep
|
||||||
)
|
)
|
||||||
|
|
||||||
def sample_typical(
|
def sample_typical(
|
||||||
|
@ -410,13 +410,13 @@ class _LlamaContext:
|
||||||
):
|
):
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
llama_cpp.llama_sample_typical(
|
llama_cpp.llama_sample_typical(
|
||||||
self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
|
self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
|
||||||
)
|
)
|
||||||
|
|
||||||
def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
|
def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
llama_cpp.llama_sample_temp(
|
llama_cpp.llama_sample_temp(
|
||||||
self.ctx, ctypes.byref(candidates.candidates), temp # type: ignore
|
self.ctx, llama_cpp.byref(candidates.candidates), temp
|
||||||
)
|
)
|
||||||
|
|
||||||
def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
|
def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
|
||||||
|
@ -424,7 +424,7 @@ class _LlamaContext:
|
||||||
assert grammar.grammar is not None
|
assert grammar.grammar is not None
|
||||||
llama_cpp.llama_sample_grammar(
|
llama_cpp.llama_sample_grammar(
|
||||||
self.ctx,
|
self.ctx,
|
||||||
ctypes.byref(candidates.candidates), # type: ignore
|
llama_cpp.byref(candidates.candidates),
|
||||||
grammar.grammar,
|
grammar.grammar,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -434,12 +434,12 @@ class _LlamaContext:
|
||||||
tau: float,
|
tau: float,
|
||||||
eta: float,
|
eta: float,
|
||||||
m: int,
|
m: int,
|
||||||
mu: ctypes._Pointer[ctypes.c_float], # type: ignore
|
mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
|
||||||
) -> int:
|
) -> int:
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
return llama_cpp.llama_sample_token_mirostat(
|
return llama_cpp.llama_sample_token_mirostat(
|
||||||
self.ctx,
|
self.ctx,
|
||||||
ctypes.byref(candidates.candidates), # type: ignore
|
llama_cpp.byref(candidates.candidates),
|
||||||
tau,
|
tau,
|
||||||
eta,
|
eta,
|
||||||
m,
|
m,
|
||||||
|
@ -447,12 +447,12 @@ class _LlamaContext:
|
||||||
)
|
)
|
||||||
|
|
||||||
def sample_token_mirostat_v2(
|
def sample_token_mirostat_v2(
|
||||||
self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: ctypes._Pointer[ctypes.c_float] # type: ignore
|
self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float]
|
||||||
) -> int:
|
) -> int:
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
return llama_cpp.llama_sample_token_mirostat_v2(
|
return llama_cpp.llama_sample_token_mirostat_v2(
|
||||||
self.ctx,
|
self.ctx,
|
||||||
ctypes.byref(candidates.candidates), # type: ignore
|
llama_cpp.byref(candidates.candidates),
|
||||||
tau,
|
tau,
|
||||||
eta,
|
eta,
|
||||||
mu,
|
mu,
|
||||||
|
@ -462,14 +462,14 @@ class _LlamaContext:
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
return llama_cpp.llama_sample_token_greedy(
|
return llama_cpp.llama_sample_token_greedy(
|
||||||
self.ctx,
|
self.ctx,
|
||||||
ctypes.byref(candidates.candidates), # type: ignore
|
llama_cpp.byref(candidates.candidates),
|
||||||
)
|
)
|
||||||
|
|
||||||
def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
|
def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
return llama_cpp.llama_sample_token(
|
return llama_cpp.llama_sample_token(
|
||||||
self.ctx,
|
self.ctx,
|
||||||
ctypes.byref(candidates.candidates), # type: ignore
|
llama_cpp.byref(candidates.candidates),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Grammar
|
# Grammar
|
||||||
|
@ -566,7 +566,7 @@ class _LlamaTokenDataArray:
|
||||||
size=self.n_vocab,
|
size=self.n_vocab,
|
||||||
sorted=False,
|
sorted=False,
|
||||||
)
|
)
|
||||||
self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc)
|
self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc) # type: ignore
|
||||||
self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)
|
self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)
|
||||||
|
|
||||||
def copy_logits(self, logits: npt.NDArray[np.single]):
|
def copy_logits(self, logits: npt.NDArray[np.single]):
|
||||||
|
@ -754,7 +754,7 @@ class _LlamaSamplingContext:
|
||||||
ctx_main.sample_repetition_penalties(
|
ctx_main.sample_repetition_penalties(
|
||||||
token_data_array,
|
token_data_array,
|
||||||
# TODO: Only create this once
|
# TODO: Only create this once
|
||||||
(llama_cpp.llama_token * len(self.prev))(*self.prev), # type: ignore
|
(llama_cpp.llama_token * len(self.prev))(*self.prev),
|
||||||
self.params.penalty_last_n,
|
self.params.penalty_last_n,
|
||||||
self.params.penalty_repeat,
|
self.params.penalty_repeat,
|
||||||
self.params.penalty_freq,
|
self.params.penalty_freq,
|
||||||
|
|
Loading…
Reference in a new issue