diff --git a/envconfig/config.go b/envconfig/config.go index 23f93270..7ae521ab 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -75,9 +75,31 @@ func Host() *url.URL { } } +// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable. +func Origins() (origins []string) { + if s := clean("OLLAMA_ORIGINS"); s != "" { + origins = strings.Split(s, ",") + } + + for _, origin := range []string{"localhost", "127.0.0.1", "0.0.0.0"} { + origins = append(origins, + fmt.Sprintf("http://%s", origin), + fmt.Sprintf("https://%s", origin), + fmt.Sprintf("http://%s", net.JoinHostPort(origin, "*")), + fmt.Sprintf("https://%s", net.JoinHostPort(origin, "*")), + ) + } + + origins = append(origins, + "app://*", + "file://*", + "tauri://*", + ) + + return origins +} + var ( - // Set via OLLAMA_ORIGINS in the environment - AllowOrigins []string // Experimental flash attention FlashAttention bool // Set via OLLAMA_KEEP_ALIVE in the environment @@ -136,7 +158,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"}, "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"}, "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"}, - "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"}, + "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"}, "OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"}, "OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"}, @@ -160,12 +182,6 @@ func Values() map[string]string { return vals } -var defaultAllowOrigins = []string{ - "localhost", - "127.0.0.1", - "0.0.0.0", -} - // Clean quotes and spaces from the value func clean(key string) string { return strings.Trim(os.Getenv(key), "\"' ") @@ -255,24 +271,6 @@ func LoadConfig() { NoPrune = true } - if origins := clean("OLLAMA_ORIGINS"); origins != "" { - AllowOrigins = strings.Split(origins, ",") - } - for _, allowOrigin := range defaultAllowOrigins { - AllowOrigins = append(AllowOrigins, - fmt.Sprintf("http://%s", allowOrigin), - fmt.Sprintf("https://%s", allowOrigin), - fmt.Sprintf("http://%s", net.JoinHostPort(allowOrigin, "*")), - fmt.Sprintf("https://%s", net.JoinHostPort(allowOrigin, "*")), - ) - } - - AllowOrigins = append(AllowOrigins, - "app://*", - "file://*", - "tauri://*", - ) - maxRunners := clean("OLLAMA_MAX_LOADED_MODELS") if maxRunners != "" { m, err := strconv.Atoi(maxRunners) diff --git a/envconfig/config_test.go b/envconfig/config_test.go index af89e7b7..dc65ef70 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -5,10 +5,11 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" ) -func TestConfig(t *testing.T) { +func TestSmoke(t *testing.T) { t.Setenv("OLLAMA_DEBUG", "") require.False(t, Debug()) @@ -38,7 +39,7 @@ func TestConfig(t *testing.T) { require.Equal(t, time.Duration(math.MaxInt64), KeepAlive) } -func TestClientFromEnvironment(t *testing.T) { +func TestHost(t *testing.T) { cases := map[string]struct { value string expect string @@ -71,3 +72,93 @@ func TestClientFromEnvironment(t *testing.T) { }) } } + +func TestOrigins(t *testing.T) { + cases := []struct { + value string + expect []string + }{ + {"", []string{ + "http://localhost", + "https://localhost", + "http://localhost:*", + "https://localhost:*", + "http://127.0.0.1", + "https://127.0.0.1", + "http://127.0.0.1:*", + "https://127.0.0.1:*", + "http://0.0.0.0", + "https://0.0.0.0", + "http://0.0.0.0:*", + "https://0.0.0.0:*", + "app://*", + "file://*", + "tauri://*", + }}, + {"http://10.0.0.1", []string{ + "http://10.0.0.1", + "http://localhost", + "https://localhost", + "http://localhost:*", + "https://localhost:*", + "http://127.0.0.1", + "https://127.0.0.1", + "http://127.0.0.1:*", + "https://127.0.0.1:*", + "http://0.0.0.0", + "https://0.0.0.0", + "http://0.0.0.0:*", + "https://0.0.0.0:*", + "app://*", + "file://*", + "tauri://*", + }}, + {"http://172.16.0.1,https://192.168.0.1", []string{ + "http://172.16.0.1", + "https://192.168.0.1", + "http://localhost", + "https://localhost", + "http://localhost:*", + "https://localhost:*", + "http://127.0.0.1", + "https://127.0.0.1", + "http://127.0.0.1:*", + "https://127.0.0.1:*", + "http://0.0.0.0", + "https://0.0.0.0", + "http://0.0.0.0:*", + "https://0.0.0.0:*", + "app://*", + "file://*", + "tauri://*", + }}, + {"http://totally.safe,http://definitely.legit", []string{ + "http://totally.safe", + "http://definitely.legit", + "http://localhost", + "https://localhost", + "http://localhost:*", + "https://localhost:*", + "http://127.0.0.1", + "https://127.0.0.1", + "http://127.0.0.1:*", + "https://127.0.0.1:*", + "http://0.0.0.0", + "https://0.0.0.0", + "http://0.0.0.0:*", + "https://0.0.0.0:*", + "app://*", + "file://*", + "tauri://*", + }}, + } + for _, tt := range cases { + t.Run(tt.value, func(t *testing.T) { + t.Setenv("OLLAMA_ORIGINS", tt.value) + + if diff := cmp.Diff(Origins(), tt.expect); diff != "" { + t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff) + } + }) + } +} diff --git a/server/routes.go b/server/routes.go index c049421b..07898d9b 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1048,7 +1048,7 @@ func (s *Server) GenerateRoutes() http.Handler { for _, prop := range openAIProperties { config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop) } - config.AllowOrigins = envconfig.AllowOrigins + config.AllowOrigins = envconfig.Origins() r := gin.Default() r.Use(