Handle prompt list

Andrei Betlen 2023-04-06 21:07:35 -04:00
parent 38f7dea6ca
commit 55279b679d


@@ -60,7 +60,7 @@ llama = llama_cpp.Llama(
 class CreateCompletionRequest(BaseModel):
-    prompt: str
+    prompt: Union[str, List[str]]
     suffix: Optional[str] = Field(None)
     max_tokens: int = 16
     temperature: float = 0.8
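A minimal, standalone sketch of the widened request model (it assumes only that pydantic is installed; the server wiring is omitted): with the union field, a plain string stays a string while a list of strings is preserved as a list by validation.

```python
from typing import List, Optional, Union

from pydantic import BaseModel, Field


class CreateCompletionRequest(BaseModel):
    # Same fields as in the diff above; everything else is left out for brevity.
    prompt: Union[str, List[str]]
    suffix: Optional[str] = Field(None)
    max_tokens: int = 16
    temperature: float = 0.8


single = CreateCompletionRequest(prompt="Hello")
batch = CreateCompletionRequest(prompt=["Hello, ", "world"])
print(type(single.prompt).__name__, type(batch.prompt).__name__)  # str list
```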
@@ -100,10 +100,10 @@ CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
     response_model=CreateCompletionResponse,
 )
 def create_completion(request: CreateCompletionRequest):
-    if request.stream:
-        chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict())  # type: ignore
-        return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
-    return llama(
+    if isinstance(request.prompt, list):
+        request.prompt = "".join(request.prompt)
+
+    completion_or_chunks = llama(
         **request.dict(
             exclude={
                 "model",
@@ -117,6 +117,11 @@ def create_completion(request: CreateCompletionRequest):
             }
         )
     )
+    if request.stream:
+        chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks  # type: ignore
+        return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
+    completion: llama_cpp.Completion = completion_or_chunks  # type: ignore
+    return completion
 
 
 class CreateEmbeddingRequest(BaseModel):
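A hedged sketch of the new list handling in the endpoint: a list prompt is concatenated into a single string before the model is called, so ["Hello, ", "world"] behaves the same as "Hello, world". The `normalize_prompt` helper is hypothetical (the actual handler mutates `request.prompt` in place, as shown in the hunk above).

```python
from typing import List, Union


def normalize_prompt(prompt: Union[str, List[str]]) -> str:
    # Hypothetical helper mirroring the handler change: a list prompt is
    # concatenated into one string; a plain string passes through unchanged.
    if isinstance(prompt, list):
        return "".join(prompt)
    return prompt


assert normalize_prompt(["Hello, ", "world"]) == "Hello, world"
assert normalize_prompt("Hello, world") == "Hello, world"
```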
@@ -259,4 +264,6 @@ if __name__ == "__main__":
     import os
     import uvicorn
 
-    uvicorn.run(app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000)))
+    uvicorn.run(
+        app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
+    )
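A small sketch of the configuration the `__main__` block relies on (the values below are illustrative, not part of the commit): HOST and PORT come from the environment as strings, so the port is cast with int() before being passed to uvicorn.run, and both fall back to localhost:8000 when unset.

```python
import os

# Environment variables are always strings, hence the int() cast on the port.
os.environ["PORT"] = "8080"            # e.g. exported by the operator (illustrative)
host = os.getenv("HOST", "localhost")  # default bind address
port = int(os.getenv("PORT", 8000))    # default port, cast to int for uvicorn
print(host, port)                      # localhost 8080
```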