Update to more sensible return signature

Andrei Betlen 2023-04-03 20:12:14 -04:00
parent f7ab8d55b2
commit 3dec778c90


@@ -2,7 +2,7 @@ import os
 import uuid
 import time
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence
+from typing import List, Optional, Union, Generator, Sequence, Iterator
 from collections import deque

 from . import llama_cpp
@@ -286,10 +286,7 @@ class Llama:
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
-    ) -> Union[
-        Generator[Completion, None, None],
-        Generator[CompletionChunk, None, None],
-    ]:
+    ) -> Union[Iterator[Completion], Iterator[CompletionChunk],]:
         assert self.ctx is not None
         completion_id = f"cmpl-{str(uuid.uuid4())}"
         created = int(time.time())
@@ -428,7 +425,7 @@ class Llama:
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
-    ) -> Union[Completion, Generator[CompletionChunk, None, None]]:
+    ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.

         Args:
@@ -465,7 +462,7 @@ class Llama:
             stream=stream,
         )
         if stream:
-            chunks: Generator[CompletionChunk, None, None] = completion_or_chunks
+            chunks: Iterator[CompletionChunk] = completion_or_chunks
             return chunks
         completion: Completion = next(completion_or_chunks)  # type: ignore
         return completion
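
For context on why the change is safe: per PEP 484, a generator that is never sent values and returns None can be annotated either as Generator[T, None, None] or simply Iterator[T]; the two are interchangeable, so the shorter form in this commit loses no type information. A minimal sketch of the equivalence (function names are illustrative, not from this commit):

    from typing import Generator, Iterator

    def stream_verbose() -> Generator[str, None, None]:
        # Full Generator form spells out yield, send, and return types.
        yield "foo"
        yield "bar"

    def stream_concise() -> Iterator[str]:
        # Equivalent annotation: with send and return types of None,
        # Iterator[str] carries the same information for callers
        # that only iterate over the values.
        yield "foo"
        yield "bar"

    # Both type-check identically at the call site:
    for chunk in stream_concise():
        print(chunk)

The Iterator form is also the more honest contract for callers, who only ever consume the yielded CompletionChunk values and never use send() or the generator's return value.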