From 9e35d9bbee4c96ca064bcb7eadc5b2eb3a200ce7 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Mon, 15 Jul 2024 13:55:57 -0700 Subject: [PATCH] server: lowercase roles for compatibility with clients (#5695) --- api/types.go | 16 ++++++++++++++-- api/types_test.go | 23 +++++++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/api/types.go b/api/types.go index bf552928..3b607cec 100644 --- a/api/types.go +++ b/api/types.go @@ -110,6 +110,18 @@ type Message struct { Images []ImageData `json:"images,omitempty"` } +func (m *Message) UnmarshalJSON(b []byte) error { + type Alias Message + var a Alias + if err := json.Unmarshal(b, &a); err != nil { + return err + } + + *m = Message(a) + m.Role = strings.ToLower(m.Role) + return nil +} + // ChatResponse is the response returned by [Client.Chat]. Its fields are // similar to [GenerateResponse]. type ChatResponse struct { @@ -243,8 +255,8 @@ type DeleteRequest struct { // ShowRequest is the request passed to [Client.Show]. type ShowRequest struct { - Model string `json:"model"` - System string `json:"system"` + Model string `json:"model"` + System string `json:"system"` // Template is deprecated Template string `json:"template"` diff --git a/api/types_test.go b/api/types_test.go index c60ed90e..4699c150 100644 --- a/api/types_test.go +++ b/api/types_test.go @@ -208,3 +208,26 @@ func TestUseMmapFormatParams(t *testing.T) { }) } } + +func TestMessage_UnmarshalJSON(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {`{"role": "USER", "content": "Hello!"}`, "user"}, + {`{"role": "System", "content": "Initialization complete."}`, "system"}, + {`{"role": "assistant", "content": "How can I help you?"}`, "assistant"}, + {`{"role": "TOOl", "content": "Access granted."}`, "tool"}, + } + + for _, test := range tests { + var msg Message + if err := json.Unmarshal([]byte(test.input), &msg); err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if msg.Role != test.expected { + t.Errorf("role not lowercased: got %v, expected %v", msg.Role, test.expected) + } + } +}