diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index b23280d..c573add 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -1435,6 +1435,8 @@ class Llama:
     def create_chat_completion(
         self,
         messages: List[ChatCompletionMessage],
+        functions: Optional[List[ChatCompletionFunction]] = None,
+        function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
diff --git a/llama_cpp/llama_types.py b/llama_cpp/llama_types.py
index 6ba8023..19f9452 100644
--- a/llama_cpp/llama_types.py
+++ b/llama_cpp/llama_types.py
@@ -63,6 +63,16 @@ class ChatCompletionMessage(TypedDict):
     user: NotRequired[str]
 
 
+class ChatCompletionFunction(TypedDict):
+    name: str
+    description: NotRequired[str]
+    parameters: Dict[str, Any]  # TODO: make this more specific
+
+
+class ChatCompletionFunctionCall(TypedDict):
+    name: str
+
+
 class ChatCompletionChoice(TypedDict):
     index: int
     message: ChatCompletionMessage
diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index 8a9b818..36763f2 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -446,6 +446,14 @@ class CreateChatCompletionRequest(BaseModel):
     messages: List[ChatCompletionRequestMessage] = Field(
         default=[], description="A list of messages to generate completions for."
     )
+    functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field(
+        default=None,
+        description="A list of functions the model may call.",
+    )
+    function_call: Optional[Union[str, llama_cpp.ChatCompletionFunctionCall]] = Field(
+        default=None,
+        description="Controls which (if any) function the model calls.",
+    )
     max_tokens: int = max_tokens_field
     temperature: float = temperature_field
     top_p: float = top_p_field
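For context, here is a minimal sketch of how a caller might exercise the new parameters. The `functions` entries follow the OpenAI-style schema implied by the `ChatCompletionFunction` TypedDict (a JSON Schema object in `parameters`); the model path and the `get_current_weather` function are placeholders, and since this diff only adds the parameters, whether the model actually emits a function call depends on handling not shown here.

```python
from llama_cpp import Llama

# Placeholder model path -- substitute a real ggml model file.
llama = Llama(model_path="./models/7B/ggml-model.bin")

# Hypothetical function declaration; `parameters` is a JSON Schema object,
# matching the ChatCompletionFunction TypedDict added above.
functions = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "City name"},
            },
            "required": ["location"],
        },
    }
]

completion = llama.create_chat_completion(
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    functions=functions,
    # Either a mode string or a ChatCompletionFunctionCall naming one function.
    function_call={"name": "get_current_weather"},
)
print(completion["choices"][0]["message"])
```

The same two fields are accepted by the server's chat completion endpoint through `CreateChatCompletionRequest`, so HTTP clients can pass `functions` and `function_call` in the request body in the same shape.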