Merge pull request #743 from jmorganca/mxyng/http-proxy

handle upstream proxies
This commit is contained in:
Michael Yang 2023-10-10 09:59:06 -07:00 committed by GitHub
commit 0040f543a2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 47 deletions

View file

@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
@ -16,14 +17,9 @@ import (
"github.com/jmorganca/ollama/version"
)
const DefaultHost = "127.0.0.1:11434"
var envHost = os.Getenv("OLLAMA_HOST")
type Client struct {
Base url.URL
HTTP http.Client
Headers http.Header
base *url.URL
http http.Client
}
func checkError(resp *http.Response, body []byte) error {
@ -42,34 +38,44 @@ func checkError(resp *http.Response, body []byte) error {
return apiError
}
// Host returns the default host to use for the client. It is determined in the following order:
// 1. The OLLAMA_HOST environment variable
// 2. The default host (localhost:11434)
func Host() string {
if envHost != "" {
return envHost
}
return DefaultHost
}
// FromEnv creates a new client using Host() as the host. An error is returns
// if the host is invalid.
func FromEnv() (*Client, error) {
h := Host()
if !strings.HasPrefix(h, "http://") && !strings.HasPrefix(h, "https://") {
h = "http://" + h
func ClientFromEnvironment() (*Client, error) {
scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
if !ok {
scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
}
u, err := url.Parse(h)
host, port, err := net.SplitHostPort(hostport)
if err != nil {
return nil, fmt.Errorf("could not parse host: %w", err)
host, port = "127.0.0.1", "11434"
if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
host = ip.String()
}
}
if u.Port() == "" {
u.Host += ":11434"
client := Client{
base: &url.URL{
Scheme: scheme,
Host: net.JoinHostPort(host, port),
},
}
return &Client{Base: *u, HTTP: http.Client{}}, nil
mockRequest, err := http.NewRequest("HEAD", client.base.String(), nil)
if err != nil {
return nil, err
}
proxyURL, err := http.ProxyFromEnvironment(mockRequest)
if err != nil {
return nil, err
}
client.http = http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxyURL),
},
}
return &client, nil
}
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
@ -84,7 +90,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
reqBody = bytes.NewReader(data)
}
requestURL := c.Base.JoinPath(path)
requestURL := c.base.JoinPath(path)
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
if err != nil {
return err
@ -94,11 +100,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
request.Header.Set("Accept", "application/json")
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
for k, v := range c.Headers {
request.Header[k] = v
}
respObj, err := c.HTTP.Do(request)
respObj, err := c.http.Do(request)
if err != nil {
return err
}
@ -134,7 +136,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
buf = bytes.NewBuffer(bts)
}
requestURL := c.Base.JoinPath(path)
requestURL := c.base.JoinPath(path)
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
if err != nil {
return err
@ -144,7 +146,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
request.Header.Set("Accept", "application/json")
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
response, err := http.DefaultClient.Do(request)
response, err := c.http.Do(request)
if err != nil {
return err
}

View file

@ -61,7 +61,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
client, err := api.FromEnv()
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
@ -119,7 +119,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
func RunHandler(cmd *cobra.Command, args []string) error {
client, err := api.FromEnv()
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
@ -144,7 +144,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
func PushHandler(cmd *cobra.Command, args []string) error {
client, err := api.FromEnv()
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
@ -188,7 +188,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
}
func ListHandler(cmd *cobra.Command, args []string) error {
client, err := api.FromEnv()
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
@ -221,7 +221,7 @@ func ListHandler(cmd *cobra.Command, args []string) error {
}
func DeleteHandler(cmd *cobra.Command, args []string) error {
client, err := api.FromEnv()
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
@ -237,7 +237,7 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
}
func ShowHandler(cmd *cobra.Command, args []string) error {
client, err := api.FromEnv()
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
@ -315,7 +315,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
}
func CopyHandler(cmd *cobra.Command, args []string) error {
client, err := api.FromEnv()
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
@ -338,7 +338,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
}
func pull(model string, insecure bool) error {
client, err := api.FromEnv()
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
@ -406,7 +406,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
type generateContextKey string
func generate(cmd *cobra.Command, model, prompt string, wordWrap bool) error {
client, err := api.FromEnv()
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
@ -906,7 +906,7 @@ func startMacApp(client *api.Client) error {
}
func checkServerHeartbeat(_ *cobra.Command, _ []string) error {
client, err := api.FromEnv()
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}

View file

@ -1486,7 +1486,18 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
req.ContentLength = contentLength
}
resp, err := http.DefaultClient.Do(req)
proxyURL, err := http.ProxyFromEnvironment(req)
if err != nil {
return nil, err
}
client := http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxyURL),
},
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}