better checking for OLLAMA_HOST variable (#3661)

This commit is contained in:
Patrick Devine 2024-04-29 19:14:07 -04:00 committed by GitHub
parent d4ac57e240
commit 9009bedf13
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 83 additions and 15 deletions

View file

@ -18,6 +18,7 @@ import (
"net/url"
"os"
"runtime"
"strconv"
"strings"
"github.com/ollama/ollama/format"
@ -57,12 +58,36 @@ func checkError(resp *http.Response, body []byte) error {
// If the variable is not specified, a default ollama host and port will be
// used.
func ClientFromEnvironment() (*Client, error) {
ollamaHost, err := GetOllamaHost()
if err != nil {
return nil, err
}
return &Client{
base: &url.URL{
Scheme: ollamaHost.Scheme,
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
},
http: http.DefaultClient,
}, nil
}
type OllamaHost struct {
Scheme string
Host string
Port string
}
func GetOllamaHost() (OllamaHost, error) {
defaultPort := "11434"
scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
hostVar := os.Getenv("OLLAMA_HOST")
hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
scheme, hostport, ok := strings.Cut(hostVar, "://")
switch {
case !ok:
scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
scheme, hostport = "http", hostVar
case scheme == "http":
defaultPort = "80"
case scheme == "https":
@ -82,12 +107,14 @@ func ClientFromEnvironment() (*Client, error) {
}
}
return &Client{
base: &url.URL{
if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
return OllamaHost{}, ErrInvalidHostPort
}
return OllamaHost{
Scheme: scheme,
Host: net.JoinHostPort(host, port),
},
http: http.DefaultClient,
Host: host,
Port: port,
}, nil
}

View file

@ -1,6 +1,12 @@
package api
import "testing"
import (
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestClientFromEnvironment(t *testing.T) {
type testCase struct {
@ -40,4 +46,40 @@ func TestClientFromEnvironment(t *testing.T) {
}
})
}
hostTestCases := map[string]*testCase{
"empty": {value: "", expect: "127.0.0.1:11434"},
"only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"},
"only port": {value: ":1234", expect: ":1234"},
"address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
"hostname": {value: "example.com", expect: "example.com:11434"},
"hostname and port": {value: "example.com:1234", expect: "example.com:1234"},
"zero port": {value: ":0", expect: ":0"},
"too large port": {value: ":66000", err: ErrInvalidHostPort},
"too small port": {value: ":-1", err: ErrInvalidHostPort},
"ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"},
"ipv6 world open": {value: "[::]", expect: "[::]:11434"},
"ipv6 no brackets": {value: "::1", expect: "[::1]:11434"},
"ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"},
"extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
"extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
"extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
}
for k, v := range hostTestCases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", v.value)
oh, err := GetOllamaHost()
if err != v.err {
t.Fatalf("expected %s, got %s", v.err, err)
}
if err == nil {
host := net.JoinHostPort(oh.Host, oh.Port)
assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
}
})
}
}

View file

@ -309,6 +309,7 @@ func (m *Metrics) Summary() {
}
var ErrInvalidOpts = errors.New("invalid options")
var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
func (opts *Options) FromMap(m map[string]interface{}) error {
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct

View file

@ -831,19 +831,17 @@ func generate(cmd *cobra.Command, opts runOptions) error {
}
func RunServer(cmd *cobra.Command, _ []string) error {
host, port, err := net.SplitHostPort(strings.Trim(os.Getenv("OLLAMA_HOST"), "\"'"))
// retrieve the OLLAMA_HOST environment variable
ollamaHost, err := api.GetOllamaHost()
if err != nil {
host, port = "127.0.0.1", "11434"
if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
host = ip.String()
}
return err
}
if err := initializeKeypair(); err != nil {
return err
}
ln, err := net.Listen("tcp", net.JoinHostPort(host, port))
ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port))
if err != nil {
return err
}