fix: LlamaHFTokenizer now receives pre_tokens

This commit is contained in:
Andrei Betlen 2024-02-23 12:23:24 -05:00
parent ded5d627a5
commit 47bad30dd7
2 changed files with 33 additions and 23 deletions

View file

@ -480,7 +480,7 @@ class Llama:
Returns: Returns:
The detokenized string. The detokenized string.
""" """
return self.tokenizer_.detokenize(tokens, prev_tokens) return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens)
def set_cache(self, cache: Optional[BaseLlamaCache]): def set_cache(self, cache: Optional[BaseLlamaCache]):
"""Set the cache. """Set the cache.
@ -1016,13 +1016,13 @@ class Llama:
grammar=grammar, grammar=grammar,
): ):
if token == self._token_eos: if token == self._token_eos:
text = self.detokenize(completion_tokens) text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
finish_reason = "stop" finish_reason = "stop"
break break
completion_tokens.append(token) completion_tokens.append(token)
all_text = self.detokenize(completion_tokens) all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
# Contains multi-byte UTF8 # Contains multi-byte UTF8
for k, char in enumerate(all_text[-3:]): for k, char in enumerate(all_text[-3:]):
@ -1046,7 +1046,7 @@ class Llama:
if stream: if stream:
remaining_tokens = completion_tokens[returned_tokens:] remaining_tokens = completion_tokens[returned_tokens:]
remaining_text = self.detokenize(remaining_tokens) remaining_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
remaining_length = len(remaining_text) remaining_length = len(remaining_text)
# We want to avoid yielding any characters from # We want to avoid yielding any characters from
@ -1068,17 +1068,17 @@ class Llama:
for token in remaining_tokens: for token in remaining_tokens:
if token == self.token_bos(): if token == self.token_bos():
continue continue
token_end_position += len(self.detokenize([token])) token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))
# Check if stop sequence is in the token # Check if stop sequence is in the token
if token_end_position > ( if token_end_position > (
remaining_length - first_stop_position remaining_length - first_stop_position
): ):
break break
token_str = self.detokenize([token]).decode( token_str = self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
"utf-8", errors="ignore" "utf-8", errors="ignore"
) )
text_offset = len(prompt) + len( text_offset = len(prompt) + len(
self.detokenize(completion_tokens[:returned_tokens]).decode( self.detokenize(completion_tokens[:returned_tokens], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
"utf-8", errors="ignore" "utf-8", errors="ignore"
) )
) )
@ -1100,7 +1100,7 @@ class Llama:
top_logprob.update({token_str: current_logprobs[int(token)]}) top_logprob.update({token_str: current_logprobs[int(token)]})
logprobs_or_none = { logprobs_or_none = {
"tokens": [ "tokens": [
self.detokenize([token]).decode( self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
"utf-8", errors="ignore" "utf-8", errors="ignore"
) )
], ],
@ -1116,7 +1116,7 @@ class Llama:
"model": model_name, "model": model_name,
"choices": [ "choices": [
{ {
"text": self.detokenize([token]).decode( "text": self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
"utf-8", errors="ignore" "utf-8", errors="ignore"
), ),
"index": 0, "index": 0,
@ -1130,7 +1130,7 @@ class Llama:
decode_success = False decode_success = False
for i in range(1, len(remaining_tokens) + 1): for i in range(1, len(remaining_tokens) + 1):
try: try:
bs = self.detokenize(remaining_tokens[:i]) bs = self.detokenize(remaining_tokens[:i], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
ts = bs.decode("utf-8") ts = bs.decode("utf-8")
decode_success = True decode_success = True
break break
@ -1165,14 +1165,14 @@ class Llama:
} }
if len(completion_tokens) >= max_tokens: if len(completion_tokens) >= max_tokens:
text = self.detokenize(completion_tokens) text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
finish_reason = "length" finish_reason = "length"
break break
if stopping_criteria is not None and stopping_criteria( if stopping_criteria is not None and stopping_criteria(
self._input_ids, self._scores[-1, :] self._input_ids, self._scores[-1, :]
): ):
text = self.detokenize(completion_tokens) text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
finish_reason = "stop" finish_reason = "stop"
if self.verbose: if self.verbose:
@ -1180,7 +1180,7 @@ class Llama:
if stream: if stream:
remaining_tokens = completion_tokens[returned_tokens:] remaining_tokens = completion_tokens[returned_tokens:]
all_text = self.detokenize(remaining_tokens) all_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
any_stop = [s for s in stop_sequences if s in all_text] any_stop = [s for s in stop_sequences if s in all_text]
if len(any_stop) > 0: if len(any_stop) > 0:
end = min(all_text.index(stop) for stop in any_stop) end = min(all_text.index(stop) for stop in any_stop)
@ -1189,7 +1189,7 @@ class Llama:
token_end_position = 0 token_end_position = 0
for token in remaining_tokens: for token in remaining_tokens:
token_end_position += len(self.detokenize([token])) token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))
logprobs_or_none: Optional[CompletionLogprobs] = None logprobs_or_none: Optional[CompletionLogprobs] = None
if logprobs is not None: if logprobs is not None:
@ -1199,7 +1199,7 @@ class Llama:
"utf-8", errors="ignore" "utf-8", errors="ignore"
) )
text_offset = len(prompt) + len( text_offset = len(prompt) + len(
self.detokenize(completion_tokens[:returned_tokens]) self.detokenize(completion_tokens[:returned_tokens], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
) )
token_offset = len(prompt_tokens) + returned_tokens - 1 token_offset = len(prompt_tokens) + returned_tokens - 1
logits = self._scores[token_offset, :] logits = self._scores[token_offset, :]
@ -1313,8 +1313,8 @@ class Llama:
all_tokens = completion_tokens all_tokens = completion_tokens
all_token_strs = [ all_token_strs = [
self.detokenize([token]).decode("utf-8", errors="ignore") self.detokenize([token], prev_tokens=all_tokens[:i]).decode("utf-8", errors="ignore")
for token in all_tokens for i, token in enumerate(all_tokens)
] ]
all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:] all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
# TODO: may be able to change this loop to use np.take_along_dim # TODO: may be able to change this loop to use np.take_along_dim
@ -1339,7 +1339,7 @@ class Llama:
) )
token_logprobs.append(logprobs_token[int(token)]) token_logprobs.append(logprobs_token[int(token)])
top_logprob: Optional[Dict[str, float]] = { top_logprob: Optional[Dict[str, float]] = {
self.detokenize([i]).decode("utf-8", errors="ignore"): logprob self.detokenize([i], prev_tokens=all_tokens[:idx]).decode("utf-8", errors="ignore"): logprob
for logprob, i in sorted_logprobs[:logprobs] for logprob, i in sorted_logprobs[:logprobs]
} }
top_logprob.update({token_str: logprobs_token[int(token)]}) top_logprob.update({token_str: logprobs_token[int(token)]})
@ -1594,6 +1594,8 @@ class Llama:
logits_processor: Optional[LogitsProcessorList] = None, logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None, grammar: Optional[LlamaGrammar] = None,
logit_bias: Optional[Dict[str, float]] = None, logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
) -> Union[ ) -> Union[
CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse] CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
]: ]:

View file

@ -16,12 +16,23 @@ class BaseLlamaTokenizer(abc.ABC):
def tokenize( def tokenize(
self, text: bytes, add_bos: bool = True, special: bool = True self, text: bytes, add_bos: bool = True, special: bool = True
) -> List[int]: ) -> List[int]:
"""Tokenize the text into tokens.
Args:
text: The text to tokenize.
add_bos: Whether to add a beginning of sequence token.
special: Whether to tokenize text literally or as special tokens."""
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def detokenize( def detokenize(
self, tokens: List[int], prev_tokens: Optional[List[int]] = None self, tokens: List[int], prev_tokens: Optional[List[int]] = None
) -> bytes: ) -> bytes:
"""Detokenize the tokens into text.
Args:
tokens: The tokens to detokenize.
prev_tokens: If tokens is a continuation of a previous sequence, the previous tokens."""
raise NotImplementedError raise NotImplementedError
@ -37,10 +48,7 @@ class LlamaTokenizer(BaseLlamaTokenizer):
def detokenize( def detokenize(
self, tokens: List[int], prev_tokens: Optional[List[int]] = None self, tokens: List[int], prev_tokens: Optional[List[int]] = None
) -> bytes: ) -> bytes:
if prev_tokens is not None: return self._model.detokenize(tokens)
return self._model.detokenize(tokens[len(prev_tokens) :])
else:
return self._model.detokenize(tokens)
def encode( def encode(
self, text: str, add_bos: bool = True, special: bool = True self, text: str, add_bos: bool = True, special: bool = True
@ -72,7 +80,7 @@ class LlamaHFTokenizer(BaseLlamaTokenizer):
self, tokens: List[int], prev_tokens: Optional[List[int]] = None self, tokens: List[int], prev_tokens: Optional[List[int]] = None
) -> bytes: ) -> bytes:
if prev_tokens is not None: if prev_tokens is not None:
text = self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore") text = self.hf_tokenizer.decode(prev_tokens + tokens).encode("utf-8", errors="ignore")
prev_text = self.hf_tokenizer.decode(prev_tokens).encode( prev_text = self.hf_tokenizer.decode(prev_tokens).encode(
"utf-8", errors="ignore" "utf-8", errors="ignore"
) )