diff --git a/ollama/engine.py b/ollama/engine.py index efcb460f..39a60da9 100644 --- a/ollama/engine.py +++ b/ollama/engine.py @@ -22,7 +22,6 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs): llm = load(model, models_home=models_home, llms=llms) prompt = ollama.prompt.template(model, prompt) - if "max_tokens" not in kwargs: kwargs.update({"max_tokens": 16384}) @@ -39,11 +38,10 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs): def load(model, models_home=".", llms={}): llm = llms.get(model, None) if not llm: - model_path = { - name: path for name, path in ollama.model.models(models_home) - }.get(model, None) - - if not model_path: + stored_model_path = os.path.join(models_home, model, ".bin") + if os.path.exists(stored_model_path): + model_path = stored_model_path + else: # try loading this as a path to a model, rather than a model name model_path = os.path.abspath(model) diff --git a/ollama/model.py b/ollama/model.py index 747faf97..41256c04 100644 --- a/ollama/model.py +++ b/ollama/model.py @@ -16,26 +16,18 @@ def models(models_home='.', *args, **kwargs): yield base -def pull(model, models_home='.', *args, **kwargs): - url = model - if not validators.url(url) and not url.startswith('huggingface.co'): - # this may just be a local model location - if model in models(models_home): - return model - # see if we have this model in our directory - response = requests.get(models_endpoint_url) - response.raise_for_status() - directory = response.json() - for model_info in directory: - if model_info.get('name') == model: - url = f"https://{model_info.get('url')}" - break - if not validators.url(url): - raise Exception(f'Unknown model {model}') +# get the url of the model from our curated directory +def get_url_from_directory(model): + response = requests.get(models_endpoint_url) + response.raise_for_status() + directory = response.json() + for model_info in directory: + if model_info.get('name') == model: + return model_info.get('url') + return model - if not (url.startswith('http://') or url.startswith('https://')): - url = f'https://{url}' +def download_from_repo(url, models_home='.'): parts = urlsplit(url) path_parts = parts.path.split('/tree/') @@ -47,7 +39,6 @@ def pull(model, models_home='.', *args, **kwargs): location = location.strip('/') - # Reconstruct the URL download_url = urlunsplit( ( 'https', @@ -57,13 +48,15 @@ def pull(model, models_home='.', *args, **kwargs): parts.fragment, ) ) - response = requests.get(download_url) - response.raise_for_status() # Raises stored HTTPError, if one occurred - + response.raise_for_status() json_response = response.json() - # get the last bin file we find, this is probably the most up to date + download_url, file_size = find_bin_file(json_response, location, branch) + return download_file(download_url, models_home, location, file_size) + + +def find_bin_file(json_response, location, branch): download_url = None file_size = 0 for file_info in json_response: @@ -77,27 +70,25 @@ def pull(model, models_home='.', *args, **kwargs): if download_url is None: raise Exception('No model found') - local_filename = os.path.join(models_home, os.path.basename(url)) + '.bin' + return download_url, file_size - # Check if file already exists - first_byte = 0 - if os.path.exists(local_filename): - # TODO: check if the file is the same SHA - first_byte = os.path.getsize(local_filename) + +def download_file(download_url, models_home, location, file_size): + local_filename = os.path.join(models_home, os.path.basename(location)) + '.bin' + + first_byte = ( + os.path.getsize(local_filename) if os.path.exists(local_filename) else 0 + ) if first_byte >= file_size: return local_filename - print(f'Pulling {model}...') + print(f'Pulling {os.path.basename(location)}...') - # If file size is non-zero, resume download - if first_byte != 0: - header = {'Range': f'bytes={first_byte}-'} - else: - header = {} + header = {'Range': f'bytes={first_byte}-'} if first_byte != 0 else {} response = requests.get(download_url, headers=header, stream=True) - response.raise_for_status() # Raises stored HTTPError, if one occurred + response.raise_for_status() total_size = int(response.headers.get('content-length', 0)) @@ -115,3 +106,26 @@ def pull(model, models_home='.', *args, **kwargs): bar.update(size) return local_filename + + +def pull(model, models_home='.', *args, **kwargs): + if os.path.exists(model): + # a file on the filesystem is being specified + return model + # check the remote model location and see if it needs to be downloaded + url = model + if not validators.url(url) and not url.startswith('huggingface.co'): + url = get_url_from_directory(model) + + if not (url.startswith('http://') or url.startswith('https://')): + url = f'https://{url}' + + if not validators.url(url): + if model in models(models_home): + # the model is already downloaded, and specified by name + return model + raise Exception(f'Unknown model {model}') + + local_filename = download_from_repo(url, models_home) + + return local_filename diff --git a/ollama/prompt.py b/ollama/prompt.py index 437bc8c8..e2bbea75 100644 --- a/ollama/prompt.py +++ b/ollama/prompt.py @@ -1,4 +1,4 @@ -from os import path +import os from difflib import SequenceMatcher from jinja2 import Environment, PackageLoader @@ -9,8 +9,8 @@ def template(model, prompt): environment = Environment(loader=PackageLoader(__name__, 'templates')) for template in environment.list_templates(): - base, _ = path.splitext(template) - ratio = SequenceMatcher(None, model.lower(), base).ratio() + base, _ = os.path.splitext(template) + ratio = SequenceMatcher(None, os.path.basename(model.lower()), base).ratio() if ratio > best_ratio: best_ratio = ratio best_template = template