Stream tokens instead of text chunks

commit f0ec6e615e (parent 21d8f5fa9f)
Author: Andrei Betlen
Date:   2023-05-18 11:35:59 -04:00


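With this change, a streaming completion yields one OpenAI-style chunk per generated token instead of one chunk per accumulated text slice, while still holding back any bytes that might be the start of a stop sequence. A minimal consumer sketch, assuming the library's high-level Llama(...) call with stream=True (model path and prompt are placeholders):

    from llama_cpp import Llama

    llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path

    # With stream=True the call returns a generator; after this commit each
    # yielded chunk carries the text of a single detokenized token.
    for chunk in llm("Q: Name the planets. A:", max_tokens=48, stream=True, stop=["Q:"]):
        print(chunk["choices"][0]["text"], end="", flush=True)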
@@ -623,7 +623,7 @@ class Llama:
             b" " + prompt.encode("utf-8")
         )
         text: bytes = b""
-        returned_characters: int = 0
+        returned_tokens: int = 0
         stop = stop if stop is not None else []
         model_name: str = model if model is not None else self.model_path
@@ -707,33 +707,42 @@ class Llama:
                 break
 
             if stream:
-                start = returned_characters
-                longest = 0
                 # We want to avoid yielding any characters from
                 # the generated text if they are part of a stop
                 # sequence.
+                longest = 0
                 for s in stop_sequences:
                     for i in range(len(s), 0, -1):
                         if all_text.endswith(s[:i]):
                             if i > longest:
                                 longest = i
                             break
-                text = all_text[: len(all_text) - longest]
-                returned_characters += len(text[start:])
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": text[start:].decode("utf-8", errors="ignore"),
-                            "index": 0,
-                            "logprobs": None,
-                            "finish_reason": None,
-                        }
-                    ],
-                }
+
+                offset = 0
+                remaining_tokens = completion_tokens[returned_tokens:]
+                remaining_length = len(self.detokenize(remaining_tokens))
+                for token in remaining_tokens:
+                    offset += len(self.detokenize([token]))
+                    # Check if stop sequence is not in the token
+                    if offset >= (remaining_length - longest - 1):
+                        break
+                    returned_tokens += 1
+                    yield {
+                        "id": completion_id,
+                        "object": "text_completion",
+                        "created": created,
+                        "model": model_name,
+                        "choices": [
+                            {
+                                "text": self.detokenize([token]).decode(
+                                    "utf-8", errors="ignore"
+                                ),
+                                "index": 0,
+                                "logprobs": None,
+                                "finish_reason": None,
+                            }
+                        ],
+                    }
 
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens)
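The `longest` scan above is what keeps partial stop sequences out of the stream: it measures the longest prefix of any stop string that the generated text currently ends with, and tokens falling inside that tail are not yielded yet. A standalone sketch of the same check, with an illustrative helper name and sample values:

    def longest_stop_prefix(all_text, stop_sequences):
        # Longest prefix of any stop sequence that all_text currently ends with.
        longest = 0
        for s in stop_sequences:
            for i in range(len(s), 0, -1):
                if all_text.endswith(s[:i]):
                    longest = max(longest, i)
                    break
        return longest

    # b"\nQ" could be the start of the stop string b"\nQ:", so 2 bytes are held back.
    print(longest_stop_prefix(b"The answer is 42.\nQ", [b"\nQ:"]))  # -> 2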
@@ -749,22 +758,57 @@ class Llama:
             llama_cpp.llama_print_timings(self.ctx)
 
         if stream:
-            yield {
-                "id": completion_id,
-                "object": "text_completion",
-                "created": created,
-                "model": model_name,
-                "choices": [
-                    {
-                        "text": text[returned_characters:].decode(
-                            "utf-8", errors="ignore"
-                        ),
-                        "index": 0,
-                        "logprobs": None,
-                        "finish_reason": finish_reason,
-                    }
-                ],
-            }
+            remaining_tokens = completion_tokens[returned_tokens:]
+            all_text = self.detokenize(remaining_tokens)
+            any_stop = [s for s in stop_sequences if s in all_text]
+            if len(any_stop) > 0:
+                end = min(all_text.index(stop) for stop in any_stop)
+            else:
+                end = len(all_text)
+
+            offset = 0
+            for token in remaining_tokens:
+                offset += len(self.detokenize([token]))
+                if offset >= end:
+                    last_text = self.detokenize([token])
+                    if offset == end - 1:
+                        break
+                    yield {
+                        "id": completion_id,
+                        "object": "text_completion",
+                        "created": created,
+                        "model": model_name,
+                        "choices": [
+                            {
+                                "text": last_text[
+                                    : len(last_text) - (offset - end)
+                                ].decode("utf-8", errors="ignore"),
+                                "index": 0,
+                                "logprobs": None,
+                                "finish_reason": finish_reason,
+                            }
+                        ],
+                    }
+                    break
+                returned_tokens += 1
+                yield {
+                    "id": completion_id,
+                    "object": "text_completion",
+                    "created": created,
+                    "model": model_name,
+                    "choices": [
+                        {
+                            "text": self.detokenize([token]).decode(
+                                "utf-8", errors="ignore"
+                            ),
+                            "index": 0,
+                            "logprobs": None,
+                            "finish_reason": finish_reason
+                            if returned_tokens == len(completion_tokens)
+                            else None,
+                        }
+                    ],
+                }
             return
 
         text_str = text.decode("utf-8", errors="ignore")
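The final flush detokenizes the held-back tokens in one pass, uses the earliest stop sequence (if any) to fix the byte offset `end`, and truncates the token that crosses that boundary so the stream ends exactly before the stop string. A rough standalone sketch of that trimming step (detokenization replaced by plain byte strings, helper name illustrative):

    def flush_remaining(token_texts, stop_sequences):
        # Concatenate held-back token texts, cutting at the earliest stop sequence.
        all_text = b"".join(token_texts)
        hits = [all_text.index(s) for s in stop_sequences if s in all_text]
        end = min(hits) if hits else len(all_text)

        out = b""
        offset = 0
        for t in token_texts:
            offset += len(t)
            if offset >= end:
                # This token crosses the stop boundary: keep only the bytes before it.
                out += t[: len(t) - (offset - end)]
                break
            out += t
        return out

    # The stop string b"Q:" begins inside the third token, so output ends after "42.\n".
    print(flush_remaining([b"The answer", b" is 42.", b"\nQ: next"], [b"Q:"]))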