From 3dec778c900852afda66c0a9221db66f3279fe96 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Mon, 3 Apr 2023 20:12:14 -0400
Subject: [PATCH] Update to more sensible return signature

---
 llama_cpp/llama.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 87f69ce..332cef9 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -2,7 +2,7 @@ import os
 import uuid
 import time
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence
+from typing import List, Optional, Union, Generator, Sequence, Iterator
 from collections import deque
 
 from . import llama_cpp
@@ -286,10 +286,7 @@ class Llama:
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
-    ) -> Union[
-        Generator[Completion, None, None],
-        Generator[CompletionChunk, None, None],
-    ]:
+    ) -> Union[Iterator[Completion], Iterator[CompletionChunk],]:
         assert self.ctx is not None
         completion_id = f"cmpl-{str(uuid.uuid4())}"
         created = int(time.time())
@@ -428,7 +425,7 @@ class Llama:
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
-    ) -> Union[Completion, Generator[CompletionChunk, None, None]]:
+    ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
         Args:
@@ -465,7 +462,7 @@
             stream=stream,
         )
         if stream:
-            chunks: Generator[CompletionChunk, None, None] = completion_or_chunks
+            chunks: Iterator[CompletionChunk] = completion_or_chunks
             return chunks
         completion: Completion = next(completion_or_chunks)  # type: ignore
         return completion
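
A caller-side sketch of what the new signature implies: with stream=False the call returns a single Completion dict, and with stream=True it advertises Iterator[CompletionChunk] that is consumed lazily. The model path, prompt, and parameter values below are illustrative assumptions, not taken from this patch.

from llama_cpp import Llama

# Illustrative model path; any compatible ggml model file would do (assumption, not from the patch).
llm = Llama(model_path="./models/ggml-model.bin")

# stream=False (the default): one Completion dict comes back.
completion = llm("Q: Name the planets in the solar system. A: ", max_tokens=32)
print(completion["choices"][0]["text"])

# stream=True: the return type is now Iterator[CompletionChunk],
# so the chunks are consumed with a plain for loop as they are produced.
for chunk in llm(
    "Q: Name the planets in the solar system. A: ", max_tokens=32, stream=True
):
    print(chunk["choices"][0]["text"], end="", flush=True)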