include modelfile messages

This commit is contained in:
Michael Yang 2024-06-19 14:14:28 -07:00
parent f5e3939220
commit 15af558423
3 changed files with 19 additions and 20 deletions

View file

@ -362,7 +362,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
opts.MultiModal = slices.Contains(info.Details.Families, "clip") opts.MultiModal = slices.Contains(info.Details.Families, "clip")
opts.ParentModel = info.Details.ParentModel opts.ParentModel = info.Details.ParentModel
opts.Messages = append(opts.Messages, info.Messages...)
if interactive { if interactive {
return generateInteractive(cmd, opts) return generateInteractive(cmd, opts)

View file

@ -70,7 +70,7 @@ type Model struct {
License []string License []string
Digest string Digest string
Options map[string]interface{} Options map[string]interface{}
Messages []Message Messages []api.Message
Template *template.Template Template *template.Template
} }
@ -191,11 +191,6 @@ func (m *Model) String() string {
return modelfile.String() return modelfile.String()
} }
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ConfigV2 struct { type ConfigV2 struct {
ModelFormat string `json:"model_format"` ModelFormat string `json:"model_format"`
ModelFamily string `json:"model_family"` ModelFamily string `json:"model_family"`

View file

@ -164,17 +164,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
} }
var b bytes.Buffer
if req.Context != nil {
s, err := r.Detokenize(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
b.WriteString(s)
}
var values template.Values var values template.Values
if req.Suffix != "" { if req.Suffix != "" {
values.Prompt = prompt values.Prompt = prompt
@ -187,6 +176,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
msgs = append(msgs, api.Message{Role: "system", Content: m.System}) msgs = append(msgs, api.Message{Role: "system", Content: m.System})
} }
if req.Context == nil {
msgs = append(msgs, m.Messages...)
}
for _, i := range images { for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)}) msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
} }
@ -194,11 +187,22 @@ func (s *Server) GenerateHandler(c *gin.Context) {
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt}) values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
} }
var b bytes.Buffer
if err := tmpl.Execute(&b, values); err != nil { if err := tmpl.Execute(&b, values); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if req.Context != nil {
s, err := r.Detokenize(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
b.WriteString(s)
}
prompt = b.String() prompt = b.String()
} }
@ -1323,11 +1327,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
msgs := append(m.Messages, req.Messages...)
if req.Messages[0].Role != "system" && m.System != "" { if req.Messages[0].Role != "system" && m.System != "" {
req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...) msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
} }
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools) prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return