Add functions parameters

Andrei Betlen 2023-07-19 03:48:20 -04:00
parent 57db1f9570
commit b43917c144
3 changed files with 20 additions and 0 deletions


@@ -1435,6 +1435,8 @@ class Llama:
     def create_chat_completion(
         self,
         messages: List[ChatCompletionMessage],
+        functions: Optional[List[ChatCompletionFunction]] = None,
+        function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
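
For context, a minimal usage sketch of the new parameters (the model path and the get_weather schema are hypothetical; the parameters dict follows the OpenAI-style JSON-Schema shape that ChatCompletionFunction leaves open):

from llama_cpp import Llama

llm = Llama(model_path="./models/ggml-model.bin")  # hypothetical model path

# Pass an OpenAI-style function schema and request a specific function call.
response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "What is the weather in Berlin?"}],
    functions=[
        {
            "name": "get_weather",  # hypothetical function
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        }
    ],
    function_call={"name": "get_weather"},
)

Note that the hunks in this commit only add the parameters and types; the sketch shows the intended call shape, not any handling logic.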


@@ -63,6 +63,16 @@ class ChatCompletionMessage(TypedDict):
     user: NotRequired[str]
 
 
+class ChatCompletionFunction(TypedDict):
+    name: str
+    description: NotRequired[str]
+    parameters: Dict[str, Any]  # TODO: make this more specific
+
+
+class ChatCompletionFunctionCall(TypedDict):
+    name: str
+
+
 class ChatCompletionChoice(TypedDict):
     index: int
     message: ChatCompletionMessage
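
The two new TypedDicts are populated with plain dict literals; a brief sketch (the schema content is illustrative, and the package-level imports follow the llama_cpp.ChatCompletionFunction references used by the server change below):

from llama_cpp import ChatCompletionFunction, ChatCompletionFunctionCall

func: ChatCompletionFunction = {
    "name": "get_weather",  # hypothetical function
    "description": "Look up the current weather for a city.",  # NotRequired
    "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
}
call: ChatCompletionFunctionCall = {"name": "get_weather"}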


@@ -446,6 +446,14 @@ class CreateChatCompletionRequest(BaseModel):
     messages: List[ChatCompletionRequestMessage] = Field(
         default=[], description="A list of messages to generate completions for."
     )
+    functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field(
+        default=None,
+        description="A list of functions to apply to the generated completions.",
+    )
+    function_call: Optional[Union[str, llama_cpp.ChatCompletionFunctionCall]] = Field(
+        default=None,
+        description="A function to apply to the generated completions.",
+    )
     max_tokens: int = max_tokens_field
     temperature: float = temperature_field
     top_p: float = top_p_field
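
On the client side, the new request fields ride along in the JSON body of the chat completions route. A sketch assuming the bundled server is running on its default localhost:8000 with the OpenAI-compatible /v1/chat/completions path:

import requests  # client-only sketch; not part of the server code

payload = {
    "messages": [{"role": "user", "content": "What is the weather in Berlin?"}],
    "functions": [
        {
            "name": "get_weather",  # hypothetical function
            "description": "Look up the current weather for a city.",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        }
    ],
    "function_call": {"name": "get_weather"},
    "max_tokens": 64,
}
response = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(response.json())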