From a8cd169251cf6c8bfef2bfc397ddb89c19f6d3d9 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Fri, 19 May 2023 03:15:08 -0400
Subject: [PATCH] Bugfix: Stop sequences can be strings

---
 llama_cpp/llama.py      | 12 ++++++------
 llama_cpp/server/app.py |  3 ++-
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 58c32e9..da5b0e3 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -602,7 +602,7 @@ class Llama:
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
-        stop: Optional[List[str]] = [],
+        stop: Optional[Union[str, List[str]]] = [],
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
@@ -624,7 +624,7 @@ class Llama:
         )
         text: bytes = b""
         returned_tokens: int = 0
-        stop = stop if stop is not None else []
+        stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
         model_name: str = model if model is not None else self.model_path
 
         if self.verbose:
@@ -973,7 +973,7 @@ class Llama:
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
-        stop: Optional[List[str]] = [],
+        stop: Optional[Union[str, List[str]]] = [],
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
@@ -1042,7 +1042,7 @@ class Llama:
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
-        stop: Optional[List[str]] = [],
+        stop: Optional[Union[str, List[str]]] = [],
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
@@ -1162,7 +1162,7 @@ class Llama:
         top_p: float = 0.95,
         top_k: int = 40,
         stream: bool = False,
-        stop: Optional[List[str]] = [],
+        stop: Optional[Union[str, List[str]]] = [],
         max_tokens: int = 256,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
@@ -1188,7 +1188,7 @@ class Llama:
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        stop = stop if stop is not None else []
+        stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
         chat_history = "".join(
             f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
             for message in messages
diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index 3f95bdd..1ff0d1e 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -1,4 +1,5 @@
 import json
+import logging
 import multiprocessing
 from threading import Lock
 from typing import List, Optional, Union, Iterator, Dict
@@ -203,7 +204,7 @@ class CreateCompletionRequest(BaseModel):
         default=False,
         description="Whether to echo the prompt in the generated text. Useful for chatbots.",
     )
-    stop: Optional[List[str]] = stop_field
+    stop: Optional[Union[str, List[str]]] = stop_field
     stream: bool = stream_field
     logprobs: Optional[int] = Field(
         default=None,
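---
Note (not part of the patch): a quick sketch of the behavior this change
enables. After the patch, `stop` is normalized with

    stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []

so a bare string, a list of strings, or None are all accepted. The usage
below is illustrative only; the model path is a placeholder, and the calls
assume a local model loadable by this version of llama-cpp-python.

    from llama_cpp import Llama

    llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path

    # Before this patch, stop sequences had to be passed as a list:
    out = llm("Q: Name the planets in the solar system. A: ",
              max_tokens=32, stop=["Q:", "\n"])

    # With this patch, a single stop sequence can be a bare string;
    # it is wrapped into a one-element list internally:
    out = llm("Q: Name the planets in the solar system. A: ",
              max_tokens=32, stop="\n")

    print(out["choices"][0]["text"])

The same relaxation applies to the server's CreateCompletionRequest model,
matching the OpenAI-style API, where `stop` may be either a string or an
array of strings.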