From 03e40efa51a75a4e8385b64996af6468f42f6c06 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Mon, 21 Oct 2024 11:07:19 -0700 Subject: [PATCH] runner.go: Merge partial unicode characters before sending We check for partial unicode characters and accumulate them before sending. However, when we did send, we still sent each individual piece separately, leading to broken output. This combines everything into a single group, which is also more efficient. This also switches to the built-in check for valid unicode characters, which is stricter. After this, we should never send back an invalid sequence. Fixes #7290 --- integration/basic_test.go | 16 ++++++++++++++++ llama/runner/runner.go | 31 ++++++++++++++++++++++--------- 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/integration/basic_test.go b/integration/basic_test.go index 8e35b5c5..1e3e5c58 100644 --- a/integration/basic_test.go +++ b/integration/basic_test.go @@ -30,6 +30,22 @@ func TestOrcaMiniBlueSky(t *testing.T) { GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"}) } +func TestUnicodeOutput(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + // Set up the test data + req := api.GenerateRequest{ + Model: "gemma2:2b", + Prompt: "Output some smily face emoji", + Stream: &stream, + Options: map[string]interface{}{ + "temperature": 0, + "seed": 123, + }, + } + GenerateTestHelper(ctx, t, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}) +} + func TestUnicodeModelDir(t *testing.T) { // This is only useful for Windows with utf-16 characters, so skip this test for other platforms if runtime.GOOS != "windows" { diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 9fb669a2..b35704b5 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -18,6 +18,7 @@ import ( "strings" "sync" "time" + "unicode/utf8" "github.com/ollama/ollama/api" "github.com/ollama/ollama/llama" @@ -293,17 +294,29 @@ func (s *Server) shiftContext(seq *Sequence) { } func flushPending(seq *Sequence) bool { - for _, p := range seq.pendingResponses { - select { - case seq.responses <- p: - case <-seq.quit: - seq.pendingResponses = []string{} - return false - } + joined := strings.Join(seq.pendingResponses, "") + seq.pendingResponses = []string{} + + // Check if there are any partial UTF-8 characters remaining. + // We already check and queue as we are generating but some may + // still make it here: + // - Sequence is ending, e.g. generation limit has been hit + // - Invalid characters in the middle of a string + // This is a stricter check to ensure we never output invalid Unicode. + for !utf8.ValidString(joined) { + joined = joined[:len(joined)-1] } - seq.pendingResponses = []string{} - return true + if len(joined) == 0 { + return true + } + + select { + case seq.responses <- joined: + return true + case <-seq.quit: + return false + } } func (s *Server) removeSequence(seqIndex int, reason string) {