perf: avoid allocating new buffers during sampling
This commit is contained in:
parent
7887376bff
commit
11eae75211
1 changed files with 4 additions and 2 deletions
|
@ -324,6 +324,8 @@ class Llama:
|
|||
self._candidates = candidates
|
||||
self._token_nl = Llama.token_nl()
|
||||
self._token_eos = Llama.token_eos()
|
||||
self._candidates_data_id = np.arange(self._n_vocab, dtype=np.intc) # type: ignore
|
||||
self._candidates_data_p = np.zeros(self._n_vocab, dtype=np.single)
|
||||
|
||||
self.n_tokens = 0
|
||||
self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
|
||||
|
@ -487,9 +489,9 @@ class Llama:
|
|||
nl_logit = logits[self._token_nl]
|
||||
candidates = self._candidates
|
||||
candidates_data = self._candidates_data
|
||||
candidates_data["id"][:] = np.arange(n_vocab, dtype=np.intc) # type: ignore
|
||||
candidates_data["id"][:] = self._candidates_data_id # type: ignore
|
||||
candidates_data["logit"][:] = logits
|
||||
candidates_data["p"][:] = np.zeros(n_vocab, dtype=np.single)
|
||||
candidates_data["p"][:] = self._candidates_data_p # type: ignore
|
||||
candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p)
|
||||
candidates.sorted = llama_cpp.c_bool(False)
|
||||
candidates.size = llama_cpp.c_size_t(n_vocab)
|
||||
|
|
Loading…
Reference in a new issue