rename model to model_name to avoid collision

parent af66c5695a
commit 5cea13ce00

4 changed files with 37 additions and 42 deletions
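The collision in question: `model` was doing double duty as the name of the `ollama.model` module (and, inside the engine, of the loaded LLM object) while also naming the string argument that says which model to use. A minimal sketch of that shadowing, with hypothetical wiring rather than the repo's exact imports:

```python
# Hypothetical sketch of the name collision this commit removes.
from ollama import model  # the ollama.model module, bound to the name `model`

def pull(model):
    # The string parameter shadows the module, so the module call blows up:
    model.pull(model)  # AttributeError: 'str' object has no attribute 'pull'

def pull_fixed(model_name):
    # With the argument renamed, `model` still refers to the module.
    model.pull(model_name)
```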
@@ -111,7 +111,7 @@ def generate_oneshot(*args, **kwargs):
     spinner.start()
     spinner_running = True
     try:
-        for output in engine.generate(*args, **kwargs):
+        for output in engine.generate(model_name=kwargs.pop('model'), *args, **kwargs):
             choices = output.get("choices", [])
             if len(choices) > 0:
                 if spinner_running:
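Callers still hand the value in under the key `model`, so the call site pops that key out of `kwargs` and forwards it under the engine's renamed parameter; the same pattern appears in `pull()` in the next hunk. A small standalone sketch of this pop-and-rename adapter, with made-up function names:

```python
# Pop-and-rename: adapt a caller that says model=... to a callee
# that now expects model_name=... (names here are illustrative).
def new_api(model_name, **kwargs):
    print(model_name, kwargs)

def old_entry(*args, **kwargs):
    # Pop so the value is not forwarded a second time via **kwargs.
    new_api(model_name=kwargs.pop('model'), *args, **kwargs)

old_entry(model="llama", temperature=0.8)
# -> llama {'temperature': 0.8}
```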
@@ -147,7 +147,7 @@ def generate_batch(*args, **kwargs):


 def pull(*args, **kwargs):
-    model.pull(*args, **kwargs)
+    model.pull(model_name=kwargs.pop('model'), *args, **kwargs)


 def run(*args, **kwargs):
@@ -38,7 +38,7 @@ def serve(*args, **kwargs):

     app.update(
         {
-            "llms": {},
+            "models": {},
         }
     )

@@ -47,32 +47,32 @@ def serve(*args, **kwargs):

 async def load(request):
     body = await request.json()
-    model = body.get("model")
-    if not model:
+    name = body.get("model")
+    if not name:
         raise web.HTTPBadRequest()

     kwargs = {
-        "llms": request.app.get("llms"),
+        "models": request.app.get("models"),
     }

-    engine.load(model, **kwargs)
+    engine.load(name, **kwargs)
     return web.Response()


 async def unload(request):
     body = await request.json()
-    model = body.get("model")
-    if not model:
+    name = body.get("model")
+    if not name:
         raise web.HTTPBadRequest()

-    engine.unload(model, llms=request.app.get("llms"))
+    engine.unload(name, models=request.app.get("models"))
     return web.Response()


 async def generate(request):
     body = await request.json()
-    model = body.get("model")
-    if not model:
+    name = body.get("model")
+    if not name:
         raise web.HTTPBadRequest()

     prompt = body.get("prompt")
@@ -83,10 +83,10 @@ async def generate(request):
     await response.prepare(request)

     kwargs = {
-        "llms": request.app.get("llms"),
+        "models": request.app.get("models"),
     }

-    for output in engine.generate(model, prompt, **kwargs):
+    for output in engine.generate(name, prompt, **kwargs):
         output = json.dumps(output).encode('utf-8')
         await response.write(output)
         await response.write(b"\n")
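The three server hunks above rename the shared registry stored on the aiohttp application from "llms" to "models", and the request-local string from `model` to `name`. `aiohttp.web.Application` is a `MutableMapping`, which is what makes the `app.update(...)` / `request.app.get(...)` pattern work; a minimal standalone sketch (the route and the stand-in value are illustrative):

```python
from aiohttp import web

app = web.Application()
app.update({"models": {}})  # dict-style application state

async def load(request):
    body = await request.json()
    name = body.get("model")
    if not name:
        raise web.HTTPBadRequest()
    # Stand-in for loading a real LLM into the shared registry.
    request.app.get("models")[name] = object()
    return web.Response()

app.add_routes([web.post("/load", load)])
```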
@@ -4,8 +4,8 @@ from os import path
 from contextlib import contextmanager
 from llama_cpp import Llama as LLM

-import ollama.model
 import ollama.prompt
+from ollama.model import models_home


 @contextmanager
@@ -18,10 +18,7 @@ def suppress_stderr():
     os.dup2(stderr, sys.stderr.fileno())


-def generate(model, prompt, llms={}, *args, **kwargs):
-    llm = load(model, llms=llms)
-
-    prompt = ollama.prompt.template(model, prompt)
+def generate(model_name, prompt, models={}, *args, **kwargs):
     if "max_tokens" not in kwargs:
         kwargs.update({"max_tokens": 16384})

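Note that `models={}` is a mutable default argument, usually a lint warning but load-bearing here: every call that omits `models` shares the one default dict, so it doubles as a process-wide cache, while the server passes its own registry explicitly. A quick demonstration of that semantics:

```python
# The default dict is created once and shared across calls.
def load(model_name, models={}):
    if model_name not in models:
        models[model_name] = f"loaded:{model_name}"  # stand-in for LLM(...)
    return models[model_name]

print(load("a"))             # loads
print(load("a"))             # cache hit: same default dict as above
print(load("a", models={}))  # fresh caller-supplied dict: loads again
```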
@@ -31,34 +28,32 @@ def generate(model, prompt, llms={}, *args, **kwargs):
     if "stream" not in kwargs:
         kwargs.update({"stream": True})

-    for output in llm(prompt, *args, **kwargs):
+    prompt = ollama.prompt.template(model_name, prompt)
+
+    model = load(model_name, models=models)
+
+    for output in model.create_completion(prompt, *args, **kwargs):
         yield output


-def load(model, llms={}):
-    llm = llms.get(model, None)
-    if not llm:
-        stored_model_path = path.join(ollama.model.models_home, model) + ".bin"
-        if path.exists(stored_model_path):
-            model_path = stored_model_path
-        else:
-            # try loading this as a path to a model, rather than a model name
-            model_path = path.abspath(model)
-
+def load(model_name, models={}):
+    model = models.get(model_name, None)
+    if not model:
+        model_path = path.expanduser(model_name)
         if not path.exists(model_path):
-            raise Exception(f"Model not found: {model}")
+            model_path = path.join(models_home, model_name + ".bin")

         try:
             # suppress LLM's output
             with suppress_stderr():
-                llm = LLM(model_path, verbose=False)
-            llms.update({model: llm})
-        except Exception as e:
+                model = LLM(model_path, verbose=False)
+            models.update({model_name: model})
+        except Exception:
             # e is sent to devnull, so create a generic exception
             raise Exception(f"Failed to load model: {model}")
-    return llm
+    return model


-def unload(model, llms={}):
-    if model in llms:
-        llms.pop(model)
+def unload(model_name, models={}):
+    if model_name in models:
+        models.pop(model_name)
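Resolution order in the new `load` is: treat the argument as a filesystem path (with `~` expanded), and only if nothing exists there fall back to `<models_home>/<name>.bin`. One caveat worth flagging: in the `except` branch, `model` is still `None` (the `models.get` miss was never overwritten before `LLM(...)` raised), so the message interpolates `None`; `model_name` looks like the intended value. A sketch of the path logic, under the assumption that `models_home` points at a directory such as `~/.ollama/models`:

```python
from os import path

def resolve(model_name, models_home=path.expanduser("~/.ollama/models")):
    model_path = path.expanduser(model_name)  # try it as a literal path first
    if not path.exists(model_path):
        # fall back to the named model in the models directory
        model_path = path.join(models_home, model_name + ".bin")
    return model_path
```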
@@ -1,16 +1,16 @@
-import os
+from os import path
 from difflib import SequenceMatcher
 from jinja2 import Environment, PackageLoader


-def template(model, prompt):
+def template(name, prompt):
     best_ratio = 0
     best_template = ''

     environment = Environment(loader=PackageLoader(__name__, 'templates'))
     for template in environment.list_templates():
-        base, _ = os.path.splitext(template)
-        ratio = SequenceMatcher(None, os.path.basename(model.lower()), base).ratio()
+        base, _ = path.splitext(template)
+        ratio = SequenceMatcher(None, path.basename(name).lower(), base).ratio()
         if ratio > best_ratio:
             best_ratio = ratio
             best_template = template
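`template` picks the prompt template whose filename is most similar to the model name; `SequenceMatcher.ratio()` returns a similarity score in [0, 1] (the reordering of `basename`/`lower` here is behavior-preserving for typical paths). For a feel of the scores, with illustrative template names:

```python
from difflib import SequenceMatcher

# Which template filename is closest to the model name "llama-7b"?
for base in ("llama", "alpaca", "vicuna"):
    print(base, round(SequenceMatcher(None, "llama-7b", base).ratio(), 3))
# "llama" scores highest, so its template would be selected.
```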