From ccdf0b2a449d812a3708a3083f6a725289f4f750 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Mon, 20 May 2024 11:26:45 -0700 Subject: [PATCH] Move the parser back + handle utf16 files (#4533) --- cmd/cmd.go | 3 +- types/model/file.go => parser/parser.go | 30 ++++++++++++++- .../file_test.go => parser/parser_test.go | 38 ++++++++++++++++++- server/images.go | 23 +++++------ server/routes.go | 3 +- server/routes_test.go | 4 +- 6 files changed, 84 insertions(+), 17 deletions(-) rename types/model/file.go => parser/parser.go (92%) rename types/model/file_test.go => parser/parser_test.go (92%) diff --git a/cmd/cmd.go b/cmd/cmd.go index 3b60334c..f79f8b97 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -35,6 +35,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/auth" "github.com/ollama/ollama/format" + "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" "github.com/ollama/ollama/server" "github.com/ollama/ollama/types/errtypes" @@ -63,7 +64,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } defer f.Close() - modelfile, err := model.ParseFile(f) + modelfile, err := parser.ParseFile(f) if err != nil { return err } diff --git a/types/model/file.go b/parser/parser.go similarity index 92% rename from types/model/file.go rename to parser/parser.go index ee398309..4f44f6af 100644 --- a/types/model/file.go +++ b/parser/parser.go @@ -1,4 +1,4 @@ -package model +package parser import ( "bufio" @@ -8,6 +8,7 @@ import ( "io" "strconv" "strings" + "unicode" ) type File struct { @@ -68,6 +69,11 @@ func ParseFile(r io.Reader) (*File, error) { var b bytes.Buffer var role string + var lineCount int + var linePos int + + var utf16 bool + var f File br := bufio.NewReader(r) @@ -79,6 +85,17 @@ func ParseFile(r io.Reader) (*File, error) { return nil, err } + // the utf16 byte order mark will be read as "unreadable" by ReadRune() + if isUnreadable(r) && lineCount == 0 && linePos == 0 { + utf16 = true + continue + } + + // skip the second byte if we're reading utf16 + if utf16 && r == 0 { + continue + } + next, r, err := parseRuneForState(r, curr) if errors.Is(err, io.ErrUnexpectedEOF) { return nil, fmt.Errorf("%w: %s", err, b.String()) @@ -86,6 +103,13 @@ func ParseFile(r io.Reader) (*File, error) { return nil, err } + if isNewline(r) { + lineCount++ + linePos = 0 + } else { + linePos++ + } + // process the state transition, some transitions need to be intercepted and redirected if next != curr { switch curr { @@ -285,6 +309,10 @@ func isNewline(r rune) bool { return r == '\r' || r == '\n' } +func isUnreadable(r rune) bool { + return r == unicode.ReplacementChar +} + func isValidMessageRole(role string) bool { return role == "system" || role == "user" || role == "assistant" } diff --git a/types/model/file_test.go b/parser/parser_test.go similarity index 92% rename from types/model/file_test.go rename to parser/parser_test.go index 8e71760c..21223cb1 100644 --- a/types/model/file_test.go +++ b/parser/parser_test.go @@ -1,11 +1,13 @@ -package model +package parser import ( "bytes" + "encoding/binary" "fmt" "io" "strings" "testing" + "unicode/utf16" "github.com/stretchr/testify/assert" ) @@ -509,3 +511,37 @@ SYSTEM "" } } + +func TestParseFileUTF16ParseFile(t *testing.T) { + data := `FROM bob +PARAMETER param1 1 +PARAMETER param2 4096 +SYSTEM You are a utf16 file. +` + // simulate a utf16 le file + utf16File := utf16.Encode(append([]rune{'\ufffe'}, []rune(data)...)) + buf := new(bytes.Buffer) + err := binary.Write(buf, binary.LittleEndian, utf16File) + assert.NoError(t, err) + + actual, err := ParseFile(buf) + assert.NoError(t, err) + + expected := []Command{ + {Name: "model", Args: "bob"}, + {Name: "param1", Args: "1"}, + {Name: "param2", Args: "4096"}, + {Name: "system", Args: "You are a utf16 file."}, + } + + assert.Equal(t, expected, actual.Commands) + + // simulate a utf16 be file + buf = new(bytes.Buffer) + err = binary.Write(buf, binary.BigEndian, utf16File) + assert.NoError(t, err) + + actual, err = ParseFile(buf) + assert.NoError(t, err) + assert.Equal(t, expected, actual.Commands) +} diff --git a/server/images.go b/server/images.go index 3f415b6d..0ccc90b9 100644 --- a/server/images.go +++ b/server/images.go @@ -27,6 +27,7 @@ import ( "github.com/ollama/ollama/auth" "github.com/ollama/ollama/format" "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/parser" "github.com/ollama/ollama/server/envconfig" "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" @@ -61,36 +62,36 @@ func (m *Model) IsEmbedding() bool { } func (m *Model) String() string { - var modelfile model.File + var modelfile parser.File - modelfile.Commands = append(modelfile.Commands, model.Command{ + modelfile.Commands = append(modelfile.Commands, parser.Command{ Name: "model", Args: m.ModelPath, }) for _, adapter := range m.AdapterPaths { - modelfile.Commands = append(modelfile.Commands, model.Command{ + modelfile.Commands = append(modelfile.Commands, parser.Command{ Name: "adapter", Args: adapter, }) } for _, projector := range m.ProjectorPaths { - modelfile.Commands = append(modelfile.Commands, model.Command{ + modelfile.Commands = append(modelfile.Commands, parser.Command{ Name: "model", Args: projector, }) } if m.Template != "" { - modelfile.Commands = append(modelfile.Commands, model.Command{ + modelfile.Commands = append(modelfile.Commands, parser.Command{ Name: "template", Args: m.Template, }) } if m.System != "" { - modelfile.Commands = append(modelfile.Commands, model.Command{ + modelfile.Commands = append(modelfile.Commands, parser.Command{ Name: "system", Args: m.System, }) @@ -100,13 +101,13 @@ func (m *Model) String() string { switch v := v.(type) { case []any: for _, s := range v { - modelfile.Commands = append(modelfile.Commands, model.Command{ + modelfile.Commands = append(modelfile.Commands, parser.Command{ Name: k, Args: fmt.Sprintf("%v", s), }) } default: - modelfile.Commands = append(modelfile.Commands, model.Command{ + modelfile.Commands = append(modelfile.Commands, parser.Command{ Name: k, Args: fmt.Sprintf("%v", v), }) @@ -114,14 +115,14 @@ func (m *Model) String() string { } for _, license := range m.License { - modelfile.Commands = append(modelfile.Commands, model.Command{ + modelfile.Commands = append(modelfile.Commands, parser.Command{ Name: "license", Args: license, }) } for _, msg := range m.Messages { - modelfile.Commands = append(modelfile.Commands, model.Command{ + modelfile.Commands = append(modelfile.Commands, parser.Command{ Name: "message", Args: fmt.Sprintf("%s %s", msg.Role, msg.Content), }) @@ -314,7 +315,7 @@ func realpath(rel, from string) string { return abspath } -func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) { +func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *parser.File, fn func(resp api.ProgressResponse)) (err error) { config := ConfigV2{ OS: "linux", Architecture: "amd64", diff --git a/server/routes.go b/server/routes.go index 5fbc2b54..fff228f3 100644 --- a/server/routes.go +++ b/server/routes.go @@ -29,6 +29,7 @@ import ( "github.com/ollama/ollama/gpu" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/openai" + "github.com/ollama/ollama/parser" "github.com/ollama/ollama/server/envconfig" "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" @@ -539,7 +540,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) { r = f } - modelfile, err := model.ParseFile(r) + modelfile, err := parser.ParseFile(r) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return diff --git a/server/routes_test.go b/server/routes_test.go index e144c957..a48819fe 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -17,7 +17,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/types/model" + "github.com/ollama/ollama/parser" "github.com/ollama/ollama/version" ) @@ -56,7 +56,7 @@ func Test_Routes(t *testing.T) { fname := createTestFile(t, "ollama-model") r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname)) - modelfile, err := model.ParseFile(r) + modelfile, err := parser.ParseFile(r) assert.Nil(t, err) fn := func(resp api.ProgressResponse) { t.Logf("Status: %s", resp.Status)