Add unit test of API routes (#1528)
This commit is contained in:
parent
6e16098a60
commit
630518f0d9
4 changed files with 122 additions and 30 deletions
|
@ -1035,12 +1035,7 @@ func RunServer(cmd *cobra.Command, _ []string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var origins []string
|
return server.Serve(ln)
|
||||||
if o := os.Getenv("OLLAMA_ORIGINS"); o != "" {
|
|
||||||
origins = strings.Split(o, ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
return server.Serve(ln, origins)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getImageData(filePath string) ([]byte, error) {
|
func getImageData(filePath string) ([]byte, error) {
|
||||||
|
|
3
go.mod
3
go.mod
|
@ -7,11 +7,14 @@ require (
|
||||||
github.com/gin-gonic/gin v1.9.1
|
github.com/gin-gonic/gin v1.9.1
|
||||||
github.com/olekukonko/tablewriter v0.0.5
|
github.com/olekukonko/tablewriter v0.0.5
|
||||||
github.com/spf13/cobra v1.7.0
|
github.com/spf13/cobra v1.7.0
|
||||||
|
github.com/stretchr/testify v1.8.3
|
||||||
golang.org/x/sync v0.3.0
|
golang.org/x/sync v0.3.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/mattn/go-runewidth v0.0.14 // indirect
|
github.com/mattn/go-runewidth v0.0.14 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/rivo/uniseg v0.2.0 // indirect
|
github.com/rivo/uniseg v0.2.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,10 @@ import (
|
||||||
|
|
||||||
var mode string = gin.DebugMode
|
var mode string = gin.DebugMode
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
WorkDir string
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
switch mode {
|
switch mode {
|
||||||
case gin.DebugMode:
|
case gin.DebugMode:
|
||||||
|
@ -800,27 +804,27 @@ var defaultAllowOrigins = []string{
|
||||||
"0.0.0.0",
|
"0.0.0.0",
|
||||||
}
|
}
|
||||||
|
|
||||||
func Serve(ln net.Listener, allowOrigins []string) error {
|
func NewServer() (*Server, error) {
|
||||||
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
|
workDir, err := os.MkdirTemp("", "ollama")
|
||||||
// clean up unused layers and manifests
|
if err != nil {
|
||||||
if err := PruneLayers(); err != nil {
|
return nil, err
|
||||||
return err
|
}
|
||||||
}
|
|
||||||
|
|
||||||
manifestsPath, err := GetManifestPath()
|
return &Server{
|
||||||
if err != nil {
|
WorkDir: workDir,
|
||||||
return err
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := PruneDirectory(manifestsPath); err != nil {
|
func (s *Server) GenerateRoutes() http.Handler {
|
||||||
return err
|
var origins []string
|
||||||
}
|
if o := os.Getenv("OLLAMA_ORIGINS"); o != "" {
|
||||||
|
origins = strings.Split(o, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
config := cors.DefaultConfig()
|
config := cors.DefaultConfig()
|
||||||
config.AllowWildcard = true
|
config.AllowWildcard = true
|
||||||
|
|
||||||
config.AllowOrigins = allowOrigins
|
config.AllowOrigins = origins
|
||||||
for _, allowOrigin := range defaultAllowOrigins {
|
for _, allowOrigin := range defaultAllowOrigins {
|
||||||
config.AllowOrigins = append(config.AllowOrigins,
|
config.AllowOrigins = append(config.AllowOrigins,
|
||||||
fmt.Sprintf("http://%s", allowOrigin),
|
fmt.Sprintf("http://%s", allowOrigin),
|
||||||
|
@ -830,17 +834,11 @@ func Serve(ln net.Listener, allowOrigins []string) error {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
workDir, err := os.MkdirTemp("", "ollama")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer os.RemoveAll(workDir)
|
|
||||||
|
|
||||||
r := gin.Default()
|
r := gin.Default()
|
||||||
r.Use(
|
r.Use(
|
||||||
cors.New(config),
|
cors.New(config),
|
||||||
func(c *gin.Context) {
|
func(c *gin.Context) {
|
||||||
c.Set("workDir", workDir)
|
c.Set("workDir", s.WorkDir)
|
||||||
c.Next()
|
c.Next()
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -868,8 +866,34 @@ func Serve(ln net.Listener, allowOrigins []string) error {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func Serve(ln net.Listener) error {
|
||||||
|
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
|
||||||
|
// clean up unused layers and manifests
|
||||||
|
if err := PruneLayers(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
manifestsPath, err := GetManifestPath()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := PruneDirectory(manifestsPath); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := NewServer()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r := s.GenerateRoutes()
|
||||||
|
|
||||||
log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
|
log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
|
||||||
s := &http.Server{
|
srvr := &http.Server{
|
||||||
Handler: r,
|
Handler: r,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -881,7 +905,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
|
||||||
if loaded.runner != nil {
|
if loaded.runner != nil {
|
||||||
loaded.runner.Close()
|
loaded.runner.Close()
|
||||||
}
|
}
|
||||||
os.RemoveAll(workDir)
|
os.RemoveAll(s.WorkDir)
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -892,7 +916,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.Serve(ln)
|
return srvr.Serve(ln)
|
||||||
}
|
}
|
||||||
|
|
||||||
func waitForStream(c *gin.Context, ch chan interface{}) {
|
func waitForStream(c *gin.Context, ch chan interface{}) {
|
||||||
|
|
70
server/routes_test.go
Normal file
70
server/routes_test.go
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupServer(t *testing.T) (*Server, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
return NewServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Routes(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
Name string
|
||||||
|
Method string
|
||||||
|
Path string
|
||||||
|
Setup func(t *testing.T, req *http.Request)
|
||||||
|
Expected func(t *testing.T, resp *http.Response)
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
|
{
|
||||||
|
Name: "Version Handler",
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/api/version",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, resp *http.Response) {
|
||||||
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
assert.Equal(t, contentType, "application/json; charset=utf-8")
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, `{"version":"0.0.0"}`, string(body))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := setupServer(t)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
router := s.GenerateRoutes()
|
||||||
|
|
||||||
|
httpSrv := httptest.NewServer(router)
|
||||||
|
t.Cleanup(httpSrv.Close)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
u := httpSrv.URL + tc.Path
|
||||||
|
req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
if tc.Setup != nil {
|
||||||
|
tc.Setup(t, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := httpSrv.Client().Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
if tc.Expected != nil {
|
||||||
|
tc.Expected(t, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in a new issue