Merge pull request #5196 from ollama/mxyng/messages-2

include modelfile messages
This commit is contained in:
Michael Yang 2024-07-31 10:18:17 -07:00 committed by GitHub
commit c4c84b7a0d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 36 additions and 40 deletions

View file

@ -362,9 +362,24 @@ func RunHandler(cmd *cobra.Command, args []string) error {
opts.MultiModal = slices.Contains(info.Details.Families, "clip") opts.MultiModal = slices.Contains(info.Details.Families, "clip")
opts.ParentModel = info.Details.ParentModel opts.ParentModel = info.Details.ParentModel
opts.Messages = append(opts.Messages, info.Messages...)
if interactive { if interactive {
if err := loadModel(cmd, &opts); err != nil {
return err
}
for _, msg := range info.Messages {
switch msg.Role {
case "user":
fmt.Printf(">>> %s\n", msg.Content)
case "assistant":
state := &displayResponseState{}
displayResponse(msg.Content, opts.WordWrap, state)
fmt.Println()
fmt.Println()
}
}
return generateInteractive(cmd, opts) return generateInteractive(cmd, opts)
} }
return generate(cmd, opts) return generate(cmd, opts)

View file

@ -48,29 +48,10 @@ func loadModel(cmd *cobra.Command, opts *runOptions) error {
KeepAlive: opts.KeepAlive, KeepAlive: opts.KeepAlive,
} }
return client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error { return client.Chat(cmd.Context(), chatReq, func(api.ChatResponse) error { return nil })
p.StopAndClear()
for _, msg := range opts.Messages {
switch msg.Role {
case "user":
fmt.Printf(">>> %s\n", msg.Content)
case "assistant":
state := &displayResponseState{}
displayResponse(msg.Content, opts.WordWrap, state)
fmt.Println()
fmt.Println()
}
}
return nil
})
} }
func generateInteractive(cmd *cobra.Command, opts runOptions) error { func generateInteractive(cmd *cobra.Command, opts runOptions) error {
err := loadModel(cmd, &opts)
if err != nil {
return err
}
usage := func() { usage := func() {
fmt.Fprintln(os.Stderr, "Available Commands:") fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set Set session variables") fmt.Fprintln(os.Stderr, " /set Set session variables")

View file

@ -70,7 +70,7 @@ type Model struct {
License []string License []string
Digest string Digest string
Options map[string]interface{} Options map[string]interface{}
Messages []Message Messages []api.Message
Template *template.Template Template *template.Template
} }
@ -191,11 +191,6 @@ func (m *Model) String() string {
return modelfile.String() return modelfile.String()
} }
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ConfigV2 struct { type ConfigV2 struct {
ModelFormat string `json:"model_format"` ModelFormat string `json:"model_format"`
ModelFamily string `json:"model_family"` ModelFamily string `json:"model_family"`

View file

@ -164,17 +164,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
} }
var b bytes.Buffer
if req.Context != nil {
s, err := r.Detokenize(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
b.WriteString(s)
}
var values template.Values var values template.Values
if req.Suffix != "" { if req.Suffix != "" {
values.Prompt = prompt values.Prompt = prompt
@ -187,6 +176,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
msgs = append(msgs, api.Message{Role: "system", Content: m.System}) msgs = append(msgs, api.Message{Role: "system", Content: m.System})
} }
if req.Context == nil {
msgs = append(msgs, m.Messages...)
}
for _, i := range images { for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)}) msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
} }
@ -194,11 +187,22 @@ func (s *Server) GenerateHandler(c *gin.Context) {
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt}) values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
} }
var b bytes.Buffer
if err := tmpl.Execute(&b, values); err != nil { if err := tmpl.Execute(&b, values); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if req.Context != nil {
s, err := r.Detokenize(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
b.WriteString(s)
}
prompt = b.String() prompt = b.String()
} }
@ -1329,11 +1333,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
msgs := append(m.Messages, req.Messages...)
if req.Messages[0].Role != "system" && m.System != "" { if req.Messages[0].Role != "system" && m.System != "" {
req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...) msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
} }
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools) prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return