handle client proxy

This commit is contained in:
Michael Yang 2023-10-09 12:18:26 -07:00
parent f6e98334e4
commit 2cfffea02e
2 changed files with 48 additions and 46 deletions

View file

@ -7,6 +7,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@ -16,14 +17,9 @@ import (
"github.com/jmorganca/ollama/version" "github.com/jmorganca/ollama/version"
) )
const DefaultHost = "127.0.0.1:11434"
var envHost = os.Getenv("OLLAMA_HOST")
type Client struct { type Client struct {
Base url.URL base *url.URL
HTTP http.Client http http.Client
Headers http.Header
} }
func checkError(resp *http.Response, body []byte) error { func checkError(resp *http.Response, body []byte) error {
@ -42,34 +38,44 @@ func checkError(resp *http.Response, body []byte) error {
return apiError return apiError
} }
// Host returns the default host to use for the client. It is determined in the following order: func ClientFromEnvironment() (*Client, error) {
// 1. The OLLAMA_HOST environment variable scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
// 2. The default host (localhost:11434) if !ok {
func Host() string { scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
if envHost != "" {
return envHost
}
return DefaultHost
}
// FromEnv creates a new client using Host() as the host. An error is returns
// if the host is invalid.
func FromEnv() (*Client, error) {
h := Host()
if !strings.HasPrefix(h, "http://") && !strings.HasPrefix(h, "https://") {
h = "http://" + h
} }
u, err := url.Parse(h) host, port, err := net.SplitHostPort(hostport)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not parse host: %w", err) host, port = "127.0.0.1", "11434"
if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
host = ip.String()
}
} }
if u.Port() == "" { client := Client{
u.Host += ":11434" base: &url.URL{
Scheme: scheme,
Host: net.JoinHostPort(host, port),
},
} }
return &Client{Base: *u, HTTP: http.Client{}}, nil mockRequest, err := http.NewRequest("HEAD", client.base.String(), nil)
if err != nil {
return nil, err
}
proxyURL, err := http.ProxyFromEnvironment(mockRequest)
if err != nil {
return nil, err
}
client.http = http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxyURL),
},
}
return &client, nil
} }
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error { func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
@ -84,7 +90,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
reqBody = bytes.NewReader(data) reqBody = bytes.NewReader(data)
} }
requestURL := c.Base.JoinPath(path) requestURL := c.base.JoinPath(path)
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody) request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
if err != nil { if err != nil {
return err return err
@ -94,11 +100,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
request.Header.Set("Accept", "application/json") request.Header.Set("Accept", "application/json")
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
for k, v := range c.Headers { respObj, err := c.http.Do(request)
request.Header[k] = v
}
respObj, err := c.HTTP.Do(request)
if err != nil { if err != nil {
return err return err
} }
@ -134,7 +136,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
buf = bytes.NewBuffer(bts) buf = bytes.NewBuffer(bts)
} }
requestURL := c.Base.JoinPath(path) requestURL := c.base.JoinPath(path)
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf) request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
if err != nil { if err != nil {
return err return err
@ -144,7 +146,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
request.Header.Set("Accept", "application/json") request.Header.Set("Accept", "application/json")
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
response, err := http.DefaultClient.Do(request) response, err := c.http.Do(request)
if err != nil { if err != nil {
return err return err
} }

View file

@ -61,7 +61,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
client, err := api.FromEnv() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
return err return err
} }
@ -119,7 +119,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
func RunHandler(cmd *cobra.Command, args []string) error { func RunHandler(cmd *cobra.Command, args []string) error {
client, err := api.FromEnv() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
return err return err
} }
@ -144,7 +144,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
} }
func PushHandler(cmd *cobra.Command, args []string) error { func PushHandler(cmd *cobra.Command, args []string) error {
client, err := api.FromEnv() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
return err return err
} }
@ -188,7 +188,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
} }
func ListHandler(cmd *cobra.Command, args []string) error { func ListHandler(cmd *cobra.Command, args []string) error {
client, err := api.FromEnv() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
return err return err
} }
@ -221,7 +221,7 @@ func ListHandler(cmd *cobra.Command, args []string) error {
} }
func DeleteHandler(cmd *cobra.Command, args []string) error { func DeleteHandler(cmd *cobra.Command, args []string) error {
client, err := api.FromEnv() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
return err return err
} }
@ -237,7 +237,7 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
} }
func ShowHandler(cmd *cobra.Command, args []string) error { func ShowHandler(cmd *cobra.Command, args []string) error {
client, err := api.FromEnv() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
return err return err
} }
@ -315,7 +315,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
} }
func CopyHandler(cmd *cobra.Command, args []string) error { func CopyHandler(cmd *cobra.Command, args []string) error {
client, err := api.FromEnv() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
return err return err
} }
@ -338,7 +338,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
} }
func pull(model string, insecure bool) error { func pull(model string, insecure bool) error {
client, err := api.FromEnv() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
return err return err
} }
@ -406,7 +406,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
type generateContextKey string type generateContextKey string
func generate(cmd *cobra.Command, model, prompt string, wordWrap bool) error { func generate(cmd *cobra.Command, model, prompt string, wordWrap bool) error {
client, err := api.FromEnv() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
return err return err
} }
@ -906,7 +906,7 @@ func startMacApp(client *api.Client) error {
} }
func checkServerHeartbeat(_ *cobra.Command, _ []string) error { func checkServerHeartbeat(_ *cobra.Command, _ []string) error {
client, err := api.FromEnv() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
return err return err
} }