diff --git a/api/types.go b/api/types.go index 0b8603fe..eda7a992 100644 --- a/api/types.go +++ b/api/types.go @@ -30,9 +30,6 @@ func (e StatusError) Error() string { } type GenerateRequest struct { - SessionID int64 `json:"session_id"` - SessionDuration Duration `json:"session_duration,omitempty"` - Model string `json:"model"` Prompt string `json:"prompt"` Context []int `json:"context,omitempty"` @@ -86,9 +83,6 @@ type ListResponseModel struct { } type GenerateResponse struct { - SessionID int64 `json:"session_id"` - SessionExpiresAt time.Time `json:"session_expires_at"` - Model string `json:"model"` CreatedAt time.Time `json:"created_at"` Response string `json:"response,omitempty"` diff --git a/cmd/cmd.go b/cmd/cmd.go index 61658f87..486dd1c8 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -260,12 +260,7 @@ func generate(cmd *cobra.Command, model, prompt string) error { generateContext = []int{} } - generateSession, ok := cmd.Context().Value(generateContextKey("session")).(int64) - if !ok { - generateSession = 0 - } - - request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, SessionID: generateSession} + request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext} fn := func(response api.GenerateResponse) error { if !spinner.IsFinished() { spinner.Finish() @@ -295,7 +290,6 @@ func generate(cmd *cobra.Command, model, prompt string) error { ctx := cmd.Context() ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context) - ctx = context.WithValue(ctx, generateContextKey("session"), latest.SessionID) cmd.SetContext(ctx) } diff --git a/server/images.go b/server/images.go index 852c5cdc..17478fd2 100644 --- a/server/images.go +++ b/server/images.go @@ -32,6 +32,7 @@ type Model struct { ModelPath string Template string System string + Digest string Options api.Options } @@ -135,6 +136,7 @@ func GetModel(name string) (*Model, error) { model := &Model{ Name: mp.GetFullTagname(), + Digest: manifest.Config.Digest, } for _, layer := range manifest.Layers { diff --git a/server/routes.go b/server/routes.go index ec01378c..a7d25357 100644 --- a/server/routes.go +++ b/server/routes.go @@ -10,6 +10,7 @@ import ( "net/http" "os" "path/filepath" + "reflect" "strings" "sync" "time" @@ -22,19 +23,21 @@ import ( "github.com/jmorganca/ollama/llama" ) -var activeSession struct { +var loaded struct { mu sync.Mutex - id int64 llm *llama.LLM expireAt time.Time expireTimer *time.Timer + + digest string + options api.Options } func GenerateHandler(c *gin.Context) { - activeSession.mu.Lock() - defer activeSession.mu.Unlock() + loaded.mu.Lock() + defer loaded.mu.Unlock() checkpointStart := time.Now() @@ -50,10 +53,10 @@ func GenerateHandler(c *gin.Context) { return } - if req.SessionID == 0 || req.SessionID != activeSession.id { - if activeSession.llm != nil { - activeSession.llm.Close() - activeSession.llm = nil + if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, req.Options) { + if loaded.llm != nil { + loaded.llm.Close() + loaded.llm = nil } opts := api.DefaultOptions() @@ -73,33 +76,31 @@ func GenerateHandler(c *gin.Context) { return } - activeSession.id = time.Now().UnixNano() - activeSession.llm = llm + loaded.llm = llm + loaded.digest = model.Digest } - sessionDuration := req.SessionDuration - sessionID := activeSession.id + sessionDuration := 5 * time.Minute - activeSession.expireAt = time.Now().Add(sessionDuration.Duration) - if activeSession.expireTimer == nil { - activeSession.expireTimer = time.AfterFunc(sessionDuration.Duration, func() { - activeSession.mu.Lock() - defer activeSession.mu.Unlock() + loaded.expireAt = time.Now().Add(sessionDuration) + if loaded.expireTimer == nil { + loaded.expireTimer = time.AfterFunc(sessionDuration, func() { + loaded.mu.Lock() + defer loaded.mu.Unlock() - if sessionID != activeSession.id { + if time.Now().Before(loaded.expireAt) { return } - if time.Now().Before(activeSession.expireAt) { + if loaded.llm == nil { return } - activeSession.llm.Close() - activeSession.llm = nil - activeSession.id = 0 + loaded.llm.Close() + loaded.llm = nil }) } - activeSession.expireTimer.Reset(sessionDuration.Duration) + loaded.expireTimer.Reset(sessionDuration) checkpointLoaded := time.Now() @@ -113,13 +114,11 @@ func GenerateHandler(c *gin.Context) { go func() { defer close(ch) fn := func(r api.GenerateResponse) { - activeSession.expireAt = time.Now().Add(sessionDuration.Duration) - activeSession.expireTimer.Reset(sessionDuration.Duration) + loaded.expireAt = time.Now().Add(sessionDuration) + loaded.expireTimer.Reset(sessionDuration) r.Model = req.Model r.CreatedAt = time.Now().UTC() - r.SessionID = activeSession.id - r.SessionExpiresAt = activeSession.expireAt.UTC() if r.Done { r.TotalDuration = time.Since(checkpointStart) r.LoadDuration = checkpointLoaded.Sub(checkpointStart) @@ -128,8 +127,7 @@ func GenerateHandler(c *gin.Context) { ch <- r } - if err := activeSession.llm.Predict(req.Context, prompt, fn); err != nil { - log.Printf("llm.Predict failed with %s", err) + if err := loaded.llm.Predict(req.Context, prompt, fn); err != nil { ch <- gin.H{"error": err.Error()} } }()