Move the parser back + handle utf16 files (#4533)

This commit is contained in:
Patrick Devine 2024-05-20 11:26:45 -07:00 committed by GitHub
parent 63a453554d
commit ccdf0b2a44
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 84 additions and 17 deletions

View file

@ -35,6 +35,7 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth" "github.com/ollama/ollama/auth"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress" "github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server" "github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
@ -63,7 +64,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
defer f.Close() defer f.Close()
modelfile, err := model.ParseFile(f) modelfile, err := parser.ParseFile(f)
if err != nil { if err != nil {
return err return err
} }

View file

@ -1,4 +1,4 @@
package model package parser
import ( import (
"bufio" "bufio"
@ -8,6 +8,7 @@ import (
"io" "io"
"strconv" "strconv"
"strings" "strings"
"unicode"
) )
type File struct { type File struct {
@ -68,6 +69,11 @@ func ParseFile(r io.Reader) (*File, error) {
var b bytes.Buffer var b bytes.Buffer
var role string var role string
var lineCount int
var linePos int
var utf16 bool
var f File var f File
br := bufio.NewReader(r) br := bufio.NewReader(r)
@ -79,6 +85,17 @@ func ParseFile(r io.Reader) (*File, error) {
return nil, err return nil, err
} }
// the utf16 byte order mark will be read as "unreadable" by ReadRune()
if isUnreadable(r) && lineCount == 0 && linePos == 0 {
utf16 = true
continue
}
// skip the second byte if we're reading utf16
if utf16 && r == 0 {
continue
}
next, r, err := parseRuneForState(r, curr) next, r, err := parseRuneForState(r, curr)
if errors.Is(err, io.ErrUnexpectedEOF) { if errors.Is(err, io.ErrUnexpectedEOF) {
return nil, fmt.Errorf("%w: %s", err, b.String()) return nil, fmt.Errorf("%w: %s", err, b.String())
@ -86,6 +103,13 @@ func ParseFile(r io.Reader) (*File, error) {
return nil, err return nil, err
} }
if isNewline(r) {
lineCount++
linePos = 0
} else {
linePos++
}
// process the state transition, some transitions need to be intercepted and redirected // process the state transition, some transitions need to be intercepted and redirected
if next != curr { if next != curr {
switch curr { switch curr {
@ -285,6 +309,10 @@ func isNewline(r rune) bool {
return r == '\r' || r == '\n' return r == '\r' || r == '\n'
} }
func isUnreadable(r rune) bool {
return r == unicode.ReplacementChar
}
func isValidMessageRole(role string) bool { func isValidMessageRole(role string) bool {
return role == "system" || role == "user" || role == "assistant" return role == "system" || role == "user" || role == "assistant"
} }

View file

@ -1,11 +1,13 @@
package model package parser
import ( import (
"bytes" "bytes"
"encoding/binary"
"fmt" "fmt"
"io" "io"
"strings" "strings"
"testing" "testing"
"unicode/utf16"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -509,3 +511,37 @@ SYSTEM ""
} }
} }
func TestParseFileUTF16ParseFile(t *testing.T) {
data := `FROM bob
PARAMETER param1 1
PARAMETER param2 4096
SYSTEM You are a utf16 file.
`
// simulate a utf16 le file
utf16File := utf16.Encode(append([]rune{'\ufffe'}, []rune(data)...))
buf := new(bytes.Buffer)
err := binary.Write(buf, binary.LittleEndian, utf16File)
assert.NoError(t, err)
actual, err := ParseFile(buf)
assert.NoError(t, err)
expected := []Command{
{Name: "model", Args: "bob"},
{Name: "param1", Args: "1"},
{Name: "param2", Args: "4096"},
{Name: "system", Args: "You are a utf16 file."},
}
assert.Equal(t, expected, actual.Commands)
// simulate a utf16 be file
buf = new(bytes.Buffer)
err = binary.Write(buf, binary.BigEndian, utf16File)
assert.NoError(t, err)
actual, err = ParseFile(buf)
assert.NoError(t, err)
assert.Equal(t, expected, actual.Commands)
}

View file

@ -27,6 +27,7 @@ import (
"github.com/ollama/ollama/auth" "github.com/ollama/ollama/auth"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/server/envconfig" "github.com/ollama/ollama/server/envconfig"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
@ -61,36 +62,36 @@ func (m *Model) IsEmbedding() bool {
} }
func (m *Model) String() string { func (m *Model) String() string {
var modelfile model.File var modelfile parser.File
modelfile.Commands = append(modelfile.Commands, model.Command{ modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: "model", Name: "model",
Args: m.ModelPath, Args: m.ModelPath,
}) })
for _, adapter := range m.AdapterPaths { for _, adapter := range m.AdapterPaths {
modelfile.Commands = append(modelfile.Commands, model.Command{ modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: "adapter", Name: "adapter",
Args: adapter, Args: adapter,
}) })
} }
for _, projector := range m.ProjectorPaths { for _, projector := range m.ProjectorPaths {
modelfile.Commands = append(modelfile.Commands, model.Command{ modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: "model", Name: "model",
Args: projector, Args: projector,
}) })
} }
if m.Template != "" { if m.Template != "" {
modelfile.Commands = append(modelfile.Commands, model.Command{ modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: "template", Name: "template",
Args: m.Template, Args: m.Template,
}) })
} }
if m.System != "" { if m.System != "" {
modelfile.Commands = append(modelfile.Commands, model.Command{ modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: "system", Name: "system",
Args: m.System, Args: m.System,
}) })
@ -100,13 +101,13 @@ func (m *Model) String() string {
switch v := v.(type) { switch v := v.(type) {
case []any: case []any:
for _, s := range v { for _, s := range v {
modelfile.Commands = append(modelfile.Commands, model.Command{ modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: k, Name: k,
Args: fmt.Sprintf("%v", s), Args: fmt.Sprintf("%v", s),
}) })
} }
default: default:
modelfile.Commands = append(modelfile.Commands, model.Command{ modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: k, Name: k,
Args: fmt.Sprintf("%v", v), Args: fmt.Sprintf("%v", v),
}) })
@ -114,14 +115,14 @@ func (m *Model) String() string {
} }
for _, license := range m.License { for _, license := range m.License {
modelfile.Commands = append(modelfile.Commands, model.Command{ modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: "license", Name: "license",
Args: license, Args: license,
}) })
} }
for _, msg := range m.Messages { for _, msg := range m.Messages {
modelfile.Commands = append(modelfile.Commands, model.Command{ modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: "message", Name: "message",
Args: fmt.Sprintf("%s %s", msg.Role, msg.Content), Args: fmt.Sprintf("%s %s", msg.Role, msg.Content),
}) })
@ -314,7 +315,7 @@ func realpath(rel, from string) string {
return abspath return abspath
} }
func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) { func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *parser.File, fn func(resp api.ProgressResponse)) (err error) {
config := ConfigV2{ config := ConfigV2{
OS: "linux", OS: "linux",
Architecture: "amd64", Architecture: "amd64",

View file

@ -29,6 +29,7 @@ import (
"github.com/ollama/ollama/gpu" "github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai" "github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/server/envconfig" "github.com/ollama/ollama/server/envconfig"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
@ -539,7 +540,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
r = f r = f
} }
modelfile, err := model.ParseFile(r) modelfile, err := parser.ParseFile(r)
if err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return

View file

@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
@ -56,7 +56,7 @@ func Test_Routes(t *testing.T) {
fname := createTestFile(t, "ollama-model") fname := createTestFile(t, "ollama-model")
r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname)) r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
modelfile, err := model.ParseFile(r) modelfile, err := parser.ParseFile(r)
assert.Nil(t, err) assert.Nil(t, err)
fn := func(resp api.ProgressResponse) { fn := func(resp api.ProgressResponse) {
t.Logf("Status: %s", resp.Status) t.Logf("Status: %s", resp.Status)