From db5508209bb7d89c1f68bd4f5ff5e97e2b7810df Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Wed, 28 Jun 2023 11:21:05 -0700
Subject: [PATCH] interactive generate

---
 ollama/cmd/cli.py | 34 +++++++++++++++++++++++++++++++---
 ollama/engine.py  |  2 +-
 2 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/ollama/cmd/cli.py b/ollama/cmd/cli.py
index f47cdcd2..f79c218c 100644
--- a/ollama/cmd/cli.py
+++ b/ollama/cmd/cli.py
@@ -1,4 +1,5 @@
 import os
+import sys
 import json
 from pathlib import Path
 from argparse import ArgumentParser
@@ -20,7 +21,7 @@ def main():
 
     generate_parser = subparsers.add_parser("generate")
     generate_parser.add_argument("model")
-    generate_parser.add_argument("prompt")
+    generate_parser.add_argument("prompt", nargs="?")
     generate_parser.set_defaults(fn=generate)
 
     add_parser = subparsers.add_parser("add")
@@ -37,6 +38,8 @@ def main():
     try:
         fn = args.pop("fn")
         fn(**args)
+    except KeyboardInterrupt:
+        pass
     except KeyError:
         parser.print_help()
     except Exception as e:
@@ -49,12 +52,37 @@ def list_models(*args, **kwargs):
 
 
 def generate(*args, **kwargs):
+    if prompt := kwargs.get('prompt'):
+        print('>>>', prompt, flush=True)
+        print(flush=True)
+        generate_oneshot(*args, **kwargs)
+        print(flush=True)
+        return
+
+    return generate_interactive(*args, **kwargs)
+
+
+def generate_oneshot(*args, **kwargs):
     for output in engine.generate(*args, **kwargs):
         output = json.loads(output)
-
         choices = output.get("choices", [])
         if len(choices) > 0:
-            print(choices[0].get("text", ""), end="")
+            print(choices[0].get("text", ""), end="", flush=True)
+
+    print()
+
+
+def generate_interactive(*args, **kwargs):
+    print('>>> ', end='', flush=True)
+    for line in sys.stdin:
+        if not sys.stdin.isatty():
+            print(line, end='')
+
+        print(flush=True)
+        kwargs.update({'prompt': line})
+        generate_oneshot(*args, **kwargs)
+        print(flush=True)
+        print('>>> ', end='', flush=True)
 
 
 def add(model, models_home):
diff --git a/ollama/engine.py b/ollama/engine.py
index 5525c410..4d80234f 100644
--- a/ollama/engine.py
+++ b/ollama/engine.py
@@ -27,7 +27,7 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
         kwargs.update({"max_tokens": 16384})
 
     if "stop" not in kwargs:
-        kwargs.update({"stop": ["Q:", "\n"]})
+        kwargs.update({"stop": ["Q:"]})
 
     if "stream" not in kwargs:
         kwargs.update({"stream": True})
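
Notes:

Expected behavior of the new interactive mode, sketched as a terminal
session. The `ollama` entry point and the model name "orca" are
assumptions for illustration; the ">>> " prompt, the blank line around
each answer, and the clean Ctrl-C exit (via the new KeyboardInterrupt
handler) come from the diff:

    $ ollama generate orca
    >>> why is the sky blue?

    ...model output streams here...

    >>> ^C
    $

When stdin is not a tty, each input line is echoed before its answer
(the sys.stdin.isatty() branch), so piped transcripts still read like a
session:

    $ echo "why is the sky blue?" | ollama generate orca
    >>> why is the sky blue?

    ...model output streams here...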
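
generate_oneshot assumes each item yielded by engine.generate is a JSON
string shaped like an OpenAI-style completion chunk, with the streamed
text under choices[0]["text"]. A minimal, runnable sketch of that
parsing loop against canned chunks; the chunk payloads are invented,
only the shape is taken from the code above:

    import json

    # Stand-ins for what engine.generate yields: one JSON string per chunk.
    fake_stream = [
        json.dumps({"choices": [{"text": "Hello"}]}),
        json.dumps({"choices": [{"text": ", world"}]}),
        json.dumps({"choices": []}),  # chunks without choices are skipped
    ]

    for output in fake_stream:
        output = json.loads(output)
        choices = output.get("choices", [])
        if len(choices) > 0:
            # flush=True so partial text appears immediately, as in the patch
            print(choices[0].get("text", ""), end="", flush=True)

    print()  # final newline, mirroring the print() added after the loop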
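
On the engine.py change: stop sequences end generation at the first
match, so keeping "\n" in the default stop list caps every completion
at a single line; dropping it allows multi-line answers while still
stopping at a fabricated "Q:" turn. A pure-Python illustration of that
truncation rule, where truncate_at_stop is a hypothetical helper, not
part of ollama/engine.py:

    def truncate_at_stop(text, stop):
        # Hypothetical helper: cut text at the earliest stop sequence,
        # the way stop tokens end a completion.
        cut = len(text)
        for s in stop:
            i = text.find(s)
            if i != -1:
                cut = min(cut, i)
        return text[:cut]

    answer = "Rayleigh scattering.\nShorter wavelengths scatter more.\nQ: next"
    print(truncate_at_stop(answer, ["Q:", "\n"]))  # old default: one line only
    print(truncate_at_stop(answer, ["Q:"]))        # new default: multi-line kept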