diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index 640dd3f..5d87e78 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -66,6 +66,10 @@ def get_llama():
     with llama_lock:
         yield llama
 
+model_field = Field(
+    description="The model to use for generating completions."
+)
+
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]]
     suffix: Optional[str] = Field(None)
@@ -76,8 +80,9 @@ class CreateCompletionRequest(BaseModel):
     stop: Optional[List[str]] = []
     stream: bool = False
 
-    # ignored or currently unsupported
-    model: Optional[str] = Field(None)
+    # ignored, but marked as required for the sake of compatibility with openai's api
+    model: str = model_field
+
     n: Optional[int] = 1
     logprobs: Optional[int] = Field(None)
     presence_penalty: Optional[float] = 0
@@ -133,7 +138,8 @@ def create_completion(
 
 
 class CreateEmbeddingRequest(BaseModel):
-    model: Optional[str]
+    # ignored, but marked as required for the sake of compatibility with openai's api
+    model: str = model_field
     input: str
     user: Optional[str]
 
@@ -173,8 +179,9 @@ class CreateChatCompletionRequest(BaseModel):
     stop: Optional[List[str]] = []
     max_tokens: int = 128
 
-    # ignored or currently unsupported
-    model: Optional[str] = Field(None)
+    # ignored, but marked as required for the sake of compatibility with openai's api
+    model: str = model_field
+
     n: Optional[int] = 1
     presence_penalty: Optional[float] = 0
     frequency_penalty: Optional[float] = 0
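
For context, a minimal sketch (not part of the patch) of how the shared model_field behaves in Pydantic: a Field(...) declared without a default makes the attribute required, so requests missing "model" are rejected even though the server ignores its value, matching the shape of OpenAI's API. The simplified prompt type and the example model name below are illustrative assumptions, not code from this change.

from pydantic import BaseModel, Field, ValidationError

# Shared field definition, mirroring the diff: no default value, so the
# field is required, and the description flows into the OpenAPI schema.
model_field = Field(description="The model to use for generating completions.")

class CreateCompletionRequest(BaseModel):
    prompt: str               # simplified from Union[str, List[str]] for brevity
    model: str = model_field  # required, but ignored by the server

# A request without "model" now fails validation.
try:
    CreateCompletionRequest(prompt="Hello")
except ValidationError as err:
    print(err)  # reports that the "model" field is required

# A request that includes "model" validates, like an OpenAI client payload.
req = CreateCompletionRequest(prompt="Hello", model="example-model")
print(req.model)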