Bugfix: avoid decoding partial utf-8 characters

commit eec9256a42
parent 6c596f4f5a
Author: Andrei Betlen
Date:   2023-03-23 16:25:13 -04:00


@@ -50,10 +50,13 @@ class Llama:
         repeat_penalty: float = 1.1,
         top_k: int = 40,
     ):
-        text = ""
+        text = b""
         finish_reason = "length"
         completion_tokens = 0
+        if stop is not None:
+            stop = [s.encode("utf-8") for s in stop]
         prompt_tokens = llama_cpp.llama_tokenize(
             self.ctx, prompt.encode("utf-8"), self.tokens, self.params.n_ctx, True
         )
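
Because `text` is now accumulated as raw bytes, the stop sequences are pre-encoded so that any containment checks compare bytes against bytes. A minimal sketch of the mismatch this avoids (the buffer contents here are made up for illustration):

    stop = ["\n\n", "###"]
    stop_bytes = [s.encode("utf-8") for s in stop]

    text = b"some generated output###"
    print(any(s in text for s in stop_bytes))  # True: bytes pattern on bytes works
    # any(s in text for s in stop) would raise TypeError: str pattern on bytes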
@@ -81,7 +84,8 @@ class Llama:
             if token == llama_cpp.llama_token_eos():
                 finish_reason = "stop"
                 break
-            text += llama_cpp.llama_token_to_str(self.ctx, token).decode("utf-8")
+            # text += llama_cpp.llama_token_to_str(self.ctx, token).decode("utf-8")
+            text += llama_cpp.llama_token_to_str(self.ctx, token)
             self.tokens[prompt_tokens + i] = token
             completion_tokens += 1
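
This hunk is the heart of the fix: `llama_token_to_str` returns the raw bytes for one token, and a single multi-byte UTF-8 character can be split across two tokens, so decoding each token's bytes on its own can fail. A self-contained reproduction (the byte split is contrived to mimic such tokenizer output):

    snowman = "☃".encode("utf-8")        # b'\xe2\x98\x83', three bytes
    pieces = [snowman[:2], snowman[2:]]  # pretend each piece came from one token

    try:
        "".join(p.decode("utf-8") for p in pieces)  # old behavior: decode per token
    except UnicodeDecodeError as e:
        print("per-token decode fails:", e)

    buf = b""                            # new behavior: accumulate raw bytes...
    for p in pieces:
        buf += p
    print(buf.decode("utf-8"))           # ...and decode once at the end: ☃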
@@ -100,6 +104,8 @@ class Llama:
             self.n_threads,
         )
+        text = text.decode("utf-8")
         if echo:
             text = prompt + text
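
With the buffer decoded exactly once after generation finishes, a partial character can no longer be decoded mid-stream. One caveat this commit does not address (an observation, not code from the repo): if generation stops on the token limit in the middle of a multi-byte character, the final strict decode could still raise; an incremental decoder is one way to tolerate a trailing partial sequence:

    import codecs

    dec = codecs.getincrementaldecoder("utf-8")()
    buf = "é".encode("utf-8")[:1]    # trailing partial character
    print(repr(dec.decode(buf)))     # '' -- the partial byte is buffered, no error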
@@ -111,6 +117,7 @@ class Llama:
             self.ctx,
         )[:logprobs]
         return {
             "id": f"cmpl-{str(uuid.uuid4())}", # Likely to change
             "object": "text_completion",