fix: LlamaHFTokenizer now receives pre_tokens
This commit is contained in:
parent
ded5d627a5
commit
47bad30dd7
2 changed files with 33 additions and 23 deletions
|
@ -480,7 +480,7 @@ class Llama:
|
|||
Returns:
|
||||
The detokenized string.
|
||||
"""
|
||||
return self.tokenizer_.detokenize(tokens, prev_tokens)
|
||||
return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens)
|
||||
|
||||
def set_cache(self, cache: Optional[BaseLlamaCache]):
|
||||
"""Set the cache.
|
||||
|
@ -1016,13 +1016,13 @@ class Llama:
|
|||
grammar=grammar,
|
||||
):
|
||||
if token == self._token_eos:
|
||||
text = self.detokenize(completion_tokens)
|
||||
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
|
||||
finish_reason = "stop"
|
||||
break
|
||||
|
||||
completion_tokens.append(token)
|
||||
|
||||
all_text = self.detokenize(completion_tokens)
|
||||
all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
|
||||
|
||||
# Contains multi-byte UTF8
|
||||
for k, char in enumerate(all_text[-3:]):
|
||||
|
@ -1046,7 +1046,7 @@ class Llama:
|
|||
|
||||
if stream:
|
||||
remaining_tokens = completion_tokens[returned_tokens:]
|
||||
remaining_text = self.detokenize(remaining_tokens)
|
||||
remaining_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
|
||||
remaining_length = len(remaining_text)
|
||||
|
||||
# We want to avoid yielding any characters from
|
||||
|
@ -1068,17 +1068,17 @@ class Llama:
|
|||
for token in remaining_tokens:
|
||||
if token == self.token_bos():
|
||||
continue
|
||||
token_end_position += len(self.detokenize([token]))
|
||||
token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))
|
||||
# Check if stop sequence is in the token
|
||||
if token_end_position > (
|
||||
remaining_length - first_stop_position
|
||||
):
|
||||
break
|
||||
token_str = self.detokenize([token]).decode(
|
||||
token_str = self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
text_offset = len(prompt) + len(
|
||||
self.detokenize(completion_tokens[:returned_tokens]).decode(
|
||||
self.detokenize(completion_tokens[:returned_tokens], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
)
|
||||
|
@ -1100,7 +1100,7 @@ class Llama:
|
|||
top_logprob.update({token_str: current_logprobs[int(token)]})
|
||||
logprobs_or_none = {
|
||||
"tokens": [
|
||||
self.detokenize([token]).decode(
|
||||
self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
],
|
||||
|
@ -1116,7 +1116,7 @@ class Llama:
|
|||
"model": model_name,
|
||||
"choices": [
|
||||
{
|
||||
"text": self.detokenize([token]).decode(
|
||||
"text": self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
|
||||
"utf-8", errors="ignore"
|
||||
),
|
||||
"index": 0,
|
||||
|
@ -1130,7 +1130,7 @@ class Llama:
|
|||
decode_success = False
|
||||
for i in range(1, len(remaining_tokens) + 1):
|
||||
try:
|
||||
bs = self.detokenize(remaining_tokens[:i])
|
||||
bs = self.detokenize(remaining_tokens[:i], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
|
||||
ts = bs.decode("utf-8")
|
||||
decode_success = True
|
||||
break
|
||||
|
@ -1165,14 +1165,14 @@ class Llama:
|
|||
}
|
||||
|
||||
if len(completion_tokens) >= max_tokens:
|
||||
text = self.detokenize(completion_tokens)
|
||||
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
|
||||
finish_reason = "length"
|
||||
break
|
||||
|
||||
if stopping_criteria is not None and stopping_criteria(
|
||||
self._input_ids, self._scores[-1, :]
|
||||
):
|
||||
text = self.detokenize(completion_tokens)
|
||||
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
|
||||
finish_reason = "stop"
|
||||
|
||||
if self.verbose:
|
||||
|
@ -1180,7 +1180,7 @@ class Llama:
|
|||
|
||||
if stream:
|
||||
remaining_tokens = completion_tokens[returned_tokens:]
|
||||
all_text = self.detokenize(remaining_tokens)
|
||||
all_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
|
||||
any_stop = [s for s in stop_sequences if s in all_text]
|
||||
if len(any_stop) > 0:
|
||||
end = min(all_text.index(stop) for stop in any_stop)
|
||||
|
@ -1189,7 +1189,7 @@ class Llama:
|
|||
|
||||
token_end_position = 0
|
||||
for token in remaining_tokens:
|
||||
token_end_position += len(self.detokenize([token]))
|
||||
token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))
|
||||
|
||||
logprobs_or_none: Optional[CompletionLogprobs] = None
|
||||
if logprobs is not None:
|
||||
|
@ -1199,7 +1199,7 @@ class Llama:
|
|||
"utf-8", errors="ignore"
|
||||
)
|
||||
text_offset = len(prompt) + len(
|
||||
self.detokenize(completion_tokens[:returned_tokens])
|
||||
self.detokenize(completion_tokens[:returned_tokens], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
|
||||
)
|
||||
token_offset = len(prompt_tokens) + returned_tokens - 1
|
||||
logits = self._scores[token_offset, :]
|
||||
|
@ -1313,8 +1313,8 @@ class Llama:
|
|||
all_tokens = completion_tokens
|
||||
|
||||
all_token_strs = [
|
||||
self.detokenize([token]).decode("utf-8", errors="ignore")
|
||||
for token in all_tokens
|
||||
self.detokenize([token], prev_tokens=all_tokens[:i]).decode("utf-8", errors="ignore")
|
||||
for i, token in enumerate(all_tokens)
|
||||
]
|
||||
all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
|
||||
# TODO: may be able to change this loop to use np.take_along_dim
|
||||
|
@ -1339,7 +1339,7 @@ class Llama:
|
|||
)
|
||||
token_logprobs.append(logprobs_token[int(token)])
|
||||
top_logprob: Optional[Dict[str, float]] = {
|
||||
self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
|
||||
self.detokenize([i], prev_tokens=all_tokens[:idx]).decode("utf-8", errors="ignore"): logprob
|
||||
for logprob, i in sorted_logprobs[:logprobs]
|
||||
}
|
||||
top_logprob.update({token_str: logprobs_token[int(token)]})
|
||||
|
@ -1594,6 +1594,8 @@ class Llama:
|
|||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
grammar: Optional[LlamaGrammar] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
) -> Union[
|
||||
CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
|
||||
]:
|
||||
|
|
|
@ -16,12 +16,23 @@ class BaseLlamaTokenizer(abc.ABC):
|
|||
def tokenize(
|
||||
self, text: bytes, add_bos: bool = True, special: bool = True
|
||||
) -> List[int]:
|
||||
"""Tokenize the text into tokens.
|
||||
|
||||
Args:
|
||||
text: The text to tokenize.
|
||||
add_bos: Whether to add a beginning of sequence token.
|
||||
special: Whether to tokenize text literally or as special tokens."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def detokenize(
|
||||
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
|
||||
) -> bytes:
|
||||
"""Detokenize the tokens into text.
|
||||
|
||||
Args:
|
||||
tokens: The tokens to detokenize.
|
||||
prev_tokens: If tokens is a continuation of a previous sequence, the previous tokens."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
@ -37,9 +48,6 @@ class LlamaTokenizer(BaseLlamaTokenizer):
|
|||
def detokenize(
|
||||
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
|
||||
) -> bytes:
|
||||
if prev_tokens is not None:
|
||||
return self._model.detokenize(tokens[len(prev_tokens) :])
|
||||
else:
|
||||
return self._model.detokenize(tokens)
|
||||
|
||||
def encode(
|
||||
|
@ -72,7 +80,7 @@ class LlamaHFTokenizer(BaseLlamaTokenizer):
|
|||
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
|
||||
) -> bytes:
|
||||
if prev_tokens is not None:
|
||||
text = self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
|
||||
text = self.hf_tokenizer.decode(prev_tokens + tokens).encode("utf-8", errors="ignore")
|
||||
prev_text = self.hf_tokenizer.decode(prev_tokens).encode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue