cache loaded model

This commit is contained in:
Jeffrey Morgan 2023-07-31 21:35:18 -04:00
parent 81f75696e2
commit 528bafa585
4 changed files with 30 additions and 42 deletions

View file

@ -30,9 +30,6 @@ func (e StatusError) Error() string {
} }
type GenerateRequest struct { type GenerateRequest struct {
SessionID int64 `json:"session_id"`
SessionDuration Duration `json:"session_duration,omitempty"`
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Context []int `json:"context,omitempty"` Context []int `json:"context,omitempty"`
@ -86,9 +83,6 @@ type ListResponseModel struct {
} }
type GenerateResponse struct { type GenerateResponse struct {
SessionID int64 `json:"session_id"`
SessionExpiresAt time.Time `json:"session_expires_at"`
Model string `json:"model"` Model string `json:"model"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
Response string `json:"response,omitempty"` Response string `json:"response,omitempty"`

View file

@ -260,12 +260,7 @@ func generate(cmd *cobra.Command, model, prompt string) error {
generateContext = []int{} generateContext = []int{}
} }
generateSession, ok := cmd.Context().Value(generateContextKey("session")).(int64) request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
if !ok {
generateSession = 0
}
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, SessionID: generateSession}
fn := func(response api.GenerateResponse) error { fn := func(response api.GenerateResponse) error {
if !spinner.IsFinished() { if !spinner.IsFinished() {
spinner.Finish() spinner.Finish()
@ -295,7 +290,6 @@ func generate(cmd *cobra.Command, model, prompt string) error {
ctx := cmd.Context() ctx := cmd.Context()
ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context) ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
ctx = context.WithValue(ctx, generateContextKey("session"), latest.SessionID)
cmd.SetContext(ctx) cmd.SetContext(ctx)
} }

View file

@ -32,6 +32,7 @@ type Model struct {
ModelPath string ModelPath string
Template string Template string
System string System string
Digest string
Options api.Options Options api.Options
} }
@ -135,6 +136,7 @@ func GetModel(name string) (*Model, error) {
model := &Model{ model := &Model{
Name: mp.GetFullTagname(), Name: mp.GetFullTagname(),
Digest: manifest.Config.Digest,
} }
for _, layer := range manifest.Layers { for _, layer := range manifest.Layers {

View file

@ -10,6 +10,7 @@ import (
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -22,19 +23,21 @@ import (
"github.com/jmorganca/ollama/llama" "github.com/jmorganca/ollama/llama"
) )
var activeSession struct { var loaded struct {
mu sync.Mutex mu sync.Mutex
id int64
llm *llama.LLM llm *llama.LLM
expireAt time.Time expireAt time.Time
expireTimer *time.Timer expireTimer *time.Timer
digest string
options api.Options
} }
func GenerateHandler(c *gin.Context) { func GenerateHandler(c *gin.Context) {
activeSession.mu.Lock() loaded.mu.Lock()
defer activeSession.mu.Unlock() defer loaded.mu.Unlock()
checkpointStart := time.Now() checkpointStart := time.Now()
@ -50,10 +53,10 @@ func GenerateHandler(c *gin.Context) {
return return
} }
if req.SessionID == 0 || req.SessionID != activeSession.id { if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, req.Options) {
if activeSession.llm != nil { if loaded.llm != nil {
activeSession.llm.Close() loaded.llm.Close()
activeSession.llm = nil loaded.llm = nil
} }
opts := api.DefaultOptions() opts := api.DefaultOptions()
@ -73,33 +76,31 @@ func GenerateHandler(c *gin.Context) {
return return
} }
activeSession.id = time.Now().UnixNano() loaded.llm = llm
activeSession.llm = llm loaded.digest = model.Digest
} }
sessionDuration := req.SessionDuration sessionDuration := 5 * time.Minute
sessionID := activeSession.id
activeSession.expireAt = time.Now().Add(sessionDuration.Duration) loaded.expireAt = time.Now().Add(sessionDuration)
if activeSession.expireTimer == nil { if loaded.expireTimer == nil {
activeSession.expireTimer = time.AfterFunc(sessionDuration.Duration, func() { loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
activeSession.mu.Lock() loaded.mu.Lock()
defer activeSession.mu.Unlock() defer loaded.mu.Unlock()
if sessionID != activeSession.id { if time.Now().Before(loaded.expireAt) {
return return
} }
if time.Now().Before(activeSession.expireAt) { if loaded.llm == nil {
return return
} }
activeSession.llm.Close() loaded.llm.Close()
activeSession.llm = nil loaded.llm = nil
activeSession.id = 0
}) })
} }
activeSession.expireTimer.Reset(sessionDuration.Duration) loaded.expireTimer.Reset(sessionDuration)
checkpointLoaded := time.Now() checkpointLoaded := time.Now()
@ -113,13 +114,11 @@ func GenerateHandler(c *gin.Context) {
go func() { go func() {
defer close(ch) defer close(ch)
fn := func(r api.GenerateResponse) { fn := func(r api.GenerateResponse) {
activeSession.expireAt = time.Now().Add(sessionDuration.Duration) loaded.expireAt = time.Now().Add(sessionDuration)
activeSession.expireTimer.Reset(sessionDuration.Duration) loaded.expireTimer.Reset(sessionDuration)
r.Model = req.Model r.Model = req.Model
r.CreatedAt = time.Now().UTC() r.CreatedAt = time.Now().UTC()
r.SessionID = activeSession.id
r.SessionExpiresAt = activeSession.expireAt.UTC()
if r.Done { if r.Done {
r.TotalDuration = time.Since(checkpointStart) r.TotalDuration = time.Since(checkpointStart)
r.LoadDuration = checkpointLoaded.Sub(checkpointStart) r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
@ -128,8 +127,7 @@ func GenerateHandler(c *gin.Context) {
ch <- r ch <- r
} }
if err := activeSession.llm.Predict(req.Context, prompt, fn); err != nil { if err := loaded.llm.Predict(req.Context, prompt, fn); err != nil {
log.Printf("llm.Predict failed with %s", err)
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()