diff --git a/api/client.go b/api/client.go index 7753eb91..a1ddf667 100644 --- a/api/client.go +++ b/api/client.go @@ -9,10 +9,17 @@ import ( "io" "net/http" "net/url" + "os" +) + +const DefaultHost = "localhost:11434" + +var ( + envHost = os.Getenv("OLLAMA_HOST") ) type Client struct { - base url.URL + Base url.URL HTTP http.Client Headers http.Header } @@ -33,14 +40,34 @@ func checkError(resp *http.Response, body []byte) error { return apiError } +// Host returns the default host to use for the client. It is determined in the following order: +// 1. The OLLAMA_HOST environment variable +// 2. The default host (localhost:11434) +func Host() string { + if envHost != "" { + return envHost + } + return DefaultHost +} + +// FromEnv creates a new client using Host() as the host. An error is returns +// if the host is invalid. +func FromEnv() (*Client, error) { + u, err := url.Parse(Host()) + if err != nil { + return nil, err + } + return &Client{Base: *u}, nil +} + func NewClient(hosts ...string) *Client { - host := "127.0.0.1:11434" + host := DefaultHost if len(hosts) > 0 { host = hosts[0] } return &Client{ - base: url.URL{Scheme: "http", Host: host}, + Base: url.URL{Scheme: "http", Host: host}, HTTP: http.Client{}, } } @@ -57,7 +84,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData reqBody = bytes.NewReader(data) } - url := c.base.JoinPath(path).String() + url := c.Base.JoinPath(path).String() req, err := http.NewRequestWithContext(ctx, method, url, reqBody) if err != nil { @@ -105,7 +132,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f buf = bytes.NewBuffer(bts) } - request, err := http.NewRequestWithContext(ctx, method, c.base.JoinPath(path).String(), buf) + request, err := http.NewRequestWithContext(ctx, method, c.Base.JoinPath(path).String(), buf) if err != nil { return err } diff --git a/cmd/cmd.go b/cmd/cmd.go index 53c039ba..8f7dee30 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -39,7 +39,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } - client := api.NewClient() + client, err := api.FromEnv() + if err != nil { + return err + } var spinner *Spinner @@ -117,7 +120,10 @@ func RunHandler(cmd *cobra.Command, args []string) error { } func PushHandler(cmd *cobra.Command, args []string) error { - client := api.NewClient() + client, err := api.FromEnv() + if err != nil { + return err + } insecure, err := cmd.Flags().GetBool("insecure") if err != nil { @@ -153,7 +159,10 @@ func PushHandler(cmd *cobra.Command, args []string) error { } func ListHandler(cmd *cobra.Command, args []string) error { - client := api.NewClient() + client, err := api.FromEnv() + if err != nil { + return err + } models, err := client.List(context.Background()) if err != nil { @@ -183,7 +192,10 @@ func ListHandler(cmd *cobra.Command, args []string) error { } func DeleteHandler(cmd *cobra.Command, args []string) error { - client := api.NewClient() + client, err := api.FromEnv() + if err != nil { + return err + } req := api.DeleteRequest{Name: args[0]} if err := client.Delete(context.Background(), &req); err != nil { @@ -194,7 +206,10 @@ func DeleteHandler(cmd *cobra.Command, args []string) error { } func CopyHandler(cmd *cobra.Command, args []string) error { - client := api.NewClient() + client, err := api.FromEnv() + if err != nil { + return err + } req := api.CopyRequest{Source: args[0], Destination: args[1]} if err := client.Copy(context.Background(), &req); err != nil { @@ -214,7 +229,10 @@ func PullHandler(cmd *cobra.Command, args []string) error { } func pull(model string, insecure bool) error { - client := api.NewClient() + client, err := api.FromEnv() + if err != nil { + return err + } var currentDigest string var bar *progressbar.ProgressBar @@ -261,7 +279,10 @@ type generateContextKey string func generate(cmd *cobra.Command, model, prompt string) error { if len(strings.TrimSpace(prompt)) > 0 { - client := api.NewClient() + client, err := api.FromEnv() + if err != nil { + return err + } spinner := NewSpinner("") go spinner.Spin(60 * time.Millisecond) @@ -644,7 +665,10 @@ func startMacApp(client *api.Client) error { } func checkServerHeartbeat(_ *cobra.Command, _ []string) error { - client := api.NewClient() + client, err := api.FromEnv() + if err != nil { + return err + } if err := client.Heartbeat(context.Background()); err != nil { if !strings.Contains(err.Error(), "connection refused") { return err