diff --git a/cmd/interactive.go b/cmd/interactive.go
index c9836372..a421a513 100644
--- a/cmd/interactive.go
+++ b/cmd/interactive.go
@@ -354,8 +354,15 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
}
if args[1] == "system" {
- opts.System = sb.String()
- opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
+ opts.System = sb.String() // for display in modelfile
+ newMessage := api.Message{Role: "system", Content: sb.String()}
+ // If the most recent message is already a system message, overwrite it rather than stacking a second one
+ if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
+ // Replace the last message
+ opts.Messages[len(opts.Messages)-1] = newMessage
+ } else {
+ opts.Messages = append(opts.Messages, newMessage)
+ }
fmt.Println("Set system message.")
sb.Reset()
} else if args[1] == "template" {
diff --git a/server/prompt.go b/server/prompt.go
index c83075d9..6b684963 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -91,7 +91,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc
}
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
-func ChatPrompt(tmpl string, system string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
+func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
type prompt struct {
System string
Prompt string
@@ -103,11 +103,6 @@ func ChatPrompt(tmpl string, system string, messages []api.Message, window int,
var p prompt
- // Set the first system prompt to the model's system prompt
- if system != "" {
- p.System = system
- }
-
// iterate through messages to build up {system,user,response} prompts
var imgId int
var prompts []prompt
diff --git a/server/prompt_test.go b/server/prompt_test.go
index 0ac8e314..75c02d7b 100644
--- a/server/prompt_test.go
+++ b/server/prompt_test.go
@@ -77,7 +77,6 @@ func TestChatPrompt(t *testing.T) {
tests := []struct {
name string
template string
- system string
messages []api.Message
window int
want string
@@ -91,16 +90,6 @@ func TestChatPrompt(t *testing.T) {
window: 1024,
want: "[INST] Hello [/INST]",
},
- {
- name: "with default system message",
- system: "You are a Wizard.",
- template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST]",
- messages: []api.Message{
- {Role: "user", Content: "Hello"},
- },
- window: 1024,
- want: "[INST] <>You are a Wizard.<> Hello [/INST]",
- },
{
name: "with system message",
template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST]",
@@ -185,24 +174,6 @@ func TestChatPrompt(t *testing.T) {
window: 1024,
want: "",
},
- {
- name: "empty list default system",
- system: "You are a Wizard.",
- template: "{{ .System }} {{ .Prompt }}",
- messages: []api.Message{},
- window: 1024,
- want: "You are a Wizard. ",
- },
- {
- name: "empty user message",
- system: "You are a Wizard.",
- template: "{{ .System }} {{ .Prompt }}",
- messages: []api.Message{
- {Role: "user", Content: ""},
- },
- window: 1024,
- want: "You are a Wizard. ",
- },
{
name: "empty prompt",
template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
@@ -221,7 +192,7 @@ func TestChatPrompt(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
- got, err := ChatPrompt(tc.template, tc.system, tc.messages, tc.window, encode)
+ got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
if err != nil {
t.Errorf("error = %v", err)
}
diff --git a/server/routes.go b/server/routes.go
index 49ea33ac..c769a876 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -1092,12 +1092,12 @@ func streamResponse(c *gin.Context, ch chan any) {
}
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, messages []api.Message) (string, error) {
+func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
encode := func(s string) ([]int, error) {
return loaded.runner.Encode(ctx, s)
}
- prompt, err := ChatPrompt(loaded.Model.Template, loaded.Model.System, messages, loaded.Options.NumCtx, encode)
+ prompt, err := ChatPrompt(template, messages, numCtx, encode)
if err != nil {
return "", err
}
@@ -1167,7 +1167,17 @@ func ChatHandler(c *gin.Context) {
checkpointLoaded := time.Now()
- prompt, err := chatPrompt(c.Request.Context(), req.Messages)
+ // if the conversation is non-empty and does not start with a system message, prepend the model's default system message
+ if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
+ req.Messages = append([]api.Message{
+ {
+ Role: "system",
+ Content: model.System,
+ },
+ }, req.Messages...)
+ }
+
+ prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return