interactive generate
This commit is contained in:
parent
52beb0a99e
commit
db5508209b
2 changed files with 32 additions and 4 deletions
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
from argparse import ArgumentParser
|
||||
|
@ -20,7 +21,7 @@ def main():
|
|||
|
||||
generate_parser = subparsers.add_parser("generate")
|
||||
generate_parser.add_argument("model")
|
||||
generate_parser.add_argument("prompt")
|
||||
generate_parser.add_argument("prompt", nargs="?")
|
||||
generate_parser.set_defaults(fn=generate)
|
||||
|
||||
add_parser = subparsers.add_parser("add")
|
||||
|
@ -37,6 +38,8 @@ def main():
|
|||
try:
|
||||
fn = args.pop("fn")
|
||||
fn(**args)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except KeyError:
|
||||
parser.print_help()
|
||||
except Exception as e:
|
||||
|
@ -49,12 +52,37 @@ def list_models(*args, **kwargs):
|
|||
|
||||
|
||||
def generate(*args, **kwargs):
    """Entry point for the `generate` subcommand.

    When a prompt was supplied on the command line, run a single one-shot
    completion; otherwise fall back to an interactive session that reads
    prompts from stdin.
    """
    prompt = kwargs.get('prompt')
    if not prompt:
        # No prompt given -> interactive REPL-style mode.
        return generate_interactive(*args, **kwargs)

    # Echo the prompt so the transcript shows what was asked.
    print('>>>', prompt, flush=True)
    print(flush=True)
    generate_oneshot(*args, **kwargs)
    print(flush=True)
    return
|
||||
|
||||
|
||||
def generate_oneshot(*args, **kwargs):
    """Stream one completion from the engine and print it as it arrives.

    Each item yielded by ``engine.generate`` is a JSON-encoded chunk
    (OpenAI-style streaming response); the first choice's text is printed
    incrementally (flushed so output appears live), followed by a final
    newline once the stream ends.
    """
    for chunk in engine.generate(*args, **kwargs):
        # Parse into a separate name instead of rebinding the loop variable,
        # which shadowed the raw chunk and hurt readability.
        response = json.loads(chunk)

        choices = response.get("choices", [])
        # Idiomatic truthiness test instead of `len(choices) > 0`.
        if choices:
            print(choices[0].get("text", ""), end="", flush=True)

    # Terminate the streamed text with a newline.
    print()
|
||||
|
||||
|
||||
def generate_interactive(*args, **kwargs):
    """Interactive loop: read prompts line-by-line from stdin and stream a
    completion for each, showing a ``>>> `` marker between turns."""
    marker = '>>> '
    print(marker, end='', flush=True)
    for line in sys.stdin:
        # When input is piped (stdin is not a TTY) the terminal does not
        # echo the line, so echo it ourselves to keep the transcript readable.
        if not sys.stdin.isatty():
            print(line, end='')

        print(flush=True)
        kwargs['prompt'] = line
        generate_oneshot(*args, **kwargs)
        print(flush=True)
        print(marker, end='', flush=True)
|
||||
|
||||
|
||||
def add(model, models_home):
|
||||
|
|
|
@ -27,7 +27,7 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
|
|||
kwargs.update({"max_tokens": 16384})
|
||||
|
||||
if "stop" not in kwargs:
|
||||
kwargs.update({"stop": ["Q:", "\n"]})
|
||||
kwargs.update({"stop": ["Q:"]})
|
||||
|
||||
if "stream" not in kwargs:
|
||||
kwargs.update({"stream": True})
|
||||
|
|
Loading…
Reference in a new issue