runner.go: Merge partial unicode characters before sending

We check for partial unicode characters and accumulate them before
sending. However, when we did send, we still sent each individual piece
separately, leading to broken output. This combines everything into
a single group, which is also more efficient.

This also switches to the built-in check for valid unicode characters,
which is stricter. After this, we should never send back an invalid
sequence.

Fixes #7290
This commit is contained in:
Jesse Gross 2024-10-21 11:07:19 -07:00 committed by Jesse Gross
parent 23f746508d
commit 03e40efa51
2 changed files with 38 additions and 9 deletions

View file

@ -30,6 +30,22 @@ func TestOrcaMiniBlueSky(t *testing.T) {
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
}
func TestUnicodeOutput(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
// Set up the test data
req := api.GenerateRequest{
Model: "gemma2:2b",
Prompt: "Output some smily face emoji",
Stream: &stream,
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
GenerateTestHelper(ctx, t, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"})
}
func TestUnicodeModelDir(t *testing.T) {
// This is only useful for Windows with utf-16 characters, so skip this test for other platforms
if runtime.GOOS != "windows" {

View file

@ -18,6 +18,7 @@ import (
"strings"
"sync"
"time"
"unicode/utf8"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llama"
@ -293,19 +294,31 @@ func (s *Server) shiftContext(seq *Sequence) {
}
func flushPending(seq *Sequence) bool {
for _, p := range seq.pendingResponses {
select {
case seq.responses <- p:
case <-seq.quit:
joined := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = []string{}
return false
}
// Check if there are any partial UTF-8 characters remaining.
// We already check and queue as we are generating but some may
// still make it here:
// - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(joined) {
joined = joined[:len(joined)-1]
}
seq.pendingResponses = []string{}
if len(joined) == 0 {
return true
}
select {
case seq.responses <- joined:
return true
case <-seq.quit:
return false
}
}
func (s *Server) removeSequence(seqIndex int, reason string) {
seq := s.seqs[seqIndex]