This commit is contained in:
Michael Yang 2024-07-03 17:02:07 -07:00
parent 4f1afd575d
commit d1a5227cad
3 changed files with 119 additions and 30 deletions

View file

@ -75,9 +75,31 @@ func Host() *url.URL {
}
}
// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
func Origins() (origins []string) {
if s := clean("OLLAMA_ORIGINS"); s != "" {
origins = strings.Split(s, ",")
}
for _, origin := range []string{"localhost", "127.0.0.1", "0.0.0.0"} {
origins = append(origins,
fmt.Sprintf("http://%s", origin),
fmt.Sprintf("https://%s", origin),
fmt.Sprintf("http://%s", net.JoinHostPort(origin, "*")),
fmt.Sprintf("https://%s", net.JoinHostPort(origin, "*")),
)
}
origins = append(origins,
"app://*",
"file://*",
"tauri://*",
)
return origins
}
var (
// Set via OLLAMA_ORIGINS in the environment
AllowOrigins []string
// Experimental flash attention
FlashAttention bool
// Set via OLLAMA_KEEP_ALIVE in the environment
@ -136,7 +158,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"},
"OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"},
@ -160,12 +182,6 @@ func Values() map[string]string {
return vals
}
var defaultAllowOrigins = []string{
"localhost",
"127.0.0.1",
"0.0.0.0",
}
// Clean quotes and spaces from the value
func clean(key string) string {
return strings.Trim(os.Getenv(key), "\"' ")
@ -255,24 +271,6 @@ func LoadConfig() {
NoPrune = true
}
if origins := clean("OLLAMA_ORIGINS"); origins != "" {
AllowOrigins = strings.Split(origins, ",")
}
for _, allowOrigin := range defaultAllowOrigins {
AllowOrigins = append(AllowOrigins,
fmt.Sprintf("http://%s", allowOrigin),
fmt.Sprintf("https://%s", allowOrigin),
fmt.Sprintf("http://%s", net.JoinHostPort(allowOrigin, "*")),
fmt.Sprintf("https://%s", net.JoinHostPort(allowOrigin, "*")),
)
}
AllowOrigins = append(AllowOrigins,
"app://*",
"file://*",
"tauri://*",
)
maxRunners := clean("OLLAMA_MAX_LOADED_MODELS")
if maxRunners != "" {
m, err := strconv.Atoi(maxRunners)

View file

@ -5,10 +5,11 @@ import (
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
)
func TestConfig(t *testing.T) {
func TestSmoke(t *testing.T) {
t.Setenv("OLLAMA_DEBUG", "")
require.False(t, Debug())
@ -38,7 +39,7 @@ func TestConfig(t *testing.T) {
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
}
func TestClientFromEnvironment(t *testing.T) {
func TestHost(t *testing.T) {
cases := map[string]struct {
value string
expect string
@ -71,3 +72,93 @@ func TestClientFromEnvironment(t *testing.T) {
})
}
}
func TestOrigins(t *testing.T) {
cases := []struct {
value string
expect []string
}{
{"", []string{
"http://localhost",
"https://localhost",
"http://localhost:*",
"https://localhost:*",
"http://127.0.0.1",
"https://127.0.0.1",
"http://127.0.0.1:*",
"https://127.0.0.1:*",
"http://0.0.0.0",
"https://0.0.0.0",
"http://0.0.0.0:*",
"https://0.0.0.0:*",
"app://*",
"file://*",
"tauri://*",
}},
{"http://10.0.0.1", []string{
"http://10.0.0.1",
"http://localhost",
"https://localhost",
"http://localhost:*",
"https://localhost:*",
"http://127.0.0.1",
"https://127.0.0.1",
"http://127.0.0.1:*",
"https://127.0.0.1:*",
"http://0.0.0.0",
"https://0.0.0.0",
"http://0.0.0.0:*",
"https://0.0.0.0:*",
"app://*",
"file://*",
"tauri://*",
}},
{"http://172.16.0.1,https://192.168.0.1", []string{
"http://172.16.0.1",
"https://192.168.0.1",
"http://localhost",
"https://localhost",
"http://localhost:*",
"https://localhost:*",
"http://127.0.0.1",
"https://127.0.0.1",
"http://127.0.0.1:*",
"https://127.0.0.1:*",
"http://0.0.0.0",
"https://0.0.0.0",
"http://0.0.0.0:*",
"https://0.0.0.0:*",
"app://*",
"file://*",
"tauri://*",
}},
{"http://totally.safe,http://definitely.legit", []string{
"http://totally.safe",
"http://definitely.legit",
"http://localhost",
"https://localhost",
"http://localhost:*",
"https://localhost:*",
"http://127.0.0.1",
"https://127.0.0.1",
"http://127.0.0.1:*",
"https://127.0.0.1:*",
"http://0.0.0.0",
"https://0.0.0.0",
"http://0.0.0.0:*",
"https://0.0.0.0:*",
"app://*",
"file://*",
"tauri://*",
}},
}
for _, tt := range cases {
t.Run(tt.value, func(t *testing.T) {
t.Setenv("OLLAMA_ORIGINS", tt.value)
if diff := cmp.Diff(Origins(), tt.expect); diff != "" {
t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff)
}
})
}
}

View file

@ -1048,7 +1048,7 @@ func (s *Server) GenerateRoutes() http.Handler {
for _, prop := range openAIProperties {
config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
}
config.AllowOrigins = envconfig.AllowOrigins
config.AllowOrigins = envconfig.Origins()
r := gin.Default()
r.Use(