runner.go: Merge partial unicode characters before sending
We check for partial unicode characters and accumulate them before sending. However, when we did send, we still sent each individual piece separately, leading to broken output. This combines everything into a single group, which is also more efficient. This also switches to the built-in check for valid unicode characters, which is stricter. After this, we should never send back an invalid sequence. Fixes #7290
This commit is contained in:
parent
23f746508d
commit
03e40efa51
2 changed files with 38 additions and 9 deletions
|
@ -30,6 +30,22 @@ func TestOrcaMiniBlueSky(t *testing.T) {
|
||||||
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
|
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnicodeOutput(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
// Set up the test data
|
||||||
|
req := api.GenerateRequest{
|
||||||
|
Model: "gemma2:2b",
|
||||||
|
Prompt: "Output some smily face emoji",
|
||||||
|
Stream: &stream,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"temperature": 0,
|
||||||
|
"seed": 123,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
GenerateTestHelper(ctx, t, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"})
|
||||||
|
}
|
||||||
|
|
||||||
func TestUnicodeModelDir(t *testing.T) {
|
func TestUnicodeModelDir(t *testing.T) {
|
||||||
// This is only useful for Windows with utf-16 characters, so skip this test for other platforms
|
// This is only useful for Windows with utf-16 characters, so skip this test for other platforms
|
||||||
if runtime.GOOS != "windows" {
|
if runtime.GOOS != "windows" {
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/llama"
|
"github.com/ollama/ollama/llama"
|
||||||
|
@ -293,17 +294,29 @@ func (s *Server) shiftContext(seq *Sequence) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func flushPending(seq *Sequence) bool {
|
func flushPending(seq *Sequence) bool {
|
||||||
for _, p := range seq.pendingResponses {
|
joined := strings.Join(seq.pendingResponses, "")
|
||||||
select {
|
seq.pendingResponses = []string{}
|
||||||
case seq.responses <- p:
|
|
||||||
case <-seq.quit:
|
// Check if there are any partial UTF-8 characters remaining.
|
||||||
seq.pendingResponses = []string{}
|
// We already check and queue as we are generating but some may
|
||||||
return false
|
// still make it here:
|
||||||
}
|
// - Sequence is ending, e.g. generation limit has been hit
|
||||||
|
// - Invalid characters in the middle of a string
|
||||||
|
// This is a stricter check to ensure we never output invalid Unicode.
|
||||||
|
for !utf8.ValidString(joined) {
|
||||||
|
joined = joined[:len(joined)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
seq.pendingResponses = []string{}
|
if len(joined) == 0 {
|
||||||
return true
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case seq.responses <- joined:
|
||||||
|
return true
|
||||||
|
case <-seq.quit:
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) removeSequence(seqIndex int, reason string) {
|
func (s *Server) removeSequence(seqIndex int, reason string) {
|
||||||
|
|
Loading…
Reference in a new issue