diff --git a/api/client.go b/api/client.go index 99c1daa1..961cd417 100644 --- a/api/client.go +++ b/api/client.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "net/url" "os" @@ -16,14 +17,9 @@ import ( "github.com/jmorganca/ollama/version" ) -const DefaultHost = "127.0.0.1:11434" - -var envHost = os.Getenv("OLLAMA_HOST") - type Client struct { - Base url.URL - HTTP http.Client - Headers http.Header + base *url.URL + http http.Client } func checkError(resp *http.Response, body []byte) error { @@ -42,34 +38,44 @@ func checkError(resp *http.Response, body []byte) error { return apiError } -// Host returns the default host to use for the client. It is determined in the following order: -// 1. The OLLAMA_HOST environment variable -// 2. The default host (localhost:11434) -func Host() string { - if envHost != "" { - return envHost - } - return DefaultHost -} - -// FromEnv creates a new client using Host() as the host. An error is returns -// if the host is invalid. -func FromEnv() (*Client, error) { - h := Host() - if !strings.HasPrefix(h, "http://") && !strings.HasPrefix(h, "https://") { - h = "http://" + h +func ClientFromEnvironment() (*Client, error) { + scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://") + if !ok { + scheme, hostport = "http", os.Getenv("OLLAMA_HOST") } - u, err := url.Parse(h) + host, port, err := net.SplitHostPort(hostport) if err != nil { - return nil, fmt.Errorf("could not parse host: %w", err) + host, port = "127.0.0.1", "11434" + if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil { + host = ip.String() + } } - if u.Port() == "" { - u.Host += ":11434" + client := Client{ + base: &url.URL{ + Scheme: scheme, + Host: net.JoinHostPort(host, port), + }, } - return &Client{Base: *u, HTTP: http.Client{}}, nil + mockRequest, err := http.NewRequest("HEAD", client.base.String(), nil) + if err != nil { + return nil, err + } + + proxyURL, err := http.ProxyFromEnvironment(mockRequest) + if err != nil { + return nil, err + } + + client.http = http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + }, + } + + return &client, nil } func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error { @@ -84,7 +90,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData reqBody = bytes.NewReader(data) } - requestURL := c.Base.JoinPath(path) + requestURL := c.base.JoinPath(path) request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody) if err != nil { return err @@ -94,11 +100,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData request.Header.Set("Accept", "application/json") request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) - for k, v := range c.Headers { - request.Header[k] = v - } - - respObj, err := c.HTTP.Do(request) + respObj, err := c.http.Do(request) if err != nil { return err } @@ -134,7 +136,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f buf = bytes.NewBuffer(bts) } - requestURL := c.Base.JoinPath(path) + requestURL := c.base.JoinPath(path) request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf) if err != nil { return err @@ -144,7 +146,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f request.Header.Set("Accept", "application/json") request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) - response, err := http.DefaultClient.Do(request) + response, err := c.http.Do(request) if err != nil { return err } diff --git a/cmd/cmd.go b/cmd/cmd.go index c978c0e8..f3fdd793 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -61,7 +61,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } - client, err := api.FromEnv() + client, err := api.ClientFromEnvironment() if err != nil { return err } @@ -119,7 +119,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } func RunHandler(cmd *cobra.Command, args []string) error { - client, err := api.FromEnv() + client, err := api.ClientFromEnvironment() if err != nil { return err } @@ -144,7 +144,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { } func PushHandler(cmd *cobra.Command, args []string) error { - client, err := api.FromEnv() + client, err := api.ClientFromEnvironment() if err != nil { return err } @@ -188,7 +188,7 @@ func PushHandler(cmd *cobra.Command, args []string) error { } func ListHandler(cmd *cobra.Command, args []string) error { - client, err := api.FromEnv() + client, err := api.ClientFromEnvironment() if err != nil { return err } @@ -221,7 +221,7 @@ func ListHandler(cmd *cobra.Command, args []string) error { } func DeleteHandler(cmd *cobra.Command, args []string) error { - client, err := api.FromEnv() + client, err := api.ClientFromEnvironment() if err != nil { return err } @@ -237,7 +237,7 @@ func DeleteHandler(cmd *cobra.Command, args []string) error { } func ShowHandler(cmd *cobra.Command, args []string) error { - client, err := api.FromEnv() + client, err := api.ClientFromEnvironment() if err != nil { return err } @@ -315,7 +315,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error { } func CopyHandler(cmd *cobra.Command, args []string) error { - client, err := api.FromEnv() + client, err := api.ClientFromEnvironment() if err != nil { return err } @@ -338,7 +338,7 @@ func PullHandler(cmd *cobra.Command, args []string) error { } func pull(model string, insecure bool) error { - client, err := api.FromEnv() + client, err := api.ClientFromEnvironment() if err != nil { return err } @@ -406,7 +406,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error { type generateContextKey string func generate(cmd *cobra.Command, model, prompt string, wordWrap bool) error { - client, err := api.FromEnv() + client, err := api.ClientFromEnvironment() if err != nil { return err } @@ -906,7 +906,7 @@ func startMacApp(client *api.Client) error { } func checkServerHeartbeat(_ *cobra.Command, _ []string) error { - client, err := api.FromEnv() + client, err := api.ClientFromEnvironment() if err != nil { return err } diff --git a/server/images.go b/server/images.go index 030c6655..0945b0a4 100644 --- a/server/images.go +++ b/server/images.go @@ -1486,7 +1486,18 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header req.ContentLength = contentLength } - resp, err := http.DefaultClient.Do(req) + proxyURL, err := http.ProxyFromEnvironment(req) + if err != nil { + return nil, err + } + + client := http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + }, + } + + resp, err := client.Do(req) if err != nil { return nil, err }