From 5b3393b6a2920c4f410ee636777533c77752106e Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 13 Nov 2024 14:12:30 -0800 Subject: [PATCH] fix(mllama): sync backend between batches --- llama/llama.go | 4 ++++ llama/runner/runner.go | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/llama/llama.go b/llama/llama.go index a092ea12..df06f0f6 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -598,6 +598,10 @@ func (c *Context) SetCrossAttention(state bool) { C.llama_set_cross_attention(c.c, C.bool(state)) } +func (c *Context) Synchronize() { + C.llama_synchronize(c.c) +} + // sampling // TODO: this is a temporary wrapper to allow calling C++ code from CGo type SamplingContext struct { diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 0a37dee0..637dd9cc 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -427,6 +427,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) return } + if crossAttention { + // synchronize state to ensure the cross attention batch is complete. + // needed specifically for multi-GPU systems otherwise an inflight + // task may be incorrectly invalidated causing a crash + s.lc.Synchronize() + } + for i, seq := range s.seqs { if seq == nil { continue