fix: float32 is not JSON serializable when streaming logits.
This commit is contained in:
parent
abda047284
commit
a05b4da80a
1 changed files with 2 additions and 2 deletions
|
@ -1555,7 +1555,7 @@ class Llama:
|
||||||
)
|
)
|
||||||
token_offset = len(prompt_tokens) + returned_tokens
|
token_offset = len(prompt_tokens) + returned_tokens
|
||||||
logits = self._scores[token_offset - 1, :]
|
logits = self._scores[token_offset - 1, :]
|
||||||
current_logprobs = Llama.logits_to_logprobs(logits)
|
current_logprobs = Llama.logits_to_logprobs(logits).tolist()
|
||||||
sorted_logprobs = list(
|
sorted_logprobs = list(
|
||||||
sorted(
|
sorted(
|
||||||
zip(current_logprobs, range(len(current_logprobs))),
|
zip(current_logprobs, range(len(current_logprobs))),
|
||||||
|
@ -1674,7 +1674,7 @@ class Llama:
|
||||||
)
|
)
|
||||||
token_offset = len(prompt_tokens) + returned_tokens - 1
|
token_offset = len(prompt_tokens) + returned_tokens - 1
|
||||||
logits = self._scores[token_offset, :]
|
logits = self._scores[token_offset, :]
|
||||||
current_logprobs = Llama.logits_to_logprobs(logits)
|
current_logprobs = Llama.logits_to_logprobs(logits).tolist()
|
||||||
sorted_logprobs = list(
|
sorted_logprobs = list(
|
||||||
sorted(
|
sorted(
|
||||||
zip(current_logprobs, range(len(current_logprobs))),
|
zip(current_logprobs, range(len(current_logprobs))),
|
||||||
|
|
Loading…
Reference in a new issue