ollama/api/client.go

315 lines
7.2 KiB
Go
Raw Normal View History

package api
import (
2023-07-04 04:47:00 +00:00
"bufio"
"bytes"
"context"
"encoding/json"
2023-11-14 22:07:40 +00:00
"errors"
2023-07-07 21:04:43 +00:00
"fmt"
2023-07-18 16:09:45 +00:00
"io"
2023-10-09 19:18:26 +00:00
"net"
"net/http"
2023-07-06 22:02:10 +00:00
"net/url"
"os"
2023-08-22 01:24:42 +00:00
"runtime"
2023-08-17 22:20:38 +00:00
"strings"
2023-08-22 01:24:42 +00:00
2023-10-12 16:34:16 +00:00
"github.com/jmorganca/ollama/format"
2023-08-22 01:24:42 +00:00
"github.com/jmorganca/ollama/version"
)
2023-07-18 16:09:45 +00:00
// Client issues HTTP requests against an ollama server.
type Client struct {
	base *url.URL    // server root URL; request paths are joined onto it
	http http.Client // shared HTTP client used for every request
}
2023-07-18 16:09:45 +00:00
func checkError(resp *http.Response, body []byte) error {
2023-08-27 04:55:21 +00:00
if resp.StatusCode < http.StatusBadRequest {
2023-07-18 16:09:45 +00:00
return nil
2023-07-11 20:05:51 +00:00
}
2023-07-18 16:09:45 +00:00
apiError := StatusError{StatusCode: resp.StatusCode}
2023-07-11 20:05:51 +00:00
2023-07-18 16:09:45 +00:00
err := json.Unmarshal(body, &apiError)
if err != nil {
// Use the full body as the message if we fail to decode a response.
2023-07-20 18:45:12 +00:00
apiError.ErrorMessage = string(body)
2023-07-18 16:09:45 +00:00
}
return apiError
2023-07-06 22:02:10 +00:00
}
2023-10-09 19:18:26 +00:00
func ClientFromEnvironment() (*Client, error) {
2023-10-26 17:47:41 +00:00
defaultPort := "11434"
2023-10-09 19:18:26 +00:00
scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
2023-10-26 17:47:41 +00:00
switch {
case !ok:
2023-10-09 19:18:26 +00:00
scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
2023-10-26 17:47:41 +00:00
case scheme == "http":
defaultPort = "80"
case scheme == "https":
defaultPort = "443"
2023-10-09 19:18:26 +00:00
}
2023-10-26 17:47:41 +00:00
// trim trailing slashes
hostport = strings.TrimRight(hostport, "/")
2023-10-09 19:18:26 +00:00
host, port, err := net.SplitHostPort(hostport)
if err != nil {
2023-10-26 17:47:41 +00:00
host, port = "127.0.0.1", defaultPort
2023-10-20 18:32:28 +00:00
if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
2023-10-09 19:18:26 +00:00
host = ip.String()
2023-10-20 18:32:28 +00:00
} else if hostport != "" {
host = hostport
2023-10-09 19:18:26 +00:00
}
}
client := Client{
base: &url.URL{
Scheme: scheme,
Host: net.JoinHostPort(host, port),
},
}
2023-11-02 20:10:58 +00:00
mockRequest, err := http.NewRequest(http.MethodHead, client.base.String(), nil)
2023-10-09 19:18:26 +00:00
if err != nil {
return nil, err
}
2023-10-09 19:18:26 +00:00
proxyURL, err := http.ProxyFromEnvironment(mockRequest)
2023-08-17 22:20:38 +00:00
if err != nil {
2023-10-09 19:18:26 +00:00
return nil, err
2023-07-06 22:02:10 +00:00
}
2023-10-09 19:18:26 +00:00
client.http = http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxyURL),
},
2023-07-18 16:09:45 +00:00
}
2023-08-17 22:20:38 +00:00
2023-10-09 19:18:26 +00:00
return &client, nil
2023-07-18 16:09:45 +00:00
}
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
var reqBody io.Reader
var data []byte
var err error
2023-11-14 22:07:40 +00:00
switch reqData := reqData.(type) {
case io.Reader:
// reqData is already an io.Reader
reqBody = reqData
case nil:
// noop
default:
2023-07-18 16:09:45 +00:00
data, err = json.Marshal(reqData)
if err != nil {
return err
}
2023-11-14 22:07:40 +00:00
2023-07-18 16:09:45 +00:00
reqBody = bytes.NewReader(data)
}
2023-10-09 19:18:26 +00:00
requestURL := c.base.JoinPath(path)
2023-08-22 01:24:42 +00:00
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
2023-07-18 16:09:45 +00:00
if err != nil {
return err
}
2023-08-22 01:24:42 +00:00
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/json")
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
2023-07-18 16:09:45 +00:00
2023-10-09 19:18:26 +00:00
respObj, err := c.http.Do(request)
2023-07-18 16:09:45 +00:00
if err != nil {
return err
}
defer respObj.Body.Close()
respBody, err := io.ReadAll(respObj.Body)
if err != nil {
return err
}
if err := checkError(respObj, respBody); err != nil {
return err
}
if len(respBody) > 0 && respData != nil {
if err := json.Unmarshal(respBody, respData); err != nil {
return err
}
}
return nil
}
2023-10-12 16:34:16 +00:00
const (
	// maxBufferSize is the scanner buffer limit used by stream, i.e. the
	// longest single response line it can read: 512 format.KiloByte units.
	maxBufferSize = format.KiloByte * 512
)
2023-10-04 18:09:00 +00:00
2023-07-11 20:05:51 +00:00
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf *bytes.Buffer
if data != nil {
bts, err := json.Marshal(data)
if err != nil {
return err
}
2023-07-06 23:53:14 +00:00
buf = bytes.NewBuffer(bts)
}
2023-10-09 19:18:26 +00:00
requestURL := c.base.JoinPath(path)
2023-08-22 01:24:42 +00:00
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
if err != nil {
return err
}
2023-07-06 22:02:10 +00:00
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/x-ndjson")
2023-08-22 01:24:42 +00:00
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
2023-07-04 04:47:00 +00:00
2023-10-09 19:18:26 +00:00
response, err := c.http.Do(request)
2023-07-04 04:47:00 +00:00
if err != nil {
return err
}
2023-07-06 22:02:10 +00:00
defer response.Body.Close()
2023-07-04 04:47:00 +00:00
scanner := bufio.NewScanner(response.Body)
2023-10-04 18:09:00 +00:00
// increase the buffer size to avoid running out of space
scanBuf := make([]byte, 0, maxBufferSize)
scanner.Buffer(scanBuf, maxBufferSize)
for scanner.Scan() {
var errorResponse struct {
2023-07-11 20:05:51 +00:00
Error string `json:"error,omitempty"`
}
bts := scanner.Bytes()
if err := json.Unmarshal(bts, &errorResponse); err != nil {
return fmt.Errorf("unmarshal: %w", err)
}
2023-07-20 19:12:08 +00:00
if errorResponse.Error != "" {
return fmt.Errorf(errorResponse.Error)
2023-07-20 19:12:08 +00:00
}
2023-08-27 04:55:21 +00:00
if response.StatusCode >= http.StatusBadRequest {
2023-07-11 20:05:51 +00:00
return StatusError{
2023-07-20 18:45:12 +00:00
StatusCode: response.StatusCode,
Status: response.Status,
ErrorMessage: errorResponse.Error,
2023-07-11 20:05:51 +00:00
}
}
2023-07-11 20:05:51 +00:00
if err := fn(bts); err != nil {
return err
}
}
2023-07-06 21:05:55 +00:00
return nil
}
2023-07-06 21:05:55 +00:00
type GenerateResponseFunc func(GenerateResponse) error
2023-07-06 21:05:55 +00:00
func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
return c.stream(ctx, http.MethodPost, "/api/generate", req, func(bts []byte) error {
var resp GenerateResponse
if err := json.Unmarshal(bts, &resp); err != nil {
return err
}
return fn(resp)
})
2023-07-04 04:47:00 +00:00
}
2023-07-06 16:24:49 +00:00
2023-07-19 01:51:30 +00:00
type PullProgressFunc func(ProgressResponse) error
2023-07-06 21:05:55 +00:00
func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
2023-07-19 01:51:30 +00:00
var resp ProgressResponse
if err := json.Unmarshal(bts, &resp); err != nil {
return err
}
return fn(resp)
})
2023-07-06 16:24:49 +00:00
}
2023-07-19 01:51:30 +00:00
type PushProgressFunc func(ProgressResponse) error
func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
2023-07-19 01:51:30 +00:00
var resp ProgressResponse
if err := json.Unmarshal(bts, &resp); err != nil {
return err
}
return fn(resp)
})
}
// CreateProgressFunc receives each ProgressResponse decoded while a model is
// being created. Returning a non-nil error stops the stream.
type CreateProgressFunc func(ProgressResponse) error

// Create POSTs req to /api/create and reports every streamed ProgressResponse
// to fn.
func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
	const endpoint = "/api/create"
	return c.stream(ctx, http.MethodPost, endpoint, req, func(line []byte) error {
		var update ProgressResponse
		if decodeErr := json.Unmarshal(line, &update); decodeErr != nil {
			return decodeErr
		}
		return fn(update)
	})
}
2023-07-18 16:09:45 +00:00
// List fetches GET /api/tags and returns the decoded ListResponse.
func (c *Client) List(ctx context.Context) (*ListResponse, error) {
	listing := new(ListResponse)
	if err := c.do(ctx, http.MethodGet, "/api/tags", nil, listing); err != nil {
		return nil, err
	}
	return listing, nil
}
2023-07-20 23:09:23 +00:00
2023-07-24 15:27:28 +00:00
// Copy POSTs req as JSON to /api/copy, ignoring any response body.
func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
	return c.do(ctx, http.MethodPost, "/api/copy", req, nil)
}
func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil {
return err
}
return nil
2023-07-20 23:09:23 +00:00
}
2023-09-06 18:04:17 +00:00
// Show POSTs req to /api/show and returns the decoded ShowResponse.
func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) {
	details := new(ShowResponse)
	if err := c.do(ctx, http.MethodPost, "/api/show", req, details); err != nil {
		return nil, err
	}
	return details, nil
}
func (c *Client) Heartbeat(ctx context.Context) error {
2023-08-01 18:50:38 +00:00
if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
return err
}
return nil
}
2023-11-14 22:07:40 +00:00
// CreateBlob makes sure the server has the blob identified by digest and
// returns the Path field of the server's CreateBlobResponse.
//
// It first queries GET /api/blobs/<digest>/path. Only when that fails with a
// StatusError whose code is 404 does it upload r via POST /api/blobs/<digest>.
// Any other lookup error, or an upload error, is returned unchanged.
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) (string, error) {
	var response CreateBlobResponse

	lookupErr := c.do(ctx, http.MethodGet, fmt.Sprintf("/api/blobs/%s/path", digest), nil, &response)
	if lookupErr == nil {
		return response.Path, nil
	}

	var statusErr StatusError
	missing := errors.As(lookupErr, &statusErr) && statusErr.StatusCode == http.StatusNotFound
	if !missing {
		return "", lookupErr
	}

	if uploadErr := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, &response); uploadErr != nil {
		return "", uploadErr
	}
	return response.Path, nil
}