consistency between generate and add naming
parent 8fc8a00752
commit 01c31aac78
3 changed files with 38 additions and 30 deletions
@@ -79,6 +79,7 @@ def generate_oneshot(*args, **kwargs):
     spinner = yaspin()
     spinner.start()
     spinner_running = True
+    try:
         for output in engine.generate(*args, **kwargs):
             choices = output.get("choices", [])
             if len(choices) > 0:
@@ -87,6 +88,9 @@ def generate_oneshot(*args, **kwargs):
                     spinner_running = False
                     print("\r", end="")  # move cursor back to beginning of line again
                 print(choices[0].get("text", ""), end="", flush=True)
+    except Exception:
+        spinner.stop()
+        raise

     # end with a new line
     print(flush=True)
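The change above wraps the generation loop so the spinner is always stopped, even when engine.generate() raises mid-stream. A minimal, self-contained sketch of the same pattern, assuming a generic iterator of text chunks (stream_with_spinner and token_stream are illustrative names, not part of this commit):

from yaspin import yaspin

def stream_with_spinner(token_stream):
    # token_stream stands in for any iterator that yields text chunks
    spinner = yaspin()
    spinner.start()
    spinner_running = True
    try:
        for token in token_stream:
            if spinner_running:
                spinner.stop()        # first output arrived: clear the spinner
                spinner_running = False
                print("\r", end="")   # move the cursor back to the start of the line
            print(token, end="", flush=True)
    except Exception:
        spinner.stop()                # don't leave the spinner thread running on failure
        raise
    print(flush=True)                 # end with a new line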
@@ -1,5 +1,4 @@
-import os
-import json
+from os import path, dup, dup2, devnull
 import sys
 from contextlib import contextmanager
 from llama_cpp import Llama as LLM
@@ -10,12 +9,12 @@ import ollama.prompt


 @contextmanager
 def suppress_stderr():
-    stderr = os.dup(sys.stderr.fileno())
-    with open(os.devnull, "w") as devnull:
-        os.dup2(devnull.fileno(), sys.stderr.fileno())
+    stderr = dup(sys.stderr.fileno())
+    with open(devnull, "w") as devnull:
+        dup2(devnull.fileno(), sys.stderr.fileno())
         yield

-        os.dup2(stderr, sys.stderr.fileno())
+        dup2(stderr, sys.stderr.fileno())


 def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
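suppress_stderr() silences llama_cpp's C-level logging by swapping the process-wide stderr file descriptor. A standalone sketch of the same dup/dup2 trick, with the restore wrapped in try/finally as an extra safety measure (quiet_stderr is an illustrative name, not the project's API):

import sys
from contextlib import contextmanager
from os import close, devnull, dup, dup2

@contextmanager
def quiet_stderr():
    saved = dup(sys.stderr.fileno())              # keep a copy of the real stderr fd
    with open(devnull, "w") as sink:
        dup2(sink.fileno(), sys.stderr.fileno())  # point fd 2 at the null device
    try:
        yield
    finally:
        dup2(saved, sys.stderr.fileno())          # restore the original stderr
        close(saved)                              # release the duplicate fd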
@@ -38,12 +37,15 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
 def load(model, models_home=".", llms={}):
     llm = llms.get(model, None)
     if not llm:
-        stored_model_path = os.path.join(models_home, model, ".bin")
-        if os.path.exists(stored_model_path):
+        stored_model_path = path.join(models_home, model) + ".bin"
+        if path.exists(stored_model_path):
             model_path = stored_model_path
         else:
             # try loading this as a path to a model, rather than a model name
-            model_path = os.path.abspath(model)
+            model_path = path.abspath(model)

+        if not path.exists(model_path):
+            raise Exception(f"Model not found: {model}")
+
         try:
             # suppress LLM's output
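Besides the import cleanup, the load() change fixes how the stored model path is built: with three arguments, path.join treats ".bin" as an extra path component instead of a file extension. A quick illustration, assuming POSIX separators and made-up values:

from os import path

models_home = "models"
model = "llama"

print(path.join(models_home, model, ".bin"))   # models/llama/.bin  (old behaviour, wrong)
print(path.join(models_home, model) + ".bin")  # models/llama.bin   (new behaviour, intended)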
@@ -1,6 +1,6 @@
-import os
 import requests
 import validators
+from os import path, walk
 from urllib.parse import urlsplit, urlunsplit
 from tqdm import tqdm

@@ -9,9 +9,9 @@ models_endpoint_url = 'https://ollama.ai/api/models'


 def models(models_home='.', *args, **kwargs):
-    for _, _, files in os.walk(models_home):
+    for _, _, files in walk(models_home):
         for file in files:
-            base, ext = os.path.splitext(file)
+            base, ext = path.splitext(file)
             if ext == '.bin':
                 yield base

@@ -27,7 +27,7 @@ def get_url_from_directory(model):
     return model


-def download_from_repo(url, models_home='.'):
+def download_from_repo(url, file_name, models_home='.'):
     parts = urlsplit(url)
     path_parts = parts.path.split('/tree/')

@@ -38,6 +38,8 @@ def download_from_repo(url, models_home='.'):
     location, branch = path_parts

     location = location.strip('/')
+    if file_name == '':
+        file_name = path.basename(location)

     download_url = urlunsplit(
         (
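The new file_name parameter falls back to the basename of the repository location when the caller passes an empty string. A small standalone illustration of that fallback (the location value is made up):

from os import path

location = 'org/llama-7b'   # repository path with surrounding '/' already stripped
file_name = ''
if file_name == '':
    file_name = path.basename(location)
print(file_name)            # -> llama-7b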
@@ -53,7 +55,7 @@ def download_from_repo(url, models_home='.'):
     json_response = response.json()

     download_url, file_size = find_bin_file(json_response, location, branch)
-    return download_file(download_url, models_home, location, file_size)
+    return download_file(download_url, models_home, file_name, file_size)


 def find_bin_file(json_response, location, branch):
@@ -73,17 +75,15 @@ def find_bin_file(json_response, location, branch):
     return download_url, file_size


-def download_file(download_url, models_home, location, file_size):
-    local_filename = os.path.join(models_home, os.path.basename(location)) + '.bin'
+def download_file(download_url, models_home, file_name, file_size):
+    local_filename = path.join(models_home, file_name) + '.bin'

-    first_byte = (
-        os.path.getsize(local_filename) if os.path.exists(local_filename) else 0
-    )
+    first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0

     if first_byte >= file_size:
         return local_filename

-    print(f'Pulling {os.path.basename(location)}...')
+    print(f'Pulling {file_name}...')

     header = {'Range': f'bytes={first_byte}-'} if first_byte != 0 else {}

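download_file() keeps its resume-on-retry behaviour, now keyed on file_name: if a partial file exists, only the remaining bytes are requested via an HTTP Range header. A compact sketch of that resume logic (resume_download is an illustrative name; the chunk size and progress-bar options are assumptions, since that part of the function is outside this hunk):

from os import path

import requests
from tqdm import tqdm

def resume_download(download_url, local_filename, file_size):
    # ask the server only for the bytes we don't have yet
    first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
    if first_byte >= file_size:
        return local_filename

    header = {'Range': f'bytes={first_byte}-'} if first_byte != 0 else {}
    response = requests.get(download_url, headers=header, stream=True)
    with open(local_filename, 'ab') as f, tqdm(
        total=file_size, initial=first_byte, unit='B', unit_scale=True
    ) as bar:
        for chunk in response.iter_content(chunk_size=1024 * 1024):
            f.write(chunk)           # append to the partial file
            bar.update(len(chunk))
    return local_filename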
@@ -109,13 +109,15 @@ def download_file(download_url, models_home, location, file_size):


 def pull(model, models_home='.', *args, **kwargs):
-    if os.path.exists(model):
+    if path.exists(model):
         # a file on the filesystem is being specified
         return model
     # check the remote model location and see if it needs to be downloaded
     url = model
+    file_name = ""
     if not validators.url(url) and not url.startswith('huggingface.co'):
         url = get_url_from_directory(model)
+        file_name = model

     if not (url.startswith('http://') or url.startswith('https://')):
         url = f'https://{url}'
@@ -126,6 +128,6 @@ def pull(model, models_home='.', *args, **kwargs):
         return model
     raise Exception(f'Unknown model {model}')

-    local_filename = download_from_repo(url, models_home)
+    local_filename = download_from_repo(url, file_name, models_home)

     return local_filename
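Taken together, pull() now decides the local file name up front: a bare model name keeps its own name, while a URL leaves file_name empty so download_from_repo() derives it from the repository path. A small sketch of that branching (resolve_file_name is an illustrative helper, not part of this commit; the inputs are made up):

import validators

def resolve_file_name(model):
    # mirrors the branching added to pull()
    url = model
    file_name = ""
    if not validators.url(url) and not url.startswith('huggingface.co'):
        file_name = model   # a plain model name keeps its own file name
    return file_name

print(resolve_file_name('llama'))                              # -> 'llama'
print(resolve_file_name('huggingface.co/org/repo/tree/main'))  # -> ''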