fix: chat system prompting overrides (#2542)
This commit is contained in:
parent
9774663013
commit
88622847c6
4 changed files with 24 additions and 41 deletions
|
@ -354,8 +354,15 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||
}
|
||||
|
||||
if args[1] == "system" {
|
||||
opts.System = sb.String()
|
||||
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
|
||||
opts.System = sb.String() // for display in modelfile
|
||||
newMessage := api.Message{Role: "system", Content: sb.String()}
|
||||
// Check if the slice is not empty and the last message is from 'system'
|
||||
if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
|
||||
// Replace the last message
|
||||
opts.Messages[len(opts.Messages)-1] = newMessage
|
||||
} else {
|
||||
opts.Messages = append(opts.Messages, newMessage)
|
||||
}
|
||||
fmt.Println("Set system message.")
|
||||
sb.Reset()
|
||||
} else if args[1] == "template" {
|
||||
|
|
|
@ -91,7 +91,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc
|
|||
}
|
||||
|
||||
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
|
||||
func ChatPrompt(tmpl string, system string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
|
||||
func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
|
||||
type prompt struct {
|
||||
System string
|
||||
Prompt string
|
||||
|
@ -103,11 +103,6 @@ func ChatPrompt(tmpl string, system string, messages []api.Message, window int,
|
|||
|
||||
var p prompt
|
||||
|
||||
// Set the first system prompt to the model's system prompt
|
||||
if system != "" {
|
||||
p.System = system
|
||||
}
|
||||
|
||||
// iterate through messages to build up {system,user,response} prompts
|
||||
var imgId int
|
||||
var prompts []prompt
|
||||
|
|
|
@ -77,7 +77,6 @@ func TestChatPrompt(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
template string
|
||||
system string
|
||||
messages []api.Message
|
||||
window int
|
||||
want string
|
||||
|
@ -91,16 +90,6 @@ func TestChatPrompt(t *testing.T) {
|
|||
window: 1024,
|
||||
want: "[INST] Hello [/INST]",
|
||||
},
|
||||
{
|
||||
name: "with default system message",
|
||||
system: "You are a Wizard.",
|
||||
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
window: 1024,
|
||||
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
|
||||
},
|
||||
{
|
||||
name: "with system message",
|
||||
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
|
||||
|
@ -185,24 +174,6 @@ func TestChatPrompt(t *testing.T) {
|
|||
window: 1024,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty list default system",
|
||||
system: "You are a Wizard.",
|
||||
template: "{{ .System }} {{ .Prompt }}",
|
||||
messages: []api.Message{},
|
||||
window: 1024,
|
||||
want: "You are a Wizard. ",
|
||||
},
|
||||
{
|
||||
name: "empty user message",
|
||||
system: "You are a Wizard.",
|
||||
template: "{{ .System }} {{ .Prompt }}",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: ""},
|
||||
},
|
||||
window: 1024,
|
||||
want: "You are a Wizard. ",
|
||||
},
|
||||
{
|
||||
name: "empty prompt",
|
||||
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
|
||||
|
@ -221,7 +192,7 @@ func TestChatPrompt(t *testing.T) {
|
|||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := ChatPrompt(tc.template, tc.system, tc.messages, tc.window, encode)
|
||||
got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
|
||||
if err != nil {
|
||||
t.Errorf("error = %v", err)
|
||||
}
|
||||
|
|
|
@ -1092,12 +1092,12 @@ func streamResponse(c *gin.Context, ch chan any) {
|
|||
}
|
||||
|
||||
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
|
||||
func chatPrompt(ctx context.Context, messages []api.Message) (string, error) {
|
||||
func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
|
||||
encode := func(s string) ([]int, error) {
|
||||
return loaded.runner.Encode(ctx, s)
|
||||
}
|
||||
|
||||
prompt, err := ChatPrompt(loaded.Model.Template, loaded.Model.System, messages, loaded.Options.NumCtx, encode)
|
||||
prompt, err := ChatPrompt(template, messages, numCtx, encode)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -1167,7 +1167,17 @@ func ChatHandler(c *gin.Context) {
|
|||
|
||||
checkpointLoaded := time.Now()
|
||||
|
||||
prompt, err := chatPrompt(c.Request.Context(), req.Messages)
|
||||
// if the first message is not a system message, then add the model's default system message
|
||||
if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
|
||||
req.Messages = append([]api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: model.System,
|
||||
},
|
||||
}, req.Messages...)
|
||||
}
|
||||
|
||||
prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
|
|
Loading…
Add table
Reference in a new issue