From 67e593e355e176a1a2841c88eef103ef7c314cbf Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Wed, 16 Aug 2023 08:03:48 -0700 Subject: [PATCH] cmd: support OLLAMA_CLIENT_HOST environment variable (#262) * cmd: support OLLAMA_HOST environment variable This commit adds support for the OLLAMA_HOST environment variable. This variable can be used to specify the host to which the client should connect. This is useful when the client is running somewhere other than the host where the server is running. The new api.FromEnv function is used to read configure clients from the environment. Clients wishing to use the environment variable being consistent with the Ollama CLI can use this new function. * Update api/client.go Co-authored-by: Jeffrey Morgan * Update api/client.go Co-authored-by: Jeffrey Morgan --------- Co-authored-by: Jeffrey Morgan --- api/client.go | 37 ++++++++++++++++++++++++++++++++----- cmd/cmd.go | 40 ++++++++++++++++++++++++++++++++-------- 2 files changed, 64 insertions(+), 13 deletions(-) diff --git a/api/client.go b/api/client.go index 7753eb91..a1ddf667 100644 --- a/api/client.go +++ b/api/client.go @@ -9,10 +9,17 @@ import ( "io" "net/http" "net/url" + "os" +) + +const DefaultHost = "localhost:11434" + +var ( + envHost = os.Getenv("OLLAMA_HOST") ) type Client struct { - base url.URL + Base url.URL HTTP http.Client Headers http.Header } @@ -33,14 +40,34 @@ func checkError(resp *http.Response, body []byte) error { return apiError } +// Host returns the default host to use for the client. It is determined in the following order: +// 1. The OLLAMA_HOST environment variable +// 2. The default host (localhost:11434) +func Host() string { + if envHost != "" { + return envHost + } + return DefaultHost +} + +// FromEnv creates a new client using Host() as the host. An error is returns +// if the host is invalid. +func FromEnv() (*Client, error) { + u, err := url.Parse(Host()) + if err != nil { + return nil, err + } + return &Client{Base: *u}, nil +} + func NewClient(hosts ...string) *Client { - host := "127.0.0.1:11434" + host := DefaultHost if len(hosts) > 0 { host = hosts[0] } return &Client{ - base: url.URL{Scheme: "http", Host: host}, + Base: url.URL{Scheme: "http", Host: host}, HTTP: http.Client{}, } } @@ -57,7 +84,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData reqBody = bytes.NewReader(data) } - url := c.base.JoinPath(path).String() + url := c.Base.JoinPath(path).String() req, err := http.NewRequestWithContext(ctx, method, url, reqBody) if err != nil { @@ -105,7 +132,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f buf = bytes.NewBuffer(bts) } - request, err := http.NewRequestWithContext(ctx, method, c.base.JoinPath(path).String(), buf) + request, err := http.NewRequestWithContext(ctx, method, c.Base.JoinPath(path).String(), buf) if err != nil { return err } diff --git a/cmd/cmd.go b/cmd/cmd.go index 53c039ba..8f7dee30 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -39,7 +39,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } - client := api.NewClient() + client, err := api.FromEnv() + if err != nil { + return err + } var spinner *Spinner @@ -117,7 +120,10 @@ func RunHandler(cmd *cobra.Command, args []string) error { } func PushHandler(cmd *cobra.Command, args []string) error { - client := api.NewClient() + client, err := api.FromEnv() + if err != nil { + return err + } insecure, err := cmd.Flags().GetBool("insecure") if err != nil { @@ -153,7 +159,10 @@ func PushHandler(cmd *cobra.Command, args []string) error { } func ListHandler(cmd *cobra.Command, args []string) error { - client := api.NewClient() + client, err := api.FromEnv() + if err != nil { + return err + } models, err := client.List(context.Background()) if err != nil { @@ -183,7 +192,10 @@ func ListHandler(cmd *cobra.Command, args []string) error { } func DeleteHandler(cmd *cobra.Command, args []string) error { - client := api.NewClient() + client, err := api.FromEnv() + if err != nil { + return err + } req := api.DeleteRequest{Name: args[0]} if err := client.Delete(context.Background(), &req); err != nil { @@ -194,7 +206,10 @@ func DeleteHandler(cmd *cobra.Command, args []string) error { } func CopyHandler(cmd *cobra.Command, args []string) error { - client := api.NewClient() + client, err := api.FromEnv() + if err != nil { + return err + } req := api.CopyRequest{Source: args[0], Destination: args[1]} if err := client.Copy(context.Background(), &req); err != nil { @@ -214,7 +229,10 @@ func PullHandler(cmd *cobra.Command, args []string) error { } func pull(model string, insecure bool) error { - client := api.NewClient() + client, err := api.FromEnv() + if err != nil { + return err + } var currentDigest string var bar *progressbar.ProgressBar @@ -261,7 +279,10 @@ type generateContextKey string func generate(cmd *cobra.Command, model, prompt string) error { if len(strings.TrimSpace(prompt)) > 0 { - client := api.NewClient() + client, err := api.FromEnv() + if err != nil { + return err + } spinner := NewSpinner("") go spinner.Spin(60 * time.Millisecond) @@ -644,7 +665,10 @@ func startMacApp(client *api.Client) error { } func checkServerHeartbeat(_ *cobra.Command, _ []string) error { - client := api.NewClient() + client, err := api.FromEnv() + if err != nil { + return err + } if err := client.Heartbeat(context.Background()); err != nil { if !strings.Contains(err.Error(), "connection refused") { return err