Relay default values to llama runner (#672)

* include seed in the params sent to the llama.cpp server and remove the `omitempty` filter for temperature

* relay default predict options to llama.cpp

- reorganize options to match predict request for readability

* omit the stop field from the request when it is empty

---------

Co-authored-by: hallh <hallh@users.noreply.github.com>
This commit is contained in:
Bruce MacDonald 2023-10-02 14:53:16 -04:00 committed by GitHub
parent 99d5161e8a
commit 1fbf3585d6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 44 deletions

View file

@ -280,38 +280,38 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
func DefaultOptions() Options { func DefaultOptions() Options {
return Options{ return Options{
Seed: -1, // options set on request to runner
NumPredict: -1,
UseNUMA: false, NumKeep: -1,
NumCtx: 2048,
NumKeep: -1,
NumBatch: 512,
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
NumGQA: 1,
LowVRAM: false,
F16KV: true,
UseMMap: true,
UseMLock: false,
RopeFrequencyBase: 10000.0,
RopeFrequencyScale: 1.0,
EmbeddingOnly: true,
RepeatLastN: 64,
RepeatPenalty: 1.1,
FrequencyPenalty: 0.0,
PresencePenalty: 0.0,
Temperature: 0.8, Temperature: 0.8,
TopK: 40, TopK: 40,
TopP: 0.9, TopP: 0.9,
TFSZ: 1.0, TFSZ: 1.0,
TypicalP: 1.0, TypicalP: 1.0,
RepeatLastN: 64,
RepeatPenalty: 1.1,
PresencePenalty: 0.0,
FrequencyPenalty: 0.0,
Mirostat: 0, Mirostat: 0,
MirostatTau: 5.0, MirostatTau: 5.0,
MirostatEta: 0.1, MirostatEta: 0.1,
PenalizeNewline: true, PenalizeNewline: true,
Seed: -1,
NumThread: 0, // let the runtime decide // options set when the model is loaded
NumCtx: 2048,
RopeFrequencyBase: 10000.0,
RopeFrequencyScale: 1.0,
NumBatch: 512,
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
NumGQA: 1,
NumThread: 0, // let the runtime decide
LowVRAM: false,
F16KV: true,
UseMLock: false,
UseMMap: true,
UseNUMA: false,
EmbeddingOnly: true,
} }
} }

View file

@ -417,28 +417,25 @@ type Prediction struct {
} }
type PredictRequest struct { type PredictRequest struct {
Stream bool `json:"stream"` Prompt string `json:"prompt"`
NPredict int `json:"n_predict,omitempty"` Stream bool `json:"stream"`
TopK int `json:"top_k,omitempty"` NPredict int `json:"n_predict"`
TopP float32 `json:"top_p,omitempty"` NKeep int `json:"n_keep"`
TfsZ float32 `json:"tfs_z,omitempty"` Temperature float32 `json:"temperature"`
TypicalP float32 `json:"typical_p,omitempty"` TopK int `json:"top_k"`
RepeatLastN int `json:"repeat_last_n,omitempty"` TopP float32 `json:"top_p"`
Temperature float32 `json:"temperature,omitempty"` TfsZ float32 `json:"tfs_z"`
RepeatPenalty float32 `json:"repeat_penalty,omitempty"` TypicalP float32 `json:"typical_p"`
PresencePenalty float32 `json:"presence_penalty,omitempty"` RepeatLastN int `json:"repeat_last_n"`
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` RepeatPenalty float32 `json:"repeat_penalty"`
Mirostat int `json:"mirostat,omitempty"` PresencePenalty float32 `json:"presence_penalty"`
MirostatTau float32 `json:"mirostat_tau,omitempty"` FrequencyPenalty float32 `json:"frequency_penalty"`
MirostatEta float32 `json:"mirostat_eta,omitempty"` Mirostat int `json:"mirostat"`
PenalizeNl bool `json:"penalize_nl,omitempty"` MirostatTau float32 `json:"mirostat_tau"`
NKeep int `json:"n_keep,omitempty"` MirostatEta float32 `json:"mirostat_eta"`
Seed int `json:"seed,omitempty"` PenalizeNl bool `json:"penalize_nl"`
Prompt string `json:"prompt,omitempty"` Seed int `json:"seed"`
NProbs int `json:"n_probs,omitempty"` Stop []string `json:"stop,omitempty"`
LogitBias map[int]float32 `json:"logit_bias,omitempty"`
IgnoreEos bool `json:"ignore_eos,omitempty"`
Stop []string `json:"stop,omitempty"`
} }
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error { func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
@ -470,8 +467,10 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
MirostatTau: llm.MirostatTau, MirostatTau: llm.MirostatTau,
MirostatEta: llm.MirostatEta, MirostatEta: llm.MirostatEta,
PenalizeNl: llm.PenalizeNewline, PenalizeNl: llm.PenalizeNewline,
Seed: llm.Seed,
Stop: llm.Stop, Stop: llm.Stop,
} }
data, err := json.Marshal(predReq) data, err := json.Marshal(predReq)
if err != nil { if err != nil {
return fmt.Errorf("error marshaling data: %v", err) return fmt.Errorf("error marshaling data: %v", err)