fix: Use numpy recarray for candidates data, fixes bug with temp < 0
parent 2d89964147
commit af3ed503e9
1 changed file with 9 additions and 13 deletions
@@ -545,13 +545,12 @@ class _LlamaBatch:
 class _LlamaTokenDataArray:
     def __init__(self, *, n_vocab: int):
         self.n_vocab = n_vocab
-        self.candidates_data = np.array(
-            [],
+        self.candidates_data = np.recarray(
+            (self.n_vocab,),
             dtype=np.dtype(
                 [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
             ),
         )
-        self.candidates_data.resize(3, self.n_vocab, refcheck=False)
         self.candidates = llama_cpp.llama_token_data_array(
             data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
             size=self.n_vocab,
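Reading this hunk against the commit title: the old code allocated an empty structured array and then resized it to shape (3, n_vocab), so taking element 0 of a field returned a whole row of n_vocab values rather than a single token id. That read happens on the temp < 0 path in the last hunk below, which is presumably the bug being fixed. A minimal standalone sketch of the shape difference (plain numpy with a toy n_vocab, not the library code):

    import numpy as np

    n_vocab = 8  # toy vocabulary size, for illustration only
    candidate_dtype = np.dtype(
        [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
    )

    # Old allocation: empty structured array resized to (3, n_vocab).
    old = np.array([], dtype=candidate_dtype)
    old.resize(3, n_vocab, refcheck=False)
    print(old["id"][0].shape)  # (8,) -- a whole row, not a single token id

    # New allocation: a 1-D recarray with one record per vocabulary entry.
    new = np.recarray((n_vocab,), dtype=candidate_dtype)
    print(new.id[0].shape)     # () -- a scalar token id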
@@ -561,14 +560,11 @@ class _LlamaTokenDataArray:
         self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)
 
     def copy_logits(self, logits: npt.NDArray[np.single]):
-        self.candidates_data["id"][:] = self.default_candidates_data_id
-        self.candidates_data["logit"][:] = logits
-        self.candidates_data["p"][:] = self.default_candidates_data_p
-        self.candidates.data = self.candidates_data.ctypes.data_as(
-            llama_cpp.llama_token_data_p
-        )
-        self.candidates.sorted = ctypes.c_bool(False)
-        self.candidates.size = ctypes.c_size_t(self.n_vocab)
+        self.candidates_data.id[:] = self.default_candidates_data_id
+        self.candidates_data.logit[:] = logits
+        self.candidates_data.p[:] = self.default_candidates_data_p
+        self.candidates.sorted = False
+        self.candidates.size = self.n_vocab
 
 
 # Python wrappers over common/common
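Two things change in copy_logits. Attribute access on the recarray (.id, .logit, .p) is just the attribute-style spelling of structured-field indexing and still writes into the same shared buffer. And because the buffer is now allocated at full size once in __init__, its address never changes, so the ctypes data pointer no longer has to be re-derived on every call, and the bookkeeping fields can be reset with plain Python values (ctypes converts them on assignment to Structure fields). A short sketch of both points; TokenDataArray below is a hypothetical stand-in for llama_cpp.llama_token_data_array, not the real binding:

    import ctypes
    import numpy as np

    n_vocab = 8  # toy size, for illustration only
    candidate_dtype = np.dtype(
        [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
    )

    rec = np.recarray((n_vocab,), dtype=candidate_dtype)
    rec.id[:] = np.arange(n_vocab, dtype=np.intc)       # same effect as rec["id"][:] = ...
    rec.logit[:] = np.zeros(n_vocab, dtype=np.single)   # attribute access is a view, not a copy
    rec.p[:] = 0.0

    class TokenDataArray(ctypes.Structure):  # hypothetical stand-in for the C struct
        _fields_ = [
            ("data", ctypes.c_void_p),
            ("size", ctypes.c_size_t),
            ("sorted", ctypes.c_bool),
        ]

    arr = TokenDataArray(data=rec.ctypes.data, size=n_vocab, sorted=False)

    # The numpy buffer never moves, so a pointer taken once stays valid;
    # only the bookkeeping fields need refreshing between calls.
    address_before = rec.ctypes.data
    rec.logit[:] = np.linspace(-1.0, 1.0, n_vocab, dtype=np.single)
    assert rec.ctypes.data == address_before

    arr.sorted = False   # plain bool is converted by ctypes; no c_bool(...) wrapper needed
    arr.size = n_vocab   # likewise for the size_t field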
@@ -759,14 +755,14 @@ class _LlamaSamplingContext:
                 self.params.penalty_present,
             )
             if not self.params.penalize_nl:
-                token_data_array.candidates_data["logit"][nl_token] = nl_logit
+                token_data_array.candidates_data.logit[nl_token] = nl_logit
 
         if self.grammar is not None:
             ctx_main.sample_grammar(token_data_array, self.grammar)
 
         if self.params.temp < 0:
             ctx_main.sample_softmax(token_data_array)
-            id = token_data_array.candidates_data["id"][0]
+            id = token_data_array.candidates_data.id[0]
         elif self.params.temp == 0:
             id = ctx_main.sample_token_greedy(token_data_array)
         else:
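For reference, the temp < 0 branch is the one the commit title targets: sample_softmax sorts the shared candidates buffer by logit (highest first) and fills in the probabilities, after which the top token's id is read back through numpy, and on the 1-D recarray .id[0] is a single integer. A rough pure-numpy stand-in for that read-back; the in-place sort below only mimics what the C-side softmax does to the buffer, it is not the library call:

    import numpy as np

    n_vocab = 8  # toy size, for illustration only
    candidate_dtype = np.dtype(
        [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
    )

    candidates = np.recarray((n_vocab,), dtype=candidate_dtype)
    candidates.id[:] = np.arange(n_vocab, dtype=np.intc)
    candidates.logit[:] = np.random.default_rng(0).normal(size=n_vocab).astype(np.single)
    candidates.p[:] = 0.0

    # Stand-in for ctx_main.sample_softmax(): sort records by logit, highest
    # first, then normalize, in the same buffer the C code would see.
    order = np.argsort(-candidates.logit)
    candidates[:] = candidates[order]
    candidates.p[:] = np.exp(candidates.logit - candidates.logit[0])
    candidates.p[:] /= candidates.p.sum()

    token_id = int(candidates.id[0])  # a single int, as the sampling code expects
    print(token_id)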