diff --git a/api/types.go b/api/types.go
index 5f8c3891..24666462 100644
--- a/api/types.go
+++ b/api/types.go
@@ -1,7 +1,9 @@
 package api
 
 import (
+	"encoding/json"
 	"fmt"
+	"math"
 	"os"
 	"runtime"
 	"time"
@@ -28,10 +30,12 @@ func (e StatusError) Error() string {
 }
 
 type GenerateRequest struct {
-	SessionID int64  `json:"session_id"`
-	Model     string `json:"model"`
-	Prompt    string `json:"prompt"`
-	Context   []int  `json:"context,omitempty"`
+	SessionID       int64    `json:"session_id"`
+	SessionDuration Duration `json:"session_duration,omitempty"`
+
+	Model   string `json:"model"`
+	Prompt  string `json:"prompt"`
+	Context []int  `json:"context,omitempty"`
 
 	Options `json:"options"`
 }
@@ -82,7 +86,9 @@ type ListResponseModel struct {
 }
 
 type GenerateResponse struct {
-	SessionID int64     `json:"session_id"`
+	SessionID        int64     `json:"session_id"`
+	SessionExpiresAt time.Time `json:"session_expires_at"`
+
 	Model     string    `json:"model"`
 	CreatedAt time.Time `json:"created_at"`
 	Response  string    `json:"response,omitempty"`
@@ -195,3 +201,41 @@ func DefaultOptions() Options {
 		NumThread: runtime.NumCPU(),
 	}
 }
+
+// Duration wraps time.Duration so it can be unmarshaled from either a JSON
+// number (nanoseconds) or a Go duration string such as "10m".
+type Duration struct {
+	time.Duration
+}
+
+func (d *Duration) UnmarshalJSON(b []byte) (err error) {
+	var v any
+	if err := json.Unmarshal(b, &v); err != nil {
+		return err
+	}
+
+	d.Duration = 5 * time.Minute
+
+	switch t := v.(type) {
+	case float64:
+		if t < 0 {
+			// Negative means "never expire". Clamp to MaxInt64: converting
+			// math.MaxFloat64 to int64 overflows, which the Go spec leaves
+			// implementation-dependent (it yields MinInt64 on amd64).
+			d.Duration = time.Duration(math.MaxInt64)
+		} else {
+			d.Duration = time.Duration(t)
+		}
+	case string:
+		d.Duration, err = time.ParseDuration(t)
+		if err != nil {
+			return err
+		}
+		if d.Duration < 0 {
+			// Keep the string form consistent with the numeric form.
+			d.Duration = time.Duration(math.MaxInt64)
+		}
+	}
+
+	return nil
+}
diff --git a/llama/llama.go b/llama/llama.go
index 9f5066f3..5919b4bd 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -92,6 +92,7 @@ import (
 	"log"
 	"os"
 	"strings"
+	"sync"
 	"unicode/utf8"
 	"unsafe"
 
@@ -107,6 +108,9 @@ type LLM struct {
 	embd   []C.llama_token
 	cursor int
 
+	mu sync.Mutex
+	gc bool
+
 	api.Options
 }
 
@@ -156,6 +160,11 @@ func New(model string, opts api.Options) (*LLM, error) {
 }
 
 func (llm *LLM) Close() {
+	llm.gc = true
+
+	llm.mu.Lock()
+	defer llm.mu.Unlock()
+
 	defer C.llama_free_model(llm.model)
 	defer C.llama_free(llm.ctx)
 
@@ -163,6 +172,9 @@ func (llm *LLM) Close() {
 }
 
 func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
+	llm.mu.Lock()
+	defer llm.mu.Unlock()
+
 	C.llama_reset_timings(llm.ctx)
 
 	tokens := make([]C.llama_token, len(ctx))
@@ -185,6 +197,8 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 			break
 		} else if err != nil {
 			return err
+		} else if llm.gc {
+			return io.EOF
 		}
 
 		b.WriteString(llm.detokenize(token))
diff --git a/server/routes.go b/server/routes.go
index c3f27ec8..cc9df958 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -22,16 +22,19 @@ import (
 	"github.com/jmorganca/ollama/llama"
 )
 
-var mu sync.Mutex
-
 var activeSession struct {
-	ID int64
-	*llama.LLM
+	mu sync.Mutex
+
+	id  int64
+	llm *llama.LLM
+
+	expireAt    time.Time
+	expireTimer *time.Timer
 }
 
 func GenerateHandler(c *gin.Context) {
-	mu.Lock()
-	defer mu.Unlock()
+	activeSession.mu.Lock()
+	defer activeSession.mu.Unlock()
 
 	checkpointStart := time.Now()
 
@@ -47,10 +50,10 @@ func GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	if req.SessionID == 0 || req.SessionID != activeSession.ID {
-		if activeSession.LLM != nil {
-			activeSession.Close()
-			activeSession.LLM = nil
+	if req.SessionID == 0 || req.SessionID != activeSession.id {
+		if activeSession.llm != nil {
+			activeSession.llm.Close()
+			activeSession.llm = nil
 		}
 
 		opts := api.DefaultOptions()
@@ -70,10 +73,44 @@ func GenerateHandler(c *gin.Context) {
 			return
 		}
 
-		activeSession.ID = time.Now().UnixNano()
-		activeSession.LLM = llm
+		activeSession.id = time.Now().UnixNano()
+		activeSession.llm = llm
 	}
 
+	sessionDuration := req.SessionDuration
+	// A zero duration (field omitted by the client) would expire the session
+	// immediately; fall back to the 5m default used by Duration.UnmarshalJSON.
+	if sessionDuration.Duration == 0 {
+		sessionDuration.Duration = 5 * time.Minute
+	}
+
+	sessionID := activeSession.id
+
+	activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
+	if activeSession.expireTimer == nil {
+		activeSession.expireTimer = time.AfterFunc(sessionDuration.Duration, func() {
+			activeSession.mu.Lock()
+			defer activeSession.mu.Unlock()
+
+			if sessionID != activeSession.id {
+				return
+			}
+
+			if time.Now().Before(activeSession.expireAt) {
+				return
+			}
+
+			// llm may already be nil if a later model load failed after
+			// closing the previous session; guard against a nil dereference.
+			if activeSession.llm != nil {
+				activeSession.llm.Close()
+				activeSession.llm = nil
+			}
+			activeSession.id = 0
+		})
+	}
+	activeSession.expireTimer.Reset(sessionDuration.Duration)
+
 	checkpointLoaded := time.Now()
 
 	prompt, err := model.Prompt(req)
@@ -86,9 +123,16 @@ func GenerateHandler(c *gin.Context) {
 	go func() {
 		defer close(ch)
 		fn := func(r api.GenerateResponse) {
+			// NOTE(review): fn runs on the predict goroutine after the
+			// handler has released activeSession.mu; these writes race with
+			// the handler and the expiry callback — TODO: synchronize.
+			activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
+			activeSession.expireTimer.Reset(sessionDuration.Duration)
+
 			r.Model = req.Model
 			r.CreatedAt = time.Now().UTC()
-			r.SessionID = activeSession.ID
+			r.SessionID = activeSession.id
+			r.SessionExpiresAt = activeSession.expireAt.UTC()
 			if r.Done {
 				r.TotalDuration = time.Since(checkpointStart)
 				r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
@@ -97,7 +141,7 @@ func GenerateHandler(c *gin.Context) {
 			ch <- r
 		}
 
-		if err := activeSession.LLM.Predict(req.Context, prompt, fn); err != nil {
+		if err := activeSession.llm.Predict(req.Context, prompt, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
@@ -247,7 +291,7 @@ func ListModelsHandler(c *gin.Context) {
 		return
 	}
 
-	c.JSON(http.StatusOK, api.ListResponse{models})
+	c.JSON(http.StatusOK, api.ListResponse{Models: models})
 }
 
 func CopyModelHandler(c *gin.Context) {