Merge pull request #52 from jmorganca/go-opts

pass model and predict options
This commit is contained in:
Patrick Devine 2023-07-07 10:59:11 -07:00 committed by GitHub
commit 3d73ad0c56
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 151 additions and 16 deletions

View file

@ -31,6 +31,92 @@ type PullProgress struct {
type GenerateRequest struct { type GenerateRequest struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
ModelOptions `json:"model_opts"`
PredictOptions `json:"predict_opts"`
}
type ModelOptions struct {
ContextSize int `json:"context_size"`
Seed int `json:"seed"`
NBatch int `json:"n_batch"`
F16Memory bool `json:"memory_f16"`
MLock bool `json:"mlock"`
MMap bool `json:"mmap"`
VocabOnly bool `json:"vocab_only"`
LowVRAM bool `json:"low_vram"`
Embeddings bool `json:"embeddings"`
NUMA bool `json:"numa"`
NGPULayers int `json:"gpu_layers"`
MainGPU string `json:"main_gpu"`
TensorSplit string `json:"tensor_split"`
}
type PredictOptions struct {
Seed int `json:"seed"`
Threads int `json:"threads"`
Tokens int `json:"tokens"`
TopK int `json:"top_k"`
Repeat int `json:"repeat"`
Batch int `json:"batch"`
NKeep int `json:"nkeep"`
TopP float64 `json:"top_p"`
Temperature float64 `json:"temp"`
Penalty float64 `json:"penalty"`
F16KV bool
DebugMode bool
StopPrompts []string
IgnoreEOS bool `json:"ignore_eos"`
TailFreeSamplingZ float64 `json:"tfs_z"`
TypicalP float64 `json:"typical_p"`
FrequencyPenalty float64 `json:"freq_penalty"`
PresencePenalty float64 `json:"pres_penalty"`
Mirostat int `json:"mirostat"`
MirostatETA float64 `json:"mirostat_lr"`
MirostatTAU float64 `json:"mirostat_ent"`
PenalizeNL bool `json:"penalize_nl"`
LogitBias string `json:"logit_bias"`
PathPromptCache string
MLock bool `json:"mlock"`
MMap bool `json:"mmap"`
PromptCacheAll bool
PromptCacheRO bool
MainGPU string
TensorSplit string
}
var DefaultModelOptions ModelOptions = ModelOptions{
ContextSize: 128,
Seed: 0,
F16Memory: true,
MLock: false,
Embeddings: true,
MMap: true,
LowVRAM: false,
}
var DefaultPredictOptions PredictOptions = PredictOptions{
Seed: -1,
Threads: -1,
Tokens: 512,
Penalty: 1.1,
Repeat: 64,
Batch: 512,
NKeep: 64,
TopK: 90,
TopP: 0.86,
TailFreeSamplingZ: 1.0,
TypicalP: 1.0,
Temperature: 0.8,
FrequencyPenalty: 0.0,
PresencePenalty: 0.0,
Mirostat: 0,
MirostatTAU: 5.0,
MirostatETA: 0.1,
MMap: true,
StopPrompts: []string{"llama"},
} }
type GenerateResponse struct { type GenerateResponse struct {

View file

@ -42,9 +42,7 @@ type LLama struct {
contextSize int contextSize int
} }
func New(model string, opts ...ModelOption) (*LLama, error) { func New(model string, mo ModelOptions) (*LLama, error) {
mo := NewModelOptions(opts...)
modelPath := C.CString(model) modelPath := C.CString(model)
defer C.free(unsafe.Pointer(modelPath)) defer C.free(unsafe.Pointer(modelPath))
@ -108,9 +106,7 @@ func (l *LLama) Eval(text string, opts ...PredictOption) error {
return nil return nil
} }
func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) { func (l *LLama) Predict(text string, po PredictOptions) (string, error) {
po := NewPredictOptions(opts...)
if po.TokenCallback != nil { if po.TokenCallback != nil {
setCallback(l.ctx, po.TokenCallback) setCallback(l.ctx, po.TokenCallback)
} }

View file

@ -26,12 +26,9 @@ var templatesFS embed.FS
var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt")) var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
func generate(c *gin.Context) { func generate(c *gin.Context) {
// TODO: these should be request parameters
gpulayers := 1
tokens := 512
threads := runtime.NumCPU()
var req api.GenerateRequest var req api.GenerateRequest
req.ModelOptions = api.DefaultModelOptions
req.PredictOptions = api.DefaultPredictOptions
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
return return
@ -41,7 +38,10 @@ func generate(c *gin.Context) {
req.Model = remoteModel.FullName() req.Model = remoteModel.FullName()
} }
model, err := llama.New(req.Model, llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(gpulayers)) modelOpts := getModelOpts(req)
modelOpts.NGPULayers = 1 // hard-code this for now
model, err := llama.New(req.Model, modelOpts)
if err != nil { if err != nil {
fmt.Println("Loading the model failed:", err.Error()) fmt.Println("Loading the model failed:", err.Error())
return return
@ -65,13 +65,16 @@ func generate(c *gin.Context) {
} }
ch := make(chan string) ch := make(chan string)
model.SetTokenCallback(func(token string) bool {
ch <- token
return true
})
predictOpts := getPredictOpts(req)
go func() { go func() {
defer close(ch) defer close(ch)
_, err := model.Predict(req.Prompt, llama.Debug, llama.SetTokenCallback(func(token string) bool { _, err := model.Predict(req.Prompt, predictOpts)
ch <- token
return true
}), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -161,3 +164,53 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i
return return
} }
func getModelOpts(req api.GenerateRequest) llama.ModelOptions {
var opts llama.ModelOptions
opts.ContextSize = req.ModelOptions.ContextSize
opts.Seed = req.ModelOptions.Seed
opts.F16Memory = req.ModelOptions.F16Memory
opts.MLock = req.ModelOptions.MLock
opts.Embeddings = req.ModelOptions.Embeddings
opts.MMap = req.ModelOptions.MMap
opts.LowVRAM = req.ModelOptions.LowVRAM
opts.NBatch = req.ModelOptions.NBatch
opts.VocabOnly = req.ModelOptions.VocabOnly
opts.NUMA = req.ModelOptions.NUMA
opts.NGPULayers = req.ModelOptions.NGPULayers
opts.MainGPU = req.ModelOptions.MainGPU
opts.TensorSplit = req.ModelOptions.TensorSplit
return opts
}
func getPredictOpts(req api.GenerateRequest) llama.PredictOptions {
var opts llama.PredictOptions
if req.PredictOptions.Threads == -1 {
opts.Threads = runtime.NumCPU()
} else {
opts.Threads = req.PredictOptions.Threads
}
opts.Seed = req.PredictOptions.Seed
opts.Tokens = req.PredictOptions.Tokens
opts.Penalty = req.PredictOptions.Penalty
opts.Repeat = req.PredictOptions.Repeat
opts.Batch = req.PredictOptions.Batch
opts.NKeep = req.PredictOptions.NKeep
opts.TopK = req.PredictOptions.TopK
opts.TopP = req.PredictOptions.TopP
opts.TailFreeSamplingZ = req.PredictOptions.TailFreeSamplingZ
opts.TypicalP = req.PredictOptions.TypicalP
opts.Temperature = req.PredictOptions.Temperature
opts.FrequencyPenalty = req.PredictOptions.FrequencyPenalty
opts.PresencePenalty = req.PredictOptions.PresencePenalty
opts.Mirostat = req.PredictOptions.Mirostat
opts.MirostatTAU = req.PredictOptions.MirostatTAU
opts.MirostatETA = req.PredictOptions.MirostatETA
opts.MMap = req.PredictOptions.MMap
return opts
}