diff --git a/ollama/cmd/cli.py b/ollama/cmd/cli.py index d19cc47b..4f9d580f 100644 --- a/ollama/cmd/cli.py +++ b/ollama/cmd/cli.py @@ -1,6 +1,5 @@ import os import sys -from pathlib import Path from argparse import ArgumentParser from yaspin import yaspin @@ -10,12 +9,9 @@ from ollama.cmd import server def main(): parser = ArgumentParser() - parser.add_argument("--models-home", default=Path.home() / ".ollama" / "models") # create models home if it doesn't exist - models_home = parser.parse_known_args()[0].models_home - if not models_home.exists(): - os.makedirs(models_home) + os.makedirs(model.models_home, exist_ok=True) subparsers = parser.add_subparsers() diff --git a/ollama/cmd/server.py b/ollama/cmd/server.py index 11e478bd..fe803c17 100644 --- a/ollama/cmd/server.py +++ b/ollama/cmd/server.py @@ -11,7 +11,7 @@ def set_parser(parser): parser.set_defaults(fn=serve) -def serve(models_home=".", *args, **kwargs): +def serve(*args, **kwargs): app = web.Application() cors = aiohttp_cors.setup( @@ -39,7 +39,6 @@ def serve(models_home=".", *args, **kwargs): app.update( { "llms": {}, - "models_home": models_home, } ) @@ -54,7 +53,6 @@ async def load(request): kwargs = { "llms": request.app.get("llms"), - "models_home": request.app.get("models_home"), } engine.load(model, **kwargs) @@ -86,7 +84,6 @@ async def generate(request): kwargs = { "llms": request.app.get("llms"), - "models_home": request.app.get("models_home"), } for output in engine.generate(model, prompt, **kwargs): diff --git a/ollama/engine.py b/ollama/engine.py index e6bbadb5..67c5cce9 100644 --- a/ollama/engine.py +++ b/ollama/engine.py @@ -18,8 +18,8 @@ def suppress_stderr(): os.dup2(stderr, sys.stderr.fileno()) -def generate(model, prompt, models_home=".", llms={}, *args, **kwargs): - llm = load(model, models_home=models_home, llms=llms) +def generate(model, prompt, llms={}, *args, **kwargs): + llm = load(model, llms=llms) prompt = ollama.prompt.template(model, prompt) if "max_tokens" not in kwargs: @@ -35,10 +35,10 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs): yield output -def load(model, models_home=".", llms={}): +def load(model, llms={}): llm = llms.get(model, None) if not llm: - stored_model_path = path.join(models_home, model) + ".bin" + stored_model_path = path.join(ollama.model.models_home, model) + ".bin" if path.exists(stored_model_path): model_path = stored_model_path else: diff --git a/ollama/model.py b/ollama/model.py index b1b686e8..55b98b51 100644 --- a/ollama/model.py +++ b/ollama/model.py @@ -1,14 +1,16 @@ import requests import validators +from pathlib import Path from os import path, walk from urllib.parse import urlsplit, urlunsplit from tqdm import tqdm models_endpoint_url = 'https://ollama.ai/api/models' +models_home = path.join(Path.home(), '.ollama', 'models') -def models(models_home='.', *args, **kwargs): +def models(*args, **kwargs): for _, _, files in walk(models_home): for file in files: base, ext = path.splitext(file) @@ -27,7 +29,7 @@ def get_url_from_directory(model): return model -def download_from_repo(url, file_name, models_home='.'): +def download_from_repo(url, file_name): parts = urlsplit(url) path_parts = parts.path.split('/tree/') @@ -55,7 +57,7 @@ def download_from_repo(url, file_name, models_home='.'): json_response = response.json() download_url, file_size = find_bin_file(json_response, location, branch) - return download_file(download_url, models_home, file_name, file_size) + return download_file(download_url, file_name, file_size) def find_bin_file(json_response, location, branch): @@ -75,7 +77,7 @@ def find_bin_file(json_response, location, branch): return download_url, file_size -def download_file(download_url, models_home, file_name, file_size): +def download_file(download_url, file_name, file_size): local_filename = path.join(models_home, file_name) + '.bin' first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0 @@ -108,7 +110,7 @@ def download_file(download_url, models_home, file_name, file_size): return local_filename -def pull(model, models_home='.', *args, **kwargs): +def pull(model, *args, **kwargs): if path.exists(model): # a file on the filesystem is being specified return model @@ -128,6 +130,6 @@ def pull(model, models_home='.', *args, **kwargs): return model raise Exception(f'Unknown model {model}') - local_filename = download_from_repo(url, file_name, models_home) + local_filename = download_from_repo(url, file_name) return local_filename