diff --git a/server/sched_test.go b/server/sched_test.go index 1fa2a4a2..e451d84a 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -47,6 +47,7 @@ func TestLoad(t *testing.T) { ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond) defer done() s := InitScheduler(ctx) + ggml := nil // value not used in tests req := &LlmRequest{ ctx: ctx, model: &Model{ModelPath: "foo"}, @@ -59,7 +60,7 @@ func TestLoad(t *testing.T) { return nil, fmt.Errorf("something failed to load model blah") } gpus := gpu.GpuInfoList{} - s.load(req, nil, gpus) + s.load(req, ggml, gpus) require.Len(t, req.successCh, 0) require.Len(t, req.errCh, 1) require.Len(t, s.loaded, 0) @@ -70,7 +71,7 @@ func TestLoad(t *testing.T) { s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) { return server, nil } - s.load(req, nil, gpus) + s.load(req, ggml, gpus) select { case err := <-req.errCh: require.NoError(t, err) @@ -82,7 +83,7 @@ func TestLoad(t *testing.T) { req.model.ModelPath = "dummy_model_path" server.waitResp = fmt.Errorf("wait failure") - s.load(req, nil, gpus) + s.load(req, ggml, gpus) select { case err := <-req.errCh: require.Contains(t, err.Error(), "wait failure")