diff --git a/README.md b/README.md
index 3a92c99d..ffdee742 100644
--- a/README.md
+++ b/README.md
@@ -87,8 +87,6 @@ Download a model
 ollama.pull("huggingface.co/thebloke/llama-7b-ggml")
 ```
 
-## Coming Soon
-
 ### `ollama.search("query")`
 
 Search for compatible models that Ollama can run
diff --git a/ollama/cmd/cli.py b/ollama/cmd/cli.py
index 00f696de..2b4abebe 100644
--- a/ollama/cmd/cli.py
+++ b/ollama/cmd/cli.py
@@ -37,14 +37,6 @@ def main():
         title='commands',
     )
 
-    server.set_parser(
-        subparsers.add_parser(
-            "serve",
-            description="Start a persistent server to interact with models via the API.",
-            help="Start a persistent server to interact with models via the API.",
-        )
-    )
-
     list_parser = subparsers.add_parser(
         "models",
         description="List all available models stored locally.",
@@ -52,6 +44,18 @@ def main():
     )
     list_parser.set_defaults(fn=list_models)
 
+    search_parser = subparsers.add_parser(
+        "search",
+        description="Search for compatible models that Ollama can run.",
+        help="Search for compatible models that Ollama can run. Usage: search [model]",
+    )
+    search_parser.add_argument(
+        "query",
+        nargs="?",
+        help="Optional name of the model to search for.",
+    )
+    search_parser.set_defaults(fn=search)
+
     pull_parser = subparsers.add_parser(
         "pull",
         description="Download a specified model from a remote source.",
@@ -73,6 +77,14 @@ def main():
     )
     run_parser.set_defaults(fn=run)
 
+    server.set_parser(
+        subparsers.add_parser(
+            "serve",
+            description="Start a persistent server to interact with models via the API.",
+            help="Start a persistent server to interact with models via the API.",
+        )
+    )
+
     args = parser.parse_args()
     args = vars(args)
 
@@ -146,6 +158,22 @@ def generate_batch(*args, **kwargs):
     generate_oneshot(*args, **kwargs)
 
 
+def search(*args, **kwargs):
+    try:
+        model_names = model.search_directory(*args, **kwargs)
+        if len(model_names) == 0:
+            print("No models found.")
+            return
+        elif len(model_names) == 1:
+            print(f"Found {len(model_names)} available model:")
+        else:
+            print(f"Found {len(model_names)} available models:")
+        for model_name in model_names:
+            print(model_name.lower())
+    except Exception:
+        print("Failed to fetch available models; check your network connection.")
+
+
 def pull(*args, **kwargs):
     model.pull(model_name=kwargs.pop('model'), *args, **kwargs)
 
diff --git a/ollama/engine.py b/ollama/engine.py
index 9ba01d34..db1d51c7 100644
--- a/ollama/engine.py
+++ b/ollama/engine.py
@@ -1,6 +1,7 @@
 import os
 import sys
 from os import path
+from pathlib import Path
 from contextlib import contextmanager
 from fuzzywuzzy import process
 from llama_cpp import Llama
@@ -30,7 +31,7 @@ def load(model_name, models={}):
     if not models.get(model_name, None):
         model_path = path.expanduser(model_name)
         if not path.exists(model_path):
-            model_path = MODELS_CACHE_PATH / model_name + ".bin"
+            model_path = str(MODELS_CACHE_PATH / (model_name + ".bin"))
 
         runners = {
             model_type: cls
@@ -52,14 +53,10 @@ def unload(model_name, models={}):
 
 
 class LlamaCppRunner:
-
     def __init__(self, model_path, model_type):
         try:
             with suppress(sys.stderr), suppress(sys.stdout):
-                self.model = Llama(model_path,
-                                   verbose=False,
-                                   n_gpu_layers=1,
-                                   seed=-1)
+                self.model = Llama(model_path, verbose=False, n_gpu_layers=1, seed=-1)
         except Exception:
             raise Exception("Failed to load model", model_path, model_type)
 
@@ -88,10 +85,10 @@ class LlamaCppRunner:
 
 
 class CtransformerRunner:
-
     def __init__(self, model_path, model_type):
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_path, model_type=model_type, local_files_only=True)
+            model_path, model_type=model_type, local_files_only=True
+        )
 
     @staticmethod
     def model_types():
diff --git a/ollama/model.py b/ollama/model.py
index c2e2e5dc..5272ee9a 100644
--- a/ollama/model.py
+++ b/ollama/model.py
@@ -18,13 +18,26 @@ def models(*args, **kwargs):
         yield base
 
 
+# search the directory and return all models whose names contain the search term as a
+# substring, or all models if no search term is provided
+def search_directory(query):
+    response = requests.get(MODELS_MANIFEST)
+    response.raise_for_status()
+    directory = response.json()
+    model_names = []
+    for model_info in directory:
+        if not query or query.lower() in model_info.get('name', '').lower():
+            model_names.append(model_info.get('name'))
+    return model_names
+
+
 # get the url of the model from our curated directory
 def get_url_from_directory(model):
     response = requests.get(MODELS_MANIFEST)
     response.raise_for_status()
     directory = response.json()
     for model_info in directory:
-        if model_info.get('name') == model:
+        if model_info.get('name').lower() == model.lower():
             return model_info.get('url')
     return model
 
@@ -42,7 +55,6 @@ def download_from_repo(url, file_name):
     location = location.strip('/')
     if file_name == '':
         file_name = path.basename(location).lower()
-
     download_url = urlunsplit(
         (
             'https',
@@ -78,7 +90,7 @@ def find_bin_file(json_response, location, branch):
 
 
 def download_file(download_url, file_name, file_size):
-    local_filename = MODELS_CACHE_PATH / file_name + '.bin'
+    local_filename = MODELS_CACHE_PATH / str(file_name + '.bin')
 
     first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
 
@@ -111,7 +123,8 @@ def download_file(download_url, file_name, file_size):
 
 
 def pull(model_name, *args, **kwargs):
-    if path.exists(model_name):
+    maybe_existing_model_location = MODELS_CACHE_PATH / str(model_name + '.bin')
+    if path.exists(model_name) or path.exists(maybe_existing_model_location):
         # a file on the filesystem is being specified
         return model_name
     # check the remote model location and see if it needs to be downloaded
@@ -120,7 +133,6 @@ def pull(model_name, *args, **kwargs):
     if not validators.url(url) and not url.startswith('huggingface.co'):
         url = get_url_from_directory(model_name)
         file_name = model_name
-
     if not (url.startswith('http://') or url.startswith('https://')):
         url = f'https://{url}'
 
diff --git a/ollama/prompt.py b/ollama/prompt.py
index 84a16f33..9759b249 100644
--- a/ollama/prompt.py
+++ b/ollama/prompt.py
@@ -1,9 +1,12 @@
+from os import path
 from difflib import get_close_matches
 from jinja2 import Environment, PackageLoader
 
 
 def template(name, prompt):
     environment = Environment(loader=PackageLoader(__name__, 'templates'))
-    best_templates = get_close_matches(name, environment.list_templates(), n=1, cutoff=0)
+    best_templates = get_close_matches(
+        path.basename(name), environment.list_templates(), n=1, cutoff=0
+    )
     template = environment.get_template(best_templates.pop())
     return template.render(prompt=prompt)
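
For reviewers: the new `search_directory` in `ollama/model.py` assumes the curated directory at `MODELS_MANIFEST` is a JSON array of objects that carry at least a `name` field (and a `url` field for pulls), which is what the existing `get_url_from_directory` lookup also implies. The sketch below is a minimal, offline illustration of the case-insensitive substring matching introduced by this patch; `SAMPLE_DIRECTORY` and `offline_search` are hypothetical names used only for this example and are not part of the change.

```python
# Hypothetical, self-contained sketch of the matching rule added in this patch
# (no network access; manifest entries are assumed to look like {"name": ..., "url": ...}).
SAMPLE_DIRECTORY = [
    {"name": "Llama-7B", "url": "https://example.com/llama-7b.bin"},
    {"name": "Orca-Mini-3B", "url": "https://example.com/orca-mini-3b.bin"},
]


def offline_search(query, directory=SAMPLE_DIRECTORY):
    # Same rule as model.search_directory: an empty query returns every model,
    # otherwise the query must appear as a case-insensitive substring of the name.
    return [
        entry.get("name")
        for entry in directory
        if not query or query.lower() in entry.get("name", "").lower()
    ]


if __name__ == "__main__":
    print(offline_search("llama"))  # ['Llama-7B']
    print(offline_search(None))     # both names, mirroring `search` with no argument
```

Assuming the console entry point is installed as `ollama`, the CLI side of the change should then behave roughly as follows (the model name shown is illustrative; `search()` prints whatever names the manifest returns, lower-cased):

```
$ ollama search llama
Found 1 available model:
llama-7b
```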