From 15af5584238c17ae21853e7619e8008078e6e792 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 19 Jun 2024 14:14:28 -0700 Subject: [PATCH 1/2] include modelfile messages --- cmd/cmd.go | 1 - server/images.go | 7 +------ server/routes.go | 31 ++++++++++++++++++------------- 3 files changed, 19 insertions(+), 20 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index b761d018..641afafb 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -362,7 +362,6 @@ func RunHandler(cmd *cobra.Command, args []string) error { opts.MultiModal = slices.Contains(info.Details.Families, "clip") opts.ParentModel = info.Details.ParentModel - opts.Messages = append(opts.Messages, info.Messages...) if interactive { return generateInteractive(cmd, opts) diff --git a/server/images.go b/server/images.go index 836dbcc2..0f616551 100644 --- a/server/images.go +++ b/server/images.go @@ -70,7 +70,7 @@ type Model struct { License []string Digest string Options map[string]interface{} - Messages []Message + Messages []api.Message Template *template.Template } @@ -191,11 +191,6 @@ func (m *Model) String() string { return modelfile.String() } -type Message struct { - Role string `json:"role"` - Content string `json:"content"` -} - type ConfigV2 struct { ModelFormat string `json:"model_format"` ModelFamily string `json:"model_family"` diff --git a/server/routes.go b/server/routes.go index e6ffe526..2b4d5794 100644 --- a/server/routes.go +++ b/server/routes.go @@ -164,17 +164,6 @@ func (s *Server) GenerateHandler(c *gin.Context) { } } - var b bytes.Buffer - if req.Context != nil { - s, err := r.Detokenize(c.Request.Context(), req.Context) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - b.WriteString(s) - } - var values template.Values if req.Suffix != "" { values.Prompt = prompt @@ -187,6 +176,10 @@ func (s *Server) GenerateHandler(c *gin.Context) { msgs = append(msgs, api.Message{Role: "system", Content: m.System}) } + if req.Context == nil { + msgs = append(msgs, m.Messages...) + } + for _, i := range images { msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)}) } @@ -194,11 +187,22 @@ func (s *Server) GenerateHandler(c *gin.Context) { values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt}) } + var b bytes.Buffer if err := tmpl.Execute(&b, values); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + if req.Context != nil { + s, err := r.Detokenize(c.Request.Context(), req.Context) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + b.WriteString(s) + } + prompt = b.String() } @@ -1323,11 +1327,12 @@ func (s *Server) ChatHandler(c *gin.Context) { return } + msgs := append(m.Messages, req.Messages...) if req.Messages[0].Role != "system" && m.System != "" { - req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...) + msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...) } - prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools) + prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return From a250c2cb13fd74b516dd138daad9ca54e30a9fab Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 26 Jul 2024 13:39:38 -0700 Subject: [PATCH 2/2] display messages --- cmd/cmd.go | 16 ++++++++++++++++ cmd/interactive.go | 27 ++++----------------------- 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 641afafb..22950885 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -364,6 +364,22 @@ func RunHandler(cmd *cobra.Command, args []string) error { opts.ParentModel = info.Details.ParentModel if interactive { + if err := loadModel(cmd, &opts); err != nil { + return err + } + + for _, msg := range info.Messages { + switch msg.Role { + case "user": + fmt.Printf(">>> %s\n", msg.Content) + case "assistant": + state := &displayResponseState{} + displayResponse(msg.Content, opts.WordWrap, state) + fmt.Println() + fmt.Println() + } + } + return generateInteractive(cmd, opts) } return generate(cmd, opts) diff --git a/cmd/interactive.go b/cmd/interactive.go index adbc3e9f..41b19971 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -46,29 +46,10 @@ func loadModel(cmd *cobra.Command, opts *runOptions) error { KeepAlive: opts.KeepAlive, } - return client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error { - p.StopAndClear() - for _, msg := range opts.Messages { - switch msg.Role { - case "user": - fmt.Printf(">>> %s\n", msg.Content) - case "assistant": - state := &displayResponseState{} - displayResponse(msg.Content, opts.WordWrap, state) - fmt.Println() - fmt.Println() - } - } - return nil - }) + return client.Chat(cmd.Context(), chatReq, func(api.ChatResponse) error { return nil }) } func generateInteractive(cmd *cobra.Command, opts runOptions) error { - err := loadModel(cmd, &opts) - if err != nil { - return err - } - usage := func() { fmt.Fprintln(os.Stderr, "Available Commands:") fmt.Fprintln(os.Stderr, " /set Set session variables") @@ -375,9 +356,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { return err } req := &api.ShowRequest{ - Name: opts.Model, - System: opts.System, - Options: opts.Options, + Name: opts.Model, + System: opts.System, + Options: opts.Options, } resp, err := client.Show(cmd.Context(), req) if err != nil {