Update to more sensible return signature
This commit is contained in:
parent
f7ab8d55b2
commit
3dec778c90
1 changed files with 4 additions and 7 deletions
|
@ -2,7 +2,7 @@ import os
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
from typing import List, Optional, Union, Generator, Sequence
|
from typing import List, Optional, Union, Generator, Sequence, Iterator
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
from . import llama_cpp
|
from . import llama_cpp
|
||||||
|
@ -286,10 +286,7 @@ class Llama:
|
||||||
repeat_penalty: float = 1.1,
|
repeat_penalty: float = 1.1,
|
||||||
top_k: int = 40,
|
top_k: int = 40,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
) -> Union[
|
) -> Union[Iterator[Completion], Iterator[CompletionChunk],]:
|
||||||
Generator[Completion, None, None],
|
|
||||||
Generator[CompletionChunk, None, None],
|
|
||||||
]:
|
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
completion_id = f"cmpl-{str(uuid.uuid4())}"
|
completion_id = f"cmpl-{str(uuid.uuid4())}"
|
||||||
created = int(time.time())
|
created = int(time.time())
|
||||||
|
@ -428,7 +425,7 @@ class Llama:
|
||||||
repeat_penalty: float = 1.1,
|
repeat_penalty: float = 1.1,
|
||||||
top_k: int = 40,
|
top_k: int = 40,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
) -> Union[Completion, Generator[CompletionChunk, None, None]]:
|
) -> Union[Completion, Iterator[CompletionChunk]]:
|
||||||
"""Generate text from a prompt.
|
"""Generate text from a prompt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -465,7 +462,7 @@ class Llama:
|
||||||
stream=stream,
|
stream=stream,
|
||||||
)
|
)
|
||||||
if stream:
|
if stream:
|
||||||
chunks: Generator[CompletionChunk, None, None] = completion_or_chunks
|
chunks: Iterator[CompletionChunk] = completion_or_chunks
|
||||||
return chunks
|
return chunks
|
||||||
completion: Completion = next(completion_or_chunks) # type: ignore
|
completion: Completion = next(completion_or_chunks) # type: ignore
|
||||||
return completion
|
return completion
|
||||||
|
|
Loading…
Reference in a new issue