fix: LlamaHFTokenizer now receives pre_tokens

Andrei Betlen 2024-02-23 12:23:24 -05:00
parent ded5d627a5
commit 47bad30dd7
2 changed files with 33 additions and 23 deletions
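
In short: every Llama.detokenize call site in the completion loop now forwards prev_tokens, and LlamaHFTokenizer.detokenize decodes prev_tokens + tokens rather than tokens alone, so the Hugging Face tokenizer sees the full context when reconstructing text. A minimal sketch of the context-aware decode this enables, assuming only a generic decode(tokens) -> str callable (the helper name is invented for illustration):

from typing import Callable, List, Optional

def detokenize_with_context(
    decode: Callable[[List[int]], str],
    tokens: List[int],
    prev_tokens: Optional[List[int]] = None,
) -> bytes:
    # No context: decode the tokens directly.
    if prev_tokens is None:
        return decode(tokens).encode("utf-8", errors="ignore")
    # With context: decode the full sequence and the prefix separately, then
    # return only the bytes contributed by `tokens`. This preserves spacing
    # and merge behaviour that a context-free decode of `tokens` would lose.
    text = decode(prev_tokens + tokens).encode("utf-8", errors="ignore")
    prev_text = decode(prev_tokens).encode("utf-8", errors="ignore")
    return text[len(prev_text):]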


@@ -480,7 +480,7 @@ class Llama:
Returns:
The detokenized string.
"""
return self.tokenizer_.detokenize(tokens, prev_tokens)
return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens)
def set_cache(self, cache: Optional[BaseLlamaCache]):
"""Set the cache.
@@ -1016,13 +1016,13 @@ class Llama:
grammar=grammar,
):
if token == self._token_eos:
text = self.detokenize(completion_tokens)
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
finish_reason = "stop"
break
completion_tokens.append(token)
all_text = self.detokenize(completion_tokens)
all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
# Contains multi-byte UTF8
for k, char in enumerate(all_text[-3:]):
@@ -1046,7 +1046,7 @@ class Llama:
if stream:
remaining_tokens = completion_tokens[returned_tokens:]
remaining_text = self.detokenize(remaining_tokens)
remaining_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
remaining_length = len(remaining_text)
# We want to avoid yielding any characters from
@@ -1068,17 +1068,17 @@ class Llama:
for token in remaining_tokens:
if token == self.token_bos():
continue
token_end_position += len(self.detokenize([token]))
token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))
# Check if stop sequence is in the token
if token_end_position > (
remaining_length - first_stop_position
):
break
token_str = self.detokenize([token]).decode(
token_str = self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
"utf-8", errors="ignore"
)
text_offset = len(prompt) + len(
self.detokenize(completion_tokens[:returned_tokens]).decode(
self.detokenize(completion_tokens[:returned_tokens], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
"utf-8", errors="ignore"
)
)
@@ -1100,7 +1100,7 @@ class Llama:
top_logprob.update({token_str: current_logprobs[int(token)]})
logprobs_or_none = {
"tokens": [
self.detokenize([token]).decode(
self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
"utf-8", errors="ignore"
)
],
@@ -1116,7 +1116,7 @@ class Llama:
"model": model_name,
"choices": [
{
"text": self.detokenize([token]).decode(
"text": self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
"utf-8", errors="ignore"
),
"index": 0,
@@ -1130,7 +1130,7 @@ class Llama:
decode_success = False
for i in range(1, len(remaining_tokens) + 1):
try:
bs = self.detokenize(remaining_tokens[:i])
bs = self.detokenize(remaining_tokens[:i], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
ts = bs.decode("utf-8")
decode_success = True
break
@@ -1165,14 +1165,14 @@ class Llama:
}
if len(completion_tokens) >= max_tokens:
text = self.detokenize(completion_tokens)
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
finish_reason = "length"
break
if stopping_criteria is not None and stopping_criteria(
self._input_ids, self._scores[-1, :]
):
text = self.detokenize(completion_tokens)
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
finish_reason = "stop"
if self.verbose:
@@ -1180,7 +1180,7 @@ class Llama:
if stream:
remaining_tokens = completion_tokens[returned_tokens:]
all_text = self.detokenize(remaining_tokens)
all_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
any_stop = [s for s in stop_sequences if s in all_text]
if len(any_stop) > 0:
end = min(all_text.index(stop) for stop in any_stop)
@@ -1189,7 +1189,7 @@ class Llama:
token_end_position = 0
for token in remaining_tokens:
token_end_position += len(self.detokenize([token]))
token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))
logprobs_or_none: Optional[CompletionLogprobs] = None
if logprobs is not None:
@@ -1199,7 +1199,7 @@ class Llama:
"utf-8", errors="ignore"
)
text_offset = len(prompt) + len(
self.detokenize(completion_tokens[:returned_tokens])
self.detokenize(completion_tokens[:returned_tokens], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
)
token_offset = len(prompt_tokens) + returned_tokens - 1
logits = self._scores[token_offset, :]
@@ -1313,8 +1313,8 @@ class Llama:
all_tokens = completion_tokens
all_token_strs = [
self.detokenize([token]).decode("utf-8", errors="ignore")
for token in all_tokens
self.detokenize([token], prev_tokens=all_tokens[:i]).decode("utf-8", errors="ignore")
for i, token in enumerate(all_tokens)
]
all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
# TODO: may be able to change this loop to use np.take_along_dim
@@ -1339,7 +1339,7 @@ class Llama:
)
token_logprobs.append(logprobs_token[int(token)])
top_logprob: Optional[Dict[str, float]] = {
self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
self.detokenize([i], prev_tokens=all_tokens[:idx]).decode("utf-8", errors="ignore"): logprob
for logprob, i in sorted_logprobs[:logprobs]
}
top_logprob.update({token_str: logprobs_token[int(token)]})
@@ -1594,6 +1594,8 @@ class Llama:
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
) -> Union[
CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
]:
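
The hunk above also adds OpenAI-style logprobs and top_logprobs parameters to create_chat_completion. A hedged usage sketch, assuming these are honoured the same way as in the completions API (the model path is a placeholder):

from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")  # placeholder path
response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Hello"}],
    logprobs=True,    # request per-token log probabilities
    top_logprobs=5,   # and the 5 most likely alternatives at each position
)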


@@ -16,12 +16,23 @@ class BaseLlamaTokenizer(abc.ABC):
def tokenize(
self, text: bytes, add_bos: bool = True, special: bool = True
) -> List[int]:
"""Tokenize the text into tokens.
Args:
text: The text to tokenize.
add_bos: Whether to add a beginning of sequence token.
special: Whether to tokenize text literally or as special tokens."""
raise NotImplementedError
@abc.abstractmethod
def detokenize(
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
) -> bytes:
"""Detokenize the tokens into text.
Args:
tokens: The tokens to detokenize.
prev_tokens: If tokens is a continuation of a previous sequence, the previous tokens."""
raise NotImplementedError
@@ -37,9 +48,6 @@ class LlamaTokenizer(BaseLlamaTokenizer):
def detokenize(
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
) -> bytes:
if prev_tokens is not None:
return self._model.detokenize(tokens[len(prev_tokens) :])
else:
return self._model.detokenize(tokens)
def encode(
@@ -72,7 +80,7 @@ class LlamaHFTokenizer(BaseLlamaTokenizer):
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
) -> bytes:
if prev_tokens is not None:
text = self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
text = self.hf_tokenizer.decode(prev_tokens + tokens).encode("utf-8", errors="ignore")
prev_text = self.hf_tokenizer.decode(prev_tokens).encode(
"utf-8", errors="ignore"
)
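
The diff is cut off here; the remainder of LlamaHFTokenizer.detokenize presumably returns only the part of text beyond prev_text. For context, a sketch of how the HF-backed tokenizer is used end to end (the llama_cpp.llama_tokenizer import path and the from_pretrained helper are assumed from this version of the library; model names are placeholders):

from llama_cpp import Llama
from llama_cpp.llama_tokenizer import LlamaHFTokenizer

# Placeholders: any GGUF model plus its matching Hugging Face tokenizer repo.
hf_tokenizer = LlamaHFTokenizer.from_pretrained("org/tokenizer-repo")
llm = Llama(model_path="./model.gguf", tokenizer=hf_tokenizer)

tokens = hf_tokenizer.tokenize(b"Hello, world!")
# With prev_tokens supplied, detokenize returns only the bytes contributed by
# the new tokens, decoded in the context of everything generated before them.
suffix = hf_tokenizer.detokenize(tokens[2:], prev_tokens=tokens[:2])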