clean up pull
This commit is contained in:
parent
d57903875e
commit
61e39bf5d9
3 changed files with 57 additions and 45 deletions
|
@ -22,7 +22,6 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
|
|||
llm = load(model, models_home=models_home, llms=llms)
|
||||
|
||||
prompt = ollama.prompt.template(model, prompt)
|
||||
|
||||
if "max_tokens" not in kwargs:
|
||||
kwargs.update({"max_tokens": 16384})
|
||||
|
||||
|
@ -39,11 +38,10 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
|
|||
def load(model, models_home=".", llms={}):
|
||||
llm = llms.get(model, None)
|
||||
if not llm:
|
||||
model_path = {
|
||||
name: path for name, path in ollama.model.models(models_home)
|
||||
}.get(model, None)
|
||||
|
||||
if not model_path:
|
||||
stored_model_path = os.path.join(models_home, model, ".bin")
|
||||
if os.path.exists(stored_model_path):
|
||||
model_path = stored_model_path
|
||||
else:
|
||||
# try loading this as a path to a model, rather than a model name
|
||||
model_path = os.path.abspath(model)
|
||||
|
||||
|
|
|
@ -16,26 +16,18 @@ def models(models_home='.', *args, **kwargs):
|
|||
yield base
|
||||
|
||||
|
||||
def pull(model, models_home='.', *args, **kwargs):
|
||||
url = model
|
||||
if not validators.url(url) and not url.startswith('huggingface.co'):
|
||||
# this may just be a local model location
|
||||
if model in models(models_home):
|
||||
return model
|
||||
# see if we have this model in our directory
|
||||
# get the url of the model from our curated directory
|
||||
def get_url_from_directory(model):
|
||||
response = requests.get(models_endpoint_url)
|
||||
response.raise_for_status()
|
||||
directory = response.json()
|
||||
for model_info in directory:
|
||||
if model_info.get('name') == model:
|
||||
url = f"https://{model_info.get('url')}"
|
||||
break
|
||||
if not validators.url(url):
|
||||
raise Exception(f'Unknown model {model}')
|
||||
return model_info.get('url')
|
||||
return model
|
||||
|
||||
if not (url.startswith('http://') or url.startswith('https://')):
|
||||
url = f'https://{url}'
|
||||
|
||||
def download_from_repo(url, models_home='.'):
|
||||
parts = urlsplit(url)
|
||||
path_parts = parts.path.split('/tree/')
|
||||
|
||||
|
@ -47,7 +39,6 @@ def pull(model, models_home='.', *args, **kwargs):
|
|||
|
||||
location = location.strip('/')
|
||||
|
||||
# Reconstruct the URL
|
||||
download_url = urlunsplit(
|
||||
(
|
||||
'https',
|
||||
|
@ -57,13 +48,15 @@ def pull(model, models_home='.', *args, **kwargs):
|
|||
parts.fragment,
|
||||
)
|
||||
)
|
||||
|
||||
response = requests.get(download_url)
|
||||
response.raise_for_status() # Raises stored HTTPError, if one occurred
|
||||
|
||||
response.raise_for_status()
|
||||
json_response = response.json()
|
||||
|
||||
# get the last bin file we find, this is probably the most up to date
|
||||
download_url, file_size = find_bin_file(json_response, location, branch)
|
||||
return download_file(download_url, models_home, location, file_size)
|
||||
|
||||
|
||||
def find_bin_file(json_response, location, branch):
|
||||
download_url = None
|
||||
file_size = 0
|
||||
for file_info in json_response:
|
||||
|
@ -77,27 +70,25 @@ def pull(model, models_home='.', *args, **kwargs):
|
|||
if download_url is None:
|
||||
raise Exception('No model found')
|
||||
|
||||
local_filename = os.path.join(models_home, os.path.basename(url)) + '.bin'
|
||||
return download_url, file_size
|
||||
|
||||
# Check if file already exists
|
||||
first_byte = 0
|
||||
if os.path.exists(local_filename):
|
||||
# TODO: check if the file is the same SHA
|
||||
first_byte = os.path.getsize(local_filename)
|
||||
|
||||
def download_file(download_url, models_home, location, file_size):
|
||||
local_filename = os.path.join(models_home, os.path.basename(location)) + '.bin'
|
||||
|
||||
first_byte = (
|
||||
os.path.getsize(local_filename) if os.path.exists(local_filename) else 0
|
||||
)
|
||||
|
||||
if first_byte >= file_size:
|
||||
return local_filename
|
||||
|
||||
print(f'Pulling {model}...')
|
||||
print(f'Pulling {os.path.basename(location)}...')
|
||||
|
||||
# If file size is non-zero, resume download
|
||||
if first_byte != 0:
|
||||
header = {'Range': f'bytes={first_byte}-'}
|
||||
else:
|
||||
header = {}
|
||||
header = {'Range': f'bytes={first_byte}-'} if first_byte != 0 else {}
|
||||
|
||||
response = requests.get(download_url, headers=header, stream=True)
|
||||
response.raise_for_status() # Raises stored HTTPError, if one occurred
|
||||
response.raise_for_status()
|
||||
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
|
||||
|
@ -115,3 +106,26 @@ def pull(model, models_home='.', *args, **kwargs):
|
|||
bar.update(size)
|
||||
|
||||
return local_filename
|
||||
|
||||
|
||||
def pull(model, models_home='.', *args, **kwargs):
|
||||
if os.path.exists(model):
|
||||
# a file on the filesystem is being specified
|
||||
return model
|
||||
# check the remote model location and see if it needs to be downloaded
|
||||
url = model
|
||||
if not validators.url(url) and not url.startswith('huggingface.co'):
|
||||
url = get_url_from_directory(model)
|
||||
|
||||
if not (url.startswith('http://') or url.startswith('https://')):
|
||||
url = f'https://{url}'
|
||||
|
||||
if not validators.url(url):
|
||||
if model in models(models_home):
|
||||
# the model is already downloaded, and specified by name
|
||||
return model
|
||||
raise Exception(f'Unknown model {model}')
|
||||
|
||||
local_filename = download_from_repo(url, models_home)
|
||||
|
||||
return local_filename
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from os import path
|
||||
import os
|
||||
from difflib import SequenceMatcher
|
||||
from jinja2 import Environment, PackageLoader
|
||||
|
||||
|
@ -9,8 +9,8 @@ def template(model, prompt):
|
|||
|
||||
environment = Environment(loader=PackageLoader(__name__, 'templates'))
|
||||
for template in environment.list_templates():
|
||||
base, _ = path.splitext(template)
|
||||
ratio = SequenceMatcher(None, model.lower(), base).ratio()
|
||||
base, _ = os.path.splitext(template)
|
||||
ratio = SequenceMatcher(None, os.path.basename(model.lower()), base).ratio()
|
||||
if ratio > best_ratio:
|
||||
best_ratio = ratio
|
||||
best_template = template
|
||||
|
|
Loading…
Reference in a new issue