diff --git a/llama/llama.go b/llama/llama.go index dbb02768..72b8b691 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -600,6 +600,10 @@ func (c *Context) SetCrossAttention(state bool) { C.llama_set_cross_attention(c.c, C.bool(state)) } +func (c *Context) Synchronize() { + C.llama_synchronize(c.c) +} + // sampling // TODO: this is a temporary wrapper to allow calling C++ code from CGo type SamplingContext struct { diff --git a/llama/runner/runner.go b/llama/runner/runner.go index cff7d148..e65bd637 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -409,6 +409,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) return } + if crossAttention { + // synchronize state to ensure the cross attention batch is complete. + // needed specifically for multi-GPU systems otherwise an inflight + // task may be incorrectly invalidated causing a crash + s.lc.Synchronize() + } + for i, seq := range s.seqs { if seq == nil { continue