diff --git a/api/types.go b/api/types.go index 39167ffd..8ec37388 100644 --- a/api/types.go +++ b/api/types.go @@ -31,6 +31,92 @@ type PullProgress struct { type GenerateRequest struct { Model string `json:"model"` Prompt string `json:"prompt"` + + ModelOptions `json:"model_opts"` + PredictOptions `json:"predict_opts"` +} + +type ModelOptions struct { + ContextSize int `json:"context_size"` + Seed int `json:"seed"` + NBatch int `json:"n_batch"` + F16Memory bool `json:"memory_f16"` + MLock bool `json:"mlock"` + MMap bool `json:"mmap"` + VocabOnly bool `json:"vocab_only"` + LowVRAM bool `json:"low_vram"` + Embeddings bool `json:"embeddings"` + NUMA bool `json:"numa"` + NGPULayers int `json:"gpu_layers"` + MainGPU string `json:"main_gpu"` + TensorSplit string `json:"tensor_split"` +} + +type PredictOptions struct { + Seed int `json:"seed"` + Threads int `json:"threads"` + Tokens int `json:"tokens"` + TopK int `json:"top_k"` + Repeat int `json:"repeat"` + Batch int `json:"batch"` + NKeep int `json:"nkeep"` + TopP float64 `json:"top_p"` + Temperature float64 `json:"temp"` + Penalty float64 `json:"penalty"` + F16KV bool + DebugMode bool + StopPrompts []string + IgnoreEOS bool `json:"ignore_eos"` + + TailFreeSamplingZ float64 `json:"tfs_z"` + TypicalP float64 `json:"typical_p"` + FrequencyPenalty float64 `json:"freq_penalty"` + PresencePenalty float64 `json:"pres_penalty"` + Mirostat int `json:"mirostat"` + MirostatETA float64 `json:"mirostat_lr"` + MirostatTAU float64 `json:"mirostat_ent"` + PenalizeNL bool `json:"penalize_nl"` + LogitBias string `json:"logit_bias"` + + PathPromptCache string + MLock bool `json:"mlock"` + MMap bool `json:"mmap"` + PromptCacheAll bool + PromptCacheRO bool + MainGPU string + TensorSplit string +} + +var DefaultModelOptions ModelOptions = ModelOptions{ + ContextSize: 128, + Seed: 0, + F16Memory: true, + MLock: false, + Embeddings: true, + MMap: true, + LowVRAM: false, +} + +var DefaultPredictOptions PredictOptions = PredictOptions{ + Seed: -1, + Threads: -1, + Tokens: 512, + Penalty: 1.1, + Repeat: 64, + Batch: 512, + NKeep: 64, + TopK: 90, + TopP: 0.86, + TailFreeSamplingZ: 1.0, + TypicalP: 1.0, + Temperature: 0.8, + FrequencyPenalty: 0.0, + PresencePenalty: 0.0, + Mirostat: 0, + MirostatTAU: 5.0, + MirostatETA: 0.1, + MMap: true, + StopPrompts: []string{"llama"}, } type GenerateResponse struct { diff --git a/llama/llama.go b/llama/llama.go index f4b3aae2..ea315ef6 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -42,9 +42,7 @@ type LLama struct { contextSize int } -func New(model string, opts ...ModelOption) (*LLama, error) { - mo := NewModelOptions(opts...) - +func New(model string, mo ModelOptions) (*LLama, error) { modelPath := C.CString(model) defer C.free(unsafe.Pointer(modelPath)) @@ -108,9 +106,7 @@ func (l *LLama) Eval(text string, opts ...PredictOption) error { return nil } -func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) { - po := NewPredictOptions(opts...) - +func (l *LLama) Predict(text string, po PredictOptions) (string, error) { if po.TokenCallback != nil { setCallback(l.ctx, po.TokenCallback) } diff --git a/server/routes.go b/server/routes.go index 4831b7ad..21a23785 100644 --- a/server/routes.go +++ b/server/routes.go @@ -26,12 +26,9 @@ var templatesFS embed.FS var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt")) func generate(c *gin.Context) { - // TODO: these should be request parameters - gpulayers := 1 - tokens := 512 - threads := runtime.NumCPU() - var req api.GenerateRequest + req.ModelOptions = api.DefaultModelOptions + req.PredictOptions = api.DefaultPredictOptions if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) return @@ -41,7 +38,10 @@ func generate(c *gin.Context) { req.Model = remoteModel.FullName() } - model, err := llama.New(req.Model, llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(gpulayers)) + modelOpts := getModelOpts(req) + modelOpts.NGPULayers = 1 // hard-code this for now + + model, err := llama.New(req.Model, modelOpts) if err != nil { fmt.Println("Loading the model failed:", err.Error()) return @@ -65,13 +65,16 @@ func generate(c *gin.Context) { } ch := make(chan string) + model.SetTokenCallback(func(token string) bool { + ch <- token + return true + }) + + predictOpts := getPredictOpts(req) go func() { defer close(ch) - _, err := model.Predict(req.Prompt, llama.Debug, llama.SetTokenCallback(func(token string) bool { - ch <- token - return true - }), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama")) + _, err := model.Predict(req.Prompt, predictOpts) if err != nil { panic(err) } @@ -161,3 +164,53 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i return } + +func getModelOpts(req api.GenerateRequest) llama.ModelOptions { + var opts llama.ModelOptions + opts.ContextSize = req.ModelOptions.ContextSize + opts.Seed = req.ModelOptions.Seed + opts.F16Memory = req.ModelOptions.F16Memory + opts.MLock = req.ModelOptions.MLock + opts.Embeddings = req.ModelOptions.Embeddings + opts.MMap = req.ModelOptions.MMap + opts.LowVRAM = req.ModelOptions.LowVRAM + + opts.NBatch = req.ModelOptions.NBatch + opts.VocabOnly = req.ModelOptions.VocabOnly + opts.NUMA = req.ModelOptions.NUMA + opts.NGPULayers = req.ModelOptions.NGPULayers + opts.MainGPU = req.ModelOptions.MainGPU + opts.TensorSplit = req.ModelOptions.TensorSplit + + return opts +} + +func getPredictOpts(req api.GenerateRequest) llama.PredictOptions { + var opts llama.PredictOptions + + if req.PredictOptions.Threads == -1 { + opts.Threads = runtime.NumCPU() + } else { + opts.Threads = req.PredictOptions.Threads + } + + opts.Seed = req.PredictOptions.Seed + opts.Tokens = req.PredictOptions.Tokens + opts.Penalty = req.PredictOptions.Penalty + opts.Repeat = req.PredictOptions.Repeat + opts.Batch = req.PredictOptions.Batch + opts.NKeep = req.PredictOptions.NKeep + opts.TopK = req.PredictOptions.TopK + opts.TopP = req.PredictOptions.TopP + opts.TailFreeSamplingZ = req.PredictOptions.TailFreeSamplingZ + opts.TypicalP = req.PredictOptions.TypicalP + opts.Temperature = req.PredictOptions.Temperature + opts.FrequencyPenalty = req.PredictOptions.FrequencyPenalty + opts.PresencePenalty = req.PredictOptions.PresencePenalty + opts.Mirostat = req.PredictOptions.Mirostat + opts.MirostatTAU = req.PredictOptions.MirostatTAU + opts.MirostatETA = req.PredictOptions.MirostatETA + opts.MMap = req.PredictOptions.MMap + + return opts +}