pull from remote
This commit is contained in:
parent
1d0c84a6c7
commit
52beb0a99e
5 changed files with 168 additions and 11 deletions
|
@ -93,8 +93,6 @@ Unload a model
|
|||
ollama.unload("model")
|
||||
```
|
||||
|
||||
## Cooming Soon
|
||||
|
||||
### `ollama.pull(model)`
|
||||
|
||||
Download a model
|
||||
|
@ -103,6 +101,8 @@ Download a model
|
|||
ollama.pull("huggingface.co/thebloke/llama-7b-ggml")
|
||||
```
|
||||
|
||||
## Cooming Soon
|
||||
|
||||
### `ollama.search("query")`
|
||||
|
||||
Search for compatible models that Ollama can run
|
||||
|
|
|
@ -27,6 +27,10 @@ def main():
|
|||
add_parser.add_argument("model")
|
||||
add_parser.set_defaults(fn=add)
|
||||
|
||||
pull_parser = subparsers.add_parser("pull")
|
||||
pull_parser.add_argument("remote")
|
||||
pull_parser.set_defaults(fn=pull)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = vars(args)
|
||||
|
||||
|
@ -55,3 +59,7 @@ def generate(*args, **kwargs):
|
|||
|
||||
def add(model, models_home):
|
||||
os.rename(model, Path(models_home) / Path(model).name)
|
||||
|
||||
|
||||
def pull(*args, **kwargs):
|
||||
model.pull(*args, **kwargs)
|
||||
|
|
|
@ -1,9 +1,76 @@
|
|||
from os import walk, path
|
||||
import os
|
||||
import requests
|
||||
from urllib.parse import urlsplit, urlunsplit
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def models(models_home='.', *args, **kwargs):
|
||||
for root, _, files in walk(models_home):
|
||||
def models(models_home=".", *args, **kwargs):
|
||||
for root, _, files in os.walk(models_home):
|
||||
for file in files:
|
||||
base, ext = path.splitext(file)
|
||||
if ext == '.bin':
|
||||
yield base, path.join(root, file)
|
||||
base, ext = os.path.splitext(file)
|
||||
if ext == ".bin":
|
||||
yield base, os.path.join(root, file)
|
||||
|
||||
|
||||
def pull(remote, models_home=".", *args, **kwargs):
|
||||
if not (remote.startswith("http://") or remote.startswith("https://")):
|
||||
remote = f"https://{remote}"
|
||||
|
||||
parts = urlsplit(remote)
|
||||
path_parts = parts.path.split("/tree/")
|
||||
|
||||
if len(path_parts) == 1:
|
||||
model = path_parts[0]
|
||||
branch = "main"
|
||||
else:
|
||||
model, branch = path_parts
|
||||
|
||||
model = model.strip("/")
|
||||
|
||||
# Reconstruct the URL
|
||||
new_url = urlunsplit(
|
||||
(
|
||||
"https",
|
||||
parts.netloc,
|
||||
f"/api/models/{model}/tree/{branch}",
|
||||
parts.query,
|
||||
parts.fragment,
|
||||
)
|
||||
)
|
||||
|
||||
print(f"Fetching model from {new_url}")
|
||||
|
||||
response = requests.get(new_url)
|
||||
response.raise_for_status() # Raises stored HTTPError, if one occurred
|
||||
|
||||
json_response = response.json()
|
||||
|
||||
for file_info in json_response:
|
||||
if file_info.get("type") == "file" and file_info.get("path").endswith(".bin"):
|
||||
f_path = file_info.get("path")
|
||||
download_url = f"https://huggingface.co/{model}/resolve/{branch}/{f_path}"
|
||||
local_filename = os.path.join(
|
||||
models_home, os.path.basename(file_info.get("path"))
|
||||
)
|
||||
|
||||
if os.path.exists(local_filename):
|
||||
# TODO: check if the file is the same
|
||||
break
|
||||
|
||||
response = requests.get(download_url, stream=True)
|
||||
response.raise_for_status() # Raises stored HTTPError, if one occurred
|
||||
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
|
||||
with open(local_filename, "wb") as file, tqdm(
|
||||
desc=local_filename,
|
||||
total=total_size,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as bar:
|
||||
for data in response.iter_content(chunk_size=1024):
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
|
||||
break # Stop after downloading the first .bin file
|
||||
|
|
86
poetry.lock
generated
86
poetry.lock
generated
|
@ -165,11 +165,22 @@ docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-
|
|||
tests = ["attrs[tests-no-zope]", "zope-interface"]
|
||||
tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
|
||||
|
||||
[[package]]
|
||||
name = "certifi"
|
||||
version = "2023.5.7"
|
||||
description = "Python package for providing Mozilla's CA Bundle."
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "certifi-2023.5.7-py3-none-any.whl", hash = "sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716"},
|
||||
{file = "certifi-2023.5.7.tar.gz", hash = "sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "charset-normalizer"
|
||||
version = "3.1.0"
|
||||
description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = ">=3.7.0"
|
||||
files = [
|
||||
{file = "charset-normalizer-3.1.0.tar.gz", hash = "sha256:34e0a2f9c370eb95597aae63bf85eb5e96826d81e3dcf88b8886012906f509b5"},
|
||||
|
@ -249,6 +260,17 @@ files = [
|
|||
{file = "charset_normalizer-3.1.0-py3-none-any.whl", hash = "sha256:3d9098b479e78c85080c98e1e35ff40b4a31d8953102bb0fd7d1b6f8a2111a3d"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colorama"
|
||||
version = "0.4.6"
|
||||
description = "Cross-platform colored terminal text."
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
||||
files = [
|
||||
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
|
||||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "diskcache"
|
||||
version = "5.6.1"
|
||||
|
@ -374,7 +396,7 @@ files = [
|
|||
name = "idna"
|
||||
version = "3.4"
|
||||
description = "Internationalized Domain Names in Applications (IDNA)"
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
files = [
|
||||
{file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"},
|
||||
|
@ -680,6 +702,27 @@ test = ["coverage", "flaky", "matplotlib", "numpy", "pandas", "pylint (>=2.5.0,<
|
|||
websockets = ["websockets (>=10.3)"]
|
||||
yapf = ["whatthepatch (>=1.0.2,<2.0.0)", "yapf (>=0.33.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "requests"
|
||||
version = "2.31.0"
|
||||
description = "Python HTTP for Humans."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"},
|
||||
{file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
certifi = ">=2017.4.17"
|
||||
charset-normalizer = ">=2,<4"
|
||||
idna = ">=2.5,<4"
|
||||
urllib3 = ">=1.21.1,<3"
|
||||
|
||||
[package.extras]
|
||||
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
|
||||
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
||||
|
||||
[[package]]
|
||||
name = "setuptools"
|
||||
version = "68.0.0"
|
||||
|
@ -696,6 +739,26 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-g
|
|||
testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
|
||||
testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
|
||||
|
||||
[[package]]
|
||||
name = "tqdm"
|
||||
version = "4.65.0"
|
||||
description = "Fast, Extensible Progress Meter"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "tqdm-4.65.0-py3-none-any.whl", hash = "sha256:c4f53a17fe37e132815abceec022631be8ffe1b9381c2e6e30aa70edc99e9671"},
|
||||
{file = "tqdm-4.65.0.tar.gz", hash = "sha256:1871fb68a86b8fb3b59ca4cdd3dcccbc7e6d613eeed31f4c332531977b89beb5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
||||
|
||||
[package.extras]
|
||||
dev = ["py-make (>=0.1.0)", "twine", "wheel"]
|
||||
notebook = ["ipywidgets (>=6)"]
|
||||
slack = ["slack-sdk"]
|
||||
telegram = ["requests"]
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.6.3"
|
||||
|
@ -777,6 +840,23 @@ files = [
|
|||
{file = "ujson-5.8.0.tar.gz", hash = "sha256:78e318def4ade898a461b3d92a79f9441e7e0e4d2ad5419abed4336d702c7425"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "urllib3"
|
||||
version = "2.0.3"
|
||||
description = "HTTP library with thread-safe connection pooling, file post, and more."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "urllib3-2.0.3-py3-none-any.whl", hash = "sha256:48e7fafa40319d358848e1bc6809b208340fafe2096f1725d05d67443d0483d1"},
|
||||
{file = "urllib3-2.0.3.tar.gz", hash = "sha256:bee28b5e56addb8226c96f7f13ac28cb4c301dd5ea8a6ca179c0b9835e032825"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
|
||||
secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"]
|
||||
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
|
||||
zstd = ["zstandard (>=0.18.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "yarl"
|
||||
version = "1.9.2"
|
||||
|
@ -870,4 +950,4 @@ server = ["aiohttp", "aiohttp-cors"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "c649ffbb8b8045d35831f9fb09a6f099b8e940abd85a020c3dfd24173b2582d8"
|
||||
content-hash = "ba168754266c6c46b2136207415a5b3a879c957e53e924cab1e64267849ceb90"
|
||||
|
|
|
@ -13,6 +13,8 @@ llama-cpp-python = "^0.1.66"
|
|||
|
||||
aiohttp = {version = "^3.8.4", optional = true}
|
||||
aiohttp-cors = {version = "^0.7.0", optional = true}
|
||||
requests = "^2.31.0"
|
||||
tqdm = "^4.65.0"
|
||||
|
||||
[tool.poetry.extras]
|
||||
server = ["aiohttp", "aiohttp_cors"]
|
||||
|
|
Loading…
Reference in a new issue