Handle prompt list
This commit is contained in:
parent
38f7dea6ca
commit
55279b679d
1 changed files with 13 additions and 6 deletions
|
@ -60,7 +60,7 @@ llama = llama_cpp.Llama(
|
|||
|
||||
|
||||
class CreateCompletionRequest(BaseModel):
|
||||
prompt: str
|
||||
prompt: Union[str, List[str]]
|
||||
suffix: Optional[str] = Field(None)
|
||||
max_tokens: int = 16
|
||||
temperature: float = 0.8
|
||||
|
@ -100,10 +100,10 @@ CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
|
|||
response_model=CreateCompletionResponse,
|
||||
)
|
||||
def create_completion(request: CreateCompletionRequest):
|
||||
if request.stream:
|
||||
chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict()) # type: ignore
|
||||
return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
|
||||
return llama(
|
||||
if isinstance(request.prompt, list):
|
||||
request.prompt = "".join(request.prompt)
|
||||
|
||||
completion_or_chunks = llama(
|
||||
**request.dict(
|
||||
exclude={
|
||||
"model",
|
||||
|
@ -117,6 +117,11 @@ def create_completion(request: CreateCompletionRequest):
|
|||
}
|
||||
)
|
||||
)
|
||||
if request.stream:
|
||||
chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks # type: ignore
|
||||
return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
|
||||
completion: llama_cpp.Completion = completion_or_chunks # type: ignore
|
||||
return completion
|
||||
|
||||
|
||||
class CreateEmbeddingRequest(BaseModel):
|
||||
|
@ -259,4 +264,6 @@ if __name__ == "__main__":
|
|||
import os
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000)))
|
||||
uvicorn.run(
|
||||
app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
|
||||
)
|
||||
|
|
Loading…
Add table
Reference in a new issue