Merge pull request #15 from jmorganca/batch

batch model
This commit is contained in:
Michael Yang 2023-06-28 17:10:39 -07:00 committed by GitHub
commit d28b244db3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 14 deletions

View file

@ -54,15 +54,18 @@ def list_models(*args, **kwargs):
def generate(*args, **kwargs): def generate(*args, **kwargs):
if prompt := kwargs.get('prompt'): if prompt := kwargs.get('prompt'):
print('>>>', prompt, flush=True) print('>>>', prompt, flush=True)
print(flush=True)
generate_oneshot(*args, **kwargs) generate_oneshot(*args, **kwargs)
print(flush=True)
return return
if sys.stdin.isatty():
return generate_interactive(*args, **kwargs) return generate_interactive(*args, **kwargs)
return generate_batch(*args, **kwargs)
def generate_oneshot(*args, **kwargs): def generate_oneshot(*args, **kwargs):
print(flush=True)
for output in engine.generate(*args, **kwargs): for output in engine.generate(*args, **kwargs):
output = json.loads(output) output = json.loads(output)
choices = output.get("choices", []) choices = output.get("choices", [])
@ -70,20 +73,26 @@ def generate_oneshot(*args, **kwargs):
print(choices[0].get("text", ""), end="", flush=True) print(choices[0].get("text", ""), end="", flush=True)
# end with a new line # end with a new line
print() print(flush=True)
print(flush=True)
def generate_interactive(*args, **kwargs): def generate_interactive(*args, **kwargs):
while True:
print('>>> ', end='', flush=True) print('>>> ', end='', flush=True)
for line in sys.stdin: line = next(sys.stdin)
if not sys.stdin.isatty(): if not line:
print(line, end='') return
print(flush=True) kwargs.update({"prompt": line})
kwargs.update({'prompt': line}) generate_oneshot(*args, **kwargs)
def generate_batch(*args, **kwargs):
for line in sys.stdin:
print('>>> ', line, end='', flush=True)
kwargs.update({"prompt": line})
generate_oneshot(*args, **kwargs) generate_oneshot(*args, **kwargs)
print(flush=True)
print('>>> ', end='', flush=True)
def add(model, models_home): def add(model, models_home):

View file

@ -45,7 +45,7 @@ def load(model, models_home=".", llms={}):
if not model_path: if not model_path:
# try loading this as a path to a model, rather than a model name # try loading this as a path to a model, rather than a model name
model_path = model model_path = os.path.abspath(model)
# suppress LLM's output # suppress LLM's output
with suppress_stderr(): with suppress_stderr():