runner.go: Don't set cross attention before sending embeddings
Currently, if an input has embeddings at any point, we set cross attention to true from the beginning. This means that any tokens sent before the embeddings will incorrectly have cross attention layers applied. This change sets cross attention only once we have an embedding, either earlier in this sequence or in the cache. It also makes cross attention capable of supporting parallelism at the runner level, though the mllama implementation doesn't support that yet.
This commit is contained in:
parent
921779bb10
commit
26acdcf44e
2 changed files with 23 additions and 9 deletions
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"hash/maphash"
|
"hash/maphash"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -96,6 +97,16 @@ func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NeedCrossAttention reports whether any of the given inputs carries an
// embedding (a non-nil embed field). It returns false when the receiver is
// nil or when c.mllama is nil, regardless of the inputs, so it is safe to
// call on a nil *ImageContext.
func (c *ImageContext) NeedCrossAttention(inputs ...input) bool {
|
||||||
|
if c == nil || c.mllama == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return slices.ContainsFunc(inputs, func(input input) bool {
|
||||||
|
return input.embed != nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
type imageCache struct {
|
type imageCache struct {
|
||||||
key uint64
|
key uint64
|
||||||
val [][]float32
|
val [][]float32
|
||||||
|
|
|
@ -52,6 +52,10 @@ type Sequence struct {
|
||||||
// input cache being used by this sequence
|
// input cache being used by this sequence
|
||||||
cache *InputCacheSlot
|
cache *InputCacheSlot
|
||||||
|
|
||||||
|
// does this sequence require cross-attention layers to be processed? - if we have seen
|
||||||
|
// an image for certain multi-modal models
|
||||||
|
crossAttention bool
|
||||||
|
|
||||||
// channel to send responses over
|
// channel to send responses over
|
||||||
responses chan string
|
responses chan string
|
||||||
|
|
||||||
|
@ -287,7 +291,6 @@ func flushPending(seq *Sequence) bool {
|
||||||
func (s *Server) removeSequence(seqIndex int, reason string) {
|
func (s *Server) removeSequence(seqIndex int, reason string) {
|
||||||
seq := s.seqs[seqIndex]
|
seq := s.seqs[seqIndex]
|
||||||
|
|
||||||
s.lc.SetCrossAttention(false)
|
|
||||||
flushPending(seq)
|
flushPending(seq)
|
||||||
seq.doneReason = reason
|
seq.doneReason = reason
|
||||||
close(seq.responses)
|
close(seq.responses)
|
||||||
|
@ -334,6 +337,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
var batch *llama.Batch
|
var batch *llama.Batch
|
||||||
|
crossAttention := false
|
||||||
|
|
||||||
seqIdx := s.nextSeq - 1
|
seqIdx := s.nextSeq - 1
|
||||||
for range s.seqs {
|
for range s.seqs {
|
||||||
|
@ -367,8 +371,9 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||||
batch = tokenBatch
|
batch = tokenBatch
|
||||||
} else {
|
} else {
|
||||||
batch = embedBatch
|
batch = embedBatch
|
||||||
|
seq.crossAttention = s.image.NeedCrossAttention(input)
|
||||||
}
|
}
|
||||||
} else if embedding != batch.IsEmbedding() {
|
} else if embedding != batch.IsEmbedding() || crossAttention != seq.crossAttention {
|
||||||
s.nextSeq = seqIdx
|
s.nextSeq = seqIdx
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -378,6 +383,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
crossAttention = seq.crossAttention
|
||||||
batch.Add(input.token, input.embed, seq.numPast, []int{seq.cache.Id}, numInputsProcessed+1 == len(seq.inputs))
|
batch.Add(input.token, input.embed, seq.numPast, []int{seq.cache.Id}, numInputsProcessed+1 == len(seq.inputs))
|
||||||
seq.numPast++
|
seq.numPast++
|
||||||
numInputsProcessed++
|
numInputsProcessed++
|
||||||
|
@ -394,6 +400,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.lc.SetCrossAttention(crossAttention)
|
||||||
|
|
||||||
err := s.lc.Decode(batch)
|
err := s.lc.Decode(batch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("failed to decode batch", "error", err)
|
slog.Error("failed to decode batch", "error", err)
|
||||||
|
@ -605,13 +613,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
for i, sq := range s.seqs {
|
for i, sq := range s.seqs {
|
||||||
if sq == nil {
|
if sq == nil {
|
||||||
for _, input := range seq.inputs {
|
|
||||||
if input.embed != nil {
|
|
||||||
s.lc.SetCrossAttention(true)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
@ -619,6 +620,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
seq.crossAttention = s.image.NeedCrossAttention(seq.cache.Inputs...)
|
||||||
|
|
||||||
s.seqs[i] = seq
|
s.seqs[i] = seq
|
||||||
s.cond.Signal()
|
s.cond.Signal()
|
||||||
break
|
break
|
||||||
|
|
Loading…
Reference in a new issue