correct precedence of serve params (args over env over default)
This commit is contained in:
parent
fb593b7bfc
commit
93492f1e18
2 changed files with 132 additions and 18 deletions
47
cmd/cmd.go
47
cmd/cmd.go
|
@ -513,28 +513,39 @@ func generateBatch(cmd *cobra.Command, model string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// getRunServerParams takes a command and the environment variables and returns the correct params
|
||||
// given the order of precedence: command line args (highest), environment variables, defaults (lowest)
|
||||
func getRunServerParams(cmd *cobra.Command) (host, port string, extraOrigins []string, err error) {
|
||||
host = os.Getenv("OLLAMA_HOST")
|
||||
hostFlag := cmd.Flags().Lookup("host")
|
||||
if hostFlag == nil {
|
||||
return "", "", nil, errors.New("host unset")
|
||||
}
|
||||
if hostFlag.Changed || host == "" {
|
||||
host = hostFlag.Value.String()
|
||||
}
|
||||
port = os.Getenv("OLLAMA_PORT")
|
||||
portFlag := cmd.Flags().Lookup("port")
|
||||
if portFlag == nil {
|
||||
return "", "", nil, errors.New("port unset")
|
||||
}
|
||||
if portFlag.Changed || port == "" {
|
||||
port = portFlag.Value.String()
|
||||
}
|
||||
extraOrigins, err = cmd.Flags().GetStringSlice("allowed-origins")
|
||||
if err != nil {
|
||||
return "", "", nil, err
|
||||
}
|
||||
return host, port, extraOrigins, nil
|
||||
}
|
||||
|
||||
func RunServer(cmd *cobra.Command, _ []string) error {
|
||||
host, err := cmd.Flags().GetString("host")
|
||||
if err != nil {
|
||||
return errors.New("host unset")
|
||||
}
|
||||
if os.Getenv("OLLAMA_HOST") != "" {
|
||||
host = os.Getenv("OLLAMA_HOST")
|
||||
}
|
||||
port, err := cmd.Flags().GetString("port")
|
||||
if err != nil {
|
||||
return errors.New("port unset")
|
||||
}
|
||||
|
||||
if os.Getenv("OLLAMA_PORT") != "" {
|
||||
port = os.Getenv("OLLAMA_PORT")
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", fmt.Sprintf("%s:%s", host, port))
|
||||
host, port, extraOrigins, err := getRunServerParams(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
extraOrigins, err := cmd.Flags().GetStringSlice("allowed-origins")
|
||||
|
||||
ln, err := net.Listen("tcp", fmt.Sprintf("%s:%s", host, port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
103
cmd/cmd_test.go
Normal file
103
cmd/cmd_test.go
Normal file
|
@ -0,0 +1,103 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetRunServerParams(t *testing.T) {
|
||||
t.Run("default values", func(t *testing.T) {
|
||||
cmd := NewCLI()
|
||||
serveCmd, _, err := cmd.Find([]string{"serve"})
|
||||
if err != nil {
|
||||
t.Errorf("expected serve command, got %s", err)
|
||||
}
|
||||
host, port, extraOrigins, err := getRunServerParams(serveCmd)
|
||||
// assertions
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error, got %s", err)
|
||||
}
|
||||
if host != "127.0.0.1" {
|
||||
t.Errorf("unexpected host, got %s", host)
|
||||
}
|
||||
if port != "11434" {
|
||||
t.Errorf("unexpected port, got %s", port)
|
||||
}
|
||||
if len(extraOrigins) != 0 {
|
||||
t.Errorf("unexpected origins, got %s", extraOrigins)
|
||||
}
|
||||
})
|
||||
t.Run("environment variables take precedence over default", func(t *testing.T) {
|
||||
cmd := NewCLI()
|
||||
serveCmd, _, err := cmd.Find([]string{"serve"})
|
||||
if err != nil {
|
||||
t.Errorf("expected serve command, got %s", err)
|
||||
}
|
||||
// setup environment variables
|
||||
err = os.Setenv("OLLAMA_HOST", "0.0.0.0")
|
||||
if err != nil {
|
||||
t.Errorf("could not set env var")
|
||||
}
|
||||
err = os.Setenv("OLLAMA_PORT", "9999")
|
||||
if err != nil {
|
||||
t.Errorf("could not set env var")
|
||||
}
|
||||
defer func() {
|
||||
os.Unsetenv("OLLAMA_HOST")
|
||||
os.Unsetenv("OLLAMA_PORT")
|
||||
}()
|
||||
|
||||
host, port, extraOrigins, err := getRunServerParams(serveCmd)
|
||||
// assertions
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error, got %s", err)
|
||||
}
|
||||
if host != "0.0.0.0" {
|
||||
t.Errorf("unexpected host, got %s", host)
|
||||
}
|
||||
if port != "9999" {
|
||||
t.Errorf("unexpected port, got %s", port)
|
||||
}
|
||||
if len(extraOrigins) != 0 {
|
||||
t.Errorf("unexpected origins, got %s", extraOrigins)
|
||||
}
|
||||
})
|
||||
t.Run("command line args take precedence over env vars", func(t *testing.T) {
|
||||
cmd := NewCLI()
|
||||
serveCmd, _, err := cmd.Find([]string{"serve"})
|
||||
if err != nil {
|
||||
t.Errorf("expected serve command, got %s", err)
|
||||
}
|
||||
// setup environment variables
|
||||
err = os.Setenv("OLLAMA_HOST", "0.0.0.0")
|
||||
if err != nil {
|
||||
t.Errorf("could not set env var")
|
||||
}
|
||||
err = os.Setenv("OLLAMA_PORT", "9999")
|
||||
if err != nil {
|
||||
t.Errorf("could not set env var")
|
||||
}
|
||||
defer func() {
|
||||
os.Unsetenv("OLLAMA_HOST")
|
||||
os.Unsetenv("OLLAMA_PORT")
|
||||
}()
|
||||
// now set command flags
|
||||
serveCmd.Flags().Set("host", "localhost")
|
||||
serveCmd.Flags().Set("port", "8888")
|
||||
serveCmd.Flags().Set("allowed-origins", "http://foo.example.com,http://192.168.1.1")
|
||||
|
||||
host, port, extraOrigins, err := getRunServerParams(serveCmd)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error, got %s", err)
|
||||
}
|
||||
if host != "localhost" {
|
||||
t.Errorf("unexpected host, got %s", host)
|
||||
}
|
||||
if port != "8888" {
|
||||
t.Errorf("unexpected port, got %s", port)
|
||||
}
|
||||
if len(extraOrigins) != 2 {
|
||||
t.Errorf("expected two origins, got length %d", len(extraOrigins))
|
||||
}
|
||||
})
|
||||
}
|
Loading…
Reference in a new issue