From 88622847c6a83508681b8876e2aaca9ca85f83b5 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald <brucewmacdonald@gmail.com> Date: Fri, 16 Feb 2024 14:42:43 -0500 Subject: [PATCH] fix: chat system prompting overrides (#2542) --- cmd/interactive.go | 11 +++++++++-- server/prompt.go | 7 +------ server/prompt_test.go | 31 +------------------------------ server/routes.go | 16 +++++++++++++--- 4 files changed, 24 insertions(+), 41 deletions(-) diff --git a/cmd/interactive.go b/cmd/interactive.go index c9836372..a421a513 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -354,8 +354,15 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { } if args[1] == "system" { - opts.System = sb.String() - opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System}) + opts.System = sb.String() // for display in modelfile + newMessage := api.Message{Role: "system", Content: sb.String()} + // Check if the slice is not empty and the last message is from 'system' + if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" { + // Replace the last message + opts.Messages[len(opts.Messages)-1] = newMessage + } else { + opts.Messages = append(opts.Messages, newMessage) + } fmt.Println("Set system message.") sb.Reset() } else if args[1] == "template" { diff --git a/server/prompt.go b/server/prompt.go index c83075d9..6b684963 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -91,7 +91,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc } // ChatPrompt builds up a prompt from a series of messages, truncating based on context window size -func ChatPrompt(tmpl string, system string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) { +func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) { type prompt struct { System string Prompt string @@ -103,11 +103,6 @@ func ChatPrompt(tmpl string, system string, messages []api.Message, window int, var p prompt - // Set the first system prompt to the model's system prompt - if system != "" { - p.System = system - } - // iterate through messages to build up {system,user,response} prompts var imgId int var prompts []prompt diff --git a/server/prompt_test.go b/server/prompt_test.go index 0ac8e314..75c02d7b 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -77,7 +77,6 @@ func TestChatPrompt(t *testing.T) { tests := []struct { name string template string - system string messages []api.Message window int want string @@ -91,16 +90,6 @@ func TestChatPrompt(t *testing.T) { window: 1024, want: "[INST] Hello [/INST]", }, - { - name: "with default system message", - system: "You are a Wizard.", - template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]", - messages: []api.Message{ - {Role: "user", Content: "Hello"}, - }, - window: 1024, - want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]", - }, { name: "with system message", template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]", @@ -185,24 +174,6 @@ func TestChatPrompt(t *testing.T) { window: 1024, want: "", }, - { - name: "empty list default system", - system: "You are a Wizard.", - template: "{{ .System }} {{ .Prompt }}", - messages: []api.Message{}, - window: 1024, - want: "You are a Wizard. ", - }, - { - name: "empty user message", - system: "You are a Wizard.", - template: "{{ .System }} {{ .Prompt }}", - messages: []api.Message{ - {Role: "user", Content: ""}, - }, - window: 1024, - want: "You are a Wizard. ", - }, { name: "empty prompt", template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ", @@ -221,7 +192,7 @@ func TestChatPrompt(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - got, err := ChatPrompt(tc.template, tc.system, tc.messages, tc.window, encode) + got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode) if err != nil { t.Errorf("error = %v", err) } diff --git a/server/routes.go b/server/routes.go index 49ea33ac..c769a876 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1092,12 +1092,12 @@ func streamResponse(c *gin.Context, ch chan any) { } // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model -func chatPrompt(ctx context.Context, messages []api.Message) (string, error) { +func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) { encode := func(s string) ([]int, error) { return loaded.runner.Encode(ctx, s) } - prompt, err := ChatPrompt(loaded.Model.Template, loaded.Model.System, messages, loaded.Options.NumCtx, encode) + prompt, err := ChatPrompt(template, messages, numCtx, encode) if err != nil { return "", err } @@ -1167,7 +1167,17 @@ func ChatHandler(c *gin.Context) { checkpointLoaded := time.Now() - prompt, err := chatPrompt(c.Request.Context(), req.Messages) + // if the first message is not a system message, then add the model's default system message + if len(req.Messages) > 0 && req.Messages[0].Role != "system" { + req.Messages = append([]api.Message{ + { + Role: "system", + Content: model.System, + }, + }, req.Messages...) + } + + prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return