diff --git a/ollama/cmd/cli.py b/ollama/cmd/cli.py index 61af0131..00f696de 100644 --- a/ollama/cmd/cli.py +++ b/ollama/cmd/cli.py @@ -151,7 +151,7 @@ def pull(*args, **kwargs): def run(*args, **kwargs): - name = model.pull(*args, **kwargs) + name = model.pull(model_name=kwargs.pop('model'), *args, **kwargs) kwargs.update({"model": name}) print(f"Running {name}...") generate(*args, **kwargs) diff --git a/ollama/engine.py b/ollama/engine.py index 91506c4a..9ba01d34 100644 --- a/ollama/engine.py +++ b/ollama/engine.py @@ -30,7 +30,7 @@ def load(model_name, models={}): if not models.get(model_name, None): model_path = path.expanduser(model_name) if not path.exists(model_path): - model_path = path.join(MODELS_CACHE_PATH, model_name + ".bin") + model_path = MODELS_CACHE_PATH / model_name + ".bin" runners = { model_type: cls diff --git a/ollama/model.py b/ollama/model.py index ec59686c..c2e2e5dc 100644 --- a/ollama/model.py +++ b/ollama/model.py @@ -7,7 +7,7 @@ from tqdm import tqdm MODELS_MANIFEST = 'https://ollama.ai/api/models' -MODELS_CACHE_PATH = path.join(Path.home(), '.ollama', 'models') +MODELS_CACHE_PATH = Path.home() / '.ollama' / 'models' def models(*args, **kwargs): @@ -78,7 +78,7 @@ def find_bin_file(json_response, location, branch): def download_file(download_url, file_name, file_size): - local_filename = path.join(MODELS_CACHE_PATH, file_name) + '.bin' + local_filename = MODELS_CACHE_PATH / file_name + '.bin' first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0 @@ -110,25 +110,25 @@ def download_file(download_url, file_name, file_size): return local_filename -def pull(model, *args, **kwargs): - if path.exists(model): +def pull(model_name, *args, **kwargs): + if path.exists(model_name): # a file on the filesystem is being specified - return model + return model_name # check the remote model location and see if it needs to be downloaded - url = model + url = model_name file_name = "" if not validators.url(url) and not url.startswith('huggingface.co'): - url = get_url_from_directory(model) - file_name = model + url = get_url_from_directory(model_name) + file_name = model_name if not (url.startswith('http://') or url.startswith('https://')): url = f'https://{url}' if not validators.url(url): - if model in models(MODELS_CACHE_PATH): + if model_name in models(MODELS_CACHE_PATH): # the model is already downloaded, and specified by name - return model - raise Exception(f'Unknown model {model}') + return model_name + raise Exception(f'Unknown model {model_name}') local_filename = download_from_repo(url, file_name)