This commit is contained in:
Michael Yang 2024-07-03 17:02:07 -07:00
parent 4f1afd575d
commit d1a5227cad
3 changed files with 119 additions and 30 deletions

View file

@ -75,9 +75,31 @@ func Host() *url.URL {
} }
} }
// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
func Origins() (origins []string) {
if s := clean("OLLAMA_ORIGINS"); s != "" {
origins = strings.Split(s, ",")
}
for _, origin := range []string{"localhost", "127.0.0.1", "0.0.0.0"} {
origins = append(origins,
fmt.Sprintf("http://%s", origin),
fmt.Sprintf("https://%s", origin),
fmt.Sprintf("http://%s", net.JoinHostPort(origin, "*")),
fmt.Sprintf("https://%s", net.JoinHostPort(origin, "*")),
)
}
origins = append(origins,
"app://*",
"file://*",
"tauri://*",
)
return origins
}
var ( var (
// Set via OLLAMA_ORIGINS in the environment
AllowOrigins []string
// Experimental flash attention // Experimental flash attention
FlashAttention bool FlashAttention bool
// Set via OLLAMA_KEEP_ALIVE in the environment // Set via OLLAMA_KEEP_ALIVE in the environment
@ -136,7 +158,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"}, "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"}, "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"}, "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"}, "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"}, "OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"},
"OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"}, "OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"},
@ -160,12 +182,6 @@ func Values() map[string]string {
return vals return vals
} }
var defaultAllowOrigins = []string{
"localhost",
"127.0.0.1",
"0.0.0.0",
}
// Clean quotes and spaces from the value // Clean quotes and spaces from the value
func clean(key string) string { func clean(key string) string {
return strings.Trim(os.Getenv(key), "\"' ") return strings.Trim(os.Getenv(key), "\"' ")
@ -255,24 +271,6 @@ func LoadConfig() {
NoPrune = true NoPrune = true
} }
if origins := clean("OLLAMA_ORIGINS"); origins != "" {
AllowOrigins = strings.Split(origins, ",")
}
for _, allowOrigin := range defaultAllowOrigins {
AllowOrigins = append(AllowOrigins,
fmt.Sprintf("http://%s", allowOrigin),
fmt.Sprintf("https://%s", allowOrigin),
fmt.Sprintf("http://%s", net.JoinHostPort(allowOrigin, "*")),
fmt.Sprintf("https://%s", net.JoinHostPort(allowOrigin, "*")),
)
}
AllowOrigins = append(AllowOrigins,
"app://*",
"file://*",
"tauri://*",
)
maxRunners := clean("OLLAMA_MAX_LOADED_MODELS") maxRunners := clean("OLLAMA_MAX_LOADED_MODELS")
if maxRunners != "" { if maxRunners != "" {
m, err := strconv.Atoi(maxRunners) m, err := strconv.Atoi(maxRunners)

View file

@ -5,10 +5,11 @@ import (
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestConfig(t *testing.T) { func TestSmoke(t *testing.T) {
t.Setenv("OLLAMA_DEBUG", "") t.Setenv("OLLAMA_DEBUG", "")
require.False(t, Debug()) require.False(t, Debug())
@ -38,7 +39,7 @@ func TestConfig(t *testing.T) {
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive) require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
} }
func TestClientFromEnvironment(t *testing.T) { func TestHost(t *testing.T) {
cases := map[string]struct { cases := map[string]struct {
value string value string
expect string expect string
@ -71,3 +72,93 @@ func TestClientFromEnvironment(t *testing.T) {
}) })
} }
} }
func TestOrigins(t *testing.T) {
cases := []struct {
value string
expect []string
}{
{"", []string{
"http://localhost",
"https://localhost",
"http://localhost:*",
"https://localhost:*",
"http://127.0.0.1",
"https://127.0.0.1",
"http://127.0.0.1:*",
"https://127.0.0.1:*",
"http://0.0.0.0",
"https://0.0.0.0",
"http://0.0.0.0:*",
"https://0.0.0.0:*",
"app://*",
"file://*",
"tauri://*",
}},
{"http://10.0.0.1", []string{
"http://10.0.0.1",
"http://localhost",
"https://localhost",
"http://localhost:*",
"https://localhost:*",
"http://127.0.0.1",
"https://127.0.0.1",
"http://127.0.0.1:*",
"https://127.0.0.1:*",
"http://0.0.0.0",
"https://0.0.0.0",
"http://0.0.0.0:*",
"https://0.0.0.0:*",
"app://*",
"file://*",
"tauri://*",
}},
{"http://172.16.0.1,https://192.168.0.1", []string{
"http://172.16.0.1",
"https://192.168.0.1",
"http://localhost",
"https://localhost",
"http://localhost:*",
"https://localhost:*",
"http://127.0.0.1",
"https://127.0.0.1",
"http://127.0.0.1:*",
"https://127.0.0.1:*",
"http://0.0.0.0",
"https://0.0.0.0",
"http://0.0.0.0:*",
"https://0.0.0.0:*",
"app://*",
"file://*",
"tauri://*",
}},
{"http://totally.safe,http://definitely.legit", []string{
"http://totally.safe",
"http://definitely.legit",
"http://localhost",
"https://localhost",
"http://localhost:*",
"https://localhost:*",
"http://127.0.0.1",
"https://127.0.0.1",
"http://127.0.0.1:*",
"https://127.0.0.1:*",
"http://0.0.0.0",
"https://0.0.0.0",
"http://0.0.0.0:*",
"https://0.0.0.0:*",
"app://*",
"file://*",
"tauri://*",
}},
}
for _, tt := range cases {
t.Run(tt.value, func(t *testing.T) {
t.Setenv("OLLAMA_ORIGINS", tt.value)
if diff := cmp.Diff(Origins(), tt.expect); diff != "" {
t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff)
}
})
}
}

View file

@ -1048,7 +1048,7 @@ func (s *Server) GenerateRoutes() http.Handler {
for _, prop := range openAIProperties { for _, prop := range openAIProperties {
config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop) config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
} }
config.AllowOrigins = envconfig.AllowOrigins config.AllowOrigins = envconfig.Origins()
r := gin.Default() r := gin.Default()
r.Use( r.Use(