commit
d28b244db3
2 changed files with 23 additions and 14 deletions
|
@ -54,15 +54,18 @@ def list_models(*args, **kwargs):
|
||||||
def generate(*args, **kwargs):
|
def generate(*args, **kwargs):
|
||||||
if prompt := kwargs.get('prompt'):
|
if prompt := kwargs.get('prompt'):
|
||||||
print('>>>', prompt, flush=True)
|
print('>>>', prompt, flush=True)
|
||||||
print(flush=True)
|
|
||||||
generate_oneshot(*args, **kwargs)
|
generate_oneshot(*args, **kwargs)
|
||||||
print(flush=True)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if sys.stdin.isatty():
|
||||||
return generate_interactive(*args, **kwargs)
|
return generate_interactive(*args, **kwargs)
|
||||||
|
|
||||||
|
return generate_batch(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def generate_oneshot(*args, **kwargs):
|
def generate_oneshot(*args, **kwargs):
|
||||||
|
print(flush=True)
|
||||||
|
|
||||||
for output in engine.generate(*args, **kwargs):
|
for output in engine.generate(*args, **kwargs):
|
||||||
output = json.loads(output)
|
output = json.loads(output)
|
||||||
choices = output.get("choices", [])
|
choices = output.get("choices", [])
|
||||||
|
@ -70,20 +73,26 @@ def generate_oneshot(*args, **kwargs):
|
||||||
print(choices[0].get("text", ""), end="", flush=True)
|
print(choices[0].get("text", ""), end="", flush=True)
|
||||||
|
|
||||||
# end with a new line
|
# end with a new line
|
||||||
print()
|
print(flush=True)
|
||||||
|
print(flush=True)
|
||||||
|
|
||||||
|
|
||||||
def generate_interactive(*args, **kwargs):
|
def generate_interactive(*args, **kwargs):
|
||||||
|
while True:
|
||||||
print('>>> ', end='', flush=True)
|
print('>>> ', end='', flush=True)
|
||||||
for line in sys.stdin:
|
line = next(sys.stdin)
|
||||||
if not sys.stdin.isatty():
|
if not line:
|
||||||
print(line, end='')
|
return
|
||||||
|
|
||||||
print(flush=True)
|
kwargs.update({"prompt": line})
|
||||||
kwargs.update({'prompt': line})
|
generate_oneshot(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_batch(*args, **kwargs):
|
||||||
|
for line in sys.stdin:
|
||||||
|
print('>>> ', line, end='', flush=True)
|
||||||
|
kwargs.update({"prompt": line})
|
||||||
generate_oneshot(*args, **kwargs)
|
generate_oneshot(*args, **kwargs)
|
||||||
print(flush=True)
|
|
||||||
print('>>> ', end='', flush=True)
|
|
||||||
|
|
||||||
|
|
||||||
def add(model, models_home):
|
def add(model, models_home):
|
||||||
|
|
|
@ -45,7 +45,7 @@ def load(model, models_home=".", llms={}):
|
||||||
|
|
||||||
if not model_path:
|
if not model_path:
|
||||||
# try loading this as a path to a model, rather than a model name
|
# try loading this as a path to a model, rather than a model name
|
||||||
model_path = model
|
model_path = os.path.abspath(model)
|
||||||
|
|
||||||
# suppress LLM's output
|
# suppress LLM's output
|
||||||
with suppress_stderr():
|
with suppress_stderr():
|
||||||
|
|
Loading…
Reference in a new issue