clean up pull

parent d57903875e
commit 61e39bf5d9

3 changed files with 57 additions and 45 deletions
@@ -22,7 +22,6 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
     llm = load(model, models_home=models_home, llms=llms)
 
     prompt = ollama.prompt.template(model, prompt)
 
     if "max_tokens" not in kwargs:
         kwargs.update({"max_tokens": 16384})
-
@@ -39,11 +38,10 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
 def load(model, models_home=".", llms={}):
     llm = llms.get(model, None)
     if not llm:
-        model_path = {
-            name: path for name, path in ollama.model.models(models_home)
-        }.get(model, None)
-
-        if not model_path:
+        stored_model_path = os.path.join(models_home, model, ".bin")
+        if os.path.exists(stored_model_path):
+            model_path = stored_model_path
+        else:
             # try loading this as a path to a model, rather than a model name
             model_path = os.path.abspath(model)
 
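Read together, the two hunks above change load() to prefer a model already stored under models_home and to fall back to treating the argument as a filesystem path. A minimal standalone sketch of just that lookup; the helper name is illustrative and not part of the commit:

    import os

    def resolve_model_path(model, models_home='.'):
        # mirror the new lookup order in load()
        stored_model_path = os.path.join(models_home, model, '.bin')
        if os.path.exists(stored_model_path):
            return stored_model_path
        # otherwise treat the argument as a path to a model file
        return os.path.abspath(model)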
@@ -16,26 +16,18 @@ def models(models_home='.', *args, **kwargs):
         yield base
 
 
-def pull(model, models_home='.', *args, **kwargs):
-    url = model
-    if not validators.url(url) and not url.startswith('huggingface.co'):
-        # this may just be a local model location
-        if model in models(models_home):
-            return model
-    # see if we have this model in our directory
+# get the url of the model from our curated directory
+def get_url_from_directory(model):
     response = requests.get(models_endpoint_url)
     response.raise_for_status()
     directory = response.json()
     for model_info in directory:
         if model_info.get('name') == model:
-            url = f"https://{model_info.get('url')}"
-            break
-    if not validators.url(url):
-        raise Exception(f'Unknown model {model}')
-
-    if not (url.startswith('http://') or url.startswith('https://')):
-        url = f'https://{url}'
-
+            return model_info.get('url')
+    return model
+
+
+def download_from_repo(url, models_home='.'):
     parts = urlsplit(url)
     path_parts = parts.path.split('/tree/')
 
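download_from_repo, introduced above and continued in the hunks below, starts from a repository tree URL and derives a location and branch from it. A small illustration of that parsing; the example URL and the branch fallback are assumptions, since part of the function falls between hunks:

    from urllib.parse import urlsplit

    url = 'https://huggingface.co/TheBloke/orca_mini_3B-GGML/tree/main'  # example input, not from the diff
    parts = urlsplit(url)
    path_parts = parts.path.split('/tree/')
    location = path_parts[0].strip('/')                        # 'TheBloke/orca_mini_3B-GGML'
    branch = path_parts[1] if len(path_parts) > 1 else 'main'  # fallback is an assumption
    print(location, branch)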
@@ -47,7 +39,6 @@ def pull(model, models_home='.', *args, **kwargs):
 
     location = location.strip('/')
 
-    # Reconstruct the URL
     download_url = urlunsplit(
         (
             'https',
@@ -57,13 +48,15 @@ def pull(model, models_home='.', *args, **kwargs):
             parts.fragment,
         )
     )
 
     response = requests.get(download_url)
-    response.raise_for_status()  # Raises stored HTTPError, if one occurred
+    response.raise_for_status()
 
     json_response = response.json()
 
-    # get the last bin file we find, this is probably the most up to date
+    download_url, file_size = find_bin_file(json_response, location, branch)
+    return download_file(download_url, models_home, location, file_size)
+
+
+def find_bin_file(json_response, location, branch):
     download_url = None
     file_size = 0
     for file_info in json_response:
@@ -77,27 +70,25 @@ def pull(model, models_home='.', *args, **kwargs):
     if download_url is None:
         raise Exception('No model found')
 
-    local_filename = os.path.join(models_home, os.path.basename(url)) + '.bin'
+    return download_url, file_size
 
-    # Check if file already exists
-    first_byte = 0
-    if os.path.exists(local_filename):
-        # TODO: check if the file is the same SHA
-        first_byte = os.path.getsize(local_filename)
+
+def download_file(download_url, models_home, location, file_size):
+    local_filename = os.path.join(models_home, os.path.basename(location)) + '.bin'
+
+    first_byte = (
+        os.path.getsize(local_filename) if os.path.exists(local_filename) else 0
+    )
 
     if first_byte >= file_size:
         return local_filename
 
-    print(f'Pulling {model}...')
+    print(f'Pulling {os.path.basename(location)}...')
 
-    # If file size is non-zero, resume download
-    if first_byte != 0:
-        header = {'Range': f'bytes={first_byte}-'}
-    else:
-        header = {}
+    header = {'Range': f'bytes={first_byte}-'} if first_byte != 0 else {}
 
     response = requests.get(download_url, headers=header, stream=True)
-    response.raise_for_status()  # Raises stored HTTPError, if one occurred
+    response.raise_for_status()
 
     total_size = int(response.headers.get('content-length', 0))
 
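The download_file changes above keep the resume behaviour: measure what is already on disk, request the remaining bytes with a Range header, and stream the rest. A compact, self-contained version of that pattern (progress bar omitted; the function name and append mode are illustrative, not taken from the commit):

    import os
    import requests

    def resume_download(url, local_filename):
        first_byte = os.path.getsize(local_filename) if os.path.exists(local_filename) else 0
        headers = {'Range': f'bytes={first_byte}-'} if first_byte else {}
        response = requests.get(url, headers=headers, stream=True)
        response.raise_for_status()
        mode = 'ab' if first_byte else 'wb'  # append when resuming so earlier bytes are kept
        with open(local_filename, mode) as f:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
        return local_filename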
@@ -115,3 +106,26 @@ def pull(model, models_home='.', *args, **kwargs):
             bar.update(size)
 
     return local_filename
+
+
+def pull(model, models_home='.', *args, **kwargs):
+    if os.path.exists(model):
+        # a file on the filesystem is being specified
+        return model
+    # check the remote model location and see if it needs to be downloaded
+    url = model
+    if not validators.url(url) and not url.startswith('huggingface.co'):
+        url = get_url_from_directory(model)
+
+    if not (url.startswith('http://') or url.startswith('https://')):
+        url = f'https://{url}'
+
+    if not validators.url(url):
+        if model in models(models_home):
+            # the model is already downloaded, and specified by name
+            return model
+        raise Exception(f'Unknown model {model}')
+
+    local_filename = download_from_repo(url, models_home)
+
+    return local_filename
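With the hunk above, pull() accepts an existing file path, a curated model name, or a repository URL, and only downloads when it has to. A hedged usage sketch; the import path follows the ollama.model.models(...) call seen in the first file, and the arguments are examples only:

    from ollama.model import pull

    pull('./models/orca-mini.bin')   # existing file: returned as-is
    pull('orca-mini')                # name: resolved via get_url_from_directory
    pull('https://huggingface.co/TheBloke/orca_mini_3B-GGML/tree/main')  # URL: fetched by download_from_repo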
@@ -1,4 +1,4 @@
-from os import path
+import os
 from difflib import SequenceMatcher
 from jinja2 import Environment, PackageLoader
 
@@ -9,8 +9,8 @@ def template(model, prompt):
 
     environment = Environment(loader=PackageLoader(__name__, 'templates'))
     for template in environment.list_templates():
-        base, _ = path.splitext(template)
-        ratio = SequenceMatcher(None, model.lower(), base).ratio()
+        base, _ = os.path.splitext(template)
+        ratio = SequenceMatcher(None, os.path.basename(model.lower()), base).ratio()
         if ratio > best_ratio:
             best_ratio = ratio
             best_template = template
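The prompt change means a model supplied as a full path is now matched against template names by its basename. A small illustration of the fuzzy match; the candidate names are assumptions rather than the actual contents of the templates directory:

    import os
    from difflib import SequenceMatcher

    model = '/Users/me/models/Orca-Mini-3B.bin'   # example path
    candidates = ['orca', 'vicuna', 'falcon']     # assumed template basenames
    best = max(
        candidates,
        key=lambda base: SequenceMatcher(None, os.path.basename(model.lower()), base).ratio(),
    )
    print(best)  # 'orca' scores highest here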