add allowed host middleware and remove workDir middleware (#3018)

This commit is contained in:
Jeffrey Morgan 2024-03-08 22:23:47 -08:00 committed by GitHub
parent ecc133d843
commit fc8c044584
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 61 additions and 26 deletions

View file

@ -10,6 +10,7 @@ import (
"log/slog" "log/slog"
"net" "net"
"net/http" "net/http"
"net/netip"
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
@ -35,7 +36,7 @@ import (
var mode string = gin.DebugMode var mode string = gin.DebugMode
type Server struct { type Server struct {
WorkDir string addr net.Addr
} }
func init() { func init() {
@ -904,15 +905,64 @@ var defaultAllowOrigins = []string{
"0.0.0.0", "0.0.0.0",
} }
func NewServer() (*Server, error) { func allowedHost(host string) bool {
workDir, err := os.MkdirTemp("", "ollama") if host == "" || host == "localhost" {
if err != nil { return true
return nil, err
} }
return &Server{ if hostname, err := os.Hostname(); err == nil && host == hostname {
WorkDir: workDir, return true
}, nil }
var tlds = []string{
".localhost",
".local",
".internal",
}
for _, tld := range tlds {
if strings.HasSuffix(host, "."+tld) {
return true
}
}
return false
}
func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
return func(c *gin.Context) {
if addr == nil {
c.Next()
return
}
if !netip.MustParseAddrPort(addr.String()).Addr().IsLoopback() {
c.Next()
return
}
if addrPort, _ := netip.ParseAddrPort(c.Request.Host); addrPort.Addr().IsLoopback() {
c.Next()
return
}
if addr, _ := netip.ParseAddr(c.Request.Host); addr.IsLoopback() {
c.Next()
return
}
host, _, err := net.SplitHostPort(c.Request.Host)
if err != nil {
host = c.Request.Host
}
if allowedHost(host) {
c.Next()
return
}
c.AbortWithStatus(http.StatusForbidden)
}
} }
func (s *Server) GenerateRoutes() http.Handler { func (s *Server) GenerateRoutes() http.Handler {
@ -938,10 +988,7 @@ func (s *Server) GenerateRoutes() http.Handler {
r := gin.Default() r := gin.Default()
r.Use( r.Use(
cors.New(config), cors.New(config),
func(c *gin.Context) { allowedHostsMiddleware(s.addr),
c.Set("workDir", s.WorkDir)
c.Next()
},
) )
r.POST("/api/pull", PullModelHandler) r.POST("/api/pull", PullModelHandler)
@ -1010,10 +1057,7 @@ func Serve(ln net.Listener) error {
} }
} }
s, err := NewServer() s := &Server{addr: ln.Addr()}
if err != nil {
return err
}
r := s.GenerateRoutes() r := s.GenerateRoutes()
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version)) slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
@ -1029,7 +1073,6 @@ func Serve(ln net.Listener) error {
if loaded.runner != nil { if loaded.runner != nil {
loaded.runner.Close() loaded.runner.Close()
} }
os.RemoveAll(s.WorkDir)
os.Exit(0) os.Exit(0)
}() }()

View file

@ -21,12 +21,6 @@ import (
"github.com/jmorganca/ollama/version" "github.com/jmorganca/ollama/version"
) )
func setupServer(t *testing.T) (*Server, error) {
t.Helper()
return NewServer()
}
func Test_Routes(t *testing.T) { func Test_Routes(t *testing.T) {
type testCase struct { type testCase struct {
Name string Name string
@ -207,9 +201,7 @@ func Test_Routes(t *testing.T) {
}, },
} }
s, err := setupServer(t) s := Server{}
assert.Nil(t, err)
router := s.GenerateRoutes() router := s.GenerateRoutes()
httpSrv := httptest.NewServer(router) httpSrv := httptest.NewServer(router)