Merge pull request #916 from jmorganca/mxyng/fix-client-host

fix(client): trim trailing slash
This commit is contained in:
Michael Yang 2023-10-26 12:24:12 -07:00 committed by GitHub
commit b88cc0fac9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 2 deletions

View file

@ -44,14 +44,24 @@ func checkError(resp *http.Response, body []byte) error {
} }
func ClientFromEnvironment() (*Client, error) { func ClientFromEnvironment() (*Client, error) {
defaultPort := "11434"
scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://") scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
if !ok { switch {
case !ok:
scheme, hostport = "http", os.Getenv("OLLAMA_HOST") scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
case scheme == "http":
defaultPort = "80"
case scheme == "https":
defaultPort = "443"
} }
// trim trailing slashes
hostport = strings.TrimRight(hostport, "/")
host, port, err := net.SplitHostPort(hostport) host, port, err := net.SplitHostPort(hostport)
if err != nil { if err != nil {
host, port = "127.0.0.1", "11434" host, port = "127.0.0.1", defaultPort
if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil { if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
host = ip.String() host = ip.String()
} else if hostport != "" { } else if hostport != "" {

43
api/client_test.go Normal file
View file

@ -0,0 +1,43 @@
package api
import "testing"
func TestClientFromEnvironment(t *testing.T) {
type testCase struct {
value string
expect string
err error
}
testCases := map[string]*testCase{
"empty": {value: "", expect: "http://127.0.0.1:11434"},
"only address": {value: "1.2.3.4", expect: "http://1.2.3.4:11434"},
"only port": {value: ":1234", expect: "http://:1234"},
"address and port": {value: "1.2.3.4:1234", expect: "http://1.2.3.4:1234"},
"scheme http and address": {value: "http://1.2.3.4", expect: "http://1.2.3.4:80"},
"scheme https and address": {value: "https://1.2.3.4", expect: "https://1.2.3.4:443"},
"scheme, address, and port": {value: "https://1.2.3.4:1234", expect: "https://1.2.3.4:1234"},
"hostname": {value: "example.com", expect: "http://example.com:11434"},
"hostname and port": {value: "example.com:1234", expect: "http://example.com:1234"},
"scheme http and hostname": {value: "http://example.com", expect: "http://example.com:80"},
"scheme https and hostname": {value: "https://example.com", expect: "https://example.com:443"},
"scheme, hostname, and port": {value: "https://example.com:1234", expect: "https://example.com:1234"},
"trailing slash": {value: "example.com/", expect: "http://example.com:11434"},
"trailing slash port": {value: "example.com:1234/", expect: "http://example.com:1234"},
}
for k, v := range testCases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", v.value)
client, err := ClientFromEnvironment()
if err != v.err {
t.Fatalf("expected %s, got %s", v.err, err)
}
if client.base.String() != v.expect {
t.Fatalf("expected %s, got %s", v.expect, client.base.String())
}
})
}
}