Handle prompt list

Andrei Betlen 2023-04-06 21:07:35 -04:00
parent 38f7dea6ca
commit 55279b679d


@@ -60,7 +60,7 @@ llama = llama_cpp.Llama(


 class CreateCompletionRequest(BaseModel):
-    prompt: str
+    prompt: Union[str, List[str]]
     suffix: Optional[str] = Field(None)
     max_tokens: int = 16
     temperature: float = 0.8
@@ -100,10 +100,10 @@ CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
     response_model=CreateCompletionResponse,
 )
 def create_completion(request: CreateCompletionRequest):
-    if request.stream:
-        chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict())  # type: ignore
-        return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
-    return llama(
+    if isinstance(request.prompt, list):
+        request.prompt = "".join(request.prompt)
+
+    completion_or_chunks = llama(
         **request.dict(
             exclude={
                 "model",
@@ -117,6 +117,11 @@ def create_completion(request: CreateCompletionRequest):
             }
         )
     )
+    if request.stream:
+        chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks  # type: ignore
+        return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
+    completion: llama_cpp.Completion = completion_or_chunks  # type: ignore
+    return completion


 class CreateEmbeddingRequest(BaseModel):
@@ -259,4 +264,6 @@ if __name__ == "__main__":
     import os
     import uvicorn

-    uvicorn.run(app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000)))
+    uvicorn.run(
+        app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
+    )
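
For reference, a minimal client sketch of the behavior this diff introduces: a list prompt is accepted by CreateCompletionRequest and joined into a single string before llama() is called, and when stream=True the same llama() result is sent back as server-sent events instead of a plain JSON completion. The URL below is an assumption: localhost:8000 matches the HOST/PORT defaults in the __main__ block, but the /v1/completions path comes from the route decorator, which is outside this diff.

import json
import requests

# Assumed endpoint: localhost:8000 matches the defaults above; the
# /v1/completions path is an assumption (the route decorator is not shown).
url = "http://localhost:8000/v1/completions"

# List prompt: the handler joins the parts into one prompt, "Hello, world!".
resp = requests.post(
    url,
    json={"prompt": ["Hello, ", "world!"], "max_tokens": 16, "stream": False},
)
print(resp.json()["choices"][0]["text"])

# Streaming: each completion chunk arrives as a server-sent event line, "data: {...}".
with requests.post(
    url,
    json={"prompt": "Hello", "max_tokens": 16, "stream": True},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data: "):
            chunk = json.loads(line[len(b"data: "):])
            print(chunk["choices"][0]["text"], end="", flush=True)

Note that the join uses no separator, so any whitespace between list items has to be included in the items themselves.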