Handle prompt list
This commit is contained in:
parent
38f7dea6ca
commit
55279b679d
1 changed files with 13 additions and 6 deletions
|
@ -60,7 +60,7 @@ llama = llama_cpp.Llama(
|
||||||
|
|
||||||
|
|
||||||
class CreateCompletionRequest(BaseModel):
|
class CreateCompletionRequest(BaseModel):
|
||||||
prompt: str
|
prompt: Union[str, List[str]]
|
||||||
suffix: Optional[str] = Field(None)
|
suffix: Optional[str] = Field(None)
|
||||||
max_tokens: int = 16
|
max_tokens: int = 16
|
||||||
temperature: float = 0.8
|
temperature: float = 0.8
|
||||||
|
@ -100,10 +100,10 @@ CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
|
||||||
response_model=CreateCompletionResponse,
|
response_model=CreateCompletionResponse,
|
||||||
)
|
)
|
||||||
def create_completion(request: CreateCompletionRequest):
|
def create_completion(request: CreateCompletionRequest):
|
||||||
if request.stream:
|
if isinstance(request.prompt, list):
|
||||||
chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict()) # type: ignore
|
request.prompt = "".join(request.prompt)
|
||||||
return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
|
|
||||||
return llama(
|
completion_or_chunks = llama(
|
||||||
**request.dict(
|
**request.dict(
|
||||||
exclude={
|
exclude={
|
||||||
"model",
|
"model",
|
||||||
|
@ -117,6 +117,11 @@ def create_completion(request: CreateCompletionRequest):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if request.stream:
|
||||||
|
chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks # type: ignore
|
||||||
|
return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
|
||||||
|
completion: llama_cpp.Completion = completion_or_chunks # type: ignore
|
||||||
|
return completion
|
||||||
|
|
||||||
|
|
||||||
class CreateEmbeddingRequest(BaseModel):
|
class CreateEmbeddingRequest(BaseModel):
|
||||||
|
@ -259,4 +264,6 @@ if __name__ == "__main__":
|
||||||
import os
|
import os
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
uvicorn.run(app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000)))
|
uvicorn.run(
|
||||||
|
app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Reference in a new issue