Ollama ps command for showing currently loaded models (#4327)

This commit is contained in:
Patrick Devine 2024-05-13 17:17:36 -07:00 committed by GitHub
parent 9eed4a90ce
commit 6845988807
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 193 additions and 50 deletions

View file

@ -354,6 +354,15 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
return &lr, nil return &lr, nil
} }
// List running models.
func (c *Client) ListRunning(ctx context.Context) (*ListResponse, error) {
var lr ListResponse
if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil {
return nil, err
}
return &lr, nil
}
// Copy copies a model - creating a model with another name from an existing // Copy copies a model - creating a model with another name from an existing
// model. // model.
func (c *Client) Copy(ctx context.Context, req *CopyRequest) error { func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {

View file

@ -289,10 +289,12 @@ type ListResponse struct {
type ModelResponse struct { type ModelResponse struct {
Name string `json:"name"` Name string `json:"name"`
Model string `json:"model"` Model string `json:"model"`
ModifiedAt time.Time `json:"modified_at"` ModifiedAt time.Time `json:"modified_at,omitempty"`
Size int64 `json:"size"` Size int64 `json:"size"`
Digest string `json:"digest"` Digest string `json:"digest"`
Details ModelDetails `json:"details,omitempty"` Details ModelDetails `json:"details,omitempty"`
ExpiresAt time.Time `json:"expires_at,omitempty"`
SizeVRAM int64 `json:"size_vram,omitempty"`
} }
type TokenResponse struct { type TokenResponse struct {

View file

@ -12,6 +12,7 @@ import (
"fmt" "fmt"
"io" "io"
"log" "log"
"math"
"net" "net"
"net/http" "net/http"
"os" "os"
@ -324,6 +325,18 @@ func RunHandler(cmd *cobra.Command, args []string) error {
} }
opts.Format = format opts.Format = format
keepAlive, err := cmd.Flags().GetString("keepalive")
if err != nil {
return err
}
if keepAlive != "" {
d, err := time.ParseDuration(keepAlive)
if err != nil {
return err
}
opts.KeepAlive = &api.Duration{Duration: d}
}
prompts := args[1:] prompts := args[1:]
// prepend stdin to the prompt if provided // prepend stdin to the prompt if provided
if !term.IsTerminal(int(os.Stdin.Fd())) { if !term.IsTerminal(int(os.Stdin.Fd())) {
@ -496,6 +509,52 @@ func ListHandler(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func ListRunningHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
models, err := client.ListRunning(cmd.Context())
if err != nil {
return err
}
var data [][]string
for _, m := range models.Models {
if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
var procStr string
switch {
case m.SizeVRAM == 0:
procStr = "100% CPU"
case m.SizeVRAM == m.Size:
procStr = "100% GPU"
case m.SizeVRAM > m.Size || m.Size == 0:
procStr = "Unknown"
default:
sizeCPU := m.Size - m.SizeVRAM
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
procStr = fmt.Sprintf("%d%%/%d%% CPU/GPU", int(cpuPercent), int(100-cpuPercent))
}
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, format.HumanTime(m.ExpiresAt, "Never")})
}
}
table := tablewriter.NewWriter(os.Stdout)
table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "UNTIL"})
table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
table.SetAlignment(tablewriter.ALIGN_LEFT)
table.SetHeaderLine(false)
table.SetBorder(false)
table.SetNoWhiteSpace(true)
table.SetTablePadding("\t")
table.AppendBulk(data)
table.Render()
return nil
}
func DeleteHandler(cmd *cobra.Command, args []string) error { func DeleteHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
@ -672,6 +731,7 @@ type runOptions struct {
Images []api.ImageData Images []api.ImageData
Options map[string]interface{} Options map[string]interface{}
MultiModal bool MultiModal bool
KeepAlive *api.Duration
} }
type displayResponseState struct { type displayResponseState struct {
@ -766,6 +826,10 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
Options: opts.Options, Options: opts.Options,
} }
if opts.KeepAlive != nil {
req.KeepAlive = opts.KeepAlive
}
if err := client.Chat(cancelCtx, req, fn); err != nil { if err := client.Chat(cancelCtx, req, fn); err != nil {
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
return nil, nil return nil, nil
@ -1075,6 +1139,7 @@ func NewCLI() *cobra.Command {
RunE: RunHandler, RunE: RunHandler,
} }
runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
runCmd.Flags().Bool("verbose", false, "Show timings for response") runCmd.Flags().Bool("verbose", false, "Show timings for response")
runCmd.Flags().Bool("insecure", false, "Use an insecure registry") runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically") runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
@ -1123,6 +1188,14 @@ Environment Variables:
PreRunE: checkServerHeartbeat, PreRunE: checkServerHeartbeat,
RunE: ListHandler, RunE: ListHandler,
} }
psCmd := &cobra.Command{
Use: "ps",
Short: "List running models",
PreRunE: checkServerHeartbeat,
RunE: ListRunningHandler,
}
copyCmd := &cobra.Command{ copyCmd := &cobra.Command{
Use: "cp SOURCE DESTINATION", Use: "cp SOURCE DESTINATION",
Short: "Copy a model", Short: "Copy a model",
@ -1146,6 +1219,7 @@ Environment Variables:
pullCmd, pullCmd,
pushCmd, pushCmd,
listCmd, listCmd,
psCmd,
copyCmd, copyCmd,
deleteCmd, deleteCmd,
} { } {
@ -1160,6 +1234,7 @@ Environment Variables:
pullCmd, pullCmd,
pushCmd, pushCmd,
listCmd, listCmd,
psCmd,
copyCmd, copyCmd,
deleteCmd, deleteCmd,
) )

View file

@ -56,6 +56,11 @@ func loadModel(cmd *cobra.Command, opts *runOptions) error {
Model: opts.Model, Model: opts.Model,
Messages: []api.Message{}, Messages: []api.Message{},
} }
if opts.KeepAlive != nil {
chatReq.KeepAlive = opts.KeepAlive
}
err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error { err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
p.StopAndClear() p.StopAndClear()
if len(opts.Messages) > 0 { if len(opts.Messages) > 0 {

View file

@ -60,7 +60,9 @@ func humanTime(t time.Time, zeroValue string) string {
} }
delta := time.Since(t) delta := time.Since(t)
if delta < 0 { if int(delta.Hours())/24/365 < -20 {
return "Forever"
} else if delta < 0 {
return humanDuration(-delta) + " from now" return humanDuration(-delta) + " from now"
} }

View file

@ -32,4 +32,14 @@ func TestHumanTime(t *testing.T) {
v := now.Add(800 * time.Millisecond) v := now.Add(800 * time.Millisecond)
assertEqual(t, HumanTime(v, ""), "Less than a second from now") assertEqual(t, HumanTime(v, ""), "Less than a second from now")
}) })
t.Run("time way in the future", func(t *testing.T) {
v := now.Add(24 * time.Hour * 365 * 200)
assertEqual(t, HumanTime(v, ""), "Forever")
})
t.Run("time way in the future lowercase", func(t *testing.T) {
v := now.Add(24 * time.Hour * 365 * 200)
assertEqual(t, HumanTimeLower(v, ""), "forever")
})
} }

View file

@ -38,6 +38,7 @@ type LlamaServer interface {
Detokenize(ctx context.Context, tokens []int) (string, error) Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error Close() error
EstimatedVRAM() uint64 EstimatedVRAM() uint64
EstimatedTotal() uint64
} }
// llmServer is an instance of the llama.cpp server // llmServer is an instance of the llama.cpp server
@ -955,6 +956,10 @@ func (s *llmServer) EstimatedVRAM() uint64 {
return s.estimatedVRAM return s.estimatedVRAM
} }
func (s *llmServer) EstimatedTotal() uint64 {
return s.estimatedTotal
}
func parseDurationMs(ms float64) time.Duration { func parseDurationMs(ms float64) time.Duration {
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms)) dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
if err != nil { if err != nil {

View file

@ -979,6 +979,7 @@ func (s *Server) GenerateRoutes() http.Handler {
r.POST("/api/show", s.ShowModelHandler) r.POST("/api/show", s.ShowModelHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler) r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler) r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.GET("/api/ps", s.ProcessHandler)
// Compatibility endpoints // Compatibility endpoints
r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler) r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
@ -1137,6 +1138,34 @@ func streamResponse(c *gin.Context, ch chan any) {
}) })
} }
func (s *Server) ProcessHandler(c *gin.Context) {
models := []api.ModelResponse{}
for _, v := range s.sched.loaded {
model := v.model
modelDetails := api.ModelDetails{
Format: model.Config.ModelFormat,
Family: model.Config.ModelFamily,
Families: model.Config.ModelFamilies,
ParameterSize: model.Config.ModelType,
QuantizationLevel: model.Config.FileType,
}
mr := api.ModelResponse{
Model: model.ShortName,
Name: model.ShortName,
Size: int64(v.estimatedTotal),
SizeVRAM: int64(v.estimatedVRAM),
Digest: model.Digest,
Details: modelDetails,
ExpiresAt: v.expiresAt,
}
models = append(models, mr)
}
c.JSON(http.StatusOK, api.ListResponse{Models: models})
}
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) { func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
encode := func(s string) ([]int, error) { encode := func(s string) ([]int, error) {

View file

@ -177,7 +177,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
} }
// Trigger an expiration to unload once it's done // Trigger an expiration to unload once it's done
runnerToExpire.refMu.Lock() runnerToExpire.refMu.Lock()
slog.Debug("resetting model to expire immediately to make room", "model", runnerToExpire.model, "refCount", runnerToExpire.refCount) slog.Debug("resetting model to expire immediately to make room", "modelPath", runnerToExpire.modelPath, "refCount", runnerToExpire.refCount)
if runnerToExpire.expireTimer != nil { if runnerToExpire.expireTimer != nil {
runnerToExpire.expireTimer.Stop() runnerToExpire.expireTimer.Stop()
runnerToExpire.expireTimer = nil runnerToExpire.expireTimer = nil
@ -190,13 +190,13 @@ func (s *Scheduler) processPending(ctx context.Context) {
// Wait for the unload to happen // Wait for the unload to happen
// Note: at this point we're queueing up all incoming requests, even if they were for // Note: at this point we're queueing up all incoming requests, even if they were for
// a different model that's loaded and not scheduled to be removed. // a different model that's loaded and not scheduled to be removed.
slog.Debug("waiting for pending requests to complete and unload to occur", "model", runnerToExpire.model) slog.Debug("waiting for pending requests to complete and unload to occur", "modelPath", runnerToExpire.modelPath)
select { select {
case <-ctx.Done(): case <-ctx.Done():
slog.Debug("shutting down scheduler pending loop") slog.Debug("shutting down scheduler pending loop")
return return
case <-s.unloadedCh: case <-s.unloadedCh:
slog.Debug("unload completed", "model", runnerToExpire.model) slog.Debug("unload completed", "modelPath", runnerToExpire.modelPath)
continue continue
} }
} }
@ -219,23 +219,23 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
runner := s.loaded[finished.model.ModelPath] runner := s.loaded[finished.model.ModelPath]
s.loadedMu.Unlock() s.loadedMu.Unlock()
if runner == nil { if runner == nil {
slog.Error("finished requeset signal received after model unloaded", "model", finished.model.ModelPath) slog.Error("finished requeset signal received after model unloaded", "modelPath", finished.model.ModelPath)
continue continue
} }
runner.refMu.Lock() runner.refMu.Lock()
runner.refCount-- runner.refCount--
if runner.refCount <= 0 { if runner.refCount <= 0 {
if runner.sessionDuration <= 0 { if runner.sessionDuration <= 0 {
slog.Debug("runner with zero duration has gone idle, expiring to unload", "model", runner.model) slog.Debug("runner with zero duration has gone idle, expiring to unload", "modelPath", runner.modelPath)
if runner.expireTimer != nil { if runner.expireTimer != nil {
runner.expireTimer.Stop() runner.expireTimer.Stop()
runner.expireTimer = nil runner.expireTimer = nil
} }
s.expiredCh <- runner s.expiredCh <- runner
} else if runner.expireTimer == nil { } else if runner.expireTimer == nil {
slog.Debug("runner with non-zero duration has gone idle, adding timer", "model", runner.model, "duration", runner.sessionDuration) slog.Debug("runner with non-zero duration has gone idle, adding timer", "modelPath", runner.modelPath, "duration", runner.sessionDuration)
runner.expireTimer = time.AfterFunc(runner.sessionDuration, func() { runner.expireTimer = time.AfterFunc(runner.sessionDuration, func() {
slog.Debug("timer expired, expiring to unload", "model", runner.model) slog.Debug("timer expired, expiring to unload", "modelPath", runner.modelPath)
runner.refMu.Lock() runner.refMu.Lock()
defer runner.refMu.Unlock() defer runner.refMu.Unlock()
if runner.expireTimer != nil { if runner.expireTimer != nil {
@ -244,19 +244,21 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
} }
s.expiredCh <- runner s.expiredCh <- runner
}) })
runner.expiresAt = time.Now().Add(runner.sessionDuration)
} else { } else {
slog.Debug("runner with non-zero duration has gone idle, resetting timer", "model", runner.model, "duration", runner.sessionDuration) slog.Debug("runner with non-zero duration has gone idle, resetting timer", "modelPath", runner.modelPath, "duration", runner.sessionDuration)
runner.expireTimer.Reset(runner.sessionDuration) runner.expireTimer.Reset(runner.sessionDuration)
runner.expiresAt = time.Now().Add(runner.sessionDuration)
} }
} }
slog.Debug("after processing request finished event", "model", runner.model, "refCount", runner.refCount) slog.Debug("after processing request finished event", "modelPath", runner.modelPath, "refCount", runner.refCount)
runner.refMu.Unlock() runner.refMu.Unlock()
case runner := <-s.expiredCh: case runner := <-s.expiredCh:
slog.Debug("runner expired event received", "model", runner.model) slog.Debug("runner expired event received", "modelPath", runner.modelPath)
runner.refMu.Lock() runner.refMu.Lock()
if runner.refCount > 0 { if runner.refCount > 0 {
// Shouldn't happen, but safeguard to ensure no leaked runners // Shouldn't happen, but safeguard to ensure no leaked runners
slog.Debug("expired event with positive ref count, retrying", "model", runner.model, "refCount", runner.refCount) slog.Debug("expired event with positive ref count, retrying", "modelPath", runner.modelPath, "refCount", runner.refCount)
go func(runner *runnerRef) { go func(runner *runnerRef) {
// We can't unload yet, but want to as soon as the current request completes // We can't unload yet, but want to as soon as the current request completes
// So queue up another expired event // So queue up another expired event
@ -268,16 +270,16 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
} }
s.loadedMu.Lock() s.loadedMu.Lock()
slog.Debug("got lock to unload", "model", runner.model) slog.Debug("got lock to unload", "modelPath", runner.modelPath)
finished := runner.waitForVRAMRecovery() finished := runner.waitForVRAMRecovery()
runner.unload() runner.unload()
delete(s.loaded, runner.model) delete(s.loaded, runner.modelPath)
s.loadedMu.Unlock() s.loadedMu.Unlock()
slog.Debug("runner released", "model", runner.model) slog.Debug("runner released", "modelPath", runner.modelPath)
runner.refMu.Unlock() runner.refMu.Unlock()
<-finished <-finished
slog.Debug("sending an unloaded event", "model", runner.model) slog.Debug("sending an unloaded event", "modelPath", runner.modelPath)
s.unloadedCh <- struct{}{} s.unloadedCh <- struct{}{}
} }
} }
@ -316,18 +318,20 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
req.errCh <- err req.errCh <- err
return return
} }
runner := &runnerRef{} runner := &runnerRef{
runner.model = req.model.ModelPath model: req.model,
runner.adapters = req.model.AdapterPaths modelPath: req.model.ModelPath,
runner.projectors = req.model.ProjectorPaths llama: llama,
runner.llama = llama Options: &req.opts,
runner.Options = &req.opts sessionDuration: req.sessionDuration,
runner.sessionDuration = req.sessionDuration gpus: gpus,
runner.gpus = gpus estimatedVRAM: llama.EstimatedVRAM(),
runner.estimatedVRAM = llama.EstimatedVRAM() estimatedTotal: llama.EstimatedTotal(),
runner.loading = true loading: true,
runner.refCount = 1 refCount: 1,
}
runner.refMu.Lock() runner.refMu.Lock()
s.loadedMu.Lock() s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner s.loaded[req.model.ModelPath] = runner
slog.Info("loaded runners", "count", len(s.loaded)) slog.Info("loaded runners", "count", len(s.loaded))
@ -339,7 +343,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
slog.Error("error loading llama server", "error", err) slog.Error("error loading llama server", "error", err)
runner.refCount-- runner.refCount--
req.errCh <- err req.errCh <- err
slog.Debug("triggering expiration for failed load", "model", runner.model) slog.Debug("triggering expiration for failed load", "model", runner.modelPath)
s.expiredCh <- runner s.expiredCh <- runner
return return
} }
@ -408,17 +412,18 @@ type runnerRef struct {
refCount uint // prevent unloading if > 0 refCount uint // prevent unloading if > 0
// unloading bool // set to true when we are trying to unload the runner // unloading bool // set to true when we are trying to unload the runner
llama llm.LlamaServer llama llm.LlamaServer
loading bool // True only during initial load, then false forever loading bool // True only during initial load, then false forever
gpus gpu.GpuInfoList // Recorded at time of provisioning gpus gpu.GpuInfoList // Recorded at time of provisioning
estimatedVRAM uint64 estimatedVRAM uint64
estimatedTotal uint64
sessionDuration time.Duration sessionDuration time.Duration
expireTimer *time.Timer expireTimer *time.Timer
expiresAt time.Time
model string model *Model
adapters []string modelPath string
projectors []string
*api.Options *api.Options
} }
@ -431,9 +436,8 @@ func (runner *runnerRef) unload() {
if runner.llama != nil { if runner.llama != nil {
runner.llama.Close() runner.llama.Close()
} }
runner.model = nil
runner.llama = nil runner.llama = nil
runner.adapters = nil
runner.projectors = nil
runner.Options = nil runner.Options = nil
runner.gpus = nil runner.gpus = nil
} }
@ -462,8 +466,8 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
ctx, cancel := context.WithTimeout(ctx, timeout) ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel() defer cancel()
if !reflect.DeepEqual(runner.adapters, req.model.AdapterPaths) || // have the adapters changed? if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(runner.projectors, req.model.ProjectorPaths) || // have the projectors changed? !reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed? !reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
runner.llama.Ping(ctx) != nil { runner.llama.Ping(ctx) != nil {
return true return true

View file

@ -164,7 +164,8 @@ func TestRequests(t *testing.T) {
// simple reload of same model // simple reload of same model
scenario2a := newScenario(t, ctx, "ollama-model-1", 20) scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
scenario2a.req.model = scenario1a.req.model tmpModel := *scenario1a.req.model
scenario2a.req.model = &tmpModel
scenario2a.ggml = scenario1a.ggml scenario2a.ggml = scenario1a.ggml
// Multiple loaded models // Multiple loaded models
@ -496,10 +497,9 @@ func TestNeedsReload(t *testing.T) {
llm := &mockLlm{} llm := &mockLlm{}
do := api.DefaultOptions() do := api.DefaultOptions()
runner := &runnerRef{ runner := &runnerRef{
adapters: []string{"adapter1"}, model: &Model{AdapterPaths: []string{"adapter1"}, ProjectorPaths: []string{"projector1"}},
projectors: []string{"projector1"}, Options: &do,
Options: &do, llama: llm,
llama: llm,
} }
req := &LlmRequest{ req := &LlmRequest{
model: &Model{ model: &Model{
@ -510,10 +510,10 @@ func TestNeedsReload(t *testing.T) {
} }
resp := runner.needsReload(ctx, req) resp := runner.needsReload(ctx, req)
require.True(t, resp) require.True(t, resp)
req.model.AdapterPaths = runner.adapters req.model.AdapterPaths = runner.model.AdapterPaths
resp = runner.needsReload(ctx, req) resp = runner.needsReload(ctx, req)
require.True(t, resp) require.True(t, resp)
req.model.ProjectorPaths = runner.projectors req.model.ProjectorPaths = runner.model.ProjectorPaths
runner.loading = true runner.loading = true
req.opts.NumBatch = 1234 req.opts.NumBatch = 1234
resp = runner.needsReload(ctx, req) resp = runner.needsReload(ctx, req)
@ -558,11 +558,11 @@ func TestUnloadAllRunners(t *testing.T) {
func TestUnload(t *testing.T) { func TestUnload(t *testing.T) {
llm1 := &mockLlm{} llm1 := &mockLlm{}
r1 := &runnerRef{llama: llm1} r1 := &runnerRef{llama: llm1}
r2 := &runnerRef{adapters: []string{"A"}} r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}}
r1.unload() r1.unload()
require.True(t, llm1.closeCalled) require.True(t, llm1.closeCalled)
r2.unload() r2.unload()
require.Nil(t, r2.adapters) require.Nil(t, r2.model)
} }
type mockLlm struct { type mockLlm struct {
@ -578,6 +578,7 @@ type mockLlm struct {
closeResp error closeResp error
closeCalled bool closeCalled bool
estimatedVRAM uint64 estimatedVRAM uint64
estimatedTotal uint64
} }
func (s *mockLlm) Ping(ctx context.Context) error { return s.pingResp } func (s *mockLlm) Ping(ctx context.Context) error { return s.pingResp }
@ -598,4 +599,5 @@ func (s *mockLlm) Close() error {
s.closeCalled = true s.closeCalled = true
return s.closeResp return s.closeResp
} }
func (s *mockLlm) EstimatedVRAM() uint64 { return s.estimatedVRAM } func (s *mockLlm) EstimatedVRAM() uint64 { return s.estimatedVRAM }
func (s *mockLlm) EstimatedTotal() uint64 { return s.estimatedTotal }