load from file path

Bruce MacDonald 2023-06-27 17:09:35 -04:00
parent ef5c75fd34
commit 843ccf5070


@@ -3,6 +3,7 @@ import json
 import sys
 from contextlib import contextmanager
 from llama_cpp import Llama as LLM
+from template import template
 import ollama.model
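
The template module imported above is not part of this diff. As a rough sketch only, a helper with this signature might look something like the following, assuming it keeps a registry of model-name substrings mapped to prompt formats (the names and formats below are illustrative, not the repository's actual templates):

    # hypothetical sketch of template.py -- not the actual module from this repo
    TEMPLATES = {
        "vicuna": "USER: {prompt}\nASSISTANT:",
        "alpaca": "### Instruction:\n{prompt}\n\n### Response:\n",
    }

    def template(model, prompt):
        # fall back to the raw prompt when no known format matches the model name
        for name, fmt in TEMPLATES.items():
            if name in model.lower():
                return fmt.format(prompt=prompt)
        return prompt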
@@ -10,39 +11,44 @@ import ollama.model
 @contextmanager
 def suppress_stderr():
     stderr = os.dup(sys.stderr.fileno())
-    with open(os.devnull, 'w') as devnull:
+    with open(os.devnull, "w") as devnull:
         os.dup2(devnull.fileno(), sys.stderr.fileno())
         yield
     os.dup2(stderr, sys.stderr.fileno())


-def generate(model, prompt, models_home='.', llms={}, *args, **kwargs):
+def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
     llm = load(model, models_home=models_home, llms=llms)
-    if 'max_tokens' not in kwargs:
-        kwargs.update({'max_tokens': 16384})
-    if 'stop' not in kwargs:
-        kwargs.update({'stop': ['Q:', '\n']})
-    if 'stream' not in kwargs:
-        kwargs.update({'stream': True})
+    prompt = template(model, prompt)
+
+    if "max_tokens" not in kwargs:
+        kwargs.update({"max_tokens": 16384})
+    if "stop" not in kwargs:
+        kwargs.update({"stop": ["Q:", "\n"]})
+    if "stream" not in kwargs:
+        kwargs.update({"stream": True})
     for output in llm(prompt, *args, **kwargs):
         yield json.dumps(output)


-def load(model, models_home='.', llms={}):
+def load(model, models_home=".", llms={}):
     llm = llms.get(model, None)
     if not llm:
         model_path = {
-            name: path
-            for name, path in ollama.model.models(models_home)
+            name: path for name, path in ollama.model.models(models_home)
         }.get(model, None)
         if model_path is None:
-            raise ValueError('Model not found')
+            # try loading this as a path to a model, rather than a model name
+            if os.path.isfile(model):
+                model_path = model
+            else:
+                raise ValueError("Model not found")
         # suppress LLM's output
         with suppress_stderr():
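
With this change, load() falls back to treating the model argument as a filesystem path when the name is not found under models_home, so generate() can be pointed directly at a model file. A minimal usage sketch, assuming this file is importable as ollama.engine and that a GGML model file exists at the path shown (both are assumptions, not confirmed by the diff):

    import json
    from ollama.engine import generate  # assumed module path

    # load by registered model name, resolved via ollama.model.models()
    for chunk in generate("orca-mini", "Why is the sky blue?"):
        print(json.loads(chunk))

    # new in this commit: load directly from a model file on disk
    for chunk in generate("/models/llama-7b.ggml.bin", "Why is the sky blue?"):
        print(json.loads(chunk))

Because generate() defaults stream=True, each yielded chunk is one llama-cpp completion event serialized with json.dumps, so callers decode output chunk by chunk rather than waiting for the full response.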