cache loaded model
This commit is contained in:
parent
81f75696e2
commit
528bafa585
4 changed files with 30 additions and 42 deletions
|
@ -30,9 +30,6 @@ func (e StatusError) Error() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
type GenerateRequest struct {
|
type GenerateRequest struct {
|
||||||
SessionID int64 `json:"session_id"`
|
|
||||||
SessionDuration Duration `json:"session_duration,omitempty"`
|
|
||||||
|
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
Context []int `json:"context,omitempty"`
|
Context []int `json:"context,omitempty"`
|
||||||
|
@ -86,9 +83,6 @@ type ListResponseModel struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type GenerateResponse struct {
|
type GenerateResponse struct {
|
||||||
SessionID int64 `json:"session_id"`
|
|
||||||
SessionExpiresAt time.Time `json:"session_expires_at"`
|
|
||||||
|
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
Response string `json:"response,omitempty"`
|
Response string `json:"response,omitempty"`
|
||||||
|
|
|
@ -260,12 +260,7 @@ func generate(cmd *cobra.Command, model, prompt string) error {
|
||||||
generateContext = []int{}
|
generateContext = []int{}
|
||||||
}
|
}
|
||||||
|
|
||||||
generateSession, ok := cmd.Context().Value(generateContextKey("session")).(int64)
|
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
|
||||||
if !ok {
|
|
||||||
generateSession = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, SessionID: generateSession}
|
|
||||||
fn := func(response api.GenerateResponse) error {
|
fn := func(response api.GenerateResponse) error {
|
||||||
if !spinner.IsFinished() {
|
if !spinner.IsFinished() {
|
||||||
spinner.Finish()
|
spinner.Finish()
|
||||||
|
@ -295,7 +290,6 @@ func generate(cmd *cobra.Command, model, prompt string) error {
|
||||||
|
|
||||||
ctx := cmd.Context()
|
ctx := cmd.Context()
|
||||||
ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
|
ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
|
||||||
ctx = context.WithValue(ctx, generateContextKey("session"), latest.SessionID)
|
|
||||||
cmd.SetContext(ctx)
|
cmd.SetContext(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,7 @@ type Model struct {
|
||||||
ModelPath string
|
ModelPath string
|
||||||
Template string
|
Template string
|
||||||
System string
|
System string
|
||||||
|
Digest string
|
||||||
Options api.Options
|
Options api.Options
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,6 +136,7 @@ func GetModel(name string) (*Model, error) {
|
||||||
|
|
||||||
model := &Model{
|
model := &Model{
|
||||||
Name: mp.GetFullTagname(),
|
Name: mp.GetFullTagname(),
|
||||||
|
Digest: manifest.Config.Digest,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, layer := range manifest.Layers {
|
for _, layer := range manifest.Layers {
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -22,19 +23,21 @@ import (
|
||||||
"github.com/jmorganca/ollama/llama"
|
"github.com/jmorganca/ollama/llama"
|
||||||
)
|
)
|
||||||
|
|
||||||
var activeSession struct {
|
var loaded struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
|
||||||
id int64
|
|
||||||
llm *llama.LLM
|
llm *llama.LLM
|
||||||
|
|
||||||
expireAt time.Time
|
expireAt time.Time
|
||||||
expireTimer *time.Timer
|
expireTimer *time.Timer
|
||||||
|
|
||||||
|
digest string
|
||||||
|
options api.Options
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateHandler(c *gin.Context) {
|
func GenerateHandler(c *gin.Context) {
|
||||||
activeSession.mu.Lock()
|
loaded.mu.Lock()
|
||||||
defer activeSession.mu.Unlock()
|
defer loaded.mu.Unlock()
|
||||||
|
|
||||||
checkpointStart := time.Now()
|
checkpointStart := time.Now()
|
||||||
|
|
||||||
|
@ -50,10 +53,10 @@ func GenerateHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.SessionID == 0 || req.SessionID != activeSession.id {
|
if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, req.Options) {
|
||||||
if activeSession.llm != nil {
|
if loaded.llm != nil {
|
||||||
activeSession.llm.Close()
|
loaded.llm.Close()
|
||||||
activeSession.llm = nil
|
loaded.llm = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
opts := api.DefaultOptions()
|
opts := api.DefaultOptions()
|
||||||
|
@ -73,33 +76,31 @@ func GenerateHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
activeSession.id = time.Now().UnixNano()
|
loaded.llm = llm
|
||||||
activeSession.llm = llm
|
loaded.digest = model.Digest
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionDuration := req.SessionDuration
|
sessionDuration := 5 * time.Minute
|
||||||
sessionID := activeSession.id
|
|
||||||
|
|
||||||
activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
|
loaded.expireAt = time.Now().Add(sessionDuration)
|
||||||
if activeSession.expireTimer == nil {
|
if loaded.expireTimer == nil {
|
||||||
activeSession.expireTimer = time.AfterFunc(sessionDuration.Duration, func() {
|
loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
|
||||||
activeSession.mu.Lock()
|
loaded.mu.Lock()
|
||||||
defer activeSession.mu.Unlock()
|
defer loaded.mu.Unlock()
|
||||||
|
|
||||||
if sessionID != activeSession.id {
|
if time.Now().Before(loaded.expireAt) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if time.Now().Before(activeSession.expireAt) {
|
if loaded.llm == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
activeSession.llm.Close()
|
loaded.llm.Close()
|
||||||
activeSession.llm = nil
|
loaded.llm = nil
|
||||||
activeSession.id = 0
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
activeSession.expireTimer.Reset(sessionDuration.Duration)
|
loaded.expireTimer.Reset(sessionDuration)
|
||||||
|
|
||||||
checkpointLoaded := time.Now()
|
checkpointLoaded := time.Now()
|
||||||
|
|
||||||
|
@ -113,13 +114,11 @@ func GenerateHandler(c *gin.Context) {
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
fn := func(r api.GenerateResponse) {
|
fn := func(r api.GenerateResponse) {
|
||||||
activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
|
loaded.expireAt = time.Now().Add(sessionDuration)
|
||||||
activeSession.expireTimer.Reset(sessionDuration.Duration)
|
loaded.expireTimer.Reset(sessionDuration)
|
||||||
|
|
||||||
r.Model = req.Model
|
r.Model = req.Model
|
||||||
r.CreatedAt = time.Now().UTC()
|
r.CreatedAt = time.Now().UTC()
|
||||||
r.SessionID = activeSession.id
|
|
||||||
r.SessionExpiresAt = activeSession.expireAt.UTC()
|
|
||||||
if r.Done {
|
if r.Done {
|
||||||
r.TotalDuration = time.Since(checkpointStart)
|
r.TotalDuration = time.Since(checkpointStart)
|
||||||
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
|
@ -128,8 +127,7 @@ func GenerateHandler(c *gin.Context) {
|
||||||
ch <- r
|
ch <- r
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := activeSession.llm.Predict(req.Context, prompt, fn); err != nil {
|
if err := loaded.llm.Predict(req.Context, prompt, fn); err != nil {
|
||||||
log.Printf("llm.Predict failed with %s", err)
|
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
Loading…
Reference in a new issue