Only support generating one prompt at a time.

This commit is contained in:
Andrei Betlen 2023-05-12 07:21:46 -04:00
parent 8895b9002a
commit 8740ddc58e

View file

@@ -166,10 +166,10 @@ frequency_penalty_field = Field(
description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.", description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
) )
class CreateCompletionRequest(BaseModel): class CreateCompletionRequest(BaseModel):
prompt: Union[str, List[str]] = Field( prompt: Union[str, List[str]] = Field(
default="", default="", description="The prompt to generate completions for."
description="The prompt to generate completions for."
) )
suffix: Optional[str] = Field( suffix: Optional[str] = Field(
default=None, default=None,
@@ -224,7 +224,8 @@ def create_completion(
request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama) request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
): ):
if isinstance(request.prompt, list): if isinstance(request.prompt, list):
request.prompt = "".join(request.prompt) assert len(request.prompt) <= 1
request.prompt = request.prompt[0] if len(request.prompt) > 0 else ""
completion_or_chunks = llama( completion_or_chunks = llama(
**request.dict( **request.dict(