From fb593b7bfc5c193f053726a1026279852a9184c9 Mon Sep 17 00:00:00 2001 From: cmiller01 Date: Mon, 7 Aug 2023 03:34:37 +0000 Subject: [PATCH 1/2] pass flags to `serve` to allow setting allowed-origins + host and port * resolves: https://github.com/jmorganca/ollama/issues/300 and https://github.com/jmorganca/ollama/issues/282 * example usage: ``` ollama serve --port 9999 --allowed-origins "http://foo.example.com,http://192.0.0.1" ``` --- cmd/cmd.go | 30 ++++++++++++++++++++++-------- server/routes.go | 6 ++++-- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 0c3c6f97..041d4421 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -513,23 +513,33 @@ func generateBatch(cmd *cobra.Command, model string) error { return nil } -func RunServer(_ *cobra.Command, _ []string) error { - host := os.Getenv("OLLAMA_HOST") - if host == "" { - host = "127.0.0.1" +func RunServer(cmd *cobra.Command, _ []string) error { + host, err := cmd.Flags().GetString("host") + if err != nil { + return errors.New("host unset") + } + if os.Getenv("OLLAMA_HOST") != "" { + host = os.Getenv("OLLAMA_HOST") + } + port, err := cmd.Flags().GetString("port") + if err != nil { + return errors.New("port unset") } - port := os.Getenv("OLLAMA_PORT") - if port == "" { - port = "11434" + if os.Getenv("OLLAMA_PORT") != "" { + port = os.Getenv("OLLAMA_PORT") } ln, err := net.Listen("tcp", fmt.Sprintf("%s:%s", host, port)) if err != nil { return err } + extraOrigins, err := cmd.Flags().GetStringSlice("allowed-origins") + if err != nil { + return err + } - return server.Serve(ln) + return server.Serve(ln, extraOrigins) } func startMacApp(client *api.Client) error { @@ -621,6 +631,10 @@ func NewCLI() *cobra.Command { RunE: RunServer, } + serveCmd.Flags().String("port", "11434", "Port to listen on, may also use OLLAMA_PORT environment variable") + serveCmd.Flags().String("host", "127.0.0.1", "Host listen address, may also use OLLAMA_HOST environment variable") + serveCmd.Flags().StringSlice("allowed-origins", []string{}, "Additional allowed CORS origins (outside of localhost), specify as comma-separated list") + pullCmd := &cobra.Command{ Use: "pull MODEL", Short: "Pull a model from a registry", diff --git a/server/routes.go b/server/routes.go index 5e8a356f..83afef1a 100644 --- a/server/routes.go +++ b/server/routes.go @@ -301,11 +301,11 @@ func CopyModelHandler(c *gin.Context) { } } -func Serve(ln net.Listener) error { +func Serve(ln net.Listener, extraOrigins []string) error { config := cors.DefaultConfig() config.AllowWildcard = true // only allow http/https from localhost - config.AllowOrigins = []string{ + allowedOrigins := []string{ "http://localhost", "http://localhost:*", "https://localhost", @@ -315,6 +315,8 @@ func Serve(ln net.Listener) error { "https://127.0.0.1", "https://127.0.0.1:*", } + allowedOrigins = append(allowedOrigins, extraOrigins...) + config.AllowOrigins = allowedOrigins r := gin.Default() r.Use(cors.New(config)) From 93492f1e18cb4d9003289a941c159612bb250b81 Mon Sep 17 00:00:00 2001 From: cmiller01 Date: Mon, 7 Aug 2023 19:55:20 +0000 Subject: [PATCH 2/2] correct precedence of serve params (args over env over default) --- cmd/cmd.go | 47 +++++++++++++--------- cmd/cmd_test.go | 103 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+), 18 deletions(-) create mode 100644 cmd/cmd_test.go diff --git a/cmd/cmd.go b/cmd/cmd.go index 041d4421..9526c864 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -513,28 +513,39 @@ func generateBatch(cmd *cobra.Command, model string) error { return nil } +// getRunServerParams takes a command and the environment variables and returns the correct params +// given the order of precedence: command line args (highest), environment variables, defaults (lowest) +func getRunServerParams(cmd *cobra.Command) (host, port string, extraOrigins []string, err error) { + host = os.Getenv("OLLAMA_HOST") + hostFlag := cmd.Flags().Lookup("host") + if hostFlag == nil { + return "", "", nil, errors.New("host unset") + } + if hostFlag.Changed || host == "" { + host = hostFlag.Value.String() + } + port = os.Getenv("OLLAMA_PORT") + portFlag := cmd.Flags().Lookup("port") + if portFlag == nil { + return "", "", nil, errors.New("port unset") + } + if portFlag.Changed || port == "" { + port = portFlag.Value.String() + } + extraOrigins, err = cmd.Flags().GetStringSlice("allowed-origins") + if err != nil { + return "", "", nil, err + } + return host, port, extraOrigins, nil +} + func RunServer(cmd *cobra.Command, _ []string) error { - host, err := cmd.Flags().GetString("host") - if err != nil { - return errors.New("host unset") - } - if os.Getenv("OLLAMA_HOST") != "" { - host = os.Getenv("OLLAMA_HOST") - } - port, err := cmd.Flags().GetString("port") - if err != nil { - return errors.New("port unset") - } - - if os.Getenv("OLLAMA_PORT") != "" { - port = os.Getenv("OLLAMA_PORT") - } - - ln, err := net.Listen("tcp", fmt.Sprintf("%s:%s", host, port)) + host, port, extraOrigins, err := getRunServerParams(cmd) if err != nil { return err } - extraOrigins, err := cmd.Flags().GetStringSlice("allowed-origins") + + ln, err := net.Listen("tcp", fmt.Sprintf("%s:%s", host, port)) if err != nil { return err } diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go new file mode 100644 index 00000000..6633e385 --- /dev/null +++ b/cmd/cmd_test.go @@ -0,0 +1,103 @@ +package cmd + +import ( + "os" + "testing" +) + +func TestGetRunServerParams(t *testing.T) { + t.Run("default values", func(t *testing.T) { + cmd := NewCLI() + serveCmd, _, err := cmd.Find([]string{"serve"}) + if err != nil { + t.Errorf("expected serve command, got %s", err) + } + host, port, extraOrigins, err := getRunServerParams(serveCmd) + // assertions + if err != nil { + t.Errorf("unexpected error, got %s", err) + } + if host != "127.0.0.1" { + t.Errorf("unexpected host, got %s", host) + } + if port != "11434" { + t.Errorf("unexpected port, got %s", port) + } + if len(extraOrigins) != 0 { + t.Errorf("unexpected origins, got %s", extraOrigins) + } + }) + t.Run("environment variables take precedence over default", func(t *testing.T) { + cmd := NewCLI() + serveCmd, _, err := cmd.Find([]string{"serve"}) + if err != nil { + t.Errorf("expected serve command, got %s", err) + } + // setup environment variables + err = os.Setenv("OLLAMA_HOST", "0.0.0.0") + if err != nil { + t.Errorf("could not set env var") + } + err = os.Setenv("OLLAMA_PORT", "9999") + if err != nil { + t.Errorf("could not set env var") + } + defer func() { + os.Unsetenv("OLLAMA_HOST") + os.Unsetenv("OLLAMA_PORT") + }() + + host, port, extraOrigins, err := getRunServerParams(serveCmd) + // assertions + if err != nil { + t.Errorf("unexpected error, got %s", err) + } + if host != "0.0.0.0" { + t.Errorf("unexpected host, got %s", host) + } + if port != "9999" { + t.Errorf("unexpected port, got %s", port) + } + if len(extraOrigins) != 0 { + t.Errorf("unexpected origins, got %s", extraOrigins) + } + }) + t.Run("command line args take precedence over env vars", func(t *testing.T) { + cmd := NewCLI() + serveCmd, _, err := cmd.Find([]string{"serve"}) + if err != nil { + t.Errorf("expected serve command, got %s", err) + } + // setup environment variables + err = os.Setenv("OLLAMA_HOST", "0.0.0.0") + if err != nil { + t.Errorf("could not set env var") + } + err = os.Setenv("OLLAMA_PORT", "9999") + if err != nil { + t.Errorf("could not set env var") + } + defer func() { + os.Unsetenv("OLLAMA_HOST") + os.Unsetenv("OLLAMA_PORT") + }() + // now set command flags + serveCmd.Flags().Set("host", "localhost") + serveCmd.Flags().Set("port", "8888") + serveCmd.Flags().Set("allowed-origins", "http://foo.example.com,http://192.168.1.1") + + host, port, extraOrigins, err := getRunServerParams(serveCmd) + if err != nil { + t.Errorf("unexpected error, got %s", err) + } + if host != "localhost" { + t.Errorf("unexpected host, got %s", host) + } + if port != "8888" { + t.Errorf("unexpected port, got %s", port) + } + if len(extraOrigins) != 2 { + t.Errorf("expected two origins, got length %d", len(extraOrigins)) + } + }) +}