use all caps for constants
This commit is contained in:
parent
4d0eb7639a
commit
07d8d56177
3 changed files with 9 additions and 9 deletions
|
@ -31,7 +31,7 @@ def main():
|
||||||
)
|
)
|
||||||
|
|
||||||
# create models home if it doesn't exist
|
# create models home if it doesn't exist
|
||||||
os.makedirs(model.models_home, exist_ok=True)
|
os.makedirs(model.MODELS_CACHE_PATH, exist_ok=True)
|
||||||
|
|
||||||
subparsers = parser.add_subparsers(
|
subparsers = parser.add_subparsers(
|
||||||
title='commands',
|
title='commands',
|
||||||
|
|
|
@ -7,7 +7,7 @@ from llama_cpp import Llama
|
||||||
from ctransformers import AutoModelForCausalLM
|
from ctransformers import AutoModelForCausalLM
|
||||||
|
|
||||||
import ollama.prompt
|
import ollama.prompt
|
||||||
from ollama.model import models_home
|
from ollama.model import MODELS_CACHE_PATH
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
@ -30,7 +30,7 @@ def load(model_name, models={}):
|
||||||
if not models.get(model_name, None):
|
if not models.get(model_name, None):
|
||||||
model_path = path.expanduser(model_name)
|
model_path = path.expanduser(model_name)
|
||||||
if not path.exists(model_path):
|
if not path.exists(model_path):
|
||||||
model_path = path.join(models_home, model_name + ".bin")
|
model_path = path.join(MODELS_CACHE_PATH, model_name + ".bin")
|
||||||
|
|
||||||
runners = {
|
runners = {
|
||||||
model_type: cls
|
model_type: cls
|
||||||
|
|
|
@ -6,12 +6,12 @@ from urllib.parse import urlsplit, urlunsplit
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
models_endpoint_url = 'https://ollama.ai/api/models'
|
MODELS_MANIFEST = 'https://ollama.ai/api/models'
|
||||||
models_home = path.join(Path.home(), '.ollama', 'models')
|
MODELS_CACHE_PATH = path.join(Path.home(), '.ollama', 'models')
|
||||||
|
|
||||||
|
|
||||||
def models(*args, **kwargs):
|
def models(*args, **kwargs):
|
||||||
for _, _, files in walk(models_home):
|
for _, _, files in walk(MODELS_CACHE_PATH):
|
||||||
for file in files:
|
for file in files:
|
||||||
base, ext = path.splitext(file)
|
base, ext = path.splitext(file)
|
||||||
if ext == '.bin':
|
if ext == '.bin':
|
||||||
|
@ -20,7 +20,7 @@ def models(*args, **kwargs):
|
||||||
|
|
||||||
# get the url of the model from our curated directory
|
# get the url of the model from our curated directory
|
||||||
def get_url_from_directory(model):
|
def get_url_from_directory(model):
|
||||||
response = requests.get(models_endpoint_url)
|
response = requests.get(MODELS_MANIFEST)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
directory = response.json()
|
directory = response.json()
|
||||||
for model_info in directory:
|
for model_info in directory:
|
||||||
|
@ -78,7 +78,7 @@ def find_bin_file(json_response, location, branch):
|
||||||
|
|
||||||
|
|
||||||
def download_file(download_url, file_name, file_size):
|
def download_file(download_url, file_name, file_size):
|
||||||
local_filename = path.join(models_home, file_name) + '.bin'
|
local_filename = path.join(MODELS_CACHE_PATH, file_name) + '.bin'
|
||||||
|
|
||||||
first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
|
first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
|
||||||
|
|
||||||
|
@ -125,7 +125,7 @@ def pull(model, *args, **kwargs):
|
||||||
url = f'https://{url}'
|
url = f'https://{url}'
|
||||||
|
|
||||||
if not validators.url(url):
|
if not validators.url(url):
|
||||||
if model in models(models_home):
|
if model in models(MODELS_CACHE_PATH):
|
||||||
# the model is already downloaded, and specified by name
|
# the model is already downloaded, and specified by name
|
||||||
return model
|
return model
|
||||||
raise Exception(f'Unknown model {model}')
|
raise Exception(f'Unknown model {model}')
|
||||||
|
|
Loading…
Reference in a new issue