resume download of models from directory

This commit is contained in:
Bruce MacDonald 2023-07-04 11:16:50 -04:00
parent 35f202f573
commit a86a4f8c26
2 changed files with 28 additions and 12 deletions

View file

@ -175,12 +175,18 @@ def search(*args, **kwargs):
def pull(*args, **kwargs):
model.pull(model_name=kwargs.pop('model'), *args, **kwargs)
print("Up to date.")
try:
model.pull(model_name=kwargs.pop('model'), *args, **kwargs)
print("Up to date.")
except Exception as e:
print(f"An error occurred: {e}")
def run(*args, **kwargs):
name = model.pull(model_name=kwargs.pop('model'), *args, **kwargs)
kwargs.update({"model": name})
print(f"Running {name}...")
generate(*args, **kwargs)
try:
name = model.pull(model_name=kwargs.pop('model'), *args, **kwargs)
kwargs.update({"model": name})
print(f"Running {name}...")
generate(*args, **kwargs)
except Exception as e:
print(f"An error occurred: {e}")

View file

@ -123,16 +123,26 @@ def download_file(download_url, file_name, file_size):
def pull(model_name, *args, **kwargs):
maybe_existing_model_location = MODELS_CACHE_PATH / str(model_name + '.bin')
if path.exists(model_name) or path.exists(maybe_existing_model_location):
# a file on the filesystem is being specified
return model_name
# check the remote model location and see if it needs to be downloaded
url = model_name
file_name = ""
if not validators.url(url) and not url.startswith('huggingface.co'):
url = get_url_from_directory(model_name)
file_name = model_name
try:
url = get_url_from_directory(model_name)
except Exception as e:
# may not have been able to check remote directory, return now
return model_name
if url is model_name:
# this is not a model from our directory, so can't check remote
maybe_existing_model_location = MODELS_CACHE_PATH / str(model_name + '.bin')
if path.exists(model_name) or path.exists(maybe_existing_model_location):
# a file on the filesystem is being specified
return model_name
raise Exception("unknown model")
else:
# this is a model from our directory, check remote
file_name = model_name
if not (url.startswith('http://') or url.startswith('https://')):
url = f'https://{url}'