Merge pull request #5031 from ollama/mxyng/fix-multibyte-utf16
fix: multibyte utf16
This commit is contained in:
commit
15a687ae4b
2 changed files with 54 additions and 45 deletions
|
@ -8,7 +8,9 @@ import (
|
|||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"golang.org/x/text/encoding/unicode"
|
||||
"golang.org/x/text/transform"
|
||||
)
|
||||
|
||||
type File struct {
|
||||
|
@ -69,14 +71,11 @@ func ParseFile(r io.Reader) (*File, error) {
|
|||
var b bytes.Buffer
|
||||
var role string
|
||||
|
||||
var lineCount int
|
||||
var linePos int
|
||||
|
||||
var utf16 bool
|
||||
|
||||
var f File
|
||||
|
||||
br := bufio.NewReader(r)
|
||||
tr := unicode.BOMOverride(unicode.UTF8.NewDecoder())
|
||||
br := bufio.NewReader(transform.NewReader(r, tr))
|
||||
|
||||
for {
|
||||
r, _, err := br.ReadRune()
|
||||
if errors.Is(err, io.EOF) {
|
||||
|
@ -85,17 +84,6 @@ func ParseFile(r io.Reader) (*File, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// the utf16 byte order mark will be read as "unreadable" by ReadRune()
|
||||
if isUnreadable(r) && lineCount == 0 && linePos == 0 {
|
||||
utf16 = true
|
||||
continue
|
||||
}
|
||||
|
||||
// skip the second byte if we're reading utf16
|
||||
if utf16 && r == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
next, r, err := parseRuneForState(r, curr)
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return nil, fmt.Errorf("%w: %s", err, b.String())
|
||||
|
@ -103,13 +91,6 @@ func ParseFile(r io.Reader) (*File, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if isNewline(r) {
|
||||
lineCount++
|
||||
linePos = 0
|
||||
} else {
|
||||
linePos++
|
||||
}
|
||||
|
||||
// process the state transition, some transitions need to be intercepted and redirected
|
||||
if next != curr {
|
||||
switch curr {
|
||||
|
@ -309,10 +290,6 @@ func isNewline(r rune) bool {
|
|||
return r == '\r' || r == '\n'
|
||||
}
|
||||
|
||||
func isUnreadable(r rune) bool {
|
||||
return r == unicode.ReplacementChar
|
||||
}
|
||||
|
||||
func isValidMessageRole(role string) bool {
|
||||
return role == "system" || role == "user" || role == "assistant"
|
||||
}
|
||||
|
|
|
@ -11,6 +11,8 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/text/encoding"
|
||||
"golang.org/x/text/encoding/unicode"
|
||||
)
|
||||
|
||||
func TestParseFileFile(t *testing.T) {
|
||||
|
@ -517,14 +519,6 @@ PARAMETER param1 1
|
|||
PARAMETER param2 4096
|
||||
SYSTEM You are a utf16 file.
|
||||
`
|
||||
// simulate a utf16 le file
|
||||
utf16File := utf16.Encode(append([]rune{'\ufffe'}, []rune(data)...))
|
||||
buf := new(bytes.Buffer)
|
||||
err := binary.Write(buf, binary.LittleEndian, utf16File)
|
||||
require.NoError(t, err)
|
||||
|
||||
actual, err := ParseFile(buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := []Command{
|
||||
{Name: "model", Args: "bob"},
|
||||
|
@ -533,14 +527,52 @@ SYSTEM You are a utf16 file.
|
|||
{Name: "system", Args: "You are a utf16 file."},
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, actual.Commands)
|
||||
t.Run("le", func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
require.NoError(t, binary.Write(&b, binary.LittleEndian, []byte{0xff, 0xfe}))
|
||||
require.NoError(t, binary.Write(&b, binary.LittleEndian, utf16.Encode([]rune(data))))
|
||||
|
||||
// simulate a utf16 be file
|
||||
buf = new(bytes.Buffer)
|
||||
err = binary.Write(buf, binary.BigEndian, utf16File)
|
||||
require.NoError(t, err)
|
||||
actual, err := ParseFile(&b)
|
||||
require.NoError(t, err)
|
||||
|
||||
actual, err = ParseFile(buf)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, actual.Commands)
|
||||
assert.Equal(t, expected, actual.Commands)
|
||||
})
|
||||
|
||||
t.Run("be", func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
require.NoError(t, binary.Write(&b, binary.BigEndian, []byte{0xfe, 0xff}))
|
||||
require.NoError(t, binary.Write(&b, binary.BigEndian, utf16.Encode([]rune(data))))
|
||||
|
||||
actual, err := ParseFile(&b)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, actual.Commands)
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseMultiByte(t *testing.T) {
|
||||
input := `FROM test
|
||||
SYSTEM 你好👋`
|
||||
|
||||
expect := []Command{
|
||||
{Name: "model", Args: "test"},
|
||||
{Name: "system", Args: "你好👋"},
|
||||
}
|
||||
|
||||
encodings := []encoding.Encoding{
|
||||
unicode.UTF8,
|
||||
unicode.UTF16(unicode.LittleEndian, unicode.UseBOM),
|
||||
unicode.UTF16(unicode.BigEndian, unicode.UseBOM),
|
||||
}
|
||||
|
||||
for _, encoding := range encodings {
|
||||
t.Run(fmt.Sprintf("%s", encoding), func(t *testing.T) {
|
||||
s, err := encoding.NewEncoder().String(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
actual, err := ParseFile(strings.NewReader(s))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, expect, actual.Commands)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue