Merge pull request #11 from jmorganca/interactive-generate

interactive generate
This commit is contained in:
Michael Yang 2023-06-28 11:32:05 -07:00 committed by GitHub
commit 5610405e77
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 4 deletions

View file

@ -1,4 +1,5 @@
import os import os
import sys
import json import json
from pathlib import Path from pathlib import Path
from argparse import ArgumentParser from argparse import ArgumentParser
@ -20,7 +21,7 @@ def main():
generate_parser = subparsers.add_parser("generate") generate_parser = subparsers.add_parser("generate")
generate_parser.add_argument("model") generate_parser.add_argument("model")
generate_parser.add_argument("prompt") generate_parser.add_argument("prompt", nargs="?")
generate_parser.set_defaults(fn=generate) generate_parser.set_defaults(fn=generate)
add_parser = subparsers.add_parser("add") add_parser = subparsers.add_parser("add")
@ -37,6 +38,8 @@ def main():
try: try:
fn = args.pop("fn") fn = args.pop("fn")
fn(**args) fn(**args)
except KeyboardInterrupt:
pass
except KeyError: except KeyError:
parser.print_help() parser.print_help()
except Exception as e: except Exception as e:
@ -49,12 +52,37 @@ def list_models(*args, **kwargs):
def generate(*args, **kwargs): def generate(*args, **kwargs):
if prompt := kwargs.get('prompt'):
print('>>>', prompt, flush=True)
print(flush=True)
generate_oneshot(*args, **kwargs)
print(flush=True)
return
return generate_interactive(*args, **kwargs)
def generate_oneshot(*args, **kwargs):
for output in engine.generate(*args, **kwargs): for output in engine.generate(*args, **kwargs):
output = json.loads(output) output = json.loads(output)
choices = output.get("choices", []) choices = output.get("choices", [])
if len(choices) > 0: if len(choices) > 0:
print(choices[0].get("text", ""), end="") print(choices[0].get("text", ""), end="", flush=True)
print()
def generate_interactive(*args, **kwargs):
print('>>> ', end='', flush=True)
for line in sys.stdin:
if not sys.stdin.isatty():
print(line, end='')
print(flush=True)
kwargs.update({'prompt': line})
generate_oneshot(*args, **kwargs)
print(flush=True)
print('>>> ', end='', flush=True)
def add(model, models_home): def add(model, models_home):

View file

@ -27,7 +27,7 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
kwargs.update({"max_tokens": 16384}) kwargs.update({"max_tokens": 16384})
if "stop" not in kwargs: if "stop" not in kwargs:
kwargs.update({"stop": ["Q:", "\n"]}) kwargs.update({"stop": ["Q:"]})
if "stream" not in kwargs: if "stream" not in kwargs:
kwargs.update({"stream": True}) kwargs.update({"stream": True})