fix: chat system prompting overrides (#2542)
This commit is contained in:
parent
9774663013
commit
88622847c6
4 changed files with 24 additions and 41 deletions
|
@ -354,8 +354,15 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if args[1] == "system" {
|
if args[1] == "system" {
|
||||||
opts.System = sb.String()
|
opts.System = sb.String() // for display in modelfile
|
||||||
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
|
newMessage := api.Message{Role: "system", Content: sb.String()}
|
||||||
|
// Check if the slice is not empty and the last message is from 'system'
|
||||||
|
if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
|
||||||
|
// Replace the last message
|
||||||
|
opts.Messages[len(opts.Messages)-1] = newMessage
|
||||||
|
} else {
|
||||||
|
opts.Messages = append(opts.Messages, newMessage)
|
||||||
|
}
|
||||||
fmt.Println("Set system message.")
|
fmt.Println("Set system message.")
|
||||||
sb.Reset()
|
sb.Reset()
|
||||||
} else if args[1] == "template" {
|
} else if args[1] == "template" {
|
||||||
|
|
|
@ -91,7 +91,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
|
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
|
||||||
func ChatPrompt(tmpl string, system string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
|
func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
|
||||||
type prompt struct {
|
type prompt struct {
|
||||||
System string
|
System string
|
||||||
Prompt string
|
Prompt string
|
||||||
|
@ -103,11 +103,6 @@ func ChatPrompt(tmpl string, system string, messages []api.Message, window int,
|
||||||
|
|
||||||
var p prompt
|
var p prompt
|
||||||
|
|
||||||
// Set the first system prompt to the model's system prompt
|
|
||||||
if system != "" {
|
|
||||||
p.System = system
|
|
||||||
}
|
|
||||||
|
|
||||||
// iterate through messages to build up {system,user,response} prompts
|
// iterate through messages to build up {system,user,response} prompts
|
||||||
var imgId int
|
var imgId int
|
||||||
var prompts []prompt
|
var prompts []prompt
|
||||||
|
|
|
@ -77,7 +77,6 @@ func TestChatPrompt(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
template string
|
template string
|
||||||
system string
|
|
||||||
messages []api.Message
|
messages []api.Message
|
||||||
window int
|
window int
|
||||||
want string
|
want string
|
||||||
|
@ -91,16 +90,6 @@ func TestChatPrompt(t *testing.T) {
|
||||||
window: 1024,
|
window: 1024,
|
||||||
want: "[INST] Hello [/INST]",
|
want: "[INST] Hello [/INST]",
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "with default system message",
|
|
||||||
system: "You are a Wizard.",
|
|
||||||
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
|
|
||||||
messages: []api.Message{
|
|
||||||
{Role: "user", Content: "Hello"},
|
|
||||||
},
|
|
||||||
window: 1024,
|
|
||||||
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "with system message",
|
name: "with system message",
|
||||||
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
|
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
|
||||||
|
@ -185,24 +174,6 @@ func TestChatPrompt(t *testing.T) {
|
||||||
window: 1024,
|
window: 1024,
|
||||||
want: "",
|
want: "",
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "empty list default system",
|
|
||||||
system: "You are a Wizard.",
|
|
||||||
template: "{{ .System }} {{ .Prompt }}",
|
|
||||||
messages: []api.Message{},
|
|
||||||
window: 1024,
|
|
||||||
want: "You are a Wizard. ",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty user message",
|
|
||||||
system: "You are a Wizard.",
|
|
||||||
template: "{{ .System }} {{ .Prompt }}",
|
|
||||||
messages: []api.Message{
|
|
||||||
{Role: "user", Content: ""},
|
|
||||||
},
|
|
||||||
window: 1024,
|
|
||||||
want: "You are a Wizard. ",
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "empty prompt",
|
name: "empty prompt",
|
||||||
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
|
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
|
||||||
|
@ -221,7 +192,7 @@ func TestChatPrompt(t *testing.T) {
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
got, err := ChatPrompt(tc.template, tc.system, tc.messages, tc.window, encode)
|
got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("error = %v", err)
|
t.Errorf("error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1092,12 +1092,12 @@ func streamResponse(c *gin.Context, ch chan any) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
|
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
|
||||||
func chatPrompt(ctx context.Context, messages []api.Message) (string, error) {
|
func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
|
||||||
encode := func(s string) ([]int, error) {
|
encode := func(s string) ([]int, error) {
|
||||||
return loaded.runner.Encode(ctx, s)
|
return loaded.runner.Encode(ctx, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
prompt, err := ChatPrompt(loaded.Model.Template, loaded.Model.System, messages, loaded.Options.NumCtx, encode)
|
prompt, err := ChatPrompt(template, messages, numCtx, encode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -1167,7 +1167,17 @@ func ChatHandler(c *gin.Context) {
|
||||||
|
|
||||||
checkpointLoaded := time.Now()
|
checkpointLoaded := time.Now()
|
||||||
|
|
||||||
prompt, err := chatPrompt(c.Request.Context(), req.Messages)
|
// if the first message is not a system message, then add the model's default system message
|
||||||
|
if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
|
||||||
|
req.Messages = append([]api.Message{
|
||||||
|
{
|
||||||
|
Role: "system",
|
||||||
|
Content: model.System,
|
||||||
|
},
|
||||||
|
}, req.Messages...)
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
|
|
Loading…
Reference in a new issue