diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index f8e0527..d7d3e85 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1060,6 +1060,20 @@ class Llama: ].decode("utf-8", errors="ignore"), "index": 0, "logprobs": logprobs_or_none, + "finish_reason": None, + } + ], + } + yield { + "id": completion_id, + "object": "text_completion", + "created": created, + "model": model_name, + "choices": [ + { + "text": "", + "index": 0, + "logprobs": None, "finish_reason": finish_reason, } ], @@ -1078,9 +1092,21 @@ class Llama: ), "index": 0, "logprobs": logprobs_or_none, - "finish_reason": finish_reason - if returned_tokens == len(completion_tokens) - else None, + "finish_reason": None, + } + ], + } + yield { + "id": completion_id, + "object": "text_completion", + "created": created, + "model": model_name, + "choices": [ + { + "text": "", + "index": 0, + "logprobs": None, + "finish_reason": finish_reason, } ], } @@ -1370,7 +1396,9 @@ class Llama: "index": 0, "delta": { "content": chunk["choices"][0]["text"], - }, + } + if chunk["choices"][0]["finish_reason"] is None + else {}, "finish_reason": chunk["choices"][0]["finish_reason"], } ], diff --git a/llama_cpp/llama_types.py b/llama_cpp/llama_types.py index 7729ced..6ba8023 100644 --- a/llama_cpp/llama_types.py +++ b/llama_cpp/llama_types.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Dict +from typing import Any, List, Optional, Dict, Union from typing_extensions import TypedDict, NotRequired, Literal @@ -77,6 +77,8 @@ class ChatCompletion(TypedDict): choices: List[ChatCompletionChoice] usage: CompletionUsage +class ChatCompletionChunkDeltaEmpty(TypedDict): + pass class ChatCompletionChunkDelta(TypedDict): role: NotRequired[Literal["assistant"]] @@ -85,7 +87,7 @@ class ChatCompletionChunkDelta(TypedDict): class ChatCompletionChunkChoice(TypedDict): index: int - delta: ChatCompletionChunkDelta + delta: Union[ChatCompletionChunkDelta, ChatCompletionChunkDeltaEmpty] finish_reason: Optional[str]