diff --git a/README.md b/README.md
index 0b03bf61..4c1786cd 100644
--- a/README.md
+++ b/README.md
@@ -93,8 +93,6 @@ Unload a model
 ollama.unload("model")
 ```
 
-## Cooming Soon
-
 ### `ollama.pull(model)`
 
 Download a model
@@ -103,6 +101,8 @@ Download a model
 ollama.pull("huggingface.co/thebloke/llama-7b-ggml")
 ```
 
+## Coming Soon
+
 ### `ollama.search("query")`
 
 Search for compatible models that Ollama can run
diff --git a/ollama/cmd/cli.py b/ollama/cmd/cli.py
index 5e4f3307..f47cdcd2 100644
--- a/ollama/cmd/cli.py
+++ b/ollama/cmd/cli.py
@@ -27,6 +27,10 @@ def main():
     add_parser.add_argument("model")
     add_parser.set_defaults(fn=add)
 
+    pull_parser = subparsers.add_parser("pull")
+    pull_parser.add_argument("remote")
+    pull_parser.set_defaults(fn=pull)
+
     args = parser.parse_args()
     args = vars(args)
 
@@ -55,3 +59,8 @@ def generate(*args, **kwargs):
 
 def add(model, models_home):
     os.rename(model, Path(models_home) / Path(model).name)
+
+
+def pull(*args, **kwargs):
+    """Forward all arguments unchanged to ``model.pull``."""
+    model.pull(*args, **kwargs)
diff --git a/ollama/model.py b/ollama/model.py
index b5e7b0b8..0d240432 100644
--- a/ollama/model.py
+++ b/ollama/model.py
@@ -1,9 +1,102 @@
-from os import walk, path
+import os
+import requests
+from urllib.parse import urlsplit, urlunsplit
+from tqdm import tqdm
+
+# (connect, read) timeouts in seconds, so a stalled server cannot hang a pull
+_REQUEST_TIMEOUT = (10, 60)
+# Model files are gigabytes; 1 MiB chunks keep per-chunk overhead low
+_CHUNK_SIZE = 1024 * 1024
 
 
-def models(models_home='.', *args, **kwargs):
-    for root, _, files in walk(models_home):
+def models(models_home=".", *args, **kwargs):
+    """Yield ``(name, path)`` for each ``.bin`` file under *models_home*."""
+    for root, _, files in os.walk(models_home):
         for file in files:
-            base, ext = path.splitext(file)
-            if ext == '.bin':
-                yield base, path.join(root, file)
+            base, ext = os.path.splitext(file)
+            if ext == ".bin":
+                yield base, os.path.join(root, file)
+
+
+def _download(url, local_filename):
+    """Stream *url* into *local_filename* while displaying a progress bar."""
+    # Write to a sibling ".part" file so an interrupted download never leaves
+    # a truncated model that a later pull would mistake for a complete one
+    partial_filename = f"{local_filename}.part"
+    os.makedirs(os.path.dirname(local_filename) or ".", exist_ok=True)
+
+    try:
+        with requests.get(url, stream=True, timeout=_REQUEST_TIMEOUT) as response:
+            response.raise_for_status()  # Raises stored HTTPError, if one occurred
+
+            # tqdm treats None as "size unknown"; 0 would draw a bogus bar
+            total_size = int(response.headers.get("content-length", 0)) or None
+
+            with open(partial_filename, "wb") as file, tqdm(
+                desc=local_filename,
+                total=total_size,
+                unit="iB",
+                unit_scale=True,
+                unit_divisor=1024,
+            ) as bar:
+                for data in response.iter_content(chunk_size=_CHUNK_SIZE):
+                    bar.update(file.write(data))
+    except BaseException:
+        # Also covers Ctrl-C: never leave a stale partial file behind
+        if os.path.exists(partial_filename):
+            os.remove(partial_filename)
+        raise
+
+    os.replace(partial_filename, local_filename)
+
+
+def pull(remote, models_home=".", *args, **kwargs):
+    """Download the first ``.bin`` file of a remote model into *models_home*.
+
+    *remote* is a model URL such as ``huggingface.co/<owner>/<name>``,
+    optionally followed by ``/tree/<branch>``. The scheme may be omitted and
+    the branch defaults to ``main``. Nothing is downloaded when a file with
+    the same name already exists in *models_home*.
+
+    Raises ``requests.HTTPError`` if the listing or the download request fails.
+    """
+    if not remote.startswith(("http://", "https://")):
+        remote = f"https://{remote}"
+
+    parts = urlsplit(remote)
+    # partition() copes with branch names that themselves contain "/tree/"
+    model, _, branch = parts.path.partition("/tree/")
+    model = model.strip("/")
+    branch = branch.strip("/") or "main"
+
+    # Reconstruct the URL
+    new_url = urlunsplit(
+        (
+            "https",
+            parts.netloc,
+            f"/api/models/{model}/tree/{branch}",
+            parts.query,
+            parts.fragment,
+        )
+    )
+
+    print(f"Fetching model from {new_url}")
+
+    response = requests.get(new_url, timeout=_REQUEST_TIMEOUT)
+    response.raise_for_status()  # Raises stored HTTPError, if one occurred
+
+    for file_info in response.json():
+        # Entries without a "path" are skipped rather than crashing on None
+        f_path = file_info.get("path") or ""
+        if file_info.get("type") != "file" or not f_path.endswith(".bin"):
+            continue
+
+        local_filename = os.path.join(models_home, os.path.basename(f_path))
+        if os.path.exists(local_filename):
+            # TODO: check if the file is the same
+            return
+
+        # Fetch from the same host that served the listing
+        download_url = f"https://{parts.netloc}/{model}/resolve/{branch}/{f_path}"
+        _download(download_url, local_filename)
+        return  # Stop after downloading the first .bin file
diff --git a/poetry.lock b/poetry.lock
index 25a75447..9eabc319 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -165,11 +165,22 @@ docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-
 tests = ["attrs[tests-no-zope]", "zope-interface"]
 tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
 
+[[package]]
+name = "certifi"
+version = "2023.5.7"
+description = "Python package for providing Mozilla's CA Bundle."
+optional = false +python-versions = ">=3.6" +files = [ + {file = "certifi-2023.5.7-py3-none-any.whl", hash = "sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716"}, + {file = "certifi-2023.5.7.tar.gz", hash = "sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7"}, +] + [[package]] name = "charset-normalizer" version = "3.1.0" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." -optional = true +optional = false python-versions = ">=3.7.0" files = [ {file = "charset-normalizer-3.1.0.tar.gz", hash = "sha256:34e0a2f9c370eb95597aae63bf85eb5e96826d81e3dcf88b8886012906f509b5"}, @@ -249,6 +260,17 @@ files = [ {file = "charset_normalizer-3.1.0-py3-none-any.whl", hash = "sha256:3d9098b479e78c85080c98e1e35ff40b4a31d8953102bb0fd7d1b6f8a2111a3d"}, ] +[[package]] +name = "colorama" +version = "0.4.6" +description = "Cross-platform colored terminal text." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] + [[package]] name = "diskcache" version = "5.6.1" @@ -374,7 +396,7 @@ files = [ name = "idna" version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" -optional = true +optional = false python-versions = ">=3.5" files = [ {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, @@ -680,6 +702,27 @@ test = ["coverage", "flaky", "matplotlib", "numpy", "pandas", "pylint (>=2.5.0,< websockets = ["websockets (>=10.3)"] yapf = ["whatthepatch (>=1.0.2,<2.0.0)", "yapf (>=0.33.0)"] +[[package]] +name = "requests" +version = "2.31.0" +description = "Python HTTP for Humans." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, + {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, +] + +[package.dependencies] +certifi = ">=2017.4.17" +charset-normalizer = ">=2,<4" +idna = ">=2.5,<4" +urllib3 = ">=1.21.1,<3" + +[package.extras] +socks = ["PySocks (>=1.5.6,!=1.5.7)"] +use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] + [[package]] name = "setuptools" version = "68.0.0" @@ -696,6 +739,26 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-g testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] +[[package]] +name = "tqdm" +version = "4.65.0" +description = "Fast, Extensible Progress Meter" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tqdm-4.65.0-py3-none-any.whl", hash = "sha256:c4f53a17fe37e132815abceec022631be8ffe1b9381c2e6e30aa70edc99e9671"}, + {file = "tqdm-4.65.0.tar.gz", hash = "sha256:1871fb68a86b8fb3b59ca4cdd3dcccbc7e6d613eeed31f4c332531977b89beb5"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +dev = ["py-make (>=0.1.0)", "twine", "wheel"] +notebook = ["ipywidgets (>=6)"] +slack = ["slack-sdk"] +telegram = ["requests"] + [[package]] name = 
"typing-extensions" version = "4.6.3" @@ -777,6 +840,23 @@ files = [ {file = "ujson-5.8.0.tar.gz", hash = "sha256:78e318def4ade898a461b3d92a79f9441e7e0e4d2ad5419abed4336d702c7425"}, ] +[[package]] +name = "urllib3" +version = "2.0.3" +description = "HTTP library with thread-safe connection pooling, file post, and more." +optional = false +python-versions = ">=3.7" +files = [ + {file = "urllib3-2.0.3-py3-none-any.whl", hash = "sha256:48e7fafa40319d358848e1bc6809b208340fafe2096f1725d05d67443d0483d1"}, + {file = "urllib3-2.0.3.tar.gz", hash = "sha256:bee28b5e56addb8226c96f7f13ac28cb4c301dd5ea8a6ca179c0b9835e032825"}, +] + +[package.extras] +brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] +socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] +zstd = ["zstandard (>=0.18.0)"] + [[package]] name = "yarl" version = "1.9.2" @@ -870,4 +950,4 @@ server = ["aiohttp", "aiohttp-cors"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "c649ffbb8b8045d35831f9fb09a6f099b8e940abd85a020c3dfd24173b2582d8" +content-hash = "ba168754266c6c46b2136207415a5b3a879c957e53e924cab1e64267849ceb90" diff --git a/pyproject.toml b/pyproject.toml index 0b59227e..fc2ebf0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,8 @@ llama-cpp-python = "^0.1.66" aiohttp = {version = "^3.8.4", optional = true} aiohttp-cors = {version = "^0.7.0", optional = true} +requests = "^2.31.0" +tqdm = "^4.65.0" [tool.poetry.extras] server = ["aiohttp", "aiohttp_cors"]