add allowed host middleware and remove workDir
middleware (#3018)
This commit is contained in:
parent
ecc133d843
commit
fc8c044584
2 changed files with 61 additions and 26 deletions
|
@ -10,6 +10,7 @@ import (
|
|||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
|
@ -35,7 +36,7 @@ import (
|
|||
var mode string = gin.DebugMode
|
||||
|
||||
type Server struct {
|
||||
WorkDir string
|
||||
addr net.Addr
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -904,15 +905,64 @@ var defaultAllowOrigins = []string{
|
|||
"0.0.0.0",
|
||||
}
|
||||
|
||||
func NewServer() (*Server, error) {
|
||||
workDir, err := os.MkdirTemp("", "ollama")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func allowedHost(host string) bool {
|
||||
if host == "" || host == "localhost" {
|
||||
return true
|
||||
}
|
||||
|
||||
return &Server{
|
||||
WorkDir: workDir,
|
||||
}, nil
|
||||
if hostname, err := os.Hostname(); err == nil && host == hostname {
|
||||
return true
|
||||
}
|
||||
|
||||
var tlds = []string{
|
||||
".localhost",
|
||||
".local",
|
||||
".internal",
|
||||
}
|
||||
|
||||
for _, tld := range tlds {
|
||||
if strings.HasSuffix(host, "."+tld) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if addr == nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if !netip.MustParseAddrPort(addr.String()).Addr().IsLoopback() {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if addrPort, _ := netip.ParseAddrPort(c.Request.Host); addrPort.Addr().IsLoopback() {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if addr, _ := netip.ParseAddr(c.Request.Host); addr.IsLoopback() {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(c.Request.Host)
|
||||
if err != nil {
|
||||
host = c.Request.Host
|
||||
}
|
||||
|
||||
if allowedHost(host) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
c.AbortWithStatus(http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) GenerateRoutes() http.Handler {
|
||||
|
@ -938,10 +988,7 @@ func (s *Server) GenerateRoutes() http.Handler {
|
|||
r := gin.Default()
|
||||
r.Use(
|
||||
cors.New(config),
|
||||
func(c *gin.Context) {
|
||||
c.Set("workDir", s.WorkDir)
|
||||
c.Next()
|
||||
},
|
||||
allowedHostsMiddleware(s.addr),
|
||||
)
|
||||
|
||||
r.POST("/api/pull", PullModelHandler)
|
||||
|
@ -1010,10 +1057,7 @@ func Serve(ln net.Listener) error {
|
|||
}
|
||||
}
|
||||
|
||||
s, err := NewServer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s := &Server{addr: ln.Addr()}
|
||||
r := s.GenerateRoutes()
|
||||
|
||||
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
|
||||
|
@ -1029,7 +1073,6 @@ func Serve(ln net.Listener) error {
|
|||
if loaded.runner != nil {
|
||||
loaded.runner.Close()
|
||||
}
|
||||
os.RemoveAll(s.WorkDir)
|
||||
os.Exit(0)
|
||||
}()
|
||||
|
||||
|
|
|
@ -21,12 +21,6 @@ import (
|
|||
"github.com/jmorganca/ollama/version"
|
||||
)
|
||||
|
||||
func setupServer(t *testing.T) (*Server, error) {
|
||||
t.Helper()
|
||||
|
||||
return NewServer()
|
||||
}
|
||||
|
||||
func Test_Routes(t *testing.T) {
|
||||
type testCase struct {
|
||||
Name string
|
||||
|
@ -207,9 +201,7 @@ func Test_Routes(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
s, err := setupServer(t)
|
||||
assert.Nil(t, err)
|
||||
|
||||
s := Server{}
|
||||
router := s.GenerateRoutes()
|
||||
|
||||
httpSrv := httptest.NewServer(router)
|
||||
|
|
Loading…
Reference in a new issue