From 1ae91f75f0cc5e31cac22efb82ec1c95a4c80d69 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Tue, 27 Jun 2023 12:19:34 -0700
Subject: [PATCH] refactor

---
 ollama.py              | 219 +----------------------------------------
 ollama/__init__.py     |   9 ++
 ollama/cmd/__init__.py |   0
 ollama/cmd/cli.py      |  43 ++++++++
 ollama/cmd/server.py   |  75 ++++++++++++++
 ollama/engine.py       |  57 +++++++++++
 ollama/model.py        |   9 ++
 7 files changed, 196 insertions(+), 216 deletions(-)
 create mode 100644 ollama/__init__.py
 create mode 100644 ollama/cmd/__init__.py
 create mode 100644 ollama/cmd/cli.py
 create mode 100644 ollama/cmd/server.py
 create mode 100644 ollama/engine.py
 create mode 100644 ollama/model.py

diff --git a/ollama.py b/ollama.py
index d704ddd9..f06aac17 100644
--- a/ollama.py
+++ b/ollama.py
@@ -1,216 +1,3 @@
-import json
-import os
-import threading
-import click
-from tqdm import tqdm
-from pathlib import Path
-from llama_cpp import Llama
-from flask import Flask, Response, stream_with_context, request
-from flask_cors import CORS
-from template import template
-
-app = Flask(__name__)
-CORS(app)  # enable CORS for all routes
-
-# llms tracks which models are loaded
-llms = {}
-lock = threading.Lock()
-
-
-def models_directory():
-    home_dir = Path.home()
-    models_dir = home_dir / ".ollama/models"
-
-    if not models_dir.exists():
-        models_dir.mkdir(parents=True)
-
-    return models_dir
-
-
-def load(model):
-    """
-    Load a model.
-
-    Args:
-        model (str): The name or path of the model to load.
-
-    Returns:
-        str or None: The name of the model
-        dict or None: If the model cannot be loaded, a dictionary with an 'error' key is returned.
-            If the model is successfully loaded, None is returned.
-    """
-
-    with lock:
-        load_from = ""
-        if os.path.exists(model) and model.endswith(".bin"):
-            # model is being referenced by path rather than name directly
-            path = os.path.abspath(model)
-            base = os.path.basename(path)
-
-            load_from = path
-            name = os.path.splitext(base)[0]  # Split the filename and extension
-        else:
-            # model is being loaded from the ollama models directory
-            dir = models_directory()
-
-            # TODO: download model from a repository if it does not exist
-            load_from = str(dir / f"{model}.bin")
-            name = model
-
-        if load_from == "":
-            return None, {"error": "Model not found."}
-
-        if not os.path.exists(load_from):
-            return None, {"error": f"The model {load_from} does not exist."}
-
-        if name not in llms:
-            llms[name] = Llama(model_path=load_from)
-
-        return name, None
-
-
-def unload(model):
-    """
-    Unload a model.
-
-    Remove a model from the list of loaded models. If the model is not loaded, this is a no-op.
-
-    Args:
-        model (str): The name of the model to unload.
- """ - llms.pop(model, None) - - -def generate(model, prompt): - # auto load - name, error = load(model) - if error is not None: - return error - generated = llms[name]( - str(prompt), # TODO: optimize prompt based on model - max_tokens=4096, - stop=["Q:", "\n"], - stream=True, - ) - for output in generated: - yield json.dumps(output) - - -def models(): - dir = models_directory() - all_files = os.listdir(dir) - bin_files = [ - file.replace(".bin", "") for file in all_files if file.endswith(".bin") - ] - return bin_files - - -@app.route("/load", methods=["POST"]) -def load_route_handler(): - data = request.get_json() - model = data.get("model") - if not model: - return Response("Model is required", status=400) - error = load(model) - if error is not None: - return error - return Response(status=204) - - -@app.route("/unload", methods=["POST"]) -def unload_route_handler(): - data = request.get_json() - model = data.get("model") - if not model: - return Response("Model is required", status=400) - unload(model) - return Response(status=204) - - -@app.route("/generate", methods=["POST"]) -def generate_route_handler(): - data = request.get_json() - model = data.get("model") - prompt = data.get("prompt") - prompt = template(model, prompt) - if not model: - return Response("Model is required", status=400) - if not prompt: - return Response("Prompt is required", status=400) - if not os.path.exists(f"{model}"): - return {"error": "The model does not exist."}, 400 - return Response( - stream_with_context(generate(model, prompt)), mimetype="text/event-stream" - ) - - -@app.route("/models", methods=["GET"]) -def models_route_handler(): - bin_files = models() - return Response(json.dumps(bin_files), mimetype="application/json") - - -@click.group(invoke_without_command=True) -@click.pass_context -def cli(ctx): - # allows the script to respond to command line input when executed directly - if ctx.invoked_subcommand is None: - click.echo(ctx.get_help()) - - -@cli.command() -@click.option("--port", default=7734, help="Port to run the server on") -@click.option("--debug", default=False, help="Enable debug mode") -def serve(port, debug): - print("Serving on http://localhost:{port}") - app.run(host="0.0.0.0", port=port, debug=debug) - - -@cli.command(name="load") -@click.argument("model") -@click.option("--file", default=False, help="Indicates that a file path is provided") -def load_cli(model, file): - if file: - error = load(path=model) - else: - error = load(model) - if error is not None: - print(error) - return - print("Model loaded") - - -@cli.command(name="generate") -@click.argument("model") -@click.option("--prompt", default="", help="The prompt for the model") -def generate_cli(model, prompt): - if prompt == "": - prompt = input("Prompt: ") - output = "" - prompt = template(model, prompt) - for generated in generate(model, prompt): - generated_json = json.loads(generated) - text = generated_json["choices"][0]["text"] - output += text - print(f"\r{output}", end="", flush=True) - - -@cli.command(name="models") -def models_cli(): - print(models()) - - -@cli.command(name="pull") -@click.argument("model") -def pull_cli(model): - print("not implemented") - - -@cli.command(name="import") -@click.argument("model") -def import_cli(model): - print("not implemented") - - -if __name__ == "__main__": - cli() +from ollama.cmd.cli import main +if __name__ == '__main__': + main() diff --git a/ollama/__init__.py b/ollama/__init__.py new file mode 100644 index 00000000..384eea19 --- /dev/null +++ b/ollama/__init__.py 
@@ -0,0 +1,9 @@
+from ollama.model import models
+from ollama.engine import generate, load, unload
+
+__all__ = [
+    'models',
+    'generate',
+    'load',
+    'unload',
+]
diff --git a/ollama/cmd/__init__.py b/ollama/cmd/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/ollama/cmd/cli.py b/ollama/cmd/cli.py
new file mode 100644
index 00000000..7a0c03c1
--- /dev/null
+++ b/ollama/cmd/cli.py
@@ -0,0 +1,43 @@
+import json
+from pathlib import Path
+from argparse import ArgumentParser
+
+from ollama import model, engine
+from ollama.cmd import server
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument('--models-home', default=Path.home() / '.ollama' / 'models')
+
+    subparsers = parser.add_subparsers()
+
+    server.set_parser(subparsers.add_parser('serve'))
+
+    list_parser = subparsers.add_parser('list')
+    list_parser.set_defaults(fn=list)
+
+    generate_parser = subparsers.add_parser('generate')
+    generate_parser.add_argument('model')
+    generate_parser.add_argument('prompt')
+    generate_parser.set_defaults(fn=generate)
+
+    args = parser.parse_args()
+    args = vars(args)
+
+    fn = args.pop('fn')
+    fn(**args)
+
+
+def list(*args, **kwargs):
+    for m in model.models(*args, **kwargs):
+        print(m)
+
+
+def generate(*args, **kwargs):
+    for output in engine.generate(*args, **kwargs):
+        output = json.loads(output)
+
+        choices = output.get('choices', [])
+        if len(choices) > 0:
+            print(choices[0].get('text', ''), end='')
diff --git a/ollama/cmd/server.py b/ollama/cmd/server.py
new file mode 100644
index 00000000..9e2180cd
--- /dev/null
+++ b/ollama/cmd/server.py
@@ -0,0 +1,75 @@
+from aiohttp import web
+
+from ollama import engine
+
+
+def set_parser(parser):
+    parser.add_argument('--host', default='127.0.0.1')
+    parser.add_argument('--port', default=7734)
+    parser.set_defaults(fn=serve)
+
+
+def serve(models_home='.', *args, **kwargs):
+    app = web.Application()
+    app.add_routes([
+        web.post('/load', load),
+        web.post('/unload', unload),
+        web.post('/generate', generate),
+    ])
+
+    app.update({
+        'llms': {},
+        'models_home': models_home,
+    })
+
+    web.run_app(app, **kwargs)
+
+
+async def load(request):
+    body = await request.json()
+    model = body.get('model')
+    if not model:
+        raise web.HTTPBadRequest()
+
+    kwargs = {
+        'llms': request.app.get('llms'),
+        'models_home': request.app.get('models_home'),
+    }
+
+    engine.load(model, **kwargs)
+    return web.Response()
+
+
+async def unload(request):
+    body = await request.json()
+    model = body.get('model')
+    if not model:
+        raise web.HTTPBadRequest()
+
+    engine.unload(model, llms=request.app.get('llms'))
+    return web.Response()
+
+
+async def generate(request):
+    body = await request.json()
+    model = body.get('model')
+    if not model:
+        raise web.HTTPBadRequest()
+
+    prompt = body.get('prompt')
+    if not prompt:
+        raise web.HTTPBadRequest()
+
+    response = web.StreamResponse()
+    await response.prepare(request)
+
+    kwargs = {
+        'llms': request.app.get('llms'),
+        'models_home': request.app.get('models_home'),
+    }
+
+    for output in engine.generate(model, prompt, **kwargs):
+        await response.write(output.encode('utf-8'))
+        await response.write(b'\n')
+
+    return response
diff --git a/ollama/engine.py b/ollama/engine.py
new file mode 100644
index 00000000..eb837d33
--- /dev/null
+++ b/ollama/engine.py
@@ -0,0 +1,57 @@
+import os
+import json
+import sys
+from contextlib import contextmanager
+from llama_cpp import Llama as LLM
+
+import ollama.model
+
+
+@contextmanager
+def suppress_stderr():
+    stderr = os.dup(sys.stderr.fileno())
+    with open(os.devnull, 'w') as devnull:
+        os.dup2(devnull.fileno(), sys.stderr.fileno())
+        yield
+
+    os.dup2(stderr, sys.stderr.fileno())
+
+
+def generate(model, prompt, models_home='.', llms={}, *args, **kwargs):
+    llm = load(model, models_home=models_home, llms=llms)
+
+    if 'max_tokens' not in kwargs:
+        kwargs.update({'max_tokens': 16384})
+
+    if 'stop' not in kwargs:
+        kwargs.update({'stop': ['Q:', '\n']})
+
+    if 'stream' not in kwargs:
+        kwargs.update({'stream': True})
+
+    for output in llm(prompt, *args, **kwargs):
+        yield json.dumps(output)
+
+
+def load(model, models_home='.', llms={}):
+    llm = llms.get(model, None)
+    if not llm:
+        model_path = {
+            name: path
+            for name, path in ollama.model.models(models_home)
+        }.get(model, None)
+
+        if model_path is None:
+            raise ValueError('Model not found')
+
+        # suppress LLM's output
+        with suppress_stderr():
+            llm = LLM(model_path, verbose=False)
+            llms.update({model: llm})
+
+    return llm
+
+
+def unload(model, llms={}):
+    if model in llms:
+        llms.pop(model)
diff --git a/ollama/model.py b/ollama/model.py
new file mode 100644
index 00000000..b5e7b0b8
--- /dev/null
+++ b/ollama/model.py
@@ -0,0 +1,9 @@
+from os import walk, path
+
+
+def models(models_home='.', *args, **kwargs):
+    for root, _, files in walk(models_home):
+        for file in files:
+            base, ext = path.splitext(file)
+            if ext == '.bin':
+                yield base, path.join(root, file)
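
Not part of the patch: a minimal usage sketch of the refactored package API exported from ollama/__init__.py, assuming llama-cpp-python is installed and that a model named "orca-mini" (a hypothetical name) exists as ~/.ollama/models/orca-mini.bin.

import json
from pathlib import Path

from ollama import generate

# mirrors the CLI's default --models-home of ~/.ollama/models
models_home = Path.home() / '.ollama' / 'models'

# engine.generate yields JSON-encoded chunks as the model streams text
for chunk in generate('orca-mini', 'Q: Why is the sky blue? A:', models_home=models_home):
    output = json.loads(chunk)
    choices = output.get('choices', [])
    if choices:
        print(choices[0].get('text', ''), end='', flush=True)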