Dont detect off tokens, detect off detokenized utf8
This commit is contained in:
parent
3a98747026
commit
eed61289b6
1 changed files with 11 additions and 7 deletions
|
@ -459,12 +459,6 @@ class Llama:
|
|||
finish_reason = "stop"
|
||||
break
|
||||
|
||||
# Contains multi-byte UTF8
|
||||
for num,pattern in [(2, 192), (3, 224), (4, 240)]:
|
||||
# Bitwise AND check
|
||||
if (pattern & token == pattern):
|
||||
multibyte_fix = num - 1
|
||||
|
||||
if self.cache and len(completion_tokens) == 0:
|
||||
if prompt_tokens not in self.cache:
|
||||
if self.verbose:
|
||||
|
@ -473,12 +467,22 @@ class Llama:
|
|||
|
||||
completion_tokens.append(token)
|
||||
|
||||
all_text = self.detokenize(completion_tokens)
|
||||
|
||||
# Contains multi-byte UTF8
|
||||
for k,char in enumerate(all_text[-3:]):
|
||||
k = 3 - k
|
||||
char = int.from_bytes(char, "big")
|
||||
for num,pattern in [(2, 192), (3, 224), (4, 240)]:
|
||||
# Bitwise AND check
|
||||
if (num > k and pattern & char == pattern):
|
||||
multibyte_fix = num - k
|
||||
|
||||
# Stop incomplete bytes from passing
|
||||
if (multibyte_fix > 0):
|
||||
multibyte_fix -= 1
|
||||
continue
|
||||
|
||||
all_text = self.detokenize(completion_tokens)
|
||||
any_stop = [s for s in stop_sequences if s in all_text]
|
||||
if len(any_stop) > 0:
|
||||
first_stop = any_stop[0]
|
||||
|
|
Loading…
Reference in a new issue