From 88622847c6a83508681b8876e2aaca9ca85f83b5 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald <brucewmacdonald@gmail.com>
Date: Fri, 16 Feb 2024 14:42:43 -0500
Subject: [PATCH] fix: chat system prompting overrides (#2542)

---
 cmd/interactive.go    | 11 +++++++++--
 server/prompt.go      |  7 +------
 server/prompt_test.go | 31 +------------------------------
 server/routes.go      | 16 +++++++++++++---
 4 files changed, 24 insertions(+), 41 deletions(-)

diff --git a/cmd/interactive.go b/cmd/interactive.go
index c9836372..a421a513 100644
--- a/cmd/interactive.go
+++ b/cmd/interactive.go
@@ -354,8 +354,15 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 					}
 
 					if args[1] == "system" {
-						opts.System = sb.String()
-						opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
+						opts.System = sb.String() // for display in modelfile
+						newMessage := api.Message{Role: "system", Content: sb.String()}
+						// Check if the slice is not empty and the last message is from 'system'
+						if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
+							// Replace the last message
+							opts.Messages[len(opts.Messages)-1] = newMessage
+						} else {
+							opts.Messages = append(opts.Messages, newMessage)
+						}
 						fmt.Println("Set system message.")
 						sb.Reset()
 					} else if args[1] == "template" {
diff --git a/server/prompt.go b/server/prompt.go
index c83075d9..6b684963 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -91,7 +91,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc
 }
 
 // ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
-func ChatPrompt(tmpl string, system string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
+func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
 	type prompt struct {
 		System   string
 		Prompt   string
@@ -103,11 +103,6 @@ func ChatPrompt(tmpl string, system string, messages []api.Message, window int,
 
 	var p prompt
 
-	// Set the first system prompt to the model's system prompt
-	if system != "" {
-		p.System = system
-	}
-
 	// iterate through messages to build up {system,user,response} prompts
 	var imgId int
 	var prompts []prompt
diff --git a/server/prompt_test.go b/server/prompt_test.go
index 0ac8e314..75c02d7b 100644
--- a/server/prompt_test.go
+++ b/server/prompt_test.go
@@ -77,7 +77,6 @@ func TestChatPrompt(t *testing.T) {
 	tests := []struct {
 		name     string
 		template string
-		system   string
 		messages []api.Message
 		window   int
 		want     string
@@ -91,16 +90,6 @@ func TestChatPrompt(t *testing.T) {
 			window: 1024,
 			want:   "[INST] Hello [/INST]",
 		},
-		{
-			name:     "with default system message",
-			system:   "You are a Wizard.",
-			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
-			messages: []api.Message{
-				{Role: "user", Content: "Hello"},
-			},
-			window: 1024,
-			want:   "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
-		},
 		{
 			name:     "with system message",
 			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
@@ -185,24 +174,6 @@ func TestChatPrompt(t *testing.T) {
 			window:   1024,
 			want:     "",
 		},
-		{
-			name:     "empty list default system",
-			system:   "You are a Wizard.",
-			template: "{{ .System }} {{ .Prompt }}",
-			messages: []api.Message{},
-			window:   1024,
-			want:     "You are a Wizard. ",
-		},
-		{
-			name:     "empty user message",
-			system:   "You are a Wizard.",
-			template: "{{ .System }} {{ .Prompt }}",
-			messages: []api.Message{
-				{Role: "user", Content: ""},
-			},
-			window: 1024,
-			want:   "You are a Wizard. ",
-		},
 		{
 			name:     "empty prompt",
 			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
@@ -221,7 +192,7 @@ func TestChatPrompt(t *testing.T) {
 
 	for _, tc := range tests {
 		t.Run(tc.name, func(t *testing.T) {
-			got, err := ChatPrompt(tc.template, tc.system, tc.messages, tc.window, encode)
+			got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
 			if err != nil {
 				t.Errorf("error = %v", err)
 			}
diff --git a/server/routes.go b/server/routes.go
index 49ea33ac..c769a876 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -1092,12 +1092,12 @@ func streamResponse(c *gin.Context, ch chan any) {
 }
 
 // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, messages []api.Message) (string, error) {
+func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
 	encode := func(s string) ([]int, error) {
 		return loaded.runner.Encode(ctx, s)
 	}
 
-	prompt, err := ChatPrompt(loaded.Model.Template, loaded.Model.System, messages, loaded.Options.NumCtx, encode)
+	prompt, err := ChatPrompt(template, messages, numCtx, encode)
 	if err != nil {
 		return "", err
 	}
@@ -1167,7 +1167,17 @@ func ChatHandler(c *gin.Context) {
 
 	checkpointLoaded := time.Now()
 
-	prompt, err := chatPrompt(c.Request.Context(), req.Messages)
+	// if the first message is not a system message, then add the model's default system message
+	if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
+		req.Messages = append([]api.Message{
+			{
+				Role:    "system",
+				Content: model.System,
+			},
+		}, req.Messages...)
+	}
+
+	prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
 	if err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return