consistency between generate and add naming

This commit is contained in:
Bruce MacDonald 2023-06-29 18:22:45 -04:00
parent 8fc8a00752
commit 01c31aac78
3 changed files with 38 additions and 30 deletions

View file

@ -79,14 +79,18 @@ def generate_oneshot(*args, **kwargs):
spinner = yaspin() spinner = yaspin()
spinner.start() spinner.start()
spinner_running = True spinner_running = True
for output in engine.generate(*args, **kwargs): try:
choices = output.get("choices", []) for output in engine.generate(*args, **kwargs):
if len(choices) > 0: choices = output.get("choices", [])
if spinner_running: if len(choices) > 0:
spinner.stop() if spinner_running:
spinner_running = False spinner.stop()
print("\r", end="") # move cursor back to beginning of line again spinner_running = False
print(choices[0].get("text", ""), end="", flush=True) print("\r", end="") # move cursor back to beginning of line again
print(choices[0].get("text", ""), end="", flush=True)
except Exception:
spinner.stop()
raise
# end with a new line # end with a new line
print(flush=True) print(flush=True)

View file

@ -1,5 +1,4 @@
import os from os import path, dup, dup2, devnull
import json
import sys import sys
from contextlib import contextmanager from contextlib import contextmanager
from llama_cpp import Llama as LLM from llama_cpp import Llama as LLM
@ -10,12 +9,12 @@ import ollama.prompt
@contextmanager @contextmanager
def suppress_stderr(): def suppress_stderr():
stderr = os.dup(sys.stderr.fileno()) stderr = dup(sys.stderr.fileno())
with open(os.devnull, "w") as devnull: with open(devnull, "w") as devnull:
os.dup2(devnull.fileno(), sys.stderr.fileno()) dup2(devnull.fileno(), sys.stderr.fileno())
yield yield
os.dup2(stderr, sys.stderr.fileno()) dup2(stderr, sys.stderr.fileno())
def generate(model, prompt, models_home=".", llms={}, *args, **kwargs): def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
@ -38,12 +37,15 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
def load(model, models_home=".", llms={}): def load(model, models_home=".", llms={}):
llm = llms.get(model, None) llm = llms.get(model, None)
if not llm: if not llm:
stored_model_path = os.path.join(models_home, model, ".bin") stored_model_path = path.join(models_home, model) + ".bin"
if os.path.exists(stored_model_path): if path.exists(stored_model_path):
model_path = stored_model_path model_path = stored_model_path
else: else:
# try loading this as a path to a model, rather than a model name # try loading this as a path to a model, rather than a model name
model_path = os.path.abspath(model) model_path = path.abspath(model)
if not path.exists(model_path):
raise Exception(f"Model not found: {model}")
try: try:
# suppress LLM's output # suppress LLM's output

View file

@ -1,6 +1,6 @@
import os
import requests import requests
import validators import validators
from os import path, walk
from urllib.parse import urlsplit, urlunsplit from urllib.parse import urlsplit, urlunsplit
from tqdm import tqdm from tqdm import tqdm
@ -9,9 +9,9 @@ models_endpoint_url = 'https://ollama.ai/api/models'
def models(models_home='.', *args, **kwargs): def models(models_home='.', *args, **kwargs):
for _, _, files in os.walk(models_home): for _, _, files in walk(models_home):
for file in files: for file in files:
base, ext = os.path.splitext(file) base, ext = path.splitext(file)
if ext == '.bin': if ext == '.bin':
yield base yield base
@ -27,7 +27,7 @@ def get_url_from_directory(model):
return model return model
def download_from_repo(url, models_home='.'): def download_from_repo(url, file_name, models_home='.'):
parts = urlsplit(url) parts = urlsplit(url)
path_parts = parts.path.split('/tree/') path_parts = parts.path.split('/tree/')
@ -38,6 +38,8 @@ def download_from_repo(url, models_home='.'):
location, branch = path_parts location, branch = path_parts
location = location.strip('/') location = location.strip('/')
if file_name == '':
file_name = path.basename(location)
download_url = urlunsplit( download_url = urlunsplit(
( (
@ -53,7 +55,7 @@ def download_from_repo(url, models_home='.'):
json_response = response.json() json_response = response.json()
download_url, file_size = find_bin_file(json_response, location, branch) download_url, file_size = find_bin_file(json_response, location, branch)
return download_file(download_url, models_home, location, file_size) return download_file(download_url, models_home, file_name, file_size)
def find_bin_file(json_response, location, branch): def find_bin_file(json_response, location, branch):
@ -73,17 +75,15 @@ def find_bin_file(json_response, location, branch):
return download_url, file_size return download_url, file_size
def download_file(download_url, models_home, location, file_size): def download_file(download_url, models_home, file_name, file_size):
local_filename = os.path.join(models_home, os.path.basename(location)) + '.bin' local_filename = path.join(models_home, file_name) + '.bin'
first_byte = ( first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
os.path.getsize(local_filename) if os.path.exists(local_filename) else 0
)
if first_byte >= file_size: if first_byte >= file_size:
return local_filename return local_filename
print(f'Pulling {os.path.basename(location)}...') print(f'Pulling {file_name}...')
header = {'Range': f'bytes={first_byte}-'} if first_byte != 0 else {} header = {'Range': f'bytes={first_byte}-'} if first_byte != 0 else {}
@ -109,13 +109,15 @@ def download_file(download_url, models_home, location, file_size):
def pull(model, models_home='.', *args, **kwargs): def pull(model, models_home='.', *args, **kwargs):
if os.path.exists(model): if path.exists(model):
# a file on the filesystem is being specified # a file on the filesystem is being specified
return model return model
# check the remote model location and see if it needs to be downloaded # check the remote model location and see if it needs to be downloaded
url = model url = model
file_name = ""
if not validators.url(url) and not url.startswith('huggingface.co'): if not validators.url(url) and not url.startswith('huggingface.co'):
url = get_url_from_directory(model) url = get_url_from_directory(model)
file_name = model
if not (url.startswith('http://') or url.startswith('https://')): if not (url.startswith('http://') or url.startswith('https://')):
url = f'https://{url}' url = f'https://{url}'
@ -126,6 +128,6 @@ def pull(model, models_home='.', *args, **kwargs):
return model return model
raise Exception(f'Unknown model {model}') raise Exception(f'Unknown model {model}')
local_filename = download_from_repo(url, models_home) local_filename = download_from_repo(url, file_name, models_home)
return local_filename return local_filename