Clean up logprobs implementation
This commit is contained in:
parent 26cc4ee029
commit 6153baab2d

1 changed file with 39 additions and 67 deletions
@@ -351,55 +351,19 @@ class Llama:
         else:
             stop_sequences = []
 
-        text_offset = 0
-        text_offsets: List[int] = []
-        token_logprobs: List[float] = []
-        tokens: List[str] = []
-        top_logprobs: List[Dict[str, float]] = []
-
-        self.reset()
-        self.eval(prompt_tokens)
-
         if logprobs is not None and self.params.logits_all is False:
             raise ValueError(
                 "logprobs is not supported for models created with logits_all=False"
             )
 
-        if logprobs is not None:
-            token_strs = [
-                self.detokenize([token]).decode("utf-8") for token in prompt_tokens
-            ]
-            logprobs_all = [
-                [Llama.logit_to_logprob(logit) for logit in row]
-                for row in self.all_logits
-            ]
-            for token, token_str, logprobs_token in zip(
-                prompt_tokens, token_strs, logprobs_all
-            ):
-                text_offsets.append(text_offset)
-                text_offset += len(token_str)
-                tokens.append(token_str)
-                sorted_logprobs = list(
-                    sorted(
-                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
-                    )
-                )
-                token_logprobs.append(sorted_logprobs[int(token)][0])
-                top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
-                    for logprob, i in sorted_logprobs[:logprobs]
-                }
-                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
-                top_logprobs.append(top_logprob)
-
         finish_reason = "length"
-        while True:
-            token = self.sample(
-                top_k=top_k,
-                top_p=top_p,
-                temp=temperature,
-                repeat_penalty=repeat_penalty,
-            )
+        for token in self.generate(
+            prompt_tokens,
+            top_k=top_k,
+            top_p=top_p,
+            temp=temperature,
+            repeat_penalty=repeat_penalty,
+        ):
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
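This first hunk replaces the manual reset/eval/sample loop, along with the pre-loop pass that computed prompt logprobs, with a single self.generate(...) iterator. The body of generate is outside this diff; as a rough sketch, assuming the same reset/eval/sample methods the removed code used, such a helper on the Llama class would look something like:

    # Hypothetical sketch of a generate() method on Llama;
    # not the library's actual implementation.
    def generate(self, tokens, top_k, top_p, temp, repeat_penalty):
        self.reset()                     # clear model state
        self.eval(tokens)                # feed the whole prompt once
        while True:
            token = self.sample(
                top_k=top_k,
                top_p=top_p,
                temp=temp,
                repeat_penalty=repeat_penalty,
            )
            yield token                  # caller breaks on EOS or max_tokens
            self.eval([token])           # advance past the token just sampled

This also explains the deletion of self.eval([token]) in the next hunk: the iterator now performs that step itself.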
@@ -443,34 +407,10 @@ class Llama:
                     ],
                 }
 
-            if logprobs is not None:
-                # TODO: Confirm wether this should happen before or after
-                # next eval.
-                token_str = self.detokenize([token]).decode("utf-8")
-                text_offsets.append(text_offset)
-                text_offset += len(token_str)
-                tokens.append(token_str)
-                logprobs_token = [
-                    Llama.logit_to_logprob(logit) for logit in self.all_logits[-1]
-                ]
-                sorted_logprobs = list(
-                    sorted(
-                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
-                    )
-                )
-                token_logprobs.append(sorted_logprobs[int(token)][0])
-                top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
-                    for logprob, i in sorted_logprobs[:logprobs]
-                }
-                top_logprob.update({token_str: logprobs_token[int(token)]})
-                top_logprobs.append(top_logprob)
-
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "length"
                 break
-            self.eval([token])
 
             if stream:
                 yield {
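Both deleted blocks, and the new pass in the next hunk, rely on Llama.logit_to_logprob to turn raw logits into log-probabilities; its definition is not part of this diff. For reference, the standard conversion for a full row of logits is a numerically stable log-softmax. The helper below is a sketch under that assumption, given a hypothetical name so it is not confused with the library's per-logit method:

    import math
    from typing import List

    def logits_to_logprobs(row: List[float]) -> List[float]:
        # log softmax: x_i - max(x) - log(sum_j exp(x_j - max(x)))
        # subtracting the row max first keeps exp() from overflowing
        max_logit = max(row)
        log_sum = math.log(sum(math.exp(x - max_logit) for x in row))
        return [x - max_logit - log_sum for x in row]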
@@ -499,6 +439,38 @@ class Llama:
 
         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
+            text_offset = 0
+            text_offsets: List[int] = []
+            token_logprobs: List[float] = []
+            tokens: List[str] = []
+            top_logprobs: List[Dict[str, float]] = []
+
+            all_tokens = prompt_tokens + completion_tokens
+            all_token_strs = [
+                self.detokenize([token]).decode("utf-8") for token in all_tokens
+            ]
+            all_logprobs = [
+                [Llama.logit_to_logprob(logit) for logit in row]
+                for row in self.all_logits
+            ]
+            for token, token_str, logprobs_token in zip(
+                all_tokens, all_token_strs, all_logprobs
+            ):
+                text_offsets.append(text_offset)
+                text_offset += len(token_str)
+                tokens.append(token_str)
+                sorted_logprobs = list(
+                    sorted(
+                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
+                    )
+                )
+                token_logprobs.append(sorted_logprobs[int(token)][0])
+                top_logprob = {
+                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+                    for logprob, i in sorted_logprobs[:logprobs]
+                }
+                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
+                top_logprobs.append(top_logprob)
             logprobs_or_none = {
                 "tokens": tokens,
                 "text_offset": text_offsets,
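The lists built in this new block feed an OpenAI-style logprobs object; "tokens" and "text_offset" are visible above, and the dict presumably continues with "token_logprobs" and "top_logprobs", the two remaining lists the loop fills. To make the sorting step concrete, here is a toy run with an invented 4-token vocabulary and logprobs=2:

    logprobs_token = [-3.2, -0.1, -5.0, -2.4]   # one row; list index == token id
    sorted_logprobs = list(
        sorted(zip(logprobs_token, range(len(logprobs_token))), reverse=True)
    )
    # -> [(-0.1, 1), (-2.4, 3), (-3.2, 0), (-5.0, 2)]  best candidates first
    top_candidates = sorted_logprobs[:2]         # the logprobs=2 cutoff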