Merge pull request #52 from jmorganca/go-opts
pass model and predict options
This commit is contained in:
commit
3d73ad0c56
3 changed files with 151 additions and 16 deletions
86
api/types.go
86
api/types.go
|
@ -31,6 +31,92 @@ type PullProgress struct {
|
||||||
type GenerateRequest struct {
|
type GenerateRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
|
||||||
|
ModelOptions `json:"model_opts"`
|
||||||
|
PredictOptions `json:"predict_opts"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModelOptions struct {
|
||||||
|
ContextSize int `json:"context_size"`
|
||||||
|
Seed int `json:"seed"`
|
||||||
|
NBatch int `json:"n_batch"`
|
||||||
|
F16Memory bool `json:"memory_f16"`
|
||||||
|
MLock bool `json:"mlock"`
|
||||||
|
MMap bool `json:"mmap"`
|
||||||
|
VocabOnly bool `json:"vocab_only"`
|
||||||
|
LowVRAM bool `json:"low_vram"`
|
||||||
|
Embeddings bool `json:"embeddings"`
|
||||||
|
NUMA bool `json:"numa"`
|
||||||
|
NGPULayers int `json:"gpu_layers"`
|
||||||
|
MainGPU string `json:"main_gpu"`
|
||||||
|
TensorSplit string `json:"tensor_split"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PredictOptions struct {
|
||||||
|
Seed int `json:"seed"`
|
||||||
|
Threads int `json:"threads"`
|
||||||
|
Tokens int `json:"tokens"`
|
||||||
|
TopK int `json:"top_k"`
|
||||||
|
Repeat int `json:"repeat"`
|
||||||
|
Batch int `json:"batch"`
|
||||||
|
NKeep int `json:"nkeep"`
|
||||||
|
TopP float64 `json:"top_p"`
|
||||||
|
Temperature float64 `json:"temp"`
|
||||||
|
Penalty float64 `json:"penalty"`
|
||||||
|
F16KV bool
|
||||||
|
DebugMode bool
|
||||||
|
StopPrompts []string
|
||||||
|
IgnoreEOS bool `json:"ignore_eos"`
|
||||||
|
|
||||||
|
TailFreeSamplingZ float64 `json:"tfs_z"`
|
||||||
|
TypicalP float64 `json:"typical_p"`
|
||||||
|
FrequencyPenalty float64 `json:"freq_penalty"`
|
||||||
|
PresencePenalty float64 `json:"pres_penalty"`
|
||||||
|
Mirostat int `json:"mirostat"`
|
||||||
|
MirostatETA float64 `json:"mirostat_lr"`
|
||||||
|
MirostatTAU float64 `json:"mirostat_ent"`
|
||||||
|
PenalizeNL bool `json:"penalize_nl"`
|
||||||
|
LogitBias string `json:"logit_bias"`
|
||||||
|
|
||||||
|
PathPromptCache string
|
||||||
|
MLock bool `json:"mlock"`
|
||||||
|
MMap bool `json:"mmap"`
|
||||||
|
PromptCacheAll bool
|
||||||
|
PromptCacheRO bool
|
||||||
|
MainGPU string
|
||||||
|
TensorSplit string
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultModelOptions ModelOptions = ModelOptions{
|
||||||
|
ContextSize: 128,
|
||||||
|
Seed: 0,
|
||||||
|
F16Memory: true,
|
||||||
|
MLock: false,
|
||||||
|
Embeddings: true,
|
||||||
|
MMap: true,
|
||||||
|
LowVRAM: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultPredictOptions PredictOptions = PredictOptions{
|
||||||
|
Seed: -1,
|
||||||
|
Threads: -1,
|
||||||
|
Tokens: 512,
|
||||||
|
Penalty: 1.1,
|
||||||
|
Repeat: 64,
|
||||||
|
Batch: 512,
|
||||||
|
NKeep: 64,
|
||||||
|
TopK: 90,
|
||||||
|
TopP: 0.86,
|
||||||
|
TailFreeSamplingZ: 1.0,
|
||||||
|
TypicalP: 1.0,
|
||||||
|
Temperature: 0.8,
|
||||||
|
FrequencyPenalty: 0.0,
|
||||||
|
PresencePenalty: 0.0,
|
||||||
|
Mirostat: 0,
|
||||||
|
MirostatTAU: 5.0,
|
||||||
|
MirostatETA: 0.1,
|
||||||
|
MMap: true,
|
||||||
|
StopPrompts: []string{"llama"},
|
||||||
}
|
}
|
||||||
|
|
||||||
type GenerateResponse struct {
|
type GenerateResponse struct {
|
||||||
|
|
|
@ -42,9 +42,7 @@ type LLama struct {
|
||||||
contextSize int
|
contextSize int
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(model string, opts ...ModelOption) (*LLama, error) {
|
func New(model string, mo ModelOptions) (*LLama, error) {
|
||||||
mo := NewModelOptions(opts...)
|
|
||||||
|
|
||||||
modelPath := C.CString(model)
|
modelPath := C.CString(model)
|
||||||
defer C.free(unsafe.Pointer(modelPath))
|
defer C.free(unsafe.Pointer(modelPath))
|
||||||
|
|
||||||
|
@ -108,9 +106,7 @@ func (l *LLama) Eval(text string, opts ...PredictOption) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) {
|
func (l *LLama) Predict(text string, po PredictOptions) (string, error) {
|
||||||
po := NewPredictOptions(opts...)
|
|
||||||
|
|
||||||
if po.TokenCallback != nil {
|
if po.TokenCallback != nil {
|
||||||
setCallback(l.ctx, po.TokenCallback)
|
setCallback(l.ctx, po.TokenCallback)
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,12 +26,9 @@ var templatesFS embed.FS
|
||||||
var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
|
var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
|
||||||
|
|
||||||
func generate(c *gin.Context) {
|
func generate(c *gin.Context) {
|
||||||
// TODO: these should be request parameters
|
|
||||||
gpulayers := 1
|
|
||||||
tokens := 512
|
|
||||||
threads := runtime.NumCPU()
|
|
||||||
|
|
||||||
var req api.GenerateRequest
|
var req api.GenerateRequest
|
||||||
|
req.ModelOptions = api.DefaultModelOptions
|
||||||
|
req.PredictOptions = api.DefaultPredictOptions
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
|
||||||
return
|
return
|
||||||
|
@ -41,7 +38,10 @@ func generate(c *gin.Context) {
|
||||||
req.Model = remoteModel.FullName()
|
req.Model = remoteModel.FullName()
|
||||||
}
|
}
|
||||||
|
|
||||||
model, err := llama.New(req.Model, llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(gpulayers))
|
modelOpts := getModelOpts(req)
|
||||||
|
modelOpts.NGPULayers = 1 // hard-code this for now
|
||||||
|
|
||||||
|
model, err := llama.New(req.Model, modelOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("Loading the model failed:", err.Error())
|
fmt.Println("Loading the model failed:", err.Error())
|
||||||
return
|
return
|
||||||
|
@ -65,13 +65,16 @@ func generate(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ch := make(chan string)
|
ch := make(chan string)
|
||||||
|
model.SetTokenCallback(func(token string) bool {
|
||||||
|
ch <- token
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
predictOpts := getPredictOpts(req)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
_, err := model.Predict(req.Prompt, llama.Debug, llama.SetTokenCallback(func(token string) bool {
|
_, err := model.Predict(req.Prompt, predictOpts)
|
||||||
ch <- token
|
|
||||||
return true
|
|
||||||
}), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -161,3 +164,53 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getModelOpts(req api.GenerateRequest) llama.ModelOptions {
|
||||||
|
var opts llama.ModelOptions
|
||||||
|
opts.ContextSize = req.ModelOptions.ContextSize
|
||||||
|
opts.Seed = req.ModelOptions.Seed
|
||||||
|
opts.F16Memory = req.ModelOptions.F16Memory
|
||||||
|
opts.MLock = req.ModelOptions.MLock
|
||||||
|
opts.Embeddings = req.ModelOptions.Embeddings
|
||||||
|
opts.MMap = req.ModelOptions.MMap
|
||||||
|
opts.LowVRAM = req.ModelOptions.LowVRAM
|
||||||
|
|
||||||
|
opts.NBatch = req.ModelOptions.NBatch
|
||||||
|
opts.VocabOnly = req.ModelOptions.VocabOnly
|
||||||
|
opts.NUMA = req.ModelOptions.NUMA
|
||||||
|
opts.NGPULayers = req.ModelOptions.NGPULayers
|
||||||
|
opts.MainGPU = req.ModelOptions.MainGPU
|
||||||
|
opts.TensorSplit = req.ModelOptions.TensorSplit
|
||||||
|
|
||||||
|
return opts
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPredictOpts(req api.GenerateRequest) llama.PredictOptions {
|
||||||
|
var opts llama.PredictOptions
|
||||||
|
|
||||||
|
if req.PredictOptions.Threads == -1 {
|
||||||
|
opts.Threads = runtime.NumCPU()
|
||||||
|
} else {
|
||||||
|
opts.Threads = req.PredictOptions.Threads
|
||||||
|
}
|
||||||
|
|
||||||
|
opts.Seed = req.PredictOptions.Seed
|
||||||
|
opts.Tokens = req.PredictOptions.Tokens
|
||||||
|
opts.Penalty = req.PredictOptions.Penalty
|
||||||
|
opts.Repeat = req.PredictOptions.Repeat
|
||||||
|
opts.Batch = req.PredictOptions.Batch
|
||||||
|
opts.NKeep = req.PredictOptions.NKeep
|
||||||
|
opts.TopK = req.PredictOptions.TopK
|
||||||
|
opts.TopP = req.PredictOptions.TopP
|
||||||
|
opts.TailFreeSamplingZ = req.PredictOptions.TailFreeSamplingZ
|
||||||
|
opts.TypicalP = req.PredictOptions.TypicalP
|
||||||
|
opts.Temperature = req.PredictOptions.Temperature
|
||||||
|
opts.FrequencyPenalty = req.PredictOptions.FrequencyPenalty
|
||||||
|
opts.PresencePenalty = req.PredictOptions.PresencePenalty
|
||||||
|
opts.Mirostat = req.PredictOptions.Mirostat
|
||||||
|
opts.MirostatTAU = req.PredictOptions.MirostatTAU
|
||||||
|
opts.MirostatETA = req.PredictOptions.MirostatETA
|
||||||
|
opts.MMap = req.PredictOptions.MMap
|
||||||
|
|
||||||
|
return opts
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue