diff --git a/ollama/engine.py b/ollama/engine.py index c43033ef..63eeb769 100644 --- a/ollama/engine.py +++ b/ollama/engine.py @@ -1,9 +1,8 @@ import os import sys from os import path -from pathlib import Path from contextlib import contextmanager -from fuzzywuzzy import process +from thefuzz import process from llama_cpp import Llama from ctransformers import AutoModelForCausalLM @@ -39,16 +38,14 @@ def load(model_name, models={}): for model_type in cls.model_types() } - while len(runners) > 0: + for match, _ in process.extract(model_path, runners.keys(), limit=len(runners)): try: - best_match, _ = process.extractOne(model_path, runners.keys()) - model = runners.get(best_match, LlamaCppRunner) - runner = model(model_path, best_match) + model = runners.get(match) + runner = model(model_path, match) models.update({model_name: runner}) - return models.get(model_name) + return runner except Exception: - # try the next runner - runners.pop(best_match) + pass raise Exception("failed to load model", model_path, model_name) diff --git a/poetry.lock b/poetry.lock index ccf07a8a..fdd5a55d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -443,23 +443,6 @@ smb = ["smbprotocol"] ssh = ["paramiko"] tqdm = ["tqdm"] -[[package]] -name = "fuzzywuzzy" -version = "0.18.0" -description = "Fuzzy string matching in python" -optional = false -python-versions = "*" -files = [ - {file = "fuzzywuzzy-0.18.0-py2.py3-none-any.whl", hash = "sha256:928244b28db720d1e0ee7587acf660ea49d7e4c632569cad4f1cd7e68a5f0993"}, - {file = "fuzzywuzzy-0.18.0.tar.gz", hash = "sha256:45016e92264780e58972dca1b3d939ac864b78437422beecebb3095f8efd00e8"}, -] - -[package.dependencies] -python-levenshtein = {version = ">=0.12", optional = true, markers = "extra == \"speedup\""} - -[package.extras] -speedup = ["python-levenshtein (>=0.12)"] - [[package]] name = "huggingface-hub" version = "0.15.1" @@ -1043,6 +1026,23 @@ files = [ [package.extras] tests = ["pytest", "pytest-cov"] +[[package]] +name = "thefuzz" +version = "0.19.0" +description = "Fuzzy string matching in python" +optional = false +python-versions = "*" +files = [ + {file = "thefuzz-0.19.0-py2.py3-none-any.whl", hash = "sha256:4fcdde8e40f5ca5e8106bc7665181f9598a9c8b18b0a4d38c41a095ba6788972"}, + {file = "thefuzz-0.19.0.tar.gz", hash = "sha256:6f7126db2f2c8a54212b05e3a740e45f4291c497d75d20751728f635bb74aa3d"}, +] + +[package.dependencies] +python-levenshtein = {version = ">=0.12", optional = true, markers = "extra == \"speedup\""} + +[package.extras] +speedup = ["python-levenshtein (>=0.12)"] + [[package]] name = "tqdm" version = "4.65.0" @@ -1211,4 +1211,4 @@ termcolor = ">=2.2,<3.0" [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "bd4b373e3903bd26b983163f3cc527a6f768f4280201fcbca4d4dc05dea66912" +content-hash = "9e9c14aae817d7863b4facda840e246f98b2d6b2517085b22eb5c5d919ae7784" diff --git a/pyproject.toml b/pyproject.toml index fdfa4e96..aae2d9e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ validators = "^0.20.0" yaspin = "^2.3.0" llama-cpp-python = "^0.1.67" ctransformers = "^0.2.10" -fuzzywuzzy = {extras = ["speedup"], version = "^0.18.0"} +thefuzz = {version = "^0.19.0", extras = ["speedup"]} [build-system] requires = ["poetry-core"] diff --git a/requirements.txt b/requirements.txt index c0f24b3b..d8cc6520 100644 --- a/requirements.txt +++ b/requirements.txt @@ -270,9 +270,6 @@ frozenlist==1.3.3 ; python_version >= "3.8" and python_version < "4.0" \ fsspec==2023.6.0 ; python_version >= "3.8" and python_version < "4.0" \ --hash=sha256:1cbad1faef3e391fba6dc005ae9b5bdcbf43005c9167ce78c915549c352c869a \ --hash=sha256:d0b2f935446169753e7a5c5c55681c54ea91996cc67be93c39a154fb3a2742af -fuzzywuzzy[speedup]==0.18.0 ; python_version >= "3.8" and python_version < "4.0" \ - --hash=sha256:45016e92264780e58972dca1b3d939ac864b78437422beecebb3095f8efd00e8 \ - --hash=sha256:928244b28db720d1e0ee7587acf660ea49d7e4c632569cad4f1cd7e68a5f0993 huggingface-hub==0.15.1 ; python_version >= "3.8" and python_version < "4.0" \ --hash=sha256:05b0fb0abbf1f625dfee864648ac3049fe225ac4371c7bafaca0c2d3a2f83445 \ --hash=sha256:a61b7d1a7769fe10119e730277c72ab99d95c48d86a3d6da3e9f3d0f632a4081 @@ -688,6 +685,9 @@ requests==2.31.0 ; python_version >= "3.8" and python_version < "4.0" \ termcolor==2.3.0 ; python_version >= "3.8" and python_version < "4.0" \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a +thefuzz[speedup]==0.19.0 ; python_version >= "3.8" and python_version < "4.0" \ + --hash=sha256:4fcdde8e40f5ca5e8106bc7665181f9598a9c8b18b0a4d38c41a095ba6788972 \ + --hash=sha256:6f7126db2f2c8a54212b05e3a740e45f4291c497d75d20751728f635bb74aa3d tqdm==4.65.0 ; python_version >= "3.8" and python_version < "4.0" \ --hash=sha256:1871fb68a86b8fb3b59ca4cdd3dcccbc7e6d613eeed31f4c332531977b89beb5 \ --hash=sha256:c4f53a17fe37e132815abceec022631be8ffe1b9381c2e6e30aa70edc99e9671