Update to more sensible return signature

Andrei Betlen 2023-04-03 20:12:14 -04:00
parent f7ab8d55b2
commit 3dec778c90


@@ -2,7 +2,7 @@ import os
 import uuid
 import time
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence
+from typing import List, Optional, Union, Generator, Sequence, Iterator
 from collections import deque

 from . import llama_cpp
@@ -286,10 +286,7 @@ class Llama:
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
-    ) -> Union[
-        Generator[Completion, None, None],
-        Generator[CompletionChunk, None, None],
-    ]:
+    ) -> Union[Iterator[Completion], Iterator[CompletionChunk],]:
         assert self.ctx is not None
         completion_id = f"cmpl-{str(uuid.uuid4())}"
         created = int(time.time())
@@ -428,7 +425,7 @@ class Llama:
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
-    ) -> Union[Completion, Generator[CompletionChunk, None, None]]:
+    ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.

         Args:
@@ -465,7 +462,7 @@ class Llama:
             stream=stream,
         )
         if stream:
-            chunks: Generator[CompletionChunk, None, None] = completion_or_chunks
+            chunks: Iterator[CompletionChunk] = completion_or_chunks
             return chunks
         completion: Completion = next(completion_or_chunks)  # type: ignore
         return completion
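
For context on why Iterator reads as the more sensible annotation here: Generator[Y, S, R] also commits to send and return types that callers of a streaming completion API never use, while Iterator[Y] promises only that the result can be iterated, and every generator already satisfies it. A minimal sketch of the difference, not taken from the library (the Chunk alias and stream_chunks names are illustrative only):

from typing import Any, Dict, Generator, Iterator

Chunk = Dict[str, Any]  # illustrative stand-in for CompletionChunk

def stream_chunks_generator(n: int) -> Generator[Chunk, None, None]:
    # Over-specified: the two None slots describe .send() and return
    # values that consumers of a streaming API never touch.
    for i in range(n):
        yield {"index": i}

def stream_chunks(n: int) -> Iterator[Chunk]:
    # Same body, narrower promise: the result can be iterated. Any
    # generator, list iterator, or custom iterator class satisfies this,
    # so the implementation can change without changing the signature.
    for i in range(n):
        yield {"index": i}

for chunk in stream_chunks(3):
    print(chunk)

With the new signature, consuming the streamed result is still a plain for-loop over the returned chunks, as in the sketch above; nothing from the Generator-specific protocol was ever needed.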