origins
This commit is contained in:
parent
4f1afd575d
commit
d1a5227cad
3 changed files with 119 additions and 30 deletions
|
@ -75,9 +75,31 @@ func Host() *url.URL {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
|
||||||
|
func Origins() (origins []string) {
|
||||||
|
if s := clean("OLLAMA_ORIGINS"); s != "" {
|
||||||
|
origins = strings.Split(s, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, origin := range []string{"localhost", "127.0.0.1", "0.0.0.0"} {
|
||||||
|
origins = append(origins,
|
||||||
|
fmt.Sprintf("http://%s", origin),
|
||||||
|
fmt.Sprintf("https://%s", origin),
|
||||||
|
fmt.Sprintf("http://%s", net.JoinHostPort(origin, "*")),
|
||||||
|
fmt.Sprintf("https://%s", net.JoinHostPort(origin, "*")),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
origins = append(origins,
|
||||||
|
"app://*",
|
||||||
|
"file://*",
|
||||||
|
"tauri://*",
|
||||||
|
)
|
||||||
|
|
||||||
|
return origins
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// Set via OLLAMA_ORIGINS in the environment
|
|
||||||
AllowOrigins []string
|
|
||||||
// Experimental flash attention
|
// Experimental flash attention
|
||||||
FlashAttention bool
|
FlashAttention bool
|
||||||
// Set via OLLAMA_KEEP_ALIVE in the environment
|
// Set via OLLAMA_KEEP_ALIVE in the environment
|
||||||
|
@ -136,7 +158,7 @@ func AsMap() map[string]EnvVar {
|
||||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
|
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
|
||||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
|
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
|
||||||
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"},
|
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"},
|
||||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
|
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
|
||||||
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
|
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
|
||||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"},
|
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"},
|
||||||
"OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"},
|
"OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"},
|
||||||
|
@ -160,12 +182,6 @@ func Values() map[string]string {
|
||||||
return vals
|
return vals
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultAllowOrigins = []string{
|
|
||||||
"localhost",
|
|
||||||
"127.0.0.1",
|
|
||||||
"0.0.0.0",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clean quotes and spaces from the value
|
// Clean quotes and spaces from the value
|
||||||
func clean(key string) string {
|
func clean(key string) string {
|
||||||
return strings.Trim(os.Getenv(key), "\"' ")
|
return strings.Trim(os.Getenv(key), "\"' ")
|
||||||
|
@ -255,24 +271,6 @@ func LoadConfig() {
|
||||||
NoPrune = true
|
NoPrune = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if origins := clean("OLLAMA_ORIGINS"); origins != "" {
|
|
||||||
AllowOrigins = strings.Split(origins, ",")
|
|
||||||
}
|
|
||||||
for _, allowOrigin := range defaultAllowOrigins {
|
|
||||||
AllowOrigins = append(AllowOrigins,
|
|
||||||
fmt.Sprintf("http://%s", allowOrigin),
|
|
||||||
fmt.Sprintf("https://%s", allowOrigin),
|
|
||||||
fmt.Sprintf("http://%s", net.JoinHostPort(allowOrigin, "*")),
|
|
||||||
fmt.Sprintf("https://%s", net.JoinHostPort(allowOrigin, "*")),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
AllowOrigins = append(AllowOrigins,
|
|
||||||
"app://*",
|
|
||||||
"file://*",
|
|
||||||
"tauri://*",
|
|
||||||
)
|
|
||||||
|
|
||||||
maxRunners := clean("OLLAMA_MAX_LOADED_MODELS")
|
maxRunners := clean("OLLAMA_MAX_LOADED_MODELS")
|
||||||
if maxRunners != "" {
|
if maxRunners != "" {
|
||||||
m, err := strconv.Atoi(maxRunners)
|
m, err := strconv.Atoi(maxRunners)
|
||||||
|
|
|
@ -5,10 +5,11 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConfig(t *testing.T) {
|
func TestSmoke(t *testing.T) {
|
||||||
t.Setenv("OLLAMA_DEBUG", "")
|
t.Setenv("OLLAMA_DEBUG", "")
|
||||||
require.False(t, Debug())
|
require.False(t, Debug())
|
||||||
|
|
||||||
|
@ -38,7 +39,7 @@ func TestConfig(t *testing.T) {
|
||||||
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
|
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClientFromEnvironment(t *testing.T) {
|
func TestHost(t *testing.T) {
|
||||||
cases := map[string]struct {
|
cases := map[string]struct {
|
||||||
value string
|
value string
|
||||||
expect string
|
expect string
|
||||||
|
@ -71,3 +72,93 @@ func TestClientFromEnvironment(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOrigins(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
value string
|
||||||
|
expect []string
|
||||||
|
}{
|
||||||
|
{"", []string{
|
||||||
|
"http://localhost",
|
||||||
|
"https://localhost",
|
||||||
|
"http://localhost:*",
|
||||||
|
"https://localhost:*",
|
||||||
|
"http://127.0.0.1",
|
||||||
|
"https://127.0.0.1",
|
||||||
|
"http://127.0.0.1:*",
|
||||||
|
"https://127.0.0.1:*",
|
||||||
|
"http://0.0.0.0",
|
||||||
|
"https://0.0.0.0",
|
||||||
|
"http://0.0.0.0:*",
|
||||||
|
"https://0.0.0.0:*",
|
||||||
|
"app://*",
|
||||||
|
"file://*",
|
||||||
|
"tauri://*",
|
||||||
|
}},
|
||||||
|
{"http://10.0.0.1", []string{
|
||||||
|
"http://10.0.0.1",
|
||||||
|
"http://localhost",
|
||||||
|
"https://localhost",
|
||||||
|
"http://localhost:*",
|
||||||
|
"https://localhost:*",
|
||||||
|
"http://127.0.0.1",
|
||||||
|
"https://127.0.0.1",
|
||||||
|
"http://127.0.0.1:*",
|
||||||
|
"https://127.0.0.1:*",
|
||||||
|
"http://0.0.0.0",
|
||||||
|
"https://0.0.0.0",
|
||||||
|
"http://0.0.0.0:*",
|
||||||
|
"https://0.0.0.0:*",
|
||||||
|
"app://*",
|
||||||
|
"file://*",
|
||||||
|
"tauri://*",
|
||||||
|
}},
|
||||||
|
{"http://172.16.0.1,https://192.168.0.1", []string{
|
||||||
|
"http://172.16.0.1",
|
||||||
|
"https://192.168.0.1",
|
||||||
|
"http://localhost",
|
||||||
|
"https://localhost",
|
||||||
|
"http://localhost:*",
|
||||||
|
"https://localhost:*",
|
||||||
|
"http://127.0.0.1",
|
||||||
|
"https://127.0.0.1",
|
||||||
|
"http://127.0.0.1:*",
|
||||||
|
"https://127.0.0.1:*",
|
||||||
|
"http://0.0.0.0",
|
||||||
|
"https://0.0.0.0",
|
||||||
|
"http://0.0.0.0:*",
|
||||||
|
"https://0.0.0.0:*",
|
||||||
|
"app://*",
|
||||||
|
"file://*",
|
||||||
|
"tauri://*",
|
||||||
|
}},
|
||||||
|
{"http://totally.safe,http://definitely.legit", []string{
|
||||||
|
"http://totally.safe",
|
||||||
|
"http://definitely.legit",
|
||||||
|
"http://localhost",
|
||||||
|
"https://localhost",
|
||||||
|
"http://localhost:*",
|
||||||
|
"https://localhost:*",
|
||||||
|
"http://127.0.0.1",
|
||||||
|
"https://127.0.0.1",
|
||||||
|
"http://127.0.0.1:*",
|
||||||
|
"https://127.0.0.1:*",
|
||||||
|
"http://0.0.0.0",
|
||||||
|
"https://0.0.0.0",
|
||||||
|
"http://0.0.0.0:*",
|
||||||
|
"https://0.0.0.0:*",
|
||||||
|
"app://*",
|
||||||
|
"file://*",
|
||||||
|
"tauri://*",
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.value, func(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_ORIGINS", tt.value)
|
||||||
|
|
||||||
|
if diff := cmp.Diff(Origins(), tt.expect); diff != "" {
|
||||||
|
t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1048,7 +1048,7 @@ func (s *Server) GenerateRoutes() http.Handler {
|
||||||
for _, prop := range openAIProperties {
|
for _, prop := range openAIProperties {
|
||||||
config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
|
config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
|
||||||
}
|
}
|
||||||
config.AllowOrigins = envconfig.AllowOrigins
|
config.AllowOrigins = envconfig.Origins()
|
||||||
|
|
||||||
r := gin.Default()
|
r := gin.Default()
|
||||||
r.Use(
|
r.Use(
|
||||||
|
|
Loading…
Reference in a new issue