rename model to model_name to avoid collision

Michael Yang 2023-06-28 12:25:08 -07:00 committed by Michael Yang
parent af66c5695a
commit 5cea13ce00
4 changed files with 37 additions and 42 deletions
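The collision in question: in this codebase `model` can mean the `ollama.model` module, the model-name string from the CLI or request, or the loaded `LLM` object, and a local variable named `model` shadows the module wherever both are in scope. A minimal, self-contained sketch of the likely failure mode (names hypothetical):

import types

# stand-in for a module named `model`, as in `ollama.model`
model = types.SimpleNamespace(pull=lambda name: print("pulling", name))

def pull_broken(model):        # parameter shadows the module-like name above
    model.pull(model)          # AttributeError: 'str' object has no attribute 'pull'

def pull_fixed(model_name):    # after the rename, `model` still means the module
    model.pull(model_name)

pull_fixed("llama")            # prints: pulling llama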

View file

@@ -111,7 +111,7 @@ def generate_oneshot(*args, **kwargs):
     spinner.start()
     spinner_running = True
     try:
-        for output in engine.generate(*args, **kwargs):
+        for output in engine.generate(model_name=kwargs.pop('model'), *args, **kwargs):
             choices = output.get("choices", [])
             if len(choices) > 0:
                 if spinner_running:
@@ -147,7 +147,7 @@ def generate_batch(*args, **kwargs):
 def pull(*args, **kwargs):
-    model.pull(*args, **kwargs)
+    model.pull(model_name=kwargs.pop('model'), *args, **kwargs)

 def run(*args, **kwargs):
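Note that `kwargs.pop('model')` does double duty: it supplies the value under the callee's new parameter name and removes the key from the dict before `**kwargs` is unpacked, so the callee never receives a stray `model` keyword. A small sketch with a hypothetical callee:

def generate(model_name, prompt, **kwargs):
    return f"{model_name}: {prompt} {kwargs}"

kwargs = {"model": "llama", "prompt": "hello", "temperature": 0.5}
# generate(**kwargs) would raise TypeError: unexpected keyword argument 'model';
# popping the key renames it in flight and strips it from the forwarded kwargs
print(generate(model_name=kwargs.pop("model"), **kwargs))  # llama: hello {'temperature': 0.5}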

View file

@@ -38,7 +38,7 @@ def serve(*args, **kwargs):
     app.update(
         {
-            "llms": {},
+            "models": {},
         }
     )
@@ -47,32 +47,32 @@ def serve(*args, **kwargs):
 async def load(request):
     body = await request.json()
-    model = body.get("model")
-    if not model:
+    name = body.get("model")
+    if not name:
         raise web.HTTPBadRequest()

     kwargs = {
-        "llms": request.app.get("llms"),
+        "models": request.app.get("models"),
     }

-    engine.load(model, **kwargs)
+    engine.load(name, **kwargs)
     return web.Response()

 async def unload(request):
     body = await request.json()
-    model = body.get("model")
-    if not model:
+    name = body.get("model")
+    if not name:
         raise web.HTTPBadRequest()

-    engine.unload(model, llms=request.app.get("llms"))
+    engine.unload(name, models=request.app.get("models"))
     return web.Response()

 async def generate(request):
     body = await request.json()
-    model = body.get("model")
-    if not model:
+    name = body.get("model")
+    if not name:
         raise web.HTTPBadRequest()

     prompt = body.get("prompt")
@@ -83,10 +83,10 @@ async def generate(request):
     await response.prepare(request)

     kwargs = {
-        "llms": request.app.get("llms"),
+        "models": request.app.get("models"),
     }

-    for output in engine.generate(model, prompt, **kwargs):
+    for output in engine.generate(name, prompt, **kwargs):
         output = json.dumps(output).encode('utf-8')
         await response.write(output)
         await response.write(b"\n")
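The handler streams newline-delimited JSON over a chunked response. A condensed, runnable sketch of the same pattern, assuming aiohttp; `fake_engine` is a hypothetical stand-in for `engine.generate`:

import json
from aiohttp import web

def fake_engine(name, prompt):
    # hypothetical stand-in yielding completion chunks
    yield {"model": name, "choices": [{"text": prompt or ""}]}

async def generate(request):
    body = await request.json()
    name = body.get("model")
    if not name:
        raise web.HTTPBadRequest()

    response = web.StreamResponse()
    await response.prepare(request)        # send headers, begin chunked body
    for output in fake_engine(name, body.get("prompt")):
        await response.write(json.dumps(output).encode("utf-8"))
        await response.write(b"\n")        # one JSON object per line
    return response

app = web.Application()
app.add_routes([web.post("/generate", generate)])
# web.run_app(app)  # start serving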

View file

@@ -4,8 +4,8 @@ from os import path
 from contextlib import contextmanager
 from llama_cpp import Llama as LLM

-import ollama.model
 import ollama.prompt
+from ollama.model import models_home

 @contextmanager
@@ -18,10 +18,7 @@ def suppress_stderr():
         os.dup2(stderr, sys.stderr.fileno())

-def generate(model, prompt, llms={}, *args, **kwargs):
-    llm = load(model, llms=llms)
-
-    prompt = ollama.prompt.template(model, prompt)
+def generate(model_name, prompt, models={}, *args, **kwargs):
     if "max_tokens" not in kwargs:
         kwargs.update({"max_tokens": 16384})
@@ -31,34 +28,32 @@ def generate(model, prompt, llms={}, *args, **kwargs):
     if "stream" not in kwargs:
         kwargs.update({"stream": True})

-    for output in llm(prompt, *args, **kwargs):
+    prompt = ollama.prompt.template(model_name, prompt)
+    model = load(model_name, models=models)
+    for output in model.create_completion(prompt, *args, **kwargs):
         yield output

-def load(model, llms={}):
-    llm = llms.get(model, None)
-    if not llm:
+def load(model_name, models={}):
+    model = models.get(model_name, None)
+    if not model:
-        stored_model_path = path.join(ollama.model.models_home, model) + ".bin"
-        if path.exists(stored_model_path):
-            model_path = stored_model_path
-        else:
-            # try loading this as a path to a model, rather than a model name
-            model_path = path.abspath(model)
+        model_path = path.expanduser(model_name)
         if not path.exists(model_path):
-            raise Exception(f"Model not found: {model}")
+            model_path = path.join(models_home, model_name + ".bin")

         try:
             # suppress LLM's output
             with suppress_stderr():
-                llm = LLM(model_path, verbose=False)
-            llms.update({model: llm})
-        except Exception as e:
+                model = LLM(model_path, verbose=False)
+            models.update({model_name: model})
+        except Exception:
             # e is sent to devnull, so create a generic exception
-            raise Exception(f"Failed to load model: {model}")
+            raise Exception(f"Failed to load model: {model_name}")
-    return llm
+    return model

-def unload(model, llms={}):
-    if model in llms:
-        llms.pop(model)
+def unload(model_name, models={}):
+    if model_name in models:
+        models.pop(model_name)
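Note that `models={}` (formerly `llms={}`) relies on Python evaluating default arguments once, at function definition time: every call that omits `models` shares the same dict, so it doubles as a process-wide cache, while the server can pass its own app-state dict instead. A demonstration of the shared-default behavior:

def load(model_name, models={}):          # default dict created once, then shared
    if model_name not in models:
        print("loading", model_name)
        models[model_name] = object()     # stand-in for LLM(model_path, verbose=False)
    return models[model_name]

a = load("llama")    # prints: loading llama
b = load("llama")    # cache hit: no print, same object returned
assert a is b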

View file

@@ -1,16 +1,16 @@
-import os
+from os import path
 from difflib import SequenceMatcher
 from jinja2 import Environment, PackageLoader

-def template(model, prompt):
+def template(name, prompt):
     best_ratio = 0
     best_template = ''

     environment = Environment(loader=PackageLoader(__name__, 'templates'))
     for template in environment.list_templates():
-        base, _ = os.path.splitext(template)
-        ratio = SequenceMatcher(None, os.path.basename(model.lower()), base).ratio()
+        base, _ = path.splitext(template)
+        ratio = SequenceMatcher(None, path.basename(name).lower(), base).ratio()
         if ratio > best_ratio:
             best_ratio = ratio
             best_template = template
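`template` picks the bundled prompt template whose base name is closest to the model name under `difflib.SequenceMatcher`; the revised line also applies `.lower()` to the basename rather than the full path, which gives the same match but reads more clearly. A sketch of the selection logic with hypothetical template names:

from difflib import SequenceMatcher

candidates = ["alpaca", "vicuna", "falcon"]   # hypothetical template base names
name = "alpaca-7b"
best = max(candidates, key=lambda base: SequenceMatcher(None, name.lower(), base).ratio())
print(best)                                   # alpaca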