remove models home param

This commit is contained in:
Bruce MacDonald 2023-06-30 11:37:03 -04:00
parent 54a94566f1
commit a11cddbf99
4 changed files with 14 additions and 19 deletions

View file

@ -1,6 +1,5 @@
import os import os
import sys import sys
from pathlib import Path
from argparse import ArgumentParser from argparse import ArgumentParser
from yaspin import yaspin from yaspin import yaspin
@ -10,12 +9,9 @@ from ollama.cmd import server
def main(): def main():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("--models-home", default=Path.home() / ".ollama" / "models")
# create models home if it doesn't exist # create models home if it doesn't exist
models_home = parser.parse_known_args()[0].models_home os.makedirs(model.models_home, exist_ok=True)
if not models_home.exists():
os.makedirs(models_home)
subparsers = parser.add_subparsers() subparsers = parser.add_subparsers()

View file

@ -11,7 +11,7 @@ def set_parser(parser):
parser.set_defaults(fn=serve) parser.set_defaults(fn=serve)
def serve(models_home=".", *args, **kwargs): def serve(*args, **kwargs):
app = web.Application() app = web.Application()
cors = aiohttp_cors.setup( cors = aiohttp_cors.setup(
@ -39,7 +39,6 @@ def serve(models_home=".", *args, **kwargs):
app.update( app.update(
{ {
"llms": {}, "llms": {},
"models_home": models_home,
} }
) )
@ -54,7 +53,6 @@ async def load(request):
kwargs = { kwargs = {
"llms": request.app.get("llms"), "llms": request.app.get("llms"),
"models_home": request.app.get("models_home"),
} }
engine.load(model, **kwargs) engine.load(model, **kwargs)
@ -86,7 +84,6 @@ async def generate(request):
kwargs = { kwargs = {
"llms": request.app.get("llms"), "llms": request.app.get("llms"),
"models_home": request.app.get("models_home"),
} }
for output in engine.generate(model, prompt, **kwargs): for output in engine.generate(model, prompt, **kwargs):

View file

@ -18,8 +18,8 @@ def suppress_stderr():
os.dup2(stderr, sys.stderr.fileno()) os.dup2(stderr, sys.stderr.fileno())
def generate(model, prompt, models_home=".", llms={}, *args, **kwargs): def generate(model, prompt, llms={}, *args, **kwargs):
llm = load(model, models_home=models_home, llms=llms) llm = load(model, llms=llms)
prompt = ollama.prompt.template(model, prompt) prompt = ollama.prompt.template(model, prompt)
if "max_tokens" not in kwargs: if "max_tokens" not in kwargs:
@ -35,10 +35,10 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
yield output yield output
def load(model, models_home=".", llms={}): def load(model, llms={}):
llm = llms.get(model, None) llm = llms.get(model, None)
if not llm: if not llm:
stored_model_path = path.join(models_home, model) + ".bin" stored_model_path = path.join(ollama.model.models_home, model) + ".bin"
if path.exists(stored_model_path): if path.exists(stored_model_path):
model_path = stored_model_path model_path = stored_model_path
else: else:

View file

@ -1,14 +1,16 @@
import requests import requests
import validators import validators
from pathlib import Path
from os import path, walk from os import path, walk
from urllib.parse import urlsplit, urlunsplit from urllib.parse import urlsplit, urlunsplit
from tqdm import tqdm from tqdm import tqdm
models_endpoint_url = 'https://ollama.ai/api/models' models_endpoint_url = 'https://ollama.ai/api/models'
models_home = path.join(Path.home(), '.ollama', 'models')
def models(models_home='.', *args, **kwargs): def models(*args, **kwargs):
for _, _, files in walk(models_home): for _, _, files in walk(models_home):
for file in files: for file in files:
base, ext = path.splitext(file) base, ext = path.splitext(file)
@ -27,7 +29,7 @@ def get_url_from_directory(model):
return model return model
def download_from_repo(url, file_name, models_home='.'): def download_from_repo(url, file_name):
parts = urlsplit(url) parts = urlsplit(url)
path_parts = parts.path.split('/tree/') path_parts = parts.path.split('/tree/')
@ -55,7 +57,7 @@ def download_from_repo(url, file_name, models_home='.'):
json_response = response.json() json_response = response.json()
download_url, file_size = find_bin_file(json_response, location, branch) download_url, file_size = find_bin_file(json_response, location, branch)
return download_file(download_url, models_home, file_name, file_size) return download_file(download_url, file_name, file_size)
def find_bin_file(json_response, location, branch): def find_bin_file(json_response, location, branch):
@ -75,7 +77,7 @@ def find_bin_file(json_response, location, branch):
return download_url, file_size return download_url, file_size
def download_file(download_url, models_home, file_name, file_size): def download_file(download_url, file_name, file_size):
local_filename = path.join(models_home, file_name) + '.bin' local_filename = path.join(models_home, file_name) + '.bin'
first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0 first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
@ -108,7 +110,7 @@ def download_file(download_url, models_home, file_name, file_size):
return local_filename return local_filename
def pull(model, models_home='.', *args, **kwargs): def pull(model, *args, **kwargs):
if path.exists(model): if path.exists(model):
# a file on the filesystem is being specified # a file on the filesystem is being specified
return model return model
@ -128,6 +130,6 @@ def pull(model, models_home='.', *args, **kwargs):
return model return model
raise Exception(f'Unknown model {model}') raise Exception(f'Unknown model {model}')
local_filename = download_from_repo(url, file_name, models_home) local_filename = download_from_repo(url, file_name)
return local_filename return local_filename