pass flags to serve to allow setting allowed-origins + host and port

This commit is contained in:
Bruce MacDonald 2023-08-08 10:41:42 -04:00 committed by GitHub
commit 34a13a9d05
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 140 additions and 10 deletions

View file

@ -513,15 +513,36 @@ func generateBatch(cmd *cobra.Command, model string) error {
return nil
}
func RunServer(_ *cobra.Command, _ []string) error {
host := os.Getenv("OLLAMA_HOST")
if host == "" {
host = "127.0.0.1"
// getRunServerParams takes a command and the environment variables and returns the correct params
// given the order of precedence: command line args (highest), environment variables, defaults (lowest)
func getRunServerParams(cmd *cobra.Command) (host, port string, extraOrigins []string, err error) {
host = os.Getenv("OLLAMA_HOST")
hostFlag := cmd.Flags().Lookup("host")
if hostFlag == nil {
return "", "", nil, errors.New("host unset")
}
if hostFlag.Changed || host == "" {
host = hostFlag.Value.String()
}
port = os.Getenv("OLLAMA_PORT")
portFlag := cmd.Flags().Lookup("port")
if portFlag == nil {
return "", "", nil, errors.New("port unset")
}
if portFlag.Changed || port == "" {
port = portFlag.Value.String()
}
extraOrigins, err = cmd.Flags().GetStringSlice("allowed-origins")
if err != nil {
return "", "", nil, err
}
return host, port, extraOrigins, nil
}
port := os.Getenv("OLLAMA_PORT")
if port == "" {
port = "11434"
func RunServer(cmd *cobra.Command, _ []string) error {
host, port, extraOrigins, err := getRunServerParams(cmd)
if err != nil {
return err
}
ln, err := net.Listen("tcp", fmt.Sprintf("%s:%s", host, port))
@ -529,7 +550,7 @@ func RunServer(_ *cobra.Command, _ []string) error {
return err
}
return server.Serve(ln)
return server.Serve(ln, extraOrigins)
}
func startMacApp(client *api.Client) error {
@ -621,6 +642,10 @@ func NewCLI() *cobra.Command {
RunE: RunServer,
}
serveCmd.Flags().String("port", "11434", "Port to listen on, may also use OLLAMA_PORT environment variable")
serveCmd.Flags().String("host", "127.0.0.1", "Host listen address, may also use OLLAMA_HOST environment variable")
serveCmd.Flags().StringSlice("allowed-origins", []string{}, "Additional allowed CORS origins (outside of localhost), specify as comma-separated list")
pullCmd := &cobra.Command{
Use: "pull MODEL",
Short: "Pull a model from a registry",

103
cmd/cmd_test.go Normal file
View file

@ -0,0 +1,103 @@
package cmd
import (
"os"
"testing"
)
func TestGetRunServerParams(t *testing.T) {
t.Run("default values", func(t *testing.T) {
cmd := NewCLI()
serveCmd, _, err := cmd.Find([]string{"serve"})
if err != nil {
t.Errorf("expected serve command, got %s", err)
}
host, port, extraOrigins, err := getRunServerParams(serveCmd)
// assertions
if err != nil {
t.Errorf("unexpected error, got %s", err)
}
if host != "127.0.0.1" {
t.Errorf("unexpected host, got %s", host)
}
if port != "11434" {
t.Errorf("unexpected port, got %s", port)
}
if len(extraOrigins) != 0 {
t.Errorf("unexpected origins, got %s", extraOrigins)
}
})
t.Run("environment variables take precedence over default", func(t *testing.T) {
cmd := NewCLI()
serveCmd, _, err := cmd.Find([]string{"serve"})
if err != nil {
t.Errorf("expected serve command, got %s", err)
}
// setup environment variables
err = os.Setenv("OLLAMA_HOST", "0.0.0.0")
if err != nil {
t.Errorf("could not set env var")
}
err = os.Setenv("OLLAMA_PORT", "9999")
if err != nil {
t.Errorf("could not set env var")
}
defer func() {
os.Unsetenv("OLLAMA_HOST")
os.Unsetenv("OLLAMA_PORT")
}()
host, port, extraOrigins, err := getRunServerParams(serveCmd)
// assertions
if err != nil {
t.Errorf("unexpected error, got %s", err)
}
if host != "0.0.0.0" {
t.Errorf("unexpected host, got %s", host)
}
if port != "9999" {
t.Errorf("unexpected port, got %s", port)
}
if len(extraOrigins) != 0 {
t.Errorf("unexpected origins, got %s", extraOrigins)
}
})
t.Run("command line args take precedence over env vars", func(t *testing.T) {
cmd := NewCLI()
serveCmd, _, err := cmd.Find([]string{"serve"})
if err != nil {
t.Errorf("expected serve command, got %s", err)
}
// setup environment variables
err = os.Setenv("OLLAMA_HOST", "0.0.0.0")
if err != nil {
t.Errorf("could not set env var")
}
err = os.Setenv("OLLAMA_PORT", "9999")
if err != nil {
t.Errorf("could not set env var")
}
defer func() {
os.Unsetenv("OLLAMA_HOST")
os.Unsetenv("OLLAMA_PORT")
}()
// now set command flags
serveCmd.Flags().Set("host", "localhost")
serveCmd.Flags().Set("port", "8888")
serveCmd.Flags().Set("allowed-origins", "http://foo.example.com,http://192.168.1.1")
host, port, extraOrigins, err := getRunServerParams(serveCmd)
if err != nil {
t.Errorf("unexpected error, got %s", err)
}
if host != "localhost" {
t.Errorf("unexpected host, got %s", host)
}
if port != "8888" {
t.Errorf("unexpected port, got %s", port)
}
if len(extraOrigins) != 2 {
t.Errorf("expected two origins, got length %d", len(extraOrigins))
}
})
}

View file

@ -301,11 +301,11 @@ func CopyModelHandler(c *gin.Context) {
}
}
func Serve(ln net.Listener) error {
func Serve(ln net.Listener, extraOrigins []string) error {
config := cors.DefaultConfig()
config.AllowWildcard = true
// only allow http/https from localhost
config.AllowOrigins = []string{
allowedOrigins := []string{
"http://localhost",
"http://localhost:*",
"https://localhost",
@ -315,6 +315,8 @@ func Serve(ln net.Listener) error {
"https://127.0.0.1",
"https://127.0.0.1:*",
}
allowedOrigins = append(allowedOrigins, extraOrigins...)
config.AllowOrigins = allowedOrigins
r := gin.Default()
r.Use(cors.New(config))