Only support generating one prompt at a time.
This commit is contained in:
parent
8895b9002a
commit
8740ddc58e
1 changed file with 4 additions and 3 deletions
@@ -166,10 +166,10 @@ frequency_penalty_field = Field(
     description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
 )
 
 
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]] = Field(
-        default="",
-        description="The prompt to generate completions for."
+        default="", description="The prompt to generate completions for."
     )
     suffix: Optional[str] = Field(
         default=None,
@@ -224,7 +224,8 @@ def create_completion(
     request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
     if isinstance(request.prompt, list):
-        request.prompt = "".join(request.prompt)
+        assert len(request.prompt) <= 1
+        request.prompt = request.prompt[0] if len(request.prompt) > 0 else ""
 
     completion_or_chunks = llama(
         **request.dict(
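A minimal sketch of the new prompt handling in isolation: the helper name `normalize_prompt` is hypothetical and not part of this commit, but it mirrors the logic added to `create_completion`. A list prompt with more than one element now fails the assertion instead of being joined into a single string.

```python
from typing import List, Union


def normalize_prompt(prompt: Union[str, List[str]]) -> str:
    """Mirror the new server behaviour: only one prompt is supported."""
    if isinstance(prompt, list):
        # Old behaviour: request.prompt = "".join(request.prompt)
        # New behaviour: reject anything beyond a single element.
        assert len(prompt) <= 1
        return prompt[0] if len(prompt) > 0 else ""
    return prompt


# A plain string and a single-element list behave the same.
assert normalize_prompt("Hello") == "Hello"
assert normalize_prompt(["Hello"]) == "Hello"
assert normalize_prompt([]) == ""
# normalize_prompt(["a", "b"]) would raise AssertionError under the new rule.
```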