From 55279b679df4153759a80945af7017a79a8ac37c Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Thu, 6 Apr 2023 21:07:35 -0400
Subject: [PATCH] Handle prompt list

---
 llama_cpp/server/__main__.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/llama_cpp/server/__main__.py b/llama_cpp/server/__main__.py
index 0362cff..0650bc0 100644
--- a/llama_cpp/server/__main__.py
+++ b/llama_cpp/server/__main__.py
@@ -60,7 +60,7 @@ llama = llama_cpp.Llama(
 
 
 class CreateCompletionRequest(BaseModel):
-    prompt: str
+    prompt: Union[str, List[str]]
     suffix: Optional[str] = Field(None)
     max_tokens: int = 16
     temperature: float = 0.8
@@ -100,10 +100,10 @@ CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
     response_model=CreateCompletionResponse,
 )
 def create_completion(request: CreateCompletionRequest):
-    if request.stream:
-        chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict())  # type: ignore
-        return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
-    return llama(
+    if isinstance(request.prompt, list):
+        request.prompt = "".join(request.prompt)
+
+    completion_or_chunks = llama(
         **request.dict(
             exclude={
                 "model",
@@ -117,6 +117,11 @@ def create_completion(request: CreateCompletionRequest):
             }
         )
     )
+    if request.stream:
+        chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks  # type: ignore
+        return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
+    completion: llama_cpp.Completion = completion_or_chunks  # type: ignore
+    return completion
 
 
 class CreateEmbeddingRequest(BaseModel):
@@ -259,4 +264,6 @@ if __name__ == "__main__":
     import os
     import uvicorn
 
-    uvicorn.run(app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000)))
+    uvicorn.run(
+        app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
+    )
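
Note: below is a minimal client-side sketch of what this patch enables, i.e.
sending "prompt" as a list of strings that the server joins into a single
prompt before calling the model. It assumes the server is running with the
defaults set in __main__.py (localhost:8000) and that the completion route
uses the OpenAI-style path "/v1/completions", which is not shown in these
hunks; adjust both if your setup differs.

    # Hypothetical usage sketch; the route path and a running server are
    # assumptions, not part of the patch itself.
    import requests

    response = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            # With this patch, "prompt" may be a plain string or a list of
            # strings; list parts are concatenated server-side via "".join().
            "prompt": ["Q: What is the capital of France?", " A:"],
            "max_tokens": 16,
            "temperature": 0.8,
        },
    )
    print(response.json())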