diff --git a/llama/llama.go b/llama/llama.go index 72b8b691..a3fbbc1d 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -384,6 +384,8 @@ func (b *Batch) Add(token int, embed []float32, pos int, logits bool, seqIds ... if logits { unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 1 + } else { + unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 0 } b.c.n_tokens += 1