From 5cea13ce007755cf3a1135989d7df87573bffacd Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 28 Jun 2023 12:25:08 -0700 Subject: [PATCH] rename model to model_name to avoid collision --- ollama/cmd/cli.py | 4 ++-- ollama/cmd/server.py | 24 ++++++++++++------------ ollama/engine.py | 43 +++++++++++++++++++------------------------ ollama/prompt.py | 8 ++++---- 4 files changed, 37 insertions(+), 42 deletions(-) diff --git a/ollama/cmd/cli.py b/ollama/cmd/cli.py index f2b53f5d..a7f0f6c1 100644 --- a/ollama/cmd/cli.py +++ b/ollama/cmd/cli.py @@ -111,7 +111,7 @@ def generate_oneshot(*args, **kwargs): spinner.start() spinner_running = True try: - for output in engine.generate(*args, **kwargs): + for output in engine.generate(model_name=kwargs.pop('model'), *args, **kwargs): choices = output.get("choices", []) if len(choices) > 0: if spinner_running: @@ -147,7 +147,7 @@ def generate_batch(*args, **kwargs): def pull(*args, **kwargs): - model.pull(*args, **kwargs) + model.pull(model_name=kwargs.pop('model'), *args, **kwargs) def run(*args, **kwargs): diff --git a/ollama/cmd/server.py b/ollama/cmd/server.py index fe803c17..d634babe 100644 --- a/ollama/cmd/server.py +++ b/ollama/cmd/server.py @@ -38,7 +38,7 @@ def serve(*args, **kwargs): app.update( { - "llms": {}, + "models": {}, } ) @@ -47,32 +47,32 @@ def serve(*args, **kwargs): async def load(request): body = await request.json() - model = body.get("model") - if not model: + name = body.get("model") + if not name: raise web.HTTPBadRequest() kwargs = { - "llms": request.app.get("llms"), + "models": request.app.get("models"), } - engine.load(model, **kwargs) + engine.load(name, **kwargs) return web.Response() async def unload(request): body = await request.json() - model = body.get("model") - if not model: + name = body.get("model") + if not name: raise web.HTTPBadRequest() - engine.unload(model, llms=request.app.get("llms")) + engine.unload(name, models=request.app.get("models")) return web.Response() async def generate(request): body = await request.json() - model = body.get("model") - if not model: + name = body.get("model") + if not name: raise web.HTTPBadRequest() prompt = body.get("prompt") @@ -83,10 +83,10 @@ async def generate(request): await response.prepare(request) kwargs = { - "llms": request.app.get("llms"), + "models": request.app.get("models"), } - for output in engine.generate(model, prompt, **kwargs): + for output in engine.generate(name, prompt, **kwargs): output = json.dumps(output).encode('utf-8') await response.write(output) await response.write(b"\n") diff --git a/ollama/engine.py b/ollama/engine.py index 67c5cce9..aa82336f 100644 --- a/ollama/engine.py +++ b/ollama/engine.py @@ -4,8 +4,8 @@ from os import path from contextlib import contextmanager from llama_cpp import Llama as LLM -import ollama.model import ollama.prompt +from ollama.model import models_home @contextmanager @@ -18,10 +18,7 @@ def suppress_stderr(): os.dup2(stderr, sys.stderr.fileno()) -def generate(model, prompt, llms={}, *args, **kwargs): - llm = load(model, llms=llms) - - prompt = ollama.prompt.template(model, prompt) +def generate(model_name, prompt, models={}, *args, **kwargs): if "max_tokens" not in kwargs: kwargs.update({"max_tokens": 16384}) @@ -31,34 +28,32 @@ def generate(model, prompt, llms={}, *args, **kwargs): if "stream" not in kwargs: kwargs.update({"stream": True}) - for output in llm(prompt, *args, **kwargs): + prompt = ollama.prompt.template(model_name, prompt) + + model = load(model_name, models=models) + for output in model.create_completion(prompt, *args, **kwargs): yield output -def load(model, llms={}): - llm = llms.get(model, None) - if not llm: - stored_model_path = path.join(ollama.model.models_home, model) + ".bin" - if path.exists(stored_model_path): - model_path = stored_model_path - else: - # try loading this as a path to a model, rather than a model name - model_path = path.abspath(model) - +def load(model_name, models={}): + model = models.get(model_name, None) + if not model: + model_path = path.expanduser(model_name) if not path.exists(model_path): - raise Exception(f"Model not found: {model}") + model_path = path.join(models_home, model_name + ".bin") try: # suppress LLM's output with suppress_stderr(): - llm = LLM(model_path, verbose=False) - llms.update({model: llm}) - except Exception as e: + model = LLM(model_path, verbose=False) + models.update({model_name: model}) + except Exception: # e is sent to devnull, so create a generic exception raise Exception(f"Failed to load model: {model}") - return llm + + return model -def unload(model, llms={}): - if model in llms: - llms.pop(model) +def unload(model_name, models={}): + if model_name in models: + models.pop(model_name) diff --git a/ollama/prompt.py b/ollama/prompt.py index e2bbea75..5e329e3e 100644 --- a/ollama/prompt.py +++ b/ollama/prompt.py @@ -1,16 +1,16 @@ -import os +from os import path from difflib import SequenceMatcher from jinja2 import Environment, PackageLoader -def template(model, prompt): +def template(name, prompt): best_ratio = 0 best_template = '' environment = Environment(loader=PackageLoader(__name__, 'templates')) for template in environment.list_templates(): - base, _ = os.path.splitext(template) - ratio = SequenceMatcher(None, os.path.basename(model.lower()), base).ratio() + base, _ = path.splitext(template) + ratio = SequenceMatcher(None, path.basename(name).lower(), base).ratio() if ratio > best_ratio: best_ratio = ratio best_template = template