Bugfix: Stop sequences can be strings

This commit is contained in:
Andrei Betlen 2023-05-19 03:15:08 -04:00
parent f0812c4d8c
commit a8cd169251
2 changed files with 8 additions and 7 deletions

View file

@@ -602,7 +602,7 @@ class Llama:
top_p: float = 0.95,
logprobs: Optional[int] = None,
echo: bool = False,
-    stop: Optional[List[str]] = [],
+    stop: Optional[Union[str, List[str]]] = [],
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
repeat_penalty: float = 1.1,
@@ -624,7 +624,7 @@ class Llama:
)
text: bytes = b""
returned_tokens: int = 0
-        stop = stop if stop is not None else []
+        stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
model_name: str = model if model is not None else self.model_path
if self.verbose:
@@ -973,7 +973,7 @@ class Llama:
top_p: float = 0.95,
logprobs: Optional[int] = None,
echo: bool = False,
-    stop: Optional[List[str]] = [],
+    stop: Optional[Union[str, List[str]]] = [],
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
repeat_penalty: float = 1.1,
@@ -1042,7 +1042,7 @@ class Llama:
top_p: float = 0.95,
logprobs: Optional[int] = None,
echo: bool = False,
-    stop: Optional[List[str]] = [],
+    stop: Optional[Union[str, List[str]]] = [],
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
repeat_penalty: float = 1.1,
@@ -1162,7 +1162,7 @@ class Llama:
top_p: float = 0.95,
top_k: int = 40,
stream: bool = False,
-    stop: Optional[List[str]] = [],
+    stop: Optional[Union[str, List[str]]] = [],
max_tokens: int = 256,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
@@ -1188,7 +1188,7 @@ class Llama:
Returns:
Generated chat completion or a stream of chat completion chunks.
"""
-        stop = stop if stop is not None else []
+        stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
chat_history = "".join(
f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
for message in messages

View file

@@ -1,4 +1,5 @@
import json
import logging
import multiprocessing
from threading import Lock
from typing import List, Optional, Union, Iterator, Dict
@@ -203,7 +204,7 @@ class CreateCompletionRequest(BaseModel):
default=False,
description="Whether to echo the prompt in the generated text. Useful for chatbots.",
)
-    stop: Optional[List[str]] = stop_field
+    stop: Optional[Union[str, List[str]]] = stop_field
stream: bool = stream_field
logprobs: Optional[int] = Field(
default=None,