diff --git a/integration/basic_test.go b/integration/basic_test.go
index 8e35b5c5..1e3e5c58 100644
--- a/integration/basic_test.go
+++ b/integration/basic_test.go
@@ -30,6 +30,24 @@ func TestOrcaMiniBlueSky(t *testing.T) {
 	GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
 }
 
+// TestUnicodeOutput prompts the model for emoji and hands a list of
+// candidate emoji to GenerateTestHelper along with the request.
+func TestUnicodeOutput(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+	// Set up the test data
+	req := api.GenerateRequest{
+		Model:  "gemma2:2b",
+		Prompt: "Output some smily face emoji",
+		Stream: &stream,
+		Options: map[string]interface{}{
+			"temperature": 0,
+			"seed":        123,
+		},
+	}
+	GenerateTestHelper(ctx, t, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"})
+}
+
 func TestUnicodeModelDir(t *testing.T) {
 	// This is only useful for Windows with utf-16 characters, so skip this test for other platforms
 	if runtime.GOOS != "windows" {
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index 9fb669a2..b35704b5 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -18,6 +18,7 @@ import (
 	"strings"
 	"sync"
 	"time"
+	"unicode/utf8"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llama"
@@ -293,17 +294,35 @@ func (s *Server) shiftContext(seq *Sequence) {
 }
 
+// flushPending joins the queued response pieces for seq, clears the queue,
+// strips any bytes that are not valid UTF-8 and sends the result on
+// seq.responses. It returns false only when seq.quit fires before the send
+// completes; an empty result is not sent and reports true.
 func flushPending(seq *Sequence) bool {
-	for _, p := range seq.pendingResponses {
-		select {
-		case seq.responses <- p:
-		case <-seq.quit:
-			seq.pendingResponses = []string{}
-			return false
-		}
+	joined := strings.Join(seq.pendingResponses, "")
+	seq.pendingResponses = []string{}
+
+	// Check if there are any partial UTF-8 characters remaining.
+	// We already check and queue as we are generating but some may
+	// still make it here:
+	// - Sequence is ending, e.g. generation limit has been hit
+	// - Invalid characters in the middle of a string
+	// This is a stricter check to ensure we never output invalid Unicode.
+	// Only the invalid bytes are removed, so valid text on either side
+	// of them is still delivered.
+	if !utf8.ValidString(joined) {
+		joined = strings.ToValidUTF8(joined, "")
 	}
 
-	seq.pendingResponses = []string{}
-	return true
+	if len(joined) == 0 {
+		return true
+	}
+
+	select {
+	case seq.responses <- joined:
+		return true
+	case <-seq.quit:
+		return false
+	}
 }
 
 func (s *Server) removeSequence(seqIndex int, reason string) {