clean up pull

This commit is contained in:
Bruce MacDonald 2023-06-29 15:06:34 -04:00
parent d57903875e
commit 61e39bf5d9
3 changed files with 57 additions and 45 deletions

View file

@ -22,7 +22,6 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
llm = load(model, models_home=models_home, llms=llms) llm = load(model, models_home=models_home, llms=llms)
prompt = ollama.prompt.template(model, prompt) prompt = ollama.prompt.template(model, prompt)
if "max_tokens" not in kwargs: if "max_tokens" not in kwargs:
kwargs.update({"max_tokens": 16384}) kwargs.update({"max_tokens": 16384})
@ -39,11 +38,10 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
def load(model, models_home=".", llms={}): def load(model, models_home=".", llms={}):
llm = llms.get(model, None) llm = llms.get(model, None)
if not llm: if not llm:
model_path = { stored_model_path = os.path.join(models_home, model, ".bin")
name: path for name, path in ollama.model.models(models_home) if os.path.exists(stored_model_path):
}.get(model, None) model_path = stored_model_path
else:
if not model_path:
# try loading this as a path to a model, rather than a model name # try loading this as a path to a model, rather than a model name
model_path = os.path.abspath(model) model_path = os.path.abspath(model)

View file

@ -16,26 +16,18 @@ def models(models_home='.', *args, **kwargs):
yield base yield base
def pull(model, models_home='.', *args, **kwargs): # get the url of the model from our curated directory
url = model def get_url_from_directory(model):
if not validators.url(url) and not url.startswith('huggingface.co'):
# this may just be a local model location
if model in models(models_home):
return model
# see if we have this model in our directory
response = requests.get(models_endpoint_url) response = requests.get(models_endpoint_url)
response.raise_for_status() response.raise_for_status()
directory = response.json() directory = response.json()
for model_info in directory: for model_info in directory:
if model_info.get('name') == model: if model_info.get('name') == model:
url = f"https://{model_info.get('url')}" return model_info.get('url')
break return model
if not validators.url(url):
raise Exception(f'Unknown model {model}')
if not (url.startswith('http://') or url.startswith('https://')):
url = f'https://{url}'
def download_from_repo(url, models_home='.'):
parts = urlsplit(url) parts = urlsplit(url)
path_parts = parts.path.split('/tree/') path_parts = parts.path.split('/tree/')
@ -47,7 +39,6 @@ def pull(model, models_home='.', *args, **kwargs):
location = location.strip('/') location = location.strip('/')
# Reconstruct the URL
download_url = urlunsplit( download_url = urlunsplit(
( (
'https', 'https',
@ -57,13 +48,15 @@ def pull(model, models_home='.', *args, **kwargs):
parts.fragment, parts.fragment,
) )
) )
response = requests.get(download_url) response = requests.get(download_url)
response.raise_for_status() # Raises stored HTTPError, if one occurred response.raise_for_status()
json_response = response.json() json_response = response.json()
# get the last bin file we find, this is probably the most up to date download_url, file_size = find_bin_file(json_response, location, branch)
return download_file(download_url, models_home, location, file_size)
def find_bin_file(json_response, location, branch):
download_url = None download_url = None
file_size = 0 file_size = 0
for file_info in json_response: for file_info in json_response:
@ -77,27 +70,25 @@ def pull(model, models_home='.', *args, **kwargs):
if download_url is None: if download_url is None:
raise Exception('No model found') raise Exception('No model found')
local_filename = os.path.join(models_home, os.path.basename(url)) + '.bin' return download_url, file_size
# Check if file already exists
first_byte = 0 def download_file(download_url, models_home, location, file_size):
if os.path.exists(local_filename): local_filename = os.path.join(models_home, os.path.basename(location)) + '.bin'
# TODO: check if the file is the same SHA
first_byte = os.path.getsize(local_filename) first_byte = (
os.path.getsize(local_filename) if os.path.exists(local_filename) else 0
)
if first_byte >= file_size: if first_byte >= file_size:
return local_filename return local_filename
print(f'Pulling {model}...') print(f'Pulling {os.path.basename(location)}...')
# If file size is non-zero, resume download header = {'Range': f'bytes={first_byte}-'} if first_byte != 0 else {}
if first_byte != 0:
header = {'Range': f'bytes={first_byte}-'}
else:
header = {}
response = requests.get(download_url, headers=header, stream=True) response = requests.get(download_url, headers=header, stream=True)
response.raise_for_status() # Raises stored HTTPError, if one occurred response.raise_for_status()
total_size = int(response.headers.get('content-length', 0)) total_size = int(response.headers.get('content-length', 0))
@ -115,3 +106,26 @@ def pull(model, models_home='.', *args, **kwargs):
bar.update(size) bar.update(size)
return local_filename return local_filename
def pull(model, models_home='.', *args, **kwargs):
    """Resolve ``model`` to a usable model reference, downloading if needed.

    Resolution order:

    1. If ``model`` is an existing filesystem path, it is returned unchanged.
    2. Otherwise ``model`` is treated as a remote location. Anything that is
       neither a valid URL nor a ``huggingface.co`` address is first passed
       through ``get_url_from_directory``; a missing scheme is filled in
       with ``https://``.
    3. If the result is a valid URL, the value returned by
       ``download_from_repo`` is handed back to the caller.
    4. If it is not a valid URL but ``model`` is one of the names yielded by
       ``models(models_home)``, ``model`` itself is returned.

    Extra positional and keyword arguments are accepted and ignored.

    Raises:
        Exception: when ``model`` cannot be resolved by any of the steps above.
    """
    # a file on the filesystem is being specified
    if os.path.exists(model):
        return model

    # check the remote model location and see if it needs to be downloaded
    url = model
    looks_remote = validators.url(url) or url.startswith('huggingface.co')
    if not looks_remote:
        url = get_url_from_directory(model)

    # scheme-less locations (e.g. "huggingface.co/...") default to https
    if not url.startswith(('http://', 'https://')):
        url = f'https://{url}'

    if validators.url(url):
        return download_from_repo(url, models_home)

    # the model is already downloaded, and specified by name
    if model in models(models_home):
        return model

    raise Exception(f'Unknown model {model}')

View file

@ -1,4 +1,4 @@
from os import path import os
from difflib import SequenceMatcher from difflib import SequenceMatcher
from jinja2 import Environment, PackageLoader from jinja2 import Environment, PackageLoader
@ -9,8 +9,8 @@ def template(model, prompt):
environment = Environment(loader=PackageLoader(__name__, 'templates')) environment = Environment(loader=PackageLoader(__name__, 'templates'))
for template in environment.list_templates(): for template in environment.list_templates():
base, _ = path.splitext(template) base, _ = os.path.splitext(template)
ratio = SequenceMatcher(None, model.lower(), base).ratio() ratio = SequenceMatcher(None, os.path.basename(model.lower()), base).ratio()
if ratio > best_ratio: if ratio > best_ratio:
best_ratio = ratio best_ratio = ratio
best_template = template best_template = template