remove models home param
parent 54a94566f1
commit a11cddbf99
4 changed files with 14 additions and 19 deletions
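
In short: instead of threading a models_home value from a CLI flag through the server, engine, and download helpers, the models directory is now defined once as a module-level constant on ollama.model, and callers reference it directly. A minimal sketch of the new arrangement (the module name is taken from the ollama.model.models_home reference in the diff; everything else is elided):

from os import path
from pathlib import Path

# ollama.model (sketch): computed once at import time; every helper resolves
# model files against this constant instead of receiving a models_home argument
models_home = path.join(Path.home(), '.ollama', 'models')
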
@@ -1,6 +1,5 @@
 import os
 import sys
-from pathlib import Path
 from argparse import ArgumentParser
 from yaspin import yaspin
 
@@ -10,12 +9,9 @@ from ollama.cmd import server
 
 def main():
     parser = ArgumentParser()
-    parser.add_argument("--models-home", default=Path.home() / ".ollama" / "models")
 
     # create models home if it doesn't exist
-    models_home = parser.parse_known_args()[0].models_home
-    if not models_home.exists():
-        os.makedirs(models_home)
+    os.makedirs(model.models_home, exist_ok=True)
 
     subparsers = parser.add_subparsers()
 
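These two hunks are the CLI entry point (the module that does "from ollama.cmd import server", per the hunk header). After the change, main() no longer defines a --models-home flag; it just ensures the default directory exists before registering subcommands. A rough sketch of the resulting function, assuming the module imports model from the ollama package (that import is outside these hunks):

import os
from argparse import ArgumentParser

from ollama import model  # assumed import; not shown in the hunks


def main():
    parser = ArgumentParser()

    # create models home if it doesn't exist
    os.makedirs(model.models_home, exist_ok=True)

    subparsers = parser.add_subparsers()
    # ... subcommand parsers (e.g. server.set_parser) are registered below (elided)
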
@@ -11,7 +11,7 @@ def set_parser(parser):
     parser.set_defaults(fn=serve)
 
 
-def serve(models_home=".", *args, **kwargs):
+def serve(*args, **kwargs):
     app = web.Application()
 
     cors = aiohttp_cors.setup(
@@ -39,7 +39,6 @@ def serve(models_home=".", *args, **kwargs):
     app.update(
         {
             "llms": {},
-            "models_home": models_home,
         }
     )
 
@@ -54,7 +53,6 @@ async def load(request):
 
     kwargs = {
         "llms": request.app.get("llms"),
-        "models_home": request.app.get("models_home"),
     }
 
     engine.load(model, **kwargs)
@@ -86,7 +84,6 @@ async def generate(request):
 
     kwargs = {
         "llms": request.app.get("llms"),
-        "models_home": request.app.get("models_home"),
     }
 
    for output in engine.generate(model, prompt, **kwargs):
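The four hunks above are the HTTP server, presumably ollama.cmd.server. With models_home gone from the aiohttp app state, the handlers only forward the shared llms cache to the engine. Roughly, the load handler now has the shape sketched below; only the kwargs construction and the engine.load call are taken from the diff, the rest is paraphrased:

from aiohttp import web

from ollama import engine  # assumed import path for the engine module


async def load(request):
    # how the model name is extracted is not shown in the hunk; this line is a guess
    model = (await request.json()).get("model")

    kwargs = {
        "llms": request.app.get("llms"),
    }

    engine.load(model, **kwargs)
    return web.Response()  # placeholder; the real handler's reply is not shown here
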
@@ -18,8 +18,8 @@ def suppress_stderr():
     os.dup2(stderr, sys.stderr.fileno())
 
 
-def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
-    llm = load(model, models_home=models_home, llms=llms)
+def generate(model, prompt, llms={}, *args, **kwargs):
+    llm = load(model, llms=llms)
 
     prompt = ollama.prompt.template(model, prompt)
     if "max_tokens" not in kwargs:
@@ -35,10 +35,10 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
         yield output
 
 
-def load(model, models_home=".", llms={}):
+def load(model, llms={}):
     llm = llms.get(model, None)
     if not llm:
-        stored_model_path = path.join(models_home, model) + ".bin"
+        stored_model_path = path.join(ollama.model.models_home, model) + ".bin"
         if path.exists(stored_model_path):
             model_path = stored_model_path
         else:
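These two hunks are the engine module: generate() and load() drop their models_home parameters, and load() resolves stored weights against ollama.model.models_home. As far as the hunk shows it, the lookup-and-cache shape of load() is roughly:

from os import path

import ollama.model


def load(model, llms={}):
    # llms acts as a cache of already-constructed models; the server passes in a
    # dict held in aiohttp app state, so loads are shared across requests
    llm = llms.get(model, None)
    if not llm:
        # weights are resolved against the shared constant instead of a parameter
        stored_model_path = path.join(ollama.model.models_home, model) + ".bin"
        if path.exists(stored_model_path):
            model_path = stored_model_path
        else:
            ...  # other resolution paths exist but fall outside this hunk
        ...  # constructing the llm from model_path and caching it is also elided
    return llm  # assumption: the cached or newly loaded object is returned to generate()
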
@@ -1,14 +1,16 @@
 import requests
 import validators
+from pathlib import Path
 from os import path, walk
 from urllib.parse import urlsplit, urlunsplit
 from tqdm import tqdm
 
 
 models_endpoint_url = 'https://ollama.ai/api/models'
+models_home = path.join(Path.home(), '.ollama', 'models')
 
 
-def models(models_home='.', *args, **kwargs):
+def models(*args, **kwargs):
     for _, _, files in walk(models_home):
         for file in files:
             base, ext = path.splitext(file)
@@ -27,7 +29,7 @@ def get_url_from_directory(model):
     return model
 
 
-def download_from_repo(url, file_name, models_home='.'):
+def download_from_repo(url, file_name):
     parts = urlsplit(url)
     path_parts = parts.path.split('/tree/')
 
@@ -55,7 +57,7 @@ def download_from_repo(url, file_name, models_home='.'):
     json_response = response.json()
 
     download_url, file_size = find_bin_file(json_response, location, branch)
-    return download_file(download_url, models_home, file_name, file_size)
+    return download_file(download_url, file_name, file_size)
 
 
 def find_bin_file(json_response, location, branch):
@@ -75,7 +77,7 @@ def find_bin_file(json_response, location, branch):
     return download_url, file_size
 
 
-def download_file(download_url, models_home, file_name, file_size):
+def download_file(download_url, file_name, file_size):
     local_filename = path.join(models_home, file_name) + '.bin'
 
     first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
@@ -108,7 +110,7 @@ def download_file(download_url, models_home, file_name, file_size):
     return local_filename
 
 
-def pull(model, models_home='.', *args, **kwargs):
+def pull(model, *args, **kwargs):
     if path.exists(model):
         # a file on the filesystem is being specified
         return model
@@ -128,6 +130,6 @@ def pull(model, models_home='.', *args, **kwargs):
             return model
        raise Exception(f'Unknown model {model}')
 
-    local_filename = download_from_repo(url, file_name, models_home)
+    local_filename = download_from_repo(url, file_name)
 
     return local_filename
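Finally, in ollama.model itself, every on-disk location is now derived from the module-level models_home, so callers only name the model. A hedged usage sketch (function and constant names come from the diff; the argument value is just a placeholder, not a known model):

import ollama.model

# Pull by name or URL; the downloaded weights land under the shared directory,
# i.e. path.join(ollama.model.models_home, file_name) + '.bin'.
local_path = ollama.model.pull("some-model")  # "some-model" is a placeholder
print(local_path)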