fix: LlamaHFTokenizer now receives pre_tokens
commit 47bad30dd7
parent ded5d627a5
2 changed files with 33 additions and 23 deletions
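In short: every detokenize() call in Llama's completion path now forwards the tokens that precede the span being decoded (prev_tokens), and the tokenizer classes accept that argument, so a Hugging Face tokenizer can decode a continuation in its context instead of in isolation. Below is a minimal illustrative sketch of why that matters; the transformers package and the hf-internal-testing/llama-tokenizer checkpoint are assumptions for illustration only, not part of this commit.

# Illustrative sketch (not from this commit): decoding continuation tokens in
# isolation can differ from decoding them after their preceding context,
# which is why detokenize() now receives prev_tokens.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

ids = tok.encode("Hello world", add_special_tokens=False)
prev_tokens, next_tokens = ids[:1], ids[1:]   # already-emitted context, new tokens

isolated = tok.decode(next_tokens)            # decoded with no context
prev_text = tok.decode(prev_tokens)
in_context = tok.decode(prev_tokens + next_tokens)[len(prev_text):]

# With SentencePiece-style tokenizers the two strings can differ (typically in
# leading whitespace), so streaming code that concatenates isolated decodes may
# not reproduce the full text.
print(repr(isolated), repr(in_context))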
@@ -480,7 +480,7 @@ class Llama:
 Returns:
 The detokenized string.
 """
-return self.tokenizer_.detokenize(tokens, prev_tokens)
+return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens)

 def set_cache(self, cache: Optional[BaseLlamaCache]):
 """Set the cache.
@@ -1016,13 +1016,13 @@ class Llama:
 grammar=grammar,
 ):
 if token == self._token_eos:
-text = self.detokenize(completion_tokens)
+text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
 finish_reason = "stop"
 break

 completion_tokens.append(token)

-all_text = self.detokenize(completion_tokens)
+all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)

 # Contains multi-byte UTF8
 for k, char in enumerate(all_text[-3:]):
@@ -1046,7 +1046,7 @@ class Llama:

 if stream:
 remaining_tokens = completion_tokens[returned_tokens:]
-remaining_text = self.detokenize(remaining_tokens)
+remaining_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
 remaining_length = len(remaining_text)

 # We want to avoid yielding any characters from
@@ -1068,17 +1068,17 @@ class Llama:
 for token in remaining_tokens:
 if token == self.token_bos():
 continue
-token_end_position += len(self.detokenize([token]))
+token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))
 # Check if stop sequence is in the token
 if token_end_position > (
 remaining_length - first_stop_position
 ):
 break
-token_str = self.detokenize([token]).decode(
+token_str = self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
 "utf-8", errors="ignore"
 )
 text_offset = len(prompt) + len(
-self.detokenize(completion_tokens[:returned_tokens]).decode(
+self.detokenize(completion_tokens[:returned_tokens], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
 "utf-8", errors="ignore"
 )
 )
@@ -1100,7 +1100,7 @@ class Llama:
 top_logprob.update({token_str: current_logprobs[int(token)]})
 logprobs_or_none = {
 "tokens": [
-self.detokenize([token]).decode(
+self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
 "utf-8", errors="ignore"
 )
 ],
@@ -1116,7 +1116,7 @@ class Llama:
 "model": model_name,
 "choices": [
 {
-"text": self.detokenize([token]).decode(
+"text": self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
 "utf-8", errors="ignore"
 ),
 "index": 0,
@@ -1130,7 +1130,7 @@ class Llama:
 decode_success = False
 for i in range(1, len(remaining_tokens) + 1):
 try:
-bs = self.detokenize(remaining_tokens[:i])
+bs = self.detokenize(remaining_tokens[:i], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
 ts = bs.decode("utf-8")
 decode_success = True
 break
@@ -1165,14 +1165,14 @@ class Llama:
 }

 if len(completion_tokens) >= max_tokens:
-text = self.detokenize(completion_tokens)
+text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
 finish_reason = "length"
 break

 if stopping_criteria is not None and stopping_criteria(
 self._input_ids, self._scores[-1, :]
 ):
-text = self.detokenize(completion_tokens)
+text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
 finish_reason = "stop"

 if self.verbose:
@@ -1180,7 +1180,7 @@ class Llama:

 if stream:
 remaining_tokens = completion_tokens[returned_tokens:]
-all_text = self.detokenize(remaining_tokens)
+all_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
 any_stop = [s for s in stop_sequences if s in all_text]
 if len(any_stop) > 0:
 end = min(all_text.index(stop) for stop in any_stop)
@@ -1189,7 +1189,7 @@ class Llama:

 token_end_position = 0
 for token in remaining_tokens:
-token_end_position += len(self.detokenize([token]))
+token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))

 logprobs_or_none: Optional[CompletionLogprobs] = None
 if logprobs is not None:
@@ -1199,7 +1199,7 @@ class Llama:
 "utf-8", errors="ignore"
 )
 text_offset = len(prompt) + len(
-self.detokenize(completion_tokens[:returned_tokens])
+self.detokenize(completion_tokens[:returned_tokens], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
 )
 token_offset = len(prompt_tokens) + returned_tokens - 1
 logits = self._scores[token_offset, :]
@@ -1313,8 +1313,8 @@ class Llama:
 all_tokens = completion_tokens

 all_token_strs = [
-self.detokenize([token]).decode("utf-8", errors="ignore")
-for token in all_tokens
+self.detokenize([token], prev_tokens=all_tokens[:i]).decode("utf-8", errors="ignore")
+for i, token in enumerate(all_tokens)
 ]
 all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
 # TODO: may be able to change this loop to use np.take_along_dim
@@ -1339,7 +1339,7 @@ class Llama:
 )
 token_logprobs.append(logprobs_token[int(token)])
 top_logprob: Optional[Dict[str, float]] = {
-self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
+self.detokenize([i], prev_tokens=all_tokens[:idx]).decode("utf-8", errors="ignore"): logprob
 for logprob, i in sorted_logprobs[:logprobs]
 }
 top_logprob.update({token_str: logprobs_token[int(token)]})
@@ -1594,6 +1594,8 @@ class Llama:
 logits_processor: Optional[LogitsProcessorList] = None,
 grammar: Optional[LlamaGrammar] = None,
 logit_bias: Optional[Dict[str, float]] = None,
+logprobs: Optional[bool] = None,
+top_logprobs: Optional[int] = None,
 ) -> Union[
 CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
 ]:
@@ -16,12 +16,23 @@ class BaseLlamaTokenizer(abc.ABC):
 def tokenize(
 self, text: bytes, add_bos: bool = True, special: bool = True
 ) -> List[int]:
+"""Tokenize the text into tokens.
+
+Args:
+text: The text to tokenize.
+add_bos: Whether to add a beginning of sequence token.
+special: Whether to tokenize text literally or as special tokens."""
 raise NotImplementedError

 @abc.abstractmethod
 def detokenize(
 self, tokens: List[int], prev_tokens: Optional[List[int]] = None
 ) -> bytes:
+"""Detokenize the tokens into text.
+
+Args:
+tokens: The tokens to detokenize.
+prev_tokens: If tokens is a continuation of a previous sequence, the previous tokens."""
 raise NotImplementedError


@@ -37,9 +48,6 @@ class LlamaTokenizer(BaseLlamaTokenizer):
 def detokenize(
 self, tokens: List[int], prev_tokens: Optional[List[int]] = None
 ) -> bytes:
-if prev_tokens is not None:
-return self._model.detokenize(tokens[len(prev_tokens) :])
-else:
 return self._model.detokenize(tokens)

 def encode(
@@ -72,7 +80,7 @@ class LlamaHFTokenizer(BaseLlamaTokenizer):
 self, tokens: List[int], prev_tokens: Optional[List[int]] = None
 ) -> bytes:
 if prev_tokens is not None:
-text = self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
+text = self.hf_tokenizer.decode(prev_tokens + tokens).encode("utf-8", errors="ignore")
 prev_text = self.hf_tokenizer.decode(prev_tokens).encode(
 "utf-8", errors="ignore"
 )
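The LlamaHFTokenizer hunk above is cut off after the prev_text computation. A self-contained sketch of the prev_tokens-aware detokenization it points at might look like the following; the final suffix-stripping return is an assumption inferred from the prev_text variable, and hf_detokenize is a hypothetical helper name rather than the library's API.

# Hedged sketch of prev_tokens-aware detokenization for a Hugging Face tokenizer.
# `hf_detokenize` is a hypothetical standalone helper; in the library this logic
# lives in LlamaHFTokenizer.detokenize. The trailing return is an assumption.
from typing import List, Optional
from transformers import PreTrainedTokenizerBase


def hf_detokenize(
    hf_tokenizer: PreTrainedTokenizerBase,
    tokens: List[int],
    prev_tokens: Optional[List[int]] = None,
) -> bytes:
    if prev_tokens is not None:
        # Decode the continuation together with its context, then drop the
        # already-decoded prefix so only the new text is returned.
        text = hf_tokenizer.decode(prev_tokens + tokens).encode("utf-8", errors="ignore")
        prev_text = hf_tokenizer.decode(prev_tokens).encode("utf-8", errors="ignore")
        return text[len(prev_text):]
    # No context available: decode the tokens as a standalone sequence.
    return hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")

Called the way the Llama hunks above call detokenize, prev_tokens would typically be prompt_tokens plus the completion tokens already returned.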