Fix logits_all bug
This commit is contained in:
parent
6ee413d79e
commit
d696251fbe
1 changed file with 2 additions and 2 deletions
|
@@ -439,7 +439,7 @@ class Llama:
|
|||
def eval_logits(self) -> Deque[List[float]]:
|
||||
return deque(
|
||||
self.scores[: self.n_tokens, :].tolist(),
|
||||
maxlen=self._n_ctx if self.model_params.logits_all else 1,
|
||||
maxlen=self._n_ctx if self.context_params.logits_all else 1,
|
||||
)
|
||||
|
||||
def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
|
||||
|
@@ -964,7 +964,7 @@ class Llama:
|
|||
else:
|
||||
stop_sequences = []
|
||||
|
||||
if logprobs is not None and self.model_params.logits_all is False:
|
||||
if logprobs is not None and self.context_params.logits_all is False:
|
||||
raise ValueError(
|
||||
"logprobs is not supported for models created with logits_all=False"
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue