fix: Use numpy recarray for candidates data, fixes bug with temp < 0

This commit is contained in:
Andrei Betlen 2024-06-01 18:09:24 -04:00
parent 2d89964147
commit af3ed503e9

View file

@ -545,13 +545,12 @@ class _LlamaBatch:
class _LlamaTokenDataArray: class _LlamaTokenDataArray:
def __init__(self, *, n_vocab: int): def __init__(self, *, n_vocab: int):
self.n_vocab = n_vocab self.n_vocab = n_vocab
self.candidates_data = np.array( self.candidates_data = np.recarray(
[], (self.n_vocab,),
dtype=np.dtype( dtype=np.dtype(
[("id", np.intc), ("logit", np.single), ("p", np.single)], align=True [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
), ),
) )
self.candidates_data.resize(3, self.n_vocab, refcheck=False)
self.candidates = llama_cpp.llama_token_data_array( self.candidates = llama_cpp.llama_token_data_array(
data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p), data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
size=self.n_vocab, size=self.n_vocab,
@ -561,14 +560,11 @@ class _LlamaTokenDataArray:
self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single) self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)
def copy_logits(self, logits: npt.NDArray[np.single]): def copy_logits(self, logits: npt.NDArray[np.single]):
self.candidates_data["id"][:] = self.default_candidates_data_id self.candidates_data.id[:] = self.default_candidates_data_id
self.candidates_data["logit"][:] = logits self.candidates_data.logit[:] = logits
self.candidates_data["p"][:] = self.default_candidates_data_p self.candidates_data.p[:] = self.default_candidates_data_p
self.candidates.data = self.candidates_data.ctypes.data_as( self.candidates.sorted = False
llama_cpp.llama_token_data_p self.candidates.size = self.n_vocab
)
self.candidates.sorted = ctypes.c_bool(False)
self.candidates.size = ctypes.c_size_t(self.n_vocab)
# Python wrappers over common/common # Python wrappers over common/common
@ -759,14 +755,14 @@ class _LlamaSamplingContext:
self.params.penalty_present, self.params.penalty_present,
) )
if not self.params.penalize_nl: if not self.params.penalize_nl:
token_data_array.candidates_data["logit"][nl_token] = nl_logit token_data_array.candidates_data.logit[nl_token] = nl_logit
if self.grammar is not None: if self.grammar is not None:
ctx_main.sample_grammar(token_data_array, self.grammar) ctx_main.sample_grammar(token_data_array, self.grammar)
if self.params.temp < 0: if self.params.temp < 0:
ctx_main.sample_softmax(token_data_array) ctx_main.sample_softmax(token_data_array)
id = token_data_array.candidates_data["id"][0] id = token_data_array.candidates_data.id[0]
elif self.params.temp == 0: elif self.params.temp == 0:
id = ctx_main.sample_token_greedy(token_data_array) id = ctx_main.sample_token_greedy(token_data_array)
else: else: