Merge pull request #5031 from ollama/mxyng/fix-multibyte-utf16

fix: multibyte utf16
This commit is contained in:
Michael Yang 2024-06-13 13:14:55 -07:00 committed by GitHub
commit 15a687ae4b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 54 additions and 45 deletions

View file

@ -8,7 +8,9 @@ import (
"io" "io"
"strconv" "strconv"
"strings" "strings"
"unicode"
"golang.org/x/text/encoding/unicode"
"golang.org/x/text/transform"
) )
type File struct { type File struct {
@ -69,14 +71,11 @@ func ParseFile(r io.Reader) (*File, error) {
var b bytes.Buffer var b bytes.Buffer
var role string var role string
var lineCount int
var linePos int
var utf16 bool
var f File var f File
br := bufio.NewReader(r) tr := unicode.BOMOverride(unicode.UTF8.NewDecoder())
br := bufio.NewReader(transform.NewReader(r, tr))
for { for {
r, _, err := br.ReadRune() r, _, err := br.ReadRune()
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
@ -85,17 +84,6 @@ func ParseFile(r io.Reader) (*File, error) {
return nil, err return nil, err
} }
// the utf16 byte order mark will be read as "unreadable" by ReadRune()
if isUnreadable(r) && lineCount == 0 && linePos == 0 {
utf16 = true
continue
}
// skip the second byte if we're reading utf16
if utf16 && r == 0 {
continue
}
next, r, err := parseRuneForState(r, curr) next, r, err := parseRuneForState(r, curr)
if errors.Is(err, io.ErrUnexpectedEOF) { if errors.Is(err, io.ErrUnexpectedEOF) {
return nil, fmt.Errorf("%w: %s", err, b.String()) return nil, fmt.Errorf("%w: %s", err, b.String())
@ -103,13 +91,6 @@ func ParseFile(r io.Reader) (*File, error) {
return nil, err return nil, err
} }
if isNewline(r) {
lineCount++
linePos = 0
} else {
linePos++
}
// process the state transition, some transitions need to be intercepted and redirected // process the state transition, some transitions need to be intercepted and redirected
if next != curr { if next != curr {
switch curr { switch curr {
@ -309,10 +290,6 @@ func isNewline(r rune) bool {
return r == '\r' || r == '\n' return r == '\r' || r == '\n'
} }
func isUnreadable(r rune) bool {
return r == unicode.ReplacementChar
}
func isValidMessageRole(role string) bool { func isValidMessageRole(role string) bool {
return role == "system" || role == "user" || role == "assistant" return role == "system" || role == "user" || role == "assistant"
} }

View file

@ -11,6 +11,8 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/text/encoding"
"golang.org/x/text/encoding/unicode"
) )
func TestParseFileFile(t *testing.T) { func TestParseFileFile(t *testing.T) {
@ -517,14 +519,6 @@ PARAMETER param1 1
PARAMETER param2 4096 PARAMETER param2 4096
SYSTEM You are a utf16 file. SYSTEM You are a utf16 file.
` `
// simulate a utf16 le file
utf16File := utf16.Encode(append([]rune{'\ufffe'}, []rune(data)...))
buf := new(bytes.Buffer)
err := binary.Write(buf, binary.LittleEndian, utf16File)
require.NoError(t, err)
actual, err := ParseFile(buf)
require.NoError(t, err)
expected := []Command{ expected := []Command{
{Name: "model", Args: "bob"}, {Name: "model", Args: "bob"},
@ -533,14 +527,52 @@ SYSTEM You are a utf16 file.
{Name: "system", Args: "You are a utf16 file."}, {Name: "system", Args: "You are a utf16 file."},
} }
assert.Equal(t, expected, actual.Commands) t.Run("le", func(t *testing.T) {
var b bytes.Buffer
require.NoError(t, binary.Write(&b, binary.LittleEndian, []byte{0xff, 0xfe}))
require.NoError(t, binary.Write(&b, binary.LittleEndian, utf16.Encode([]rune(data))))
// simulate a utf16 be file actual, err := ParseFile(&b)
buf = new(bytes.Buffer) require.NoError(t, err)
err = binary.Write(buf, binary.BigEndian, utf16File)
require.NoError(t, err)
actual, err = ParseFile(buf) assert.Equal(t, expected, actual.Commands)
require.NoError(t, err) })
assert.Equal(t, expected, actual.Commands)
t.Run("be", func(t *testing.T) {
var b bytes.Buffer
require.NoError(t, binary.Write(&b, binary.BigEndian, []byte{0xfe, 0xff}))
require.NoError(t, binary.Write(&b, binary.BigEndian, utf16.Encode([]rune(data))))
actual, err := ParseFile(&b)
require.NoError(t, err)
assert.Equal(t, expected, actual.Commands)
})
}
func TestParseMultiByte(t *testing.T) {
input := `FROM test
SYSTEM 你好👋`
expect := []Command{
{Name: "model", Args: "test"},
{Name: "system", Args: "你好👋"},
}
encodings := []encoding.Encoding{
unicode.UTF8,
unicode.UTF16(unicode.LittleEndian, unicode.UseBOM),
unicode.UTF16(unicode.BigEndian, unicode.UseBOM),
}
for _, encoding := range encodings {
t.Run(fmt.Sprintf("%s", encoding), func(t *testing.T) {
s, err := encoding.NewEncoder().String(input)
require.NoError(t, err)
actual, err := ParseFile(strings.NewReader(s))
require.NoError(t, err)
assert.Equal(t, expect, actual.Commands)
})
}
} }