Provide variable ggml for TestLoad
This commit is contained in:
parent
284e02bed0
commit
ceb0e26e5e
1 changed files with 4 additions and 3 deletions
|
@ -47,6 +47,7 @@ func TestLoad(t *testing.T) {
|
||||||
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||||
defer done()
|
defer done()
|
||||||
s := InitScheduler(ctx)
|
s := InitScheduler(ctx)
|
||||||
|
ggml := nil // value not used in tests
|
||||||
req := &LlmRequest{
|
req := &LlmRequest{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
model: &Model{ModelPath: "foo"},
|
model: &Model{ModelPath: "foo"},
|
||||||
|
@ -59,7 +60,7 @@ func TestLoad(t *testing.T) {
|
||||||
return nil, fmt.Errorf("something failed to load model blah")
|
return nil, fmt.Errorf("something failed to load model blah")
|
||||||
}
|
}
|
||||||
gpus := gpu.GpuInfoList{}
|
gpus := gpu.GpuInfoList{}
|
||||||
s.load(req, nil, gpus)
|
s.load(req, ggml, gpus)
|
||||||
require.Len(t, req.successCh, 0)
|
require.Len(t, req.successCh, 0)
|
||||||
require.Len(t, req.errCh, 1)
|
require.Len(t, req.errCh, 1)
|
||||||
require.Len(t, s.loaded, 0)
|
require.Len(t, s.loaded, 0)
|
||||||
|
@ -70,7 +71,7 @@ func TestLoad(t *testing.T) {
|
||||||
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
|
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
|
||||||
return server, nil
|
return server, nil
|
||||||
}
|
}
|
||||||
s.load(req, nil, gpus)
|
s.load(req, ggml, gpus)
|
||||||
select {
|
select {
|
||||||
case err := <-req.errCh:
|
case err := <-req.errCh:
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -82,7 +83,7 @@ func TestLoad(t *testing.T) {
|
||||||
|
|
||||||
req.model.ModelPath = "dummy_model_path"
|
req.model.ModelPath = "dummy_model_path"
|
||||||
server.waitResp = fmt.Errorf("wait failure")
|
server.waitResp = fmt.Errorf("wait failure")
|
||||||
s.load(req, nil, gpus)
|
s.load(req, ggml, gpus)
|
||||||
select {
|
select {
|
||||||
case err := <-req.errCh:
|
case err := <-req.errCh:
|
||||||
require.Contains(t, err.Error(), "wait failure")
|
require.Contains(t, err.Error(), "wait failure")
|
||||||
|
|
Loading…
Reference in a new issue