use all caps for constants

This commit is contained in:
Michael Yang 2023-06-30 10:35:26 -07:00
parent 4d0eb7639a
commit 07d8d56177
3 changed files with 9 additions and 9 deletions

View file

@ -31,7 +31,7 @@ def main():
) )
# create models home if it doesn't exist # create models home if it doesn't exist
os.makedirs(model.models_home, exist_ok=True) os.makedirs(model.MODELS_CACHE_PATH, exist_ok=True)
subparsers = parser.add_subparsers( subparsers = parser.add_subparsers(
title='commands', title='commands',

View file

@ -7,7 +7,7 @@ from llama_cpp import Llama
from ctransformers import AutoModelForCausalLM from ctransformers import AutoModelForCausalLM
import ollama.prompt import ollama.prompt
from ollama.model import models_home from ollama.model import MODELS_CACHE_PATH
@contextmanager @contextmanager
@ -30,7 +30,7 @@ def load(model_name, models={}):
if not models.get(model_name, None): if not models.get(model_name, None):
model_path = path.expanduser(model_name) model_path = path.expanduser(model_name)
if not path.exists(model_path): if not path.exists(model_path):
model_path = path.join(models_home, model_name + ".bin") model_path = path.join(MODELS_CACHE_PATH, model_name + ".bin")
runners = { runners = {
model_type: cls model_type: cls

View file

@ -6,12 +6,12 @@ from urllib.parse import urlsplit, urlunsplit
from tqdm import tqdm from tqdm import tqdm
models_endpoint_url = 'https://ollama.ai/api/models' MODELS_MANIFEST = 'https://ollama.ai/api/models'
models_home = path.join(Path.home(), '.ollama', 'models') MODELS_CACHE_PATH = path.join(Path.home(), '.ollama', 'models')
def models(*args, **kwargs): def models(*args, **kwargs):
for _, _, files in walk(models_home): for _, _, files in walk(MODELS_CACHE_PATH):
for file in files: for file in files:
base, ext = path.splitext(file) base, ext = path.splitext(file)
if ext == '.bin': if ext == '.bin':
@ -20,7 +20,7 @@ def models(*args, **kwargs):
# get the url of the model from our curated directory # get the url of the model from our curated directory
def get_url_from_directory(model): def get_url_from_directory(model):
response = requests.get(models_endpoint_url) response = requests.get(MODELS_MANIFEST)
response.raise_for_status() response.raise_for_status()
directory = response.json() directory = response.json()
for model_info in directory: for model_info in directory:
@ -78,7 +78,7 @@ def find_bin_file(json_response, location, branch):
def download_file(download_url, file_name, file_size): def download_file(download_url, file_name, file_size):
local_filename = path.join(models_home, file_name) + '.bin' local_filename = path.join(MODELS_CACHE_PATH, file_name) + '.bin'
first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0 first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
@ -125,7 +125,7 @@ def pull(model, *args, **kwargs):
url = f'https://{url}' url = f'https://{url}'
if not validators.url(url): if not validators.url(url):
if model in models(models_home): if model in models(MODELS_CACHE_PATH):
# the model is already downloaded, and specified by name # the model is already downloaded, and specified by name
return model return model
raise Exception(f'Unknown model {model}') raise Exception(f'Unknown model {model}')