Merge pull request #743 from jmorganca/mxyng/http-proxy
handle upstream proxies
This commit is contained in:
commit
0040f543a2
3 changed files with 60 additions and 47 deletions
|
@ -7,6 +7,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
@ -16,14 +17,9 @@ import (
|
||||||
"github.com/jmorganca/ollama/version"
|
"github.com/jmorganca/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
const DefaultHost = "127.0.0.1:11434"
|
|
||||||
|
|
||||||
var envHost = os.Getenv("OLLAMA_HOST")
|
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
Base url.URL
|
base *url.URL
|
||||||
HTTP http.Client
|
http http.Client
|
||||||
Headers http.Header
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkError(resp *http.Response, body []byte) error {
|
func checkError(resp *http.Response, body []byte) error {
|
||||||
|
@ -42,34 +38,44 @@ func checkError(resp *http.Response, body []byte) error {
|
||||||
return apiError
|
return apiError
|
||||||
}
|
}
|
||||||
|
|
||||||
// Host returns the default host to use for the client. It is determined in the following order:
|
func ClientFromEnvironment() (*Client, error) {
|
||||||
// 1. The OLLAMA_HOST environment variable
|
scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
|
||||||
// 2. The default host (localhost:11434)
|
if !ok {
|
||||||
func Host() string {
|
scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
|
||||||
if envHost != "" {
|
|
||||||
return envHost
|
|
||||||
}
|
|
||||||
return DefaultHost
|
|
||||||
}
|
|
||||||
|
|
||||||
// FromEnv creates a new client using Host() as the host. An error is returns
|
|
||||||
// if the host is invalid.
|
|
||||||
func FromEnv() (*Client, error) {
|
|
||||||
h := Host()
|
|
||||||
if !strings.HasPrefix(h, "http://") && !strings.HasPrefix(h, "https://") {
|
|
||||||
h = "http://" + h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
u, err := url.Parse(h)
|
host, port, err := net.SplitHostPort(hostport)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not parse host: %w", err)
|
host, port = "127.0.0.1", "11434"
|
||||||
|
if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
|
||||||
|
host = ip.String()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.Port() == "" {
|
client := Client{
|
||||||
u.Host += ":11434"
|
base: &url.URL{
|
||||||
|
Scheme: scheme,
|
||||||
|
Host: net.JoinHostPort(host, port),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Client{Base: *u, HTTP: http.Client{}}, nil
|
mockRequest, err := http.NewRequest("HEAD", client.base.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL, err := http.ProxyFromEnvironment(mockRequest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
client.http = http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
Proxy: http.ProxyURL(proxyURL),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return &client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
|
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
|
||||||
|
@ -84,7 +90,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
||||||
reqBody = bytes.NewReader(data)
|
reqBody = bytes.NewReader(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
requestURL := c.Base.JoinPath(path)
|
requestURL := c.base.JoinPath(path)
|
||||||
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
|
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -94,11 +100,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
||||||
request.Header.Set("Accept", "application/json")
|
request.Header.Set("Accept", "application/json")
|
||||||
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||||
|
|
||||||
for k, v := range c.Headers {
|
respObj, err := c.http.Do(request)
|
||||||
request.Header[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
respObj, err := c.HTTP.Do(request)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -134,7 +136,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||||
buf = bytes.NewBuffer(bts)
|
buf = bytes.NewBuffer(bts)
|
||||||
}
|
}
|
||||||
|
|
||||||
requestURL := c.Base.JoinPath(path)
|
requestURL := c.base.JoinPath(path)
|
||||||
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
|
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -144,7 +146,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||||
request.Header.Set("Accept", "application/json")
|
request.Header.Set("Accept", "application/json")
|
||||||
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||||
|
|
||||||
response, err := http.DefaultClient.Do(request)
|
response, err := c.http.Do(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
20
cmd/cmd.go
20
cmd/cmd.go
|
@ -61,7 +61,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := api.FromEnv()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -119,7 +119,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||||
client, err := api.FromEnv()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -144,7 +144,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func PushHandler(cmd *cobra.Command, args []string) error {
|
func PushHandler(cmd *cobra.Command, args []string) error {
|
||||||
client, err := api.FromEnv()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -188,7 +188,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListHandler(cmd *cobra.Command, args []string) error {
|
func ListHandler(cmd *cobra.Command, args []string) error {
|
||||||
client, err := api.FromEnv()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -221,7 +221,7 @@ func ListHandler(cmd *cobra.Command, args []string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteHandler(cmd *cobra.Command, args []string) error {
|
func DeleteHandler(cmd *cobra.Command, args []string) error {
|
||||||
client, err := api.FromEnv()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -237,7 +237,7 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func ShowHandler(cmd *cobra.Command, args []string) error {
|
func ShowHandler(cmd *cobra.Command, args []string) error {
|
||||||
client, err := api.FromEnv()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -315,7 +315,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyHandler(cmd *cobra.Command, args []string) error {
|
func CopyHandler(cmd *cobra.Command, args []string) error {
|
||||||
client, err := api.FromEnv()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -338,7 +338,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func pull(model string, insecure bool) error {
|
func pull(model string, insecure bool) error {
|
||||||
client, err := api.FromEnv()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -406,7 +406,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
|
||||||
type generateContextKey string
|
type generateContextKey string
|
||||||
|
|
||||||
func generate(cmd *cobra.Command, model, prompt string, wordWrap bool) error {
|
func generate(cmd *cobra.Command, model, prompt string, wordWrap bool) error {
|
||||||
client, err := api.FromEnv()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -906,7 +906,7 @@ func startMacApp(client *api.Client) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkServerHeartbeat(_ *cobra.Command, _ []string) error {
|
func checkServerHeartbeat(_ *cobra.Command, _ []string) error {
|
||||||
client, err := api.FromEnv()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1486,7 +1486,18 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
|
||||||
req.ContentLength = contentLength
|
req.ContentLength = contentLength
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
proxyURL, err := http.ProxyFromEnvironment(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
client := http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
Proxy: http.ProxyURL(proxyURL),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue