Add unit test of API routes (#1528)

This commit is contained in:
Patrick Devine 2023-12-14 16:47:40 -08:00 committed by GitHub
parent 6e16098a60
commit 630518f0d9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 122 additions and 30 deletions

View file

@ -1035,12 +1035,7 @@ func RunServer(cmd *cobra.Command, _ []string) error {
return err return err
} }
var origins []string return server.Serve(ln)
if o := os.Getenv("OLLAMA_ORIGINS"); o != "" {
origins = strings.Split(o, ",")
}
return server.Serve(ln, origins)
} }
func getImageData(filePath string) ([]byte, error) { func getImageData(filePath string) ([]byte, error) {

3
go.mod
View file

@ -7,11 +7,14 @@ require (
github.com/gin-gonic/gin v1.9.1 github.com/gin-gonic/gin v1.9.1
github.com/olekukonko/tablewriter v0.0.5 github.com/olekukonko/tablewriter v0.0.5
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.8.3
golang.org/x/sync v0.3.0 golang.org/x/sync v0.3.0
) )
require ( require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/mattn/go-runewidth v0.0.14 // indirect github.com/mattn/go-runewidth v0.0.14 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect
) )

View file

@ -32,6 +32,10 @@ import (
var mode string = gin.DebugMode var mode string = gin.DebugMode
type Server struct {
WorkDir string
}
func init() { func init() {
switch mode { switch mode {
case gin.DebugMode: case gin.DebugMode:
@ -800,27 +804,27 @@ var defaultAllowOrigins = []string{
"0.0.0.0", "0.0.0.0",
} }
func Serve(ln net.Listener, allowOrigins []string) error { func NewServer() (*Server, error) {
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { workDir, err := os.MkdirTemp("", "ollama")
// clean up unused layers and manifests if err != nil {
if err := PruneLayers(); err != nil { return nil, err
return err }
}
manifestsPath, err := GetManifestPath() return &Server{
if err != nil { WorkDir: workDir,
return err }, nil
} }
if err := PruneDirectory(manifestsPath); err != nil { func (s *Server) GenerateRoutes() http.Handler {
return err var origins []string
} if o := os.Getenv("OLLAMA_ORIGINS"); o != "" {
origins = strings.Split(o, ",")
} }
config := cors.DefaultConfig() config := cors.DefaultConfig()
config.AllowWildcard = true config.AllowWildcard = true
config.AllowOrigins = allowOrigins config.AllowOrigins = origins
for _, allowOrigin := range defaultAllowOrigins { for _, allowOrigin := range defaultAllowOrigins {
config.AllowOrigins = append(config.AllowOrigins, config.AllowOrigins = append(config.AllowOrigins,
fmt.Sprintf("http://%s", allowOrigin), fmt.Sprintf("http://%s", allowOrigin),
@ -830,17 +834,11 @@ func Serve(ln net.Listener, allowOrigins []string) error {
) )
} }
workDir, err := os.MkdirTemp("", "ollama")
if err != nil {
return err
}
defer os.RemoveAll(workDir)
r := gin.Default() r := gin.Default()
r.Use( r.Use(
cors.New(config), cors.New(config),
func(c *gin.Context) { func(c *gin.Context) {
c.Set("workDir", workDir) c.Set("workDir", s.WorkDir)
c.Next() c.Next()
}, },
) )
@ -868,8 +866,34 @@ func Serve(ln net.Listener, allowOrigins []string) error {
}) })
} }
return r
}
func Serve(ln net.Listener) error {
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
// clean up unused layers and manifests
if err := PruneLayers(); err != nil {
return err
}
manifestsPath, err := GetManifestPath()
if err != nil {
return err
}
if err := PruneDirectory(manifestsPath); err != nil {
return err
}
}
s, err := NewServer()
if err != nil {
return err
}
r := s.GenerateRoutes()
log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version) log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
s := &http.Server{ srvr := &http.Server{
Handler: r, Handler: r,
} }
@ -881,7 +905,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
if loaded.runner != nil { if loaded.runner != nil {
loaded.runner.Close() loaded.runner.Close()
} }
os.RemoveAll(workDir) os.RemoveAll(s.WorkDir)
os.Exit(0) os.Exit(0)
}() }()
@ -892,7 +916,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
} }
} }
return s.Serve(ln) return srvr.Serve(ln)
} }
func waitForStream(c *gin.Context, ch chan interface{}) { func waitForStream(c *gin.Context, ch chan interface{}) {

70
server/routes_test.go Normal file
View file

@ -0,0 +1,70 @@
package server
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func setupServer(t *testing.T) (*Server, error) {
t.Helper()
return NewServer()
}
func Test_Routes(t *testing.T) {
type testCase struct {
Name string
Method string
Path string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, resp *http.Response)
}
testCases := []testCase{
{
Name: "Version Handler",
Method: http.MethodGet,
Path: "/api/version",
Setup: func(t *testing.T, req *http.Request) {
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, contentType, "application/json; charset=utf-8")
body, err := io.ReadAll(resp.Body)
assert.Nil(t, err)
assert.Equal(t, `{"version":"0.0.0"}`, string(body))
},
},
}
s, err := setupServer(t)
assert.Nil(t, err)
router := s.GenerateRoutes()
httpSrv := httptest.NewServer(router)
t.Cleanup(httpSrv.Close)
for _, tc := range testCases {
u := httpSrv.URL + tc.Path
req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
assert.Nil(t, err)
if tc.Setup != nil {
tc.Setup(t, req)
}
resp, err := httpSrv.Client().Do(req)
assert.Nil(t, err)
if tc.Expected != nil {
tc.Expected(t, resp)
}
}
}