pull from remote

This commit is contained in:
Bruce MacDonald 2023-06-28 12:13:13 -04:00
parent 1d0c84a6c7
commit 52beb0a99e
5 changed files with 168 additions and 11 deletions

View file

@ -93,8 +93,6 @@ Unload a model
ollama.unload("model")
```
## Cooming Soon
### `ollama.pull(model)`
Download a model
@ -103,6 +101,8 @@ Download a model
ollama.pull("huggingface.co/thebloke/llama-7b-ggml")
```
## Cooming Soon
### `ollama.search("query")`
Search for compatible models that Ollama can run

View file

@ -27,6 +27,10 @@ def main():
add_parser.add_argument("model")
add_parser.set_defaults(fn=add)
pull_parser = subparsers.add_parser("pull")
pull_parser.add_argument("remote")
pull_parser.set_defaults(fn=pull)
args = parser.parse_args()
args = vars(args)
@ -55,3 +59,7 @@ def generate(*args, **kwargs):
def add(model, models_home):
os.rename(model, Path(models_home) / Path(model).name)
def pull(*args, **kwargs):
model.pull(*args, **kwargs)

View file

@ -1,9 +1,76 @@
from os import walk, path
import os
import requests
from urllib.parse import urlsplit, urlunsplit
from tqdm import tqdm
def models(models_home='.', *args, **kwargs):
for root, _, files in walk(models_home):
def models(models_home=".", *args, **kwargs):
for root, _, files in os.walk(models_home):
for file in files:
base, ext = path.splitext(file)
if ext == '.bin':
yield base, path.join(root, file)
base, ext = os.path.splitext(file)
if ext == ".bin":
yield base, os.path.join(root, file)
def pull(remote, models_home=".", *args, **kwargs):
if not (remote.startswith("http://") or remote.startswith("https://")):
remote = f"https://{remote}"
parts = urlsplit(remote)
path_parts = parts.path.split("/tree/")
if len(path_parts) == 1:
model = path_parts[0]
branch = "main"
else:
model, branch = path_parts
model = model.strip("/")
# Reconstruct the URL
new_url = urlunsplit(
(
"https",
parts.netloc,
f"/api/models/{model}/tree/{branch}",
parts.query,
parts.fragment,
)
)
print(f"Fetching model from {new_url}")
response = requests.get(new_url)
response.raise_for_status() # Raises stored HTTPError, if one occurred
json_response = response.json()
for file_info in json_response:
if file_info.get("type") == "file" and file_info.get("path").endswith(".bin"):
f_path = file_info.get("path")
download_url = f"https://huggingface.co/{model}/resolve/{branch}/{f_path}"
local_filename = os.path.join(
models_home, os.path.basename(file_info.get("path"))
)
if os.path.exists(local_filename):
# TODO: check if the file is the same
break
response = requests.get(download_url, stream=True)
response.raise_for_status() # Raises stored HTTPError, if one occurred
total_size = int(response.headers.get("content-length", 0))
with open(local_filename, "wb") as file, tqdm(
desc=local_filename,
total=total_size,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as bar:
for data in response.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)
break # Stop after downloading the first .bin file

86
poetry.lock generated
View file

@ -165,11 +165,22 @@ docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-
tests = ["attrs[tests-no-zope]", "zope-interface"]
tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
[[package]]
name = "certifi"
version = "2023.5.7"
description = "Python package for providing Mozilla's CA Bundle."
optional = false
python-versions = ">=3.6"
files = [
{file = "certifi-2023.5.7-py3-none-any.whl", hash = "sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716"},
{file = "certifi-2023.5.7.tar.gz", hash = "sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7"},
]
[[package]]
name = "charset-normalizer"
version = "3.1.0"
description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
optional = true
optional = false
python-versions = ">=3.7.0"
files = [
{file = "charset-normalizer-3.1.0.tar.gz", hash = "sha256:34e0a2f9c370eb95597aae63bf85eb5e96826d81e3dcf88b8886012906f509b5"},
@ -249,6 +260,17 @@ files = [
{file = "charset_normalizer-3.1.0-py3-none-any.whl", hash = "sha256:3d9098b479e78c85080c98e1e35ff40b4a31d8953102bb0fd7d1b6f8a2111a3d"},
]
[[package]]
name = "colorama"
version = "0.4.6"
description = "Cross-platform colored terminal text."
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
files = [
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
[[package]]
name = "diskcache"
version = "5.6.1"
@ -374,7 +396,7 @@ files = [
name = "idna"
version = "3.4"
description = "Internationalized Domain Names in Applications (IDNA)"
optional = true
optional = false
python-versions = ">=3.5"
files = [
{file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"},
@ -680,6 +702,27 @@ test = ["coverage", "flaky", "matplotlib", "numpy", "pandas", "pylint (>=2.5.0,<
websockets = ["websockets (>=10.3)"]
yapf = ["whatthepatch (>=1.0.2,<2.0.0)", "yapf (>=0.33.0)"]
[[package]]
name = "requests"
version = "2.31.0"
description = "Python HTTP for Humans."
optional = false
python-versions = ">=3.7"
files = [
{file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"},
{file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"},
]
[package.dependencies]
certifi = ">=2017.4.17"
charset-normalizer = ">=2,<4"
idna = ">=2.5,<4"
urllib3 = ">=1.21.1,<3"
[package.extras]
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "setuptools"
version = "68.0.0"
@ -696,6 +739,26 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-g
testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
[[package]]
name = "tqdm"
version = "4.65.0"
description = "Fast, Extensible Progress Meter"
optional = false
python-versions = ">=3.7"
files = [
{file = "tqdm-4.65.0-py3-none-any.whl", hash = "sha256:c4f53a17fe37e132815abceec022631be8ffe1b9381c2e6e30aa70edc99e9671"},
{file = "tqdm-4.65.0.tar.gz", hash = "sha256:1871fb68a86b8fb3b59ca4cdd3dcccbc7e6d613eeed31f4c332531977b89beb5"},
]
[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}
[package.extras]
dev = ["py-make (>=0.1.0)", "twine", "wheel"]
notebook = ["ipywidgets (>=6)"]
slack = ["slack-sdk"]
telegram = ["requests"]
[[package]]
name = "typing-extensions"
version = "4.6.3"
@ -777,6 +840,23 @@ files = [
{file = "ujson-5.8.0.tar.gz", hash = "sha256:78e318def4ade898a461b3d92a79f9441e7e0e4d2ad5419abed4336d702c7425"},
]
[[package]]
name = "urllib3"
version = "2.0.3"
description = "HTTP library with thread-safe connection pooling, file post, and more."
optional = false
python-versions = ">=3.7"
files = [
{file = "urllib3-2.0.3-py3-none-any.whl", hash = "sha256:48e7fafa40319d358848e1bc6809b208340fafe2096f1725d05d67443d0483d1"},
{file = "urllib3-2.0.3.tar.gz", hash = "sha256:bee28b5e56addb8226c96f7f13ac28cb4c301dd5ea8a6ca179c0b9835e032825"},
]
[package.extras]
brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"]
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "yarl"
version = "1.9.2"
@ -870,4 +950,4 @@ server = ["aiohttp", "aiohttp-cors"]
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "c649ffbb8b8045d35831f9fb09a6f099b8e940abd85a020c3dfd24173b2582d8"
content-hash = "ba168754266c6c46b2136207415a5b3a879c957e53e924cab1e64267849ceb90"

View file

@ -13,6 +13,8 @@ llama-cpp-python = "^0.1.66"
aiohttp = {version = "^3.8.4", optional = true}
aiohttp-cors = {version = "^0.7.0", optional = true}
requests = "^2.31.0"
tqdm = "^4.65.0"
[tool.poetry.extras]
server = ["aiohttp", "aiohttp_cors"]