do not reload the running llm when runtime params change (#840)

- only reload the running llm if the model has changed, or the options for loading the running model have changed
- rename loaded llm to runner to differentiate from loaded model image
- remove logic which keeps the first system prompt in the generation context
This commit is contained in:
Bruce MacDonald 2023-10-19 10:39:58 -04:00 committed by GitHub
parent 235e43d7f6
commit fe6f3b48f7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 66 additions and 86 deletions

View file

@ -161,15 +161,10 @@ func (r *GenerateResponse) Summary() {
}
}
type Options struct {
Seed int `json:"seed,omitempty"`
// Backend options
// Runner options which must be set when the model is loaded into memory
type Runner struct {
UseNUMA bool `json:"numa,omitempty"`
// Model options
NumCtx int `json:"num_ctx,omitempty"`
NumKeep int `json:"num_keep,omitempty"`
NumBatch int `json:"num_batch,omitempty"`
NumGQA int `json:"num_gqa,omitempty"`
NumGPU int `json:"num_gpu,omitempty"`
@ -183,8 +178,15 @@ type Options struct {
EmbeddingOnly bool `json:"embedding_only,omitempty"`
RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"`
RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
NumThread int `json:"num_thread,omitempty"`
}
// Predict options
type Options struct {
Runner
// Predict options used at runtime
NumKeep int `json:"num_keep,omitempty"`
Seed int `json:"seed,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP float32 `json:"top_p,omitempty"`
@ -200,8 +202,6 @@ type Options struct {
MirostatEta float32 `json:"mirostat_eta,omitempty"`
PenalizeNewline bool `json:"penalize_newline,omitempty"`
Stop []string `json:"stop,omitempty"`
NumThread int `json:"num_thread,omitempty"`
}
var ErrInvalidOpts = fmt.Errorf("invalid options")
@ -309,6 +309,7 @@ func DefaultOptions() Options {
PenalizeNewline: true,
Seed: -1,
Runner: Runner{
// options set when the model is loaded
NumCtx: 2048,
RopeFrequencyBase: 10000.0,
@ -323,6 +324,7 @@ func DefaultOptions() Options {
UseMMap: true,
UseNUMA: false,
EmbeddingOnly: true,
},
}
}

View file

@ -45,7 +45,6 @@ type Model struct {
System string
License []string
Digest string
ConfigDigest string
Options map[string]interface{}
}
@ -169,7 +168,6 @@ func GetModel(name string) (*Model, error) {
Name: mp.GetFullTagname(),
ShortName: mp.GetShortTagname(),
Digest: digest,
ConfigDigest: manifest.Config.Digest,
Template: "{{ .Prompt }}",
License: []string{},
}

View file

@ -46,13 +46,13 @@ func init() {
var loaded struct {
mu sync.Mutex
llm llm.LLM
runner llm.LLM
expireAt time.Time
expireTimer *time.Timer
digest string
options api.Options
*Model
*api.Options
}
var defaultSessionDuration = 5 * time.Minute
@ -70,59 +70,39 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
}
// check if the loaded model is still running in a subprocess, in case something unexpected happened
if loaded.llm != nil {
if err := loaded.llm.Ping(ctx); err != nil {
if loaded.runner != nil {
if err := loaded.runner.Ping(ctx); err != nil {
log.Print("loaded llm process not responding, closing now")
// the subprocess is no longer running, so close it
loaded.llm.Close()
loaded.llm = nil
loaded.digest = ""
loaded.runner.Close()
loaded.runner = nil
loaded.Model = nil
loaded.Options = nil
}
}
if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, opts) {
if loaded.llm != nil {
needLoad := loaded.runner == nil || // is there a model loaded?
loaded.ModelPath != model.ModelPath || // has the base model changed?
!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) // have the runner options changed?
if needLoad {
if loaded.runner != nil {
log.Println("changing loaded model")
loaded.llm.Close()
loaded.llm = nil
loaded.digest = ""
loaded.runner.Close()
loaded.runner = nil
loaded.Model = nil
loaded.Options = nil
}
llmModel, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
if err != nil {
return err
}
// set cache values before modifying opts
loaded.llm = llmModel
loaded.digest = model.Digest
loaded.options = opts
if opts.NumKeep < 0 {
promptWithSystem, err := model.Prompt(api.GenerateRequest{})
if err != nil {
return err
}
promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}})
if err != nil {
return err
}
tokensWithSystem, err := llmModel.Encode(ctx, promptWithSystem)
if err != nil {
return err
}
tokensNoSystem, err := llmModel.Encode(ctx, promptNoSystem)
if err != nil {
return err
}
opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem)
llmModel.SetOptions(opts)
}
loaded.Model = model
loaded.runner = llmRunner
loaded.Options = &opts
}
loaded.expireAt = time.Now().Add(sessionDuration)
@ -136,13 +116,13 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
return
}
if loaded.llm == nil {
return
if loaded.runner != nil {
loaded.runner.Close()
}
loaded.llm.Close()
loaded.llm = nil
loaded.digest = ""
loaded.runner = nil
loaded.Model = nil
loaded.Options = nil
})
}
@ -215,7 +195,7 @@ func GenerateHandler(c *gin.Context) {
if req.Prompt == "" && req.Template == "" && req.System == "" {
ch <- api.GenerateResponse{Model: req.Model, Done: true}
} else {
if err := loaded.llm.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}
@ -263,12 +243,12 @@ func EmbeddingHandler(c *gin.Context) {
return
}
if !loaded.options.EmbeddingOnly {
if !loaded.Options.EmbeddingOnly {
c.JSON(http.StatusBadRequest, gin.H{"error": "embedding option must be set to true"})
return
}
embedding, err := loaded.llm.Embedding(c.Request.Context(), req.Prompt)
embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
log.Printf("embedding generation failed: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@ -599,8 +579,8 @@ func Serve(ln net.Listener, allowOrigins []string) error {
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-signals
if loaded.llm != nil {
loaded.llm.Close()
if loaded.runner != nil {
loaded.runner.Close()
}
os.RemoveAll(workDir)
os.Exit(0)