add session expiration

This commit is contained in:
Michael Yang 2023-07-19 15:00:28 -07:00
parent 3003fc03fc
commit f62a882760
3 changed files with 100 additions and 20 deletions

View file

@ -1,7 +1,9 @@
package api package api
import ( import (
"encoding/json"
"fmt" "fmt"
"math"
"os" "os"
"runtime" "runtime"
"time" "time"
@ -28,10 +30,12 @@ func (e StatusError) Error() string {
} }
type GenerateRequest struct { type GenerateRequest struct {
SessionID int64 `json:"session_id"` SessionID int64 `json:"session_id"`
Model string `json:"model"` SessionDuration Duration `json:"session_duration,omitempty"`
Prompt string `json:"prompt"`
Context []int `json:"context,omitempty"` Model string `json:"model"`
Prompt string `json:"prompt"`
Context []int `json:"context,omitempty"`
Options `json:"options"` Options `json:"options"`
} }
@ -82,7 +86,9 @@ type ListResponseModel struct {
} }
type GenerateResponse struct { type GenerateResponse struct {
SessionID int64 `json:"session_id"` SessionID int64 `json:"session_id"`
SessionExpiresAt time.Time `json:"session_expires_at"`
Model string `json:"model"` Model string `json:"model"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
Response string `json:"response,omitempty"` Response string `json:"response,omitempty"`
@ -195,3 +201,32 @@ func DefaultOptions() Options {
NumThread: runtime.NumCPU(), NumThread: runtime.NumCPU(),
} }
} }
type Duration struct {
time.Duration
}
func (d *Duration) UnmarshalJSON(b []byte) (err error) {
var v any
if err := json.Unmarshal(b, &v); err != nil {
return err
}
d.Duration = 5 * time.Minute
switch t := v.(type) {
case float64:
if t < 0 {
t = math.MaxFloat64
}
d.Duration = time.Duration(t)
case string:
d.Duration, err = time.ParseDuration(t)
if err != nil {
return err
}
}
return nil
}

View file

@ -92,6 +92,7 @@ import (
"log" "log"
"os" "os"
"strings" "strings"
"sync"
"unicode/utf8" "unicode/utf8"
"unsafe" "unsafe"
@ -107,6 +108,9 @@ type LLM struct {
embd []C.llama_token embd []C.llama_token
cursor int cursor int
mu sync.Mutex
gc bool
api.Options api.Options
} }
@ -156,6 +160,11 @@ func New(model string, opts api.Options) (*LLM, error) {
} }
func (llm *LLM) Close() { func (llm *LLM) Close() {
llm.gc = true
llm.mu.Lock()
defer llm.mu.Unlock()
defer C.llama_free_model(llm.model) defer C.llama_free_model(llm.model)
defer C.llama_free(llm.ctx) defer C.llama_free(llm.ctx)
@ -163,6 +172,9 @@ func (llm *LLM) Close() {
} }
func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error { func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
llm.mu.Lock()
defer llm.mu.Unlock()
C.llama_reset_timings(llm.ctx) C.llama_reset_timings(llm.ctx)
tokens := make([]C.llama_token, len(ctx)) tokens := make([]C.llama_token, len(ctx))
@ -185,6 +197,8 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
break break
} else if err != nil { } else if err != nil {
return err return err
} else if llm.gc {
return io.EOF
} }
b.WriteString(llm.detokenize(token)) b.WriteString(llm.detokenize(token))

View file

@ -22,16 +22,19 @@ import (
"github.com/jmorganca/ollama/llama" "github.com/jmorganca/ollama/llama"
) )
var mu sync.Mutex
var activeSession struct { var activeSession struct {
ID int64 mu sync.Mutex
*llama.LLM
id int64
llm *llama.LLM
expireAt time.Time
expireTimer *time.Timer
} }
func GenerateHandler(c *gin.Context) { func GenerateHandler(c *gin.Context) {
mu.Lock() activeSession.mu.Lock()
defer mu.Unlock() defer activeSession.mu.Unlock()
checkpointStart := time.Now() checkpointStart := time.Now()
@ -47,10 +50,10 @@ func GenerateHandler(c *gin.Context) {
return return
} }
if req.SessionID == 0 || req.SessionID != activeSession.ID { if req.SessionID == 0 || req.SessionID != activeSession.id {
if activeSession.LLM != nil { if activeSession.llm != nil {
activeSession.Close() activeSession.llm.Close()
activeSession.LLM = nil activeSession.llm = nil
} }
opts := api.DefaultOptions() opts := api.DefaultOptions()
@ -70,10 +73,34 @@ func GenerateHandler(c *gin.Context) {
return return
} }
activeSession.ID = time.Now().UnixNano() activeSession.id = time.Now().UnixNano()
activeSession.LLM = llm activeSession.llm = llm
} }
sessionDuration := req.SessionDuration
sessionID := activeSession.id
activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
if activeSession.expireTimer == nil {
activeSession.expireTimer = time.AfterFunc(sessionDuration.Duration, func() {
activeSession.mu.Lock()
defer activeSession.mu.Unlock()
if sessionID != activeSession.id {
return
}
if time.Now().Before(activeSession.expireAt) {
return
}
activeSession.llm.Close()
activeSession.llm = nil
activeSession.id = 0
})
}
activeSession.expireTimer.Reset(sessionDuration.Duration)
checkpointLoaded := time.Now() checkpointLoaded := time.Now()
prompt, err := model.Prompt(req) prompt, err := model.Prompt(req)
@ -86,9 +113,13 @@ func GenerateHandler(c *gin.Context) {
go func() { go func() {
defer close(ch) defer close(ch)
fn := func(r api.GenerateResponse) { fn := func(r api.GenerateResponse) {
activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
activeSession.expireTimer.Reset(sessionDuration.Duration)
r.Model = req.Model r.Model = req.Model
r.CreatedAt = time.Now().UTC() r.CreatedAt = time.Now().UTC()
r.SessionID = activeSession.ID r.SessionID = activeSession.id
r.SessionExpiresAt = activeSession.expireAt.UTC()
if r.Done { if r.Done {
r.TotalDuration = time.Since(checkpointStart) r.TotalDuration = time.Since(checkpointStart)
r.LoadDuration = checkpointLoaded.Sub(checkpointStart) r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
@ -97,7 +128,7 @@ func GenerateHandler(c *gin.Context) {
ch <- r ch <- r
} }
if err := activeSession.LLM.Predict(req.Context, prompt, fn); err != nil { if err := activeSession.llm.Predict(req.Context, prompt, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
@ -247,7 +278,7 @@ func ListModelsHandler(c *gin.Context) {
return return
} }
c.JSON(http.StatusOK, api.ListResponse{models}) c.JSON(http.StatusOK, api.ListResponse{Models: models})
} }
func CopyModelHandler(c *gin.Context) { func CopyModelHandler(c *gin.Context) {