fix: chat system prompting overrides (#2542)

This commit is contained in:
Bruce MacDonald 2024-02-16 14:42:43 -05:00 committed by GitHub
parent 9774663013
commit 88622847c6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 24 additions and 41 deletions

View file

@ -354,8 +354,15 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
}
if args[1] == "system" {
opts.System = sb.String()
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
opts.System = sb.String() // for display in modelfile
newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system'
if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
// Replace the last message
opts.Messages[len(opts.Messages)-1] = newMessage
} else {
opts.Messages = append(opts.Messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
} else if args[1] == "template" {

View file

@ -91,7 +91,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc
}
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
func ChatPrompt(tmpl string, system string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
type prompt struct {
System string
Prompt string
@ -103,11 +103,6 @@ func ChatPrompt(tmpl string, system string, messages []api.Message, window int,
var p prompt
// Set the first system prompt to the model's system prompt
if system != "" {
p.System = system
}
// iterate through messages to build up {system,user,response} prompts
var imgId int
var prompts []prompt

View file

@ -77,7 +77,6 @@ func TestChatPrompt(t *testing.T) {
tests := []struct {
name string
template string
system string
messages []api.Message
window int
want string
@ -91,16 +90,6 @@ func TestChatPrompt(t *testing.T) {
window: 1024,
want: "[INST] Hello [/INST]",
},
{
name: "with default system message",
system: "You are a Wizard.",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
},
{
name: "with system message",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
@ -185,24 +174,6 @@ func TestChatPrompt(t *testing.T) {
window: 1024,
want: "",
},
{
name: "empty list default system",
system: "You are a Wizard.",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{},
window: 1024,
want: "You are a Wizard. ",
},
{
name: "empty user message",
system: "You are a Wizard.",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{
{Role: "user", Content: ""},
},
window: 1024,
want: "You are a Wizard. ",
},
{
name: "empty prompt",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
@ -221,7 +192,7 @@ func TestChatPrompt(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := ChatPrompt(tc.template, tc.system, tc.messages, tc.window, encode)
got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
if err != nil {
t.Errorf("error = %v", err)
}

View file

@ -1092,12 +1092,12 @@ func streamResponse(c *gin.Context, ch chan any) {
}
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func chatPrompt(ctx context.Context, messages []api.Message) (string, error) {
func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
encode := func(s string) ([]int, error) {
return loaded.runner.Encode(ctx, s)
}
prompt, err := ChatPrompt(loaded.Model.Template, loaded.Model.System, messages, loaded.Options.NumCtx, encode)
prompt, err := ChatPrompt(template, messages, numCtx, encode)
if err != nil {
return "", err
}
@ -1167,7 +1167,17 @@ func ChatHandler(c *gin.Context) {
checkpointLoaded := time.Now()
prompt, err := chatPrompt(c.Request.Context(), req.Messages)
// if the first message is not a system message, then add the model's default system message
if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
req.Messages = append([]api.Message{
{
Role: "system",
Content: model.System,
},
}, req.Messages...)
}
prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return