Merge pull request #5031 from ollama/mxyng/fix-multibyte-utf16

fix: multibyte utf16
This commit is contained in:
Michael Yang 2024-06-13 13:14:55 -07:00 committed by GitHub
commit 15a687ae4b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 54 additions and 45 deletions

View file

@ -8,7 +8,9 @@ import (
"io"
"strconv"
"strings"
"unicode"
"golang.org/x/text/encoding/unicode"
"golang.org/x/text/transform"
)
type File struct {
@ -69,14 +71,11 @@ func ParseFile(r io.Reader) (*File, error) {
var b bytes.Buffer
var role string
var lineCount int
var linePos int
var utf16 bool
var f File
br := bufio.NewReader(r)
tr := unicode.BOMOverride(unicode.UTF8.NewDecoder())
br := bufio.NewReader(transform.NewReader(r, tr))
for {
r, _, err := br.ReadRune()
if errors.Is(err, io.EOF) {
@ -85,17 +84,6 @@ func ParseFile(r io.Reader) (*File, error) {
return nil, err
}
// the utf16 byte order mark will be read as "unreadable" by ReadRune()
if isUnreadable(r) && lineCount == 0 && linePos == 0 {
utf16 = true
continue
}
// skip the second byte if we're reading utf16
if utf16 && r == 0 {
continue
}
next, r, err := parseRuneForState(r, curr)
if errors.Is(err, io.ErrUnexpectedEOF) {
return nil, fmt.Errorf("%w: %s", err, b.String())
@ -103,13 +91,6 @@ func ParseFile(r io.Reader) (*File, error) {
return nil, err
}
if isNewline(r) {
lineCount++
linePos = 0
} else {
linePos++
}
// process the state transition, some transitions need to be intercepted and redirected
if next != curr {
switch curr {
@ -309,10 +290,6 @@ func isNewline(r rune) bool {
return r == '\r' || r == '\n'
}
func isUnreadable(r rune) bool {
return r == unicode.ReplacementChar
}
func isValidMessageRole(role string) bool {
return role == "system" || role == "user" || role == "assistant"
}

View file

@ -11,6 +11,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/encoding"
"golang.org/x/text/encoding/unicode"
)
func TestParseFileFile(t *testing.T) {
@ -517,14 +519,6 @@ PARAMETER param1 1
PARAMETER param2 4096
SYSTEM You are a utf16 file.
`
// simulate a utf16 le file
utf16File := utf16.Encode(append([]rune{'\ufffe'}, []rune(data)...))
buf := new(bytes.Buffer)
err := binary.Write(buf, binary.LittleEndian, utf16File)
require.NoError(t, err)
actual, err := ParseFile(buf)
require.NoError(t, err)
expected := []Command{
{Name: "model", Args: "bob"},
@ -533,14 +527,52 @@ SYSTEM You are a utf16 file.
{Name: "system", Args: "You are a utf16 file."},
}
assert.Equal(t, expected, actual.Commands)
t.Run("le", func(t *testing.T) {
var b bytes.Buffer
require.NoError(t, binary.Write(&b, binary.LittleEndian, []byte{0xff, 0xfe}))
require.NoError(t, binary.Write(&b, binary.LittleEndian, utf16.Encode([]rune(data))))
// simulate a utf16 be file
buf = new(bytes.Buffer)
err = binary.Write(buf, binary.BigEndian, utf16File)
actual, err := ParseFile(&b)
require.NoError(t, err)
actual, err = ParseFile(buf)
assert.Equal(t, expected, actual.Commands)
})
t.Run("be", func(t *testing.T) {
var b bytes.Buffer
require.NoError(t, binary.Write(&b, binary.BigEndian, []byte{0xfe, 0xff}))
require.NoError(t, binary.Write(&b, binary.BigEndian, utf16.Encode([]rune(data))))
actual, err := ParseFile(&b)
require.NoError(t, err)
assert.Equal(t, expected, actual.Commands)
})
}
func TestParseMultiByte(t *testing.T) {
input := `FROM test
SYSTEM 你好👋`
expect := []Command{
{Name: "model", Args: "test"},
{Name: "system", Args: "你好👋"},
}
encodings := []encoding.Encoding{
unicode.UTF8,
unicode.UTF16(unicode.LittleEndian, unicode.UseBOM),
unicode.UTF16(unicode.BigEndian, unicode.UseBOM),
}
for _, encoding := range encodings {
t.Run(fmt.Sprintf("%s", encoding), func(t *testing.T) {
s, err := encoding.NewEncoder().String(input)
require.NoError(t, err)
actual, err := ParseFile(strings.NewReader(s))
require.NoError(t, err)
assert.Equal(t, expect, actual.Commands)
})
}
}