Format
parent 34c505edf2
commit 4f2b5d0b53

1 changed file with 9 additions and 5 deletions
@@ -324,7 +324,7 @@ class Llama:
         self._candidates = candidates
         self._token_nl = Llama.token_nl()
         self._token_eos = Llama.token_eos()
-        self._candidates_data_id = np.arange(self._n_vocab, dtype=np.intc) # type: ignore
+        self._candidates_data_id = np.arange(self._n_vocab, dtype=np.intc)  # type: ignore
         self._candidates_data_p = np.zeros(self._n_vocab, dtype=np.single)

         self.n_tokens = 0
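
The lines above preallocate per-vocabulary scratch buffers that later sampling calls reuse instead of reallocating. A minimal sketch of what they hold, using a toy vocabulary size (n_vocab and the variable names below are illustrative, not taken from the commit):

import numpy as np

n_vocab = 5  # illustrative; the real value comes from the loaded model
candidates_data_id = np.arange(n_vocab, dtype=np.intc)   # token ids 0..n_vocab-1, fixed once
candidates_data_p = np.zeros(n_vocab, dtype=np.single)   # per-token probabilities, cleared up front
# candidates_data_id -> array([0, 1, 2, 3, 4], dtype=int32)
# candidates_data_p  -> array([0., 0., 0., 0., 0.], dtype=float32)
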
@@ -445,8 +445,12 @@ class Llama:
         # Save logits
         rows = n_tokens if self.params.logits_all else 1
         cols = self._n_vocab
-        offset = 0 if self.params.logits_all else n_tokens - 1 # NOTE: Only save the last token logits if logits_all is False
-        self.scores[self.n_tokens + offset: self.n_tokens + n_tokens, :].reshape(-1)[:] = llama_cpp.llama_get_logits(self.ctx)[:rows * cols]
+        offset = (
+            0 if self.params.logits_all else n_tokens - 1
+        )  # NOTE: Only save the last token logits if logits_all is False
+        self.scores[self.n_tokens + offset : self.n_tokens + n_tokens, :].reshape(
+            -1
+        )[:] = llama_cpp.llama_get_logits(self.ctx)[: rows * cols]
         # Update n_tokens
         self.n_tokens += n_tokens

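The reformatted block above copies the logits returned by llama_get_logits into a preallocated scores matrix, writing either every newly evaluated row or only the last one. A self-contained NumPy sketch of that indexing, with all sizes and the flat source buffer invented for illustration:

import numpy as np

n_ctx, n_vocab, n_tokens_done = 8, 4, 2   # illustrative sizes, not from the commit
scores = np.zeros((n_ctx, n_vocab), dtype=np.single)
n_tokens, logits_all = 3, False           # pretend we just evaluated 3 new tokens

rows = n_tokens if logits_all else 1
cols = n_vocab
offset = (
    0 if logits_all else n_tokens - 1
)  # only the last token's row is written when logits_all is False
flat_logits = np.arange(rows * cols, dtype=np.single)  # stand-in for llama_get_logits(ctx)

scores[n_tokens_done + offset : n_tokens_done + n_tokens, :].reshape(-1)[:] = flat_logits[: rows * cols]
# with logits_all=False this fills exactly one row, scores[4, :]
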
@@ -491,7 +495,7 @@ class Llama:
         candidates_data = self._candidates_data
         candidates_data["id"][:] = self._candidates_data_id  # type: ignore
         candidates_data["logit"][:] = logits
-        candidates_data["p"][:] = self._candidates_data_p # type: ignore
+        candidates_data["p"][:] = self._candidates_data_p  # type: ignore
         candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p)
         candidates.sorted = llama_cpp.c_bool(False)
         candidates.size = llama_cpp.c_size_t(n_vocab)
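
The block above fills a structured array whose fields (id, logit, p) mirror llama.cpp's llama_token_data and then hands its raw buffer to the C API without copying. A sketch of that layout, assuming the record dtype below (the dtype definition is not part of this diff):

import numpy as np

n_vocab = 4  # illustrative
candidates_dtype = np.dtype(
    [("id", np.intc), ("logit", np.single), ("p", np.single)]
)  # assumed field layout mirroring llama.cpp's llama_token_data
candidates_data = np.empty(n_vocab, dtype=candidates_dtype)

candidates_data["id"][:] = np.arange(n_vocab, dtype=np.intc)
candidates_data["logit"][:] = np.array([0.1, 2.0, -1.0, 0.5], dtype=np.single)
candidates_data["p"][:] = np.zeros(n_vocab, dtype=np.single)
# candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p) then exposes this buffer
# to the C sampler without copying, as the diff does
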
@@ -537,7 +541,7 @@ class Llama:
             mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
             llama_cpp.llama_sample_temperature(
                 ctx=self.ctx,
-                candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
+                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 temp=temp,
             )
             return llama_cpp.llama_sample_token_mirostat_v2(