Fix logits_all bug
This commit is contained in:
parent
6ee413d79e
commit
d696251fbe
1 changed files with 2 additions and 2 deletions
|
@ -439,7 +439,7 @@ class Llama:
|
||||||
def eval_logits(self) -> Deque[List[float]]:
|
def eval_logits(self) -> Deque[List[float]]:
|
||||||
return deque(
|
return deque(
|
||||||
self.scores[: self.n_tokens, :].tolist(),
|
self.scores[: self.n_tokens, :].tolist(),
|
||||||
maxlen=self._n_ctx if self.model_params.logits_all else 1,
|
maxlen=self._n_ctx if self.context_params.logits_all else 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
|
def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
|
||||||
|
@ -964,7 +964,7 @@ class Llama:
|
||||||
else:
|
else:
|
||||||
stop_sequences = []
|
stop_sequences = []
|
||||||
|
|
||||||
if logprobs is not None and self.model_params.logits_all is False:
|
if logprobs is not None and self.context_params.logits_all is False:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"logprobs is not supported for models created with logits_all=False"
|
"logprobs is not supported for models created with logits_all=False"
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue