add magic header for unit tests (#1558)

This commit is contained in:
Patrick Devine 2023-12-18 10:41:02 -08:00 committed by GitHub
parent b85982eb91
commit 3948c6ea06
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
@ -33,17 +32,29 @@ func Test_Routes(t *testing.T) {
Setup func(t *testing.T, req *http.Request) Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, resp *http.Response) Expected func(t *testing.T, resp *http.Response)
} }
var tempModelFile string
createTestFile := func(t *testing.T, name string) string {
f, err := os.CreateTemp(t.TempDir(), name)
assert.Nil(t, err)
defer f.Close()
_, err = f.Write([]byte("GGUF"))
assert.Nil(t, err)
_, err = f.Write([]byte{0x2, 0})
assert.Nil(t, err)
return f.Name()
}
createTestModel := func(t *testing.T, name string) { createTestModel := func(t *testing.T, name string) {
f, err := os.CreateTemp("", "ollama-model") fname := createTestFile(t, "ollama-model")
assert.Nil(t, err)
defer os.RemoveAll(f.Name())
modelfile := strings.NewReader(fmt.Sprintf("FROM %s", f.Name())) modelfile := strings.NewReader(fmt.Sprintf("FROM %s", fname))
commands, err := parser.Parse(modelfile) commands, err := parser.Parse(modelfile)
assert.Nil(t, err) assert.Nil(t, err)
fn := func(resp api.ProgressResponse) {} fn := func(resp api.ProgressResponse) {
t.Logf("Status: %s", resp.Status)
}
err = CreateModel(context.TODO(), name, "", commands, fn) err = CreateModel(context.TODO(), name, "", commands, fn)
assert.Nil(t, err) assert.Nil(t, err)
} }
@ -107,9 +118,9 @@ func Test_Routes(t *testing.T) {
Method: http.MethodPost, Method: http.MethodPost,
Path: "/api/create", Path: "/api/create",
Setup: func(t *testing.T, req *http.Request) { Setup: func(t *testing.T, req *http.Request) {
f, err := os.CreateTemp("", "ollama-model") f, err := os.CreateTemp(t.TempDir(), "ollama-model")
assert.Nil(t, err) assert.Nil(t, err)
tempModelFile = f.Name() defer f.Close()
stream := false stream := false
createReq := api.CreateRequest{ createReq := api.CreateRequest{
@ -123,8 +134,6 @@ func Test_Routes(t *testing.T) {
req.Body = io.NopCloser(bytes.NewReader(jsonData)) req.Body = io.NopCloser(bytes.NewReader(jsonData))
}, },
Expected: func(t *testing.T, resp *http.Response) { Expected: func(t *testing.T, resp *http.Response) {
os.RemoveAll(tempModelFile)
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json", contentType) assert.Equal(t, "application/json", contentType)
_, err := io.ReadAll(resp.Body) _, err := io.ReadAll(resp.Body)
@ -173,6 +182,7 @@ func Test_Routes(t *testing.T) {
os.Setenv("OLLAMA_MODELS", workDir) os.Setenv("OLLAMA_MODELS", workDir)
for _, tc := range testCases { for _, tc := range testCases {
t.Logf("Running Test: [%s]", tc.Name)
u := httpSrv.URL + tc.Path u := httpSrv.URL + tc.Path
req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil) req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
assert.Nil(t, err) assert.Nil(t, err)