Clean up logprobs implementation
This commit is contained in:
parent
26cc4ee029
commit
6153baab2d
1 changed files with 39 additions and 67 deletions
|
@ -351,55 +351,19 @@ class Llama:
|
|||
else:
|
||||
stop_sequences = []
|
||||
|
||||
text_offset = 0
|
||||
text_offsets: List[int] = []
|
||||
token_logprobs: List[float] = []
|
||||
tokens: List[str] = []
|
||||
top_logprobs: List[Dict[str, float]] = []
|
||||
|
||||
self.reset()
|
||||
self.eval(prompt_tokens)
|
||||
|
||||
if logprobs is not None and self.params.logits_all is False:
|
||||
raise ValueError(
|
||||
"logprobs is not supported for models created with logits_all=False"
|
||||
)
|
||||
|
||||
if logprobs is not None:
|
||||
token_strs = [
|
||||
self.detokenize([token]).decode("utf-8") for token in prompt_tokens
|
||||
]
|
||||
logprobs_all = [
|
||||
[Llama.logit_to_logprob(logit) for logit in row]
|
||||
for row in self.all_logits
|
||||
]
|
||||
for token, token_str, logprobs_token in zip(
|
||||
prompt_tokens, token_strs, logprobs_all
|
||||
):
|
||||
text_offsets.append(text_offset)
|
||||
text_offset += len(token_str)
|
||||
tokens.append(token_str)
|
||||
sorted_logprobs = list(
|
||||
sorted(
|
||||
zip(logprobs_token, range(len(logprobs_token))), reverse=True
|
||||
)
|
||||
)
|
||||
token_logprobs.append(sorted_logprobs[int(token)][0])
|
||||
top_logprob = {
|
||||
self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
|
||||
for logprob, i in sorted_logprobs[:logprobs]
|
||||
}
|
||||
top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
|
||||
top_logprobs.append(top_logprob)
|
||||
|
||||
finish_reason = "length"
|
||||
while True:
|
||||
token = self.sample(
|
||||
for token in self.generate(
|
||||
prompt_tokens,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temp=temperature,
|
||||
repeat_penalty=repeat_penalty,
|
||||
)
|
||||
):
|
||||
if token == llama_cpp.llama_token_eos():
|
||||
text = self.detokenize(completion_tokens)
|
||||
finish_reason = "stop"
|
||||
|
@ -443,34 +407,10 @@ class Llama:
|
|||
],
|
||||
}
|
||||
|
||||
if logprobs is not None:
|
||||
# TODO: Confirm wether this should happen before or after
|
||||
# next eval.
|
||||
token_str = self.detokenize([token]).decode("utf-8")
|
||||
text_offsets.append(text_offset)
|
||||
text_offset += len(token_str)
|
||||
tokens.append(token_str)
|
||||
logprobs_token = [
|
||||
Llama.logit_to_logprob(logit) for logit in self.all_logits[-1]
|
||||
]
|
||||
sorted_logprobs = list(
|
||||
sorted(
|
||||
zip(logprobs_token, range(len(logprobs_token))), reverse=True
|
||||
)
|
||||
)
|
||||
token_logprobs.append(sorted_logprobs[int(token)][0])
|
||||
top_logprob = {
|
||||
self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
|
||||
for logprob, i in sorted_logprobs[:logprobs]
|
||||
}
|
||||
top_logprob.update({token_str: logprobs_token[int(token)]})
|
||||
top_logprobs.append(top_logprob)
|
||||
|
||||
if len(completion_tokens) >= max_tokens:
|
||||
text = self.detokenize(completion_tokens)
|
||||
finish_reason = "length"
|
||||
break
|
||||
self.eval([token])
|
||||
|
||||
if stream:
|
||||
yield {
|
||||
|
@ -499,6 +439,38 @@ class Llama:
|
|||
|
||||
logprobs_or_none: Optional[CompletionLogprobs] = None
|
||||
if logprobs is not None:
|
||||
text_offset = 0
|
||||
text_offsets: List[int] = []
|
||||
token_logprobs: List[float] = []
|
||||
tokens: List[str] = []
|
||||
top_logprobs: List[Dict[str, float]] = []
|
||||
|
||||
all_tokens = prompt_tokens + completion_tokens
|
||||
all_token_strs = [
|
||||
self.detokenize([token]).decode("utf-8") for token in all_tokens
|
||||
]
|
||||
all_logprobs = [
|
||||
[Llama.logit_to_logprob(logit) for logit in row]
|
||||
for row in self.all_logits
|
||||
]
|
||||
for token, token_str, logprobs_token in zip(
|
||||
all_tokens, all_token_strs, all_logprobs
|
||||
):
|
||||
text_offsets.append(text_offset)
|
||||
text_offset += len(token_str)
|
||||
tokens.append(token_str)
|
||||
sorted_logprobs = list(
|
||||
sorted(
|
||||
zip(logprobs_token, range(len(logprobs_token))), reverse=True
|
||||
)
|
||||
)
|
||||
token_logprobs.append(sorted_logprobs[int(token)][0])
|
||||
top_logprob = {
|
||||
self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
|
||||
for logprob, i in sorted_logprobs[:logprobs]
|
||||
}
|
||||
top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
|
||||
top_logprobs.append(top_logprob)
|
||||
logprobs_or_none = {
|
||||
"tokens": tokens,
|
||||
"text_offset": text_offsets,
|
||||
|
|
Loading…
Reference in a new issue