Bugfix: avoid decoding partial utf-8 characters

This commit is contained in:
Andrei Betlen 2023-03-23 16:25:13 -04:00
parent 6c596f4f5a
commit eec9256a42

View file

@ -50,10 +50,13 @@ class Llama:
repeat_penalty: float = 1.1, repeat_penalty: float = 1.1,
top_k: int = 40, top_k: int = 40,
): ):
text = "" text = b""
finish_reason = "length" finish_reason = "length"
completion_tokens = 0 completion_tokens = 0
if stop is not None:
stop = [s.encode("utf-8") for s in stop]
prompt_tokens = llama_cpp.llama_tokenize( prompt_tokens = llama_cpp.llama_tokenize(
self.ctx, prompt.encode("utf-8"), self.tokens, self.params.n_ctx, True self.ctx, prompt.encode("utf-8"), self.tokens, self.params.n_ctx, True
) )
@ -81,7 +84,8 @@ class Llama:
if token == llama_cpp.llama_token_eos(): if token == llama_cpp.llama_token_eos():
finish_reason = "stop" finish_reason = "stop"
break break
text += llama_cpp.llama_token_to_str(self.ctx, token).decode("utf-8") # text += llama_cpp.llama_token_to_str(self.ctx, token).decode("utf-8")
text += llama_cpp.llama_token_to_str(self.ctx, token)
self.tokens[prompt_tokens + i] = token self.tokens[prompt_tokens + i] = token
completion_tokens += 1 completion_tokens += 1
@ -100,6 +104,8 @@ class Llama:
self.n_threads, self.n_threads,
) )
text = text.decode("utf-8")
if echo: if echo:
text = prompt + text text = prompt + text
@ -111,6 +117,7 @@ class Llama:
self.ctx, self.ctx,
)[:logprobs] )[:logprobs]
return { return {
"id": f"cmpl-{str(uuid.uuid4())}", # Likely to change "id": f"cmpl-{str(uuid.uuid4())}", # Likely to change
"object": "text_completion", "object": "text_completion",