update python client create example (#1227)

* add remote create to python example client
This commit is contained in:
Bruce MacDonald 2023-11-27 15:36:19 -05:00 committed by GitHub
parent 39c6d949fc
commit 928950fcc6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,6 +1,10 @@
import os import os
import json import json
import requests import requests
import os
import hashlib
import json
from pathlib import Path
BASE_URL = os.environ.get('OLLAMA_HOST', 'http://localhost:11434') BASE_URL = os.environ.get('OLLAMA_HOST', 'http://localhost:11434')
@ -58,29 +62,85 @@ def generate(model_name, prompt, system=None, template=None, format="", context=
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
return None, None return None, None
# Create a blob file on the server if it doesn't exist.
def create_blob(digest, file_path):
url = f"{BASE_URL}/api/blobs/{digest}"
# Check if the blob exists
response = requests.head(url)
if response.status_code != 404:
return # Blob already exists, no need to upload
response.raise_for_status()
# Upload the blob
with open(file_path, 'rb') as file_data:
requests.post(url, data=file_data)
# Create a model from a Modelfile. Use the callback function to override the default handler. # Create a model from a Modelfile. Use the callback function to override the default handler.
def create(model_name, model_path, callback=None): def create(model_name, filename, callback=None):
try: try:
file_path = Path(filename).expanduser().resolve()
processed_lines = []
# Read and process the modelfile
with open(file_path, 'r') as f:
for line in f:
# Skip empty or whitespace-only lines
if not line.strip():
continue
command, args = line.split(maxsplit=1)
if command.upper() in ["FROM", "ADAPTER"]:
path = Path(args.strip()).expanduser()
# Check if path is relative and resolve it
if not path.is_absolute():
path = (file_path.parent / path)
# Skip if file does not exist for "model", this is handled by the server
if not path.exists():
processed_lines.append(line)
continue
# Calculate SHA-256 hash
with open(path, 'rb') as bin_file:
hash = hashlib.sha256()
hash.update(bin_file.read())
blob = f"sha256:{hash.hexdigest()}"
# Add the file to the remote server
create_blob(blob, path)
# Replace path with digest in the line
line = f"{command} @{blob}\n"
processed_lines.append(line)
# Combine processed lines back into a single string
modelfile_content = '\n'.join(processed_lines)
url = f"{BASE_URL}/api/create" url = f"{BASE_URL}/api/create"
payload = {"name": model_name, "path": model_path} payload = {"name": model_name, "modelfile": modelfile_content}
# Making a POST request with the stream parameter set to True to handle streaming responses # Making a POST request with the stream parameter set to True to handle streaming responses
with requests.post(url, json=payload, stream=True) as response: with requests.post(url, json=payload, stream=True) as response:
response.raise_for_status() response.raise_for_status()
# Iterating over the response line by line and displaying the status # Iterating over the response line by line and displaying the status
for line in response.iter_lines(): for line in response.iter_lines():
if line: if line:
# Parsing each line (JSON chunk) and extracting the status
chunk = json.loads(line) chunk = json.loads(line)
if callback: if callback:
callback(chunk) callback(chunk)
else: else:
print(f"Status: {chunk.get('status')}") print(f"Status: {chunk.get('status')}")
except requests.exceptions.RequestException as e:
except Exception as e:
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
# Pull a model from a the model registry. Cancelled pulls are resumed from where they left off, and multiple # Pull a model from a the model registry. Cancelled pulls are resumed from where they left off, and multiple
# calls to will share the same download progress. Use the callback function to override the default handler. # calls to will share the same download progress. Use the callback function to override the default handler.
def pull(model_name, insecure=False, callback=None): def pull(model_name, insecure=False, callback=None):
@ -222,5 +282,3 @@ def heartbeat():
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
return "Ollama is not running" return "Ollama is not running"