From 9009bedf13f439b8d355e393805254a4a5eb9af8 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Mon, 29 Apr 2024 19:14:07 -0400 Subject: [PATCH] better checking for OLLAMA_HOST variable (#3661) --- api/client.go | 43 +++++++++++++++++++++++++++++++++++-------- api/client_test.go | 44 +++++++++++++++++++++++++++++++++++++++++++- api/types.go | 1 + cmd/cmd.go | 10 ++++------ 4 files changed, 83 insertions(+), 15 deletions(-) diff --git a/api/client.go b/api/client.go index 101382ca..074103cc 100644 --- a/api/client.go +++ b/api/client.go @@ -18,6 +18,7 @@ import ( "net/url" "os" "runtime" + "strconv" "strings" "github.com/ollama/ollama/format" @@ -57,12 +58,36 @@ func checkError(resp *http.Response, body []byte) error { // If the variable is not specified, a default ollama host and port will be // used. func ClientFromEnvironment() (*Client, error) { + ollamaHost, err := GetOllamaHost() + if err != nil { + return nil, err + } + + return &Client{ + base: &url.URL{ + Scheme: ollamaHost.Scheme, + Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port), + }, + http: http.DefaultClient, + }, nil +} + +type OllamaHost struct { + Scheme string + Host string + Port string +} + +func GetOllamaHost() (OllamaHost, error) { defaultPort := "11434" - scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://") + hostVar := os.Getenv("OLLAMA_HOST") + hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'")) + + scheme, hostport, ok := strings.Cut(hostVar, "://") switch { case !ok: - scheme, hostport = "http", os.Getenv("OLLAMA_HOST") + scheme, hostport = "http", hostVar case scheme == "http": defaultPort = "80" case scheme == "https": @@ -82,12 +107,14 @@ func ClientFromEnvironment() (*Client, error) { } } - return &Client{ - base: &url.URL{ - Scheme: scheme, - Host: net.JoinHostPort(host, port), - }, - http: http.DefaultClient, + if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 { + return OllamaHost{}, ErrInvalidHostPort + } + + return OllamaHost{ + Scheme: scheme, + Host: host, + Port: port, }, nil } diff --git a/api/client_test.go b/api/client_test.go index 0eafedca..b2c51d00 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -1,6 +1,12 @@ package api -import "testing" +import ( + "fmt" + "net" + "testing" + + "github.com/stretchr/testify/assert" +) func TestClientFromEnvironment(t *testing.T) { type testCase struct { @@ -40,4 +46,40 @@ func TestClientFromEnvironment(t *testing.T) { } }) } + + hostTestCases := map[string]*testCase{ + "empty": {value: "", expect: "127.0.0.1:11434"}, + "only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"}, + "only port": {value: ":1234", expect: ":1234"}, + "address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"}, + "hostname": {value: "example.com", expect: "example.com:11434"}, + "hostname and port": {value: "example.com:1234", expect: "example.com:1234"}, + "zero port": {value: ":0", expect: ":0"}, + "too large port": {value: ":66000", err: ErrInvalidHostPort}, + "too small port": {value: ":-1", err: ErrInvalidHostPort}, + "ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"}, + "ipv6 world open": {value: "[::]", expect: "[::]:11434"}, + "ipv6 no brackets": {value: "::1", expect: "[::1]:11434"}, + "ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"}, + "extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"}, + "extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"}, + "extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"}, + "extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"}, + } + + for k, v := range hostTestCases { + t.Run(k, func(t *testing.T) { + t.Setenv("OLLAMA_HOST", v.value) + + oh, err := GetOllamaHost() + if err != v.err { + t.Fatalf("expected %s, got %s", v.err, err) + } + + if err == nil { + host := net.JoinHostPort(oh.Host, oh.Port) + assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host)) + } + }) + } } diff --git a/api/types.go b/api/types.go index 9200949c..7cfd5ff7 100644 --- a/api/types.go +++ b/api/types.go @@ -309,6 +309,7 @@ func (m *Metrics) Summary() { } var ErrInvalidOpts = errors.New("invalid options") +var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST") func (opts *Options) FromMap(m map[string]interface{}) error { valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct diff --git a/cmd/cmd.go b/cmd/cmd.go index 0a1dc7ed..a1eb8eba 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -831,19 +831,17 @@ func generate(cmd *cobra.Command, opts runOptions) error { } func RunServer(cmd *cobra.Command, _ []string) error { - host, port, err := net.SplitHostPort(strings.Trim(os.Getenv("OLLAMA_HOST"), "\"'")) + // retrieve the OLLAMA_HOST environment variable + ollamaHost, err := api.GetOllamaHost() if err != nil { - host, port = "127.0.0.1", "11434" - if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil { - host = ip.String() - } + return err } if err := initializeKeypair(); err != nil { return err } - ln, err := net.Listen("tcp", net.JoinHostPort(host, port)) + ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port)) if err != nil { return err }