Merge pull request #916 from jmorganca/mxyng/fix-client-host
fix(client): trim trailing slash
This commit is contained in:
commit
b88cc0fac9
2 changed files with 55 additions and 2 deletions
|
@ -44,14 +44,24 @@ func checkError(resp *http.Response, body []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClientFromEnvironment() (*Client, error) {
|
func ClientFromEnvironment() (*Client, error) {
|
||||||
|
defaultPort := "11434"
|
||||||
|
|
||||||
scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
|
scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
|
||||||
if !ok {
|
switch {
|
||||||
|
case !ok:
|
||||||
scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
|
scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
|
||||||
|
case scheme == "http":
|
||||||
|
defaultPort = "80"
|
||||||
|
case scheme == "https":
|
||||||
|
defaultPort = "443"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// trim trailing slashes
|
||||||
|
hostport = strings.TrimRight(hostport, "/")
|
||||||
|
|
||||||
host, port, err := net.SplitHostPort(hostport)
|
host, port, err := net.SplitHostPort(hostport)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
host, port = "127.0.0.1", "11434"
|
host, port = "127.0.0.1", defaultPort
|
||||||
if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
|
if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
|
||||||
host = ip.String()
|
host = ip.String()
|
||||||
} else if hostport != "" {
|
} else if hostport != "" {
|
||||||
|
|
43
api/client_test.go
Normal file
43
api/client_test.go
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestClientFromEnvironment(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
value string
|
||||||
|
expect string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := map[string]*testCase{
|
||||||
|
"empty": {value: "", expect: "http://127.0.0.1:11434"},
|
||||||
|
"only address": {value: "1.2.3.4", expect: "http://1.2.3.4:11434"},
|
||||||
|
"only port": {value: ":1234", expect: "http://:1234"},
|
||||||
|
"address and port": {value: "1.2.3.4:1234", expect: "http://1.2.3.4:1234"},
|
||||||
|
"scheme http and address": {value: "http://1.2.3.4", expect: "http://1.2.3.4:80"},
|
||||||
|
"scheme https and address": {value: "https://1.2.3.4", expect: "https://1.2.3.4:443"},
|
||||||
|
"scheme, address, and port": {value: "https://1.2.3.4:1234", expect: "https://1.2.3.4:1234"},
|
||||||
|
"hostname": {value: "example.com", expect: "http://example.com:11434"},
|
||||||
|
"hostname and port": {value: "example.com:1234", expect: "http://example.com:1234"},
|
||||||
|
"scheme http and hostname": {value: "http://example.com", expect: "http://example.com:80"},
|
||||||
|
"scheme https and hostname": {value: "https://example.com", expect: "https://example.com:443"},
|
||||||
|
"scheme, hostname, and port": {value: "https://example.com:1234", expect: "https://example.com:1234"},
|
||||||
|
"trailing slash": {value: "example.com/", expect: "http://example.com:11434"},
|
||||||
|
"trailing slash port": {value: "example.com:1234/", expect: "http://example.com:1234"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range testCases {
|
||||||
|
t.Run(k, func(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_HOST", v.value)
|
||||||
|
|
||||||
|
client, err := ClientFromEnvironment()
|
||||||
|
if err != v.err {
|
||||||
|
t.Fatalf("expected %s, got %s", v.err, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if client.base.String() != v.expect {
|
||||||
|
t.Fatalf("expected %s, got %s", v.expect, client.base.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue