From 26acdcf44e9e0c64fe0918b9cf59a61ce3339757 Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Thu, 31 Oct 2024 10:55:31 -0700
Subject: [PATCH] runner.go: Don't set cross attention before sending embeddings

Currently if an input has embeddings at any point then we will set
cross attention to true from the beginning. This means that any tokens
before the embeddings are sent will incorrectly have cross attention
layers applied.

This only sets cross attention when we have an embedding, either
previously in this sequence or in the cache. It also makes cross
attention capable of supporting parallelism at the runner level,
though the mllama implementation doesn't support that yet.
---
 llama/runner/image.go  | 11 +++++++++++
 llama/runner/runner.go | 21 ++++++++++++---------
 2 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/llama/runner/image.go b/llama/runner/image.go
index d50645e8..3b562186 100644
--- a/llama/runner/image.go
+++ b/llama/runner/image.go
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"hash/maphash"
 	"log/slog"
+	"slices"
 	"sync"
 	"time"
 
@@ -96,6 +97,16 @@ func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
 	}
 }
 
+func (c *ImageContext) NeedCrossAttention(inputs ...input) bool {
+	if c == nil || c.mllama == nil {
+		return false
+	}
+
+	return slices.ContainsFunc(inputs, func(input input) bool {
+		return input.embed != nil
+	})
+}
+
 type imageCache struct {
 	key uint64
 	val [][]float32
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index a137f879..a7e0e3b0 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -52,6 +52,10 @@ type Sequence struct {
 	// input cache being used by this sequence
 	cache *InputCacheSlot
 
+	// does this sequence require cross-attention layers to be processed? - if we have seen
+	// an image for certain multi-modal models
+	crossAttention bool
+
 	// channel to send responses over
 	responses chan string
 
@@ -287,7 +291,6 @@ func flushPending(seq *Sequence) bool {
 func (s *Server) removeSequence(seqIndex int, reason string) {
 	seq := s.seqs[seqIndex]
 
-	s.lc.SetCrossAttention(false)
 	flushPending(seq)
 	seq.doneReason = reason
 	close(seq.responses)
@@ -334,6 +337,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 	defer s.mu.Unlock()
 
 	var batch *llama.Batch
+	crossAttention := false
 
 	seqIdx := s.nextSeq - 1
 	for range s.seqs {
@@ -367,8 +371,9 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 					batch = tokenBatch
 				} else {
 					batch = embedBatch
+					seq.crossAttention = s.image.NeedCrossAttention(input)
 				}
-			} else if embedding != batch.IsEmbedding() {
+			} else if embedding != batch.IsEmbedding() || crossAttention != seq.crossAttention {
 				s.nextSeq = seqIdx
 				break
 			}
@@ -378,6 +383,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 				break
 			}
 
+			crossAttention = seq.crossAttention
 			batch.Add(input.token, input.embed, seq.numPast, []int{seq.cache.Id}, numInputsProcessed+1 == len(seq.inputs))
 			seq.numPast++
 			numInputsProcessed++
@@ -394,6 +400,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		return
 	}
 
+	s.lc.SetCrossAttention(crossAttention)
+
 	err := s.lc.Decode(batch)
 	if err != nil {
 		slog.Error("failed to decode batch", "error", err)
@@ -605,13 +613,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	s.mu.Lock()
 	for i, sq := range s.seqs {
 		if sq == nil {
-			for _, input := range seq.inputs {
-				if input.embed != nil {
-					s.lc.SetCrossAttention(true)
-					break
-				}
-			}
-
 			seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
 			if err != nil {
 				s.mu.Unlock()
@@ -619,6 +620,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 				return
 			}
 
+			seq.crossAttention = s.image.NeedCrossAttention(seq.cache.Inputs...)
+
 			s.seqs[i] = seq
 			s.cond.Signal()
 			break
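
Illustrative note (not part of the patch): a minimal, self-contained Go sketch of the check that NeedCrossAttention performs, using simplified stand-in types because the runner's input type is unexported; the helper name needCrossAttention below is hypothetical.

package main

import "fmt"

// input is a simplified stand-in for the runner's input type: either a
// text token or a pre-computed image embedding.
type input struct {
	token int
	embed []float32
}

// needCrossAttention reports whether any input carries an image embedding,
// mirroring the check added in ImageContext.NeedCrossAttention.
func needCrossAttention(inputs ...input) bool {
	for _, in := range inputs {
		if in.embed != nil {
			return true
		}
	}
	return false
}

func main() {
	textOnly := []input{{token: 1}, {token: 2}}
	withImage := []input{{token: 1}, {embed: []float32{0.1, 0.2}}}

	// Tokens seen before any embedding no longer force cross attention on.
	fmt.Println(needCrossAttention(textOnly...))  // false
	fmt.Println(needCrossAttention(withImage...)) // true
}

In the patch itself the same check is also applied to seq.cache.Inputs when a cache slot is loaded, so a sequence whose cache already contains an image keeps cross attention enabled for its remaining tokens.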