Stream tokens instead of text chunks
parent 21d8f5fa9f
commit f0ec6e615e

1 changed file with 78 additions and 34 deletions
@@ -623,7 +623,7 @@ class Llama:
             b" " + prompt.encode("utf-8")
         )
         text: bytes = b""
-        returned_characters: int = 0
+        returned_tokens: int = 0
         stop = stop if stop is not None else []
         model_name: str = model if model is not None else self.model_path
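The counter rename above is the crux of the change: progress through the stream is now tracked in tokens rather than in decoded characters. This rests on the assumption that detokenizing one token at a time concatenates to the same bytes as detokenizing the whole list in one call; a minimal sketch of that invariant (illustrative only, not part of the diff):

    # Assumed invariant: per-token detokenization concatenates to the
    # same byte string as detokenizing the full sequence at once.
    def detokenize_incrementally(llama, tokens):
        out = b""
        for tok in tokens:
            out += llama.detokenize([tok])
        return out  # expected to equal llama.detokenize(tokens)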
@@ -707,19 +707,26 @@ class Llama:
                 break
 
             if stream:
-                start = returned_characters
-                longest = 0
                 # We want to avoid yielding any characters from
                 # the generated text if they are part of a stop
                 # sequence.
+                longest = 0
                 for s in stop_sequences:
                     for i in range(len(s), 0, -1):
                         if all_text.endswith(s[:i]):
                             if i > longest:
                                 longest = i
                             break
-                text = all_text[: len(all_text) - longest]
-                returned_characters += len(text[start:])
+
+                offset = 0
+                remaining_tokens = completion_tokens[returned_tokens:]
+                remaining_length = len(self.detokenize(remaining_tokens))
+                for token in remaining_tokens:
+                    offset += len(self.detokenize([token]))
+                    # Check if stop sequence is not in the token
+                    if offset >= (remaining_length - longest - 1):
+                        break
+                    returned_tokens += 1
                     yield {
                         "id": completion_id,
                         "object": "text_completion",
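The prefix scan in this hunk computes `longest`, the length of the longest stop-sequence prefix that the generated text currently ends with; that many trailing bytes are held back so a stop sequence is never partially streamed. The same rule as a standalone sketch (helper name is illustrative, not from the library):

    def longest_stop_prefix(all_text: bytes, stop_sequences) -> int:
        longest = 0
        for s in stop_sequences:
            for i in range(len(s), 0, -1):
                if all_text.endswith(s[:i]):
                    longest = max(longest, i)
                    break
        return longest

    # With stop sequence b"User:", the trailing b"Us" is withheld:
    assert longest_stop_prefix(b"Hello, Us", [b"User:"]) == 2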
@@ -727,7 +734,9 @@ class Llama:
                         "model": model_name,
                         "choices": [
                             {
-                                "text": text[start:].decode("utf-8", errors="ignore"),
+                                "text": self.detokenize([token]).decode(
+                                    "utf-8", errors="ignore"
+                                ),
                                 "index": 0,
                                 "logprobs": None,
                                 "finish_reason": None,
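Note that each chunk is now decoded independently with errors="ignore": a single token can end in the middle of a multi-byte UTF-8 character, and the partial bytes then simply produce no output rather than raising. For example:

    chunk = "é".encode("utf-8")[:1]  # first byte of a two-byte character
    assert chunk.decode("utf-8", errors="ignore") == ""  # dropped, no error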
@@ -749,6 +758,21 @@ class Llama:
             llama_cpp.llama_print_timings(self.ctx)
 
         if stream:
+            remaining_tokens = completion_tokens[returned_tokens:]
+            all_text = self.detokenize(remaining_tokens)
+            any_stop = [s for s in stop_sequences if s in all_text]
+            if len(any_stop) > 0:
+                end = min(all_text.index(stop) for stop in any_stop)
+            else:
+                end = len(all_text)
+
+            offset = 0
+            for token in remaining_tokens:
+                offset += len(self.detokenize([token]))
+                if offset >= end:
+                    last_text = self.detokenize([token])
+                    if offset == end - 1:
+                        break
                     yield {
                         "id": completion_id,
                         "object": "text_completion",
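Once generation finishes, this hunk flushes the tokens that were held back, truncated at the earliest stop sequence found in the remaining text. The selection of the cut point, sketched as a standalone function (names are ours, not the library's):

    def truncate_at_stop(all_text: bytes, stop_sequences) -> bytes:
        any_stop = [s for s in stop_sequences if s in all_text]
        end = min((all_text.index(s) for s in any_stop), default=len(all_text))
        return all_text[:end]

    assert truncate_at_stop(b"42\nUser: next", [b"User:"]) == b"42\n"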
@@ -756,12 +780,32 @@ class Llama:
                         "model": model_name,
                         "choices": [
                             {
-                                "text": text[returned_characters:].decode(
+                                "text": last_text[
+                                    : len(last_text) - (offset - end)
+                                ].decode("utf-8", errors="ignore"),
+                                "index": 0,
+                                "logprobs": None,
+                                "finish_reason": finish_reason,
+                            }
+                        ],
+                    }
+                    break
+                returned_tokens += 1
+                yield {
+                    "id": completion_id,
+                    "object": "text_completion",
+                    "created": created,
+                    "model": model_name,
+                    "choices": [
+                        {
+                            "text": self.detokenize([token]).decode(
                                 "utf-8", errors="ignore"
                             ),
                             "index": 0,
                             "logprobs": None,
-                            "finish_reason": finish_reason,
+                            "finish_reason": finish_reason
+                            if returned_tokens == len(completion_tokens)
+                            else None,
                         }
                     ],
                 }
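From the caller's side the interface is unchanged; only the granularity of the chunks differs, with each chunk now carrying the text of a single token. A usage sketch against the Llama API of this era (model path and prompt are placeholders):

    llama = Llama(model_path="./models/ggml-model.bin")
    stream = llama("Q: Name the planets in the solar system. A: ",
                   max_tokens=64, stop=["Q:"], stream=True)
    for chunk in stream:
        print(chunk["choices"][0]["text"], end="", flush=True)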