Bugfix: Stop sequences can be strings

This commit is contained in:
Andrei Betlen 2023-05-19 03:15:08 -04:00
parent f0812c4d8c
commit a8cd169251
2 changed files with 8 additions and 7 deletions

View file

@@ -602,7 +602,7 @@ class Llama:
top_p: float = 0.95, top_p: float = 0.95,
logprobs: Optional[int] = None, logprobs: Optional[int] = None,
echo: bool = False, echo: bool = False,
stop: Optional[List[str]] = [], stop: Optional[Union[str, List[str]]] = [],
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
repeat_penalty: float = 1.1, repeat_penalty: float = 1.1,
@@ -624,7 +624,7 @@ class Llama:
) )
text: bytes = b"" text: bytes = b""
returned_tokens: int = 0 returned_tokens: int = 0
stop = stop if stop is not None else [] stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
model_name: str = model if model is not None else self.model_path model_name: str = model if model is not None else self.model_path
if self.verbose: if self.verbose:
@@ -973,7 +973,7 @@ class Llama:
top_p: float = 0.95, top_p: float = 0.95,
logprobs: Optional[int] = None, logprobs: Optional[int] = None,
echo: bool = False, echo: bool = False,
stop: Optional[List[str]] = [], stop: Optional[Union[str, List[str]]] = [],
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
repeat_penalty: float = 1.1, repeat_penalty: float = 1.1,
@@ -1042,7 +1042,7 @@ class Llama:
top_p: float = 0.95, top_p: float = 0.95,
logprobs: Optional[int] = None, logprobs: Optional[int] = None,
echo: bool = False, echo: bool = False,
stop: Optional[List[str]] = [], stop: Optional[Union[str, List[str]]] = [],
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
repeat_penalty: float = 1.1, repeat_penalty: float = 1.1,
@@ -1162,7 +1162,7 @@ class Llama:
top_p: float = 0.95, top_p: float = 0.95,
top_k: int = 40, top_k: int = 40,
stream: bool = False, stream: bool = False,
stop: Optional[List[str]] = [], stop: Optional[Union[str, List[str]]] = [],
max_tokens: int = 256, max_tokens: int = 256,
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
@@ -1188,7 +1188,7 @@ class Llama:
Returns: Returns:
Generated chat completion or a stream of chat completion chunks. Generated chat completion or a stream of chat completion chunks.
""" """
stop = stop if stop is not None else [] stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
chat_history = "".join( chat_history = "".join(
f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}' f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
for message in messages for message in messages

View file

@@ -1,4 +1,5 @@
import json import json
import logging
import multiprocessing import multiprocessing
from threading import Lock from threading import Lock
from typing import List, Optional, Union, Iterator, Dict from typing import List, Optional, Union, Iterator, Dict
@@ -203,7 +204,7 @@ class CreateCompletionRequest(BaseModel):
default=False, default=False,
description="Whether to echo the prompt in the generated text. Useful for chatbots.", description="Whether to echo the prompt in the generated text. Useful for chatbots.",
) )
stop: Optional[List[str]] = stop_field stop: Optional[Union[str, List[str]]] = stop_field
stream: bool = stream_field stream: bool = stream_field
logprobs: Optional[int] = Field( logprobs: Optional[int] = Field(
default=None, default=None,