add session expiration

This commit is contained in:
Michael Yang 2023-07-19 15:00:28 -07:00
parent 3003fc03fc
commit f62a882760
3 changed files with 100 additions and 20 deletions

View file

@ -1,7 +1,9 @@
package api
import (
"encoding/json"
"fmt"
"math"
"os"
"runtime"
"time"
@ -28,10 +30,12 @@ func (e StatusError) Error() string {
}
type GenerateRequest struct {
SessionID int64 `json:"session_id"`
Model string `json:"model"`
Prompt string `json:"prompt"`
Context []int `json:"context,omitempty"`
SessionID int64 `json:"session_id"`
SessionDuration Duration `json:"session_duration,omitempty"`
Model string `json:"model"`
Prompt string `json:"prompt"`
Context []int `json:"context,omitempty"`
Options `json:"options"`
}
@ -82,7 +86,9 @@ type ListResponseModel struct {
}
type GenerateResponse struct {
SessionID int64 `json:"session_id"`
SessionID int64 `json:"session_id"`
SessionExpiresAt time.Time `json:"session_expires_at"`
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Response string `json:"response,omitempty"`
@ -195,3 +201,32 @@ func DefaultOptions() Options {
NumThread: runtime.NumCPU(),
}
}
// Duration wraps time.Duration so that it can be decoded from either a JSON
// number or a JSON duration string.
type Duration struct {
	time.Duration
}

// UnmarshalJSON decodes b into d.
//
// Accepted encodings:
//   - a JSON number, interpreted as a count of nanoseconds (any fractional
//     part is truncated). A negative number, or one too large to fit in a
//     time.Duration, is clamped to the largest representable duration.
//   - a JSON string, parsed with time.ParseDuration (e.g. "5m", "1h30m").
//     A parse failure is returned to the caller.
//
// Any other valid JSON value (null, boolean, array, object) leaves d set to
// the 5 minute default. Malformed JSON is reported as an error.
func (d *Duration) UnmarshalJSON(b []byte) error {
	const maxDuration = time.Duration(math.MaxInt64)

	var v any
	if err := json.Unmarshal(b, &v); err != nil {
		return err
	}

	// Default used when the value is neither a number nor a string.
	d.Duration = 5 * time.Minute

	switch t := v.(type) {
	case float64:
		// Converting an out-of-range float64 to an integer type is
		// implementation-dependent in Go (on amd64 it produces
		// math.MinInt64), so clamp explicitly rather than converting
		// math.MaxFloat64 and hoping for saturation.
		if t < 0 || t >= float64(maxDuration) {
			d.Duration = maxDuration
		} else {
			d.Duration = time.Duration(t)
		}
	case string:
		parsed, err := time.ParseDuration(t)
		if err != nil {
			return err
		}
		d.Duration = parsed
	}

	return nil
}

View file

@ -92,6 +92,7 @@ import (
"log"
"os"
"strings"
"sync"
"unicode/utf8"
"unsafe"
@ -107,6 +108,9 @@ type LLM struct {
embd []C.llama_token
cursor int
mu sync.Mutex
gc bool
api.Options
}
@ -156,6 +160,11 @@ func New(model string, opts api.Options) (*LLM, error) {
}
func (llm *LLM) Close() {
llm.gc = true
llm.mu.Lock()
defer llm.mu.Unlock()
defer C.llama_free_model(llm.model)
defer C.llama_free(llm.ctx)
@ -163,6 +172,9 @@ func (llm *LLM) Close() {
}
func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
llm.mu.Lock()
defer llm.mu.Unlock()
C.llama_reset_timings(llm.ctx)
tokens := make([]C.llama_token, len(ctx))
@ -185,6 +197,8 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
break
} else if err != nil {
return err
} else if llm.gc {
return io.EOF
}
b.WriteString(llm.detokenize(token))

View file

@ -22,16 +22,19 @@ import (
"github.com/jmorganca/ollama/llama"
)
var mu sync.Mutex
var activeSession struct {
ID int64
*llama.LLM
mu sync.Mutex
id int64
llm *llama.LLM
expireAt time.Time
expireTimer *time.Timer
}
func GenerateHandler(c *gin.Context) {
mu.Lock()
defer mu.Unlock()
activeSession.mu.Lock()
defer activeSession.mu.Unlock()
checkpointStart := time.Now()
@ -47,10 +50,10 @@ func GenerateHandler(c *gin.Context) {
return
}
if req.SessionID == 0 || req.SessionID != activeSession.ID {
if activeSession.LLM != nil {
activeSession.Close()
activeSession.LLM = nil
if req.SessionID == 0 || req.SessionID != activeSession.id {
if activeSession.llm != nil {
activeSession.llm.Close()
activeSession.llm = nil
}
opts := api.DefaultOptions()
@ -70,10 +73,34 @@ func GenerateHandler(c *gin.Context) {
return
}
activeSession.ID = time.Now().UnixNano()
activeSession.LLM = llm
activeSession.id = time.Now().UnixNano()
activeSession.llm = llm
}
sessionDuration := req.SessionDuration
sessionID := activeSession.id
activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
if activeSession.expireTimer == nil {
activeSession.expireTimer = time.AfterFunc(sessionDuration.Duration, func() {
activeSession.mu.Lock()
defer activeSession.mu.Unlock()
if sessionID != activeSession.id {
return
}
if time.Now().Before(activeSession.expireAt) {
return
}
activeSession.llm.Close()
activeSession.llm = nil
activeSession.id = 0
})
}
activeSession.expireTimer.Reset(sessionDuration.Duration)
checkpointLoaded := time.Now()
prompt, err := model.Prompt(req)
@ -86,9 +113,13 @@ func GenerateHandler(c *gin.Context) {
go func() {
defer close(ch)
fn := func(r api.GenerateResponse) {
activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
activeSession.expireTimer.Reset(sessionDuration.Duration)
r.Model = req.Model
r.CreatedAt = time.Now().UTC()
r.SessionID = activeSession.ID
r.SessionID = activeSession.id
r.SessionExpiresAt = activeSession.expireAt.UTC()
if r.Done {
r.TotalDuration = time.Since(checkpointStart)
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
@ -97,7 +128,7 @@ func GenerateHandler(c *gin.Context) {
ch <- r
}
if err := activeSession.LLM.Predict(req.Context, prompt, fn); err != nil {
if err := activeSession.llm.Predict(req.Context, prompt, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@ -247,7 +278,7 @@ func ListModelsHandler(c *gin.Context) {
return
}
c.JSON(http.StatusOK, api.ListResponse{models})
c.JSON(http.StatusOK, api.ListResponse{Models: models})
}
func CopyModelHandler(c *gin.Context) {