add session expiration
This commit is contained in:
parent
3003fc03fc
commit
f62a882760
3 changed files with 100 additions and 20 deletions
35
api/types.go
35
api/types.go
|
@ -1,7 +1,9 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
@ -29,6 +31,8 @@ func (e StatusError) Error() string {
|
||||||
|
|
||||||
type GenerateRequest struct {
|
type GenerateRequest struct {
|
||||||
SessionID int64 `json:"session_id"`
|
SessionID int64 `json:"session_id"`
|
||||||
|
SessionDuration Duration `json:"session_duration,omitempty"`
|
||||||
|
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
Context []int `json:"context,omitempty"`
|
Context []int `json:"context,omitempty"`
|
||||||
|
@ -83,6 +87,8 @@ type ListResponseModel struct {
|
||||||
|
|
||||||
type GenerateResponse struct {
|
type GenerateResponse struct {
|
||||||
SessionID int64 `json:"session_id"`
|
SessionID int64 `json:"session_id"`
|
||||||
|
SessionExpiresAt time.Time `json:"session_expires_at"`
|
||||||
|
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
Response string `json:"response,omitempty"`
|
Response string `json:"response,omitempty"`
|
||||||
|
@ -195,3 +201,32 @@ func DefaultOptions() Options {
|
||||||
NumThread: runtime.NumCPU(),
|
NumThread: runtime.NumCPU(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Duration struct {
|
||||||
|
time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||||
|
var v any
|
||||||
|
if err := json.Unmarshal(b, &v); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
d.Duration = 5 * time.Minute
|
||||||
|
|
||||||
|
switch t := v.(type) {
|
||||||
|
case float64:
|
||||||
|
if t < 0 {
|
||||||
|
t = math.MaxFloat64
|
||||||
|
}
|
||||||
|
|
||||||
|
d.Duration = time.Duration(t)
|
||||||
|
case string:
|
||||||
|
d.Duration, err = time.ParseDuration(t)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -92,6 +92,7 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
@ -107,6 +108,9 @@ type LLM struct {
|
||||||
embd []C.llama_token
|
embd []C.llama_token
|
||||||
cursor int
|
cursor int
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
gc bool
|
||||||
|
|
||||||
api.Options
|
api.Options
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -156,6 +160,11 @@ func New(model string, opts api.Options) (*LLM, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *LLM) Close() {
|
func (llm *LLM) Close() {
|
||||||
|
llm.gc = true
|
||||||
|
|
||||||
|
llm.mu.Lock()
|
||||||
|
defer llm.mu.Unlock()
|
||||||
|
|
||||||
defer C.llama_free_model(llm.model)
|
defer C.llama_free_model(llm.model)
|
||||||
defer C.llama_free(llm.ctx)
|
defer C.llama_free(llm.ctx)
|
||||||
|
|
||||||
|
@ -163,6 +172,9 @@ func (llm *LLM) Close() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
|
func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
|
||||||
|
llm.mu.Lock()
|
||||||
|
defer llm.mu.Unlock()
|
||||||
|
|
||||||
C.llama_reset_timings(llm.ctx)
|
C.llama_reset_timings(llm.ctx)
|
||||||
|
|
||||||
tokens := make([]C.llama_token, len(ctx))
|
tokens := make([]C.llama_token, len(ctx))
|
||||||
|
@ -185,6 +197,8 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
|
||||||
break
|
break
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return err
|
return err
|
||||||
|
} else if llm.gc {
|
||||||
|
return io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
b.WriteString(llm.detokenize(token))
|
b.WriteString(llm.detokenize(token))
|
||||||
|
|
|
@ -22,16 +22,19 @@ import (
|
||||||
"github.com/jmorganca/ollama/llama"
|
"github.com/jmorganca/ollama/llama"
|
||||||
)
|
)
|
||||||
|
|
||||||
var mu sync.Mutex
|
|
||||||
|
|
||||||
var activeSession struct {
|
var activeSession struct {
|
||||||
ID int64
|
mu sync.Mutex
|
||||||
*llama.LLM
|
|
||||||
|
id int64
|
||||||
|
llm *llama.LLM
|
||||||
|
|
||||||
|
expireAt time.Time
|
||||||
|
expireTimer *time.Timer
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateHandler(c *gin.Context) {
|
func GenerateHandler(c *gin.Context) {
|
||||||
mu.Lock()
|
activeSession.mu.Lock()
|
||||||
defer mu.Unlock()
|
defer activeSession.mu.Unlock()
|
||||||
|
|
||||||
checkpointStart := time.Now()
|
checkpointStart := time.Now()
|
||||||
|
|
||||||
|
@ -47,10 +50,10 @@ func GenerateHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.SessionID == 0 || req.SessionID != activeSession.ID {
|
if req.SessionID == 0 || req.SessionID != activeSession.id {
|
||||||
if activeSession.LLM != nil {
|
if activeSession.llm != nil {
|
||||||
activeSession.Close()
|
activeSession.llm.Close()
|
||||||
activeSession.LLM = nil
|
activeSession.llm = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
opts := api.DefaultOptions()
|
opts := api.DefaultOptions()
|
||||||
|
@ -70,10 +73,34 @@ func GenerateHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
activeSession.ID = time.Now().UnixNano()
|
activeSession.id = time.Now().UnixNano()
|
||||||
activeSession.LLM = llm
|
activeSession.llm = llm
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sessionDuration := req.SessionDuration
|
||||||
|
sessionID := activeSession.id
|
||||||
|
|
||||||
|
activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
|
||||||
|
if activeSession.expireTimer == nil {
|
||||||
|
activeSession.expireTimer = time.AfterFunc(sessionDuration.Duration, func() {
|
||||||
|
activeSession.mu.Lock()
|
||||||
|
defer activeSession.mu.Unlock()
|
||||||
|
|
||||||
|
if sessionID != activeSession.id {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().Before(activeSession.expireAt) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
activeSession.llm.Close()
|
||||||
|
activeSession.llm = nil
|
||||||
|
activeSession.id = 0
|
||||||
|
})
|
||||||
|
}
|
||||||
|
activeSession.expireTimer.Reset(sessionDuration.Duration)
|
||||||
|
|
||||||
checkpointLoaded := time.Now()
|
checkpointLoaded := time.Now()
|
||||||
|
|
||||||
prompt, err := model.Prompt(req)
|
prompt, err := model.Prompt(req)
|
||||||
|
@ -86,9 +113,13 @@ func GenerateHandler(c *gin.Context) {
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
fn := func(r api.GenerateResponse) {
|
fn := func(r api.GenerateResponse) {
|
||||||
|
activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
|
||||||
|
activeSession.expireTimer.Reset(sessionDuration.Duration)
|
||||||
|
|
||||||
r.Model = req.Model
|
r.Model = req.Model
|
||||||
r.CreatedAt = time.Now().UTC()
|
r.CreatedAt = time.Now().UTC()
|
||||||
r.SessionID = activeSession.ID
|
r.SessionID = activeSession.id
|
||||||
|
r.SessionExpiresAt = activeSession.expireAt.UTC()
|
||||||
if r.Done {
|
if r.Done {
|
||||||
r.TotalDuration = time.Since(checkpointStart)
|
r.TotalDuration = time.Since(checkpointStart)
|
||||||
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
|
@ -97,7 +128,7 @@ func GenerateHandler(c *gin.Context) {
|
||||||
ch <- r
|
ch <- r
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := activeSession.LLM.Predict(req.Context, prompt, fn); err != nil {
|
if err := activeSession.llm.Predict(req.Context, prompt, fn); err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -247,7 +278,7 @@ func ListModelsHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, api.ListResponse{models})
|
c.JSON(http.StatusOK, api.ListResponse{Models: models})
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyModelHandler(c *gin.Context) {
|
func CopyModelHandler(c *gin.Context) {
|
||||||
|
|
Loading…
Reference in a new issue