fix(mllama): sync backend between batches
This commit is contained in:
parent
c2e8cbaa14
commit
5b3393b6a2
2 changed files with 11 additions and 0 deletions
|
@ -598,6 +598,10 @@ func (c *Context) SetCrossAttention(state bool) {
|
||||||
C.llama_set_cross_attention(c.c, C.bool(state))
|
C.llama_set_cross_attention(c.c, C.bool(state))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Synchronize calls the C function llama_synchronize on the wrapped
// llama context (c.c). It takes no arguments and returns nothing.
// NOTE(review): presumably this blocks until the backend's pending work
// has completed — confirm against the llama.cpp header.
func (c *Context) Synchronize() {
|
||||||
|
C.llama_synchronize(c.c)
|
||||||
|
}
|
||||||
|
|
||||||
// sampling
|
// sampling
|
||||||
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
|
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
|
||||||
type SamplingContext struct {
|
type SamplingContext struct {
|
||||||
|
|
|
@ -427,6 +427,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if crossAttention {
|
||||||
|
// synchronize state to ensure the cross attention batch is complete.
|
||||||
|
// needed specifically for multi-GPU systems; otherwise an in-flight
|
||||||
|
// task may be incorrectly invalidated, causing a crash
|
||||||
|
s.lc.Synchronize()
|
||||||
|
}
|
||||||
|
|
||||||
for i, seq := range s.seqs {
|
for i, seq := range s.seqs {
|
||||||
if seq == nil {
|
if seq == nil {
|
||||||
continue
|
continue
|
||||||
|
|
Loading…
Reference in a new issue