From a103dae01eb947f08d49a8b73c6b66ad97204a19 Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Fri, 1 Nov 2024 14:29:57 -0700
Subject: [PATCH] runner.go: Only allocate 1 element embedding batches for
 mllama

Mllama has large embeddings (100 MB per image) and each embedding is
represented as 1 token when passed to llama.cpp. Batches are
pre-allocated for the size of the tokens times the batch size, so this
results in allocations of over 50 GB at the default batch size. On some
systems, these mallocs will fail.

Since an image is represented as a single token and mllama doesn't
support more than 1 image per request, we only need to allocate a batch
size of 1, which is much more reasonable. In addition, for
non-multimodal models, we don't need to allocate the embedding batches
at all.

Fixes #7464
---
 llama/llama.go         | 38 ++++++++++++++++++++++++--------------
 llama/runner/image.go  | 17 +++++++++++++++++
 llama/runner/runner.go | 20 +++++++++++++-------
 3 files changed, 54 insertions(+), 21 deletions(-)

diff --git a/llama/llama.go b/llama/llama.go
index 2fb19ae7..89943380 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -315,20 +315,30 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
 type Batch struct {
 	c         C.struct_llama_batch
 	batchSize int
+	maxSeq    int
 	embedSize int
 }
 
-// Creates a new batch for either word tokens if embed is 0 or
-// image embeddings if embed is specified. Batches cannot contain
-// both types at the same time
-func NewBatch(nTokens int, embed int, maxSeq int) *Batch {
+// Creates a new batch for either word tokens or image embeddings (if embedSize is non-zero).
+// Batches cannot contain both types at the same time. batchSize is the maximum number of entries
+// that can be added per sequence
+func NewBatch(batchSize int, maxSeq int, embedSize int) *Batch {
 	return &Batch{
-		c:         C.llama_batch_init(C.int(nTokens), C.int(embed), C.int(maxSeq)),
-		batchSize: nTokens,
-		embedSize: embed,
+		c:         C.llama_batch_init(C.int(batchSize*maxSeq), C.int(embedSize), C.int(maxSeq)),
+		batchSize: batchSize,
+		maxSeq:    maxSeq,
+		embedSize: embedSize,
 	}
 }
 
+func (b *Batch) Size() int {
+	return b.batchSize
+}
+
+func (b *Batch) allocSize() int {
+	return b.batchSize * b.maxSeq
+}
+
 func (b *Batch) NumTokens() int {
 	return int(b.c.n_tokens)
 }
@@ -341,21 +351,21 @@ func (b *Batch) IsEmbedding() bool {
 // when the batch was initialized. The other argument will be ignored. Adds to the
 // batch with the given position for the given sequence ids, and optionally instructs
 // to include logits.
-func (b *Batch) Add(token int, embed []float32, pos int, seqIds []int, logits bool) {
+func (b *Batch) Add(token int, embed []float32, pos int, logits bool, seqIds ...int) {
 	if !b.IsEmbedding() {
-		unsafe.Slice(b.c.token, b.batchSize)[b.c.n_tokens] = C.llama_token(token)
+		unsafe.Slice(b.c.token, b.allocSize())[b.c.n_tokens] = C.llama_token(token)
 	} else {
-		copy(unsafe.Slice((*float32)(b.c.embd), b.batchSize*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
+		copy(unsafe.Slice((*float32)(b.c.embd), b.allocSize()*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
 	}
-	unsafe.Slice(b.c.pos, b.batchSize)[b.c.n_tokens] = C.llama_pos(pos)
-	unsafe.Slice(b.c.n_seq_id, b.batchSize)[b.c.n_tokens] = C.int(len(seqIds))
+	unsafe.Slice(b.c.pos, b.allocSize())[b.c.n_tokens] = C.llama_pos(pos)
+	unsafe.Slice(b.c.n_seq_id, b.allocSize())[b.c.n_tokens] = C.int(len(seqIds))
 
 	for i, s := range seqIds {
-		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.batchSize)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
+		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.allocSize())[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
 	}
 
 	if logits {
-		unsafe.Slice(b.c.logits, b.batchSize)[b.c.n_tokens] = 1
+		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 1
 	}
 
 	b.c.n_tokens += 1
diff --git a/llama/runner/image.go b/llama/runner/image.go
index 3b562186..ee76f47a 100644
--- a/llama/runner/image.go
+++ b/llama/runner/image.go
@@ -89,6 +89,23 @@ func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte, aspect
 	return embed
 }
 
+func (c *ImageContext) BatchSize(configuredBatchSize int) int {
+	// If images are not supported, we don't need to allocate embedding batches
+	if c == nil {
+		return 0
+	}
+
+	// Mllama maps an image to 1 embedding token (llava creates many tokens)
+	// and doesn't support more than a single image per request.
+	// The embeddings are large (100 MB), so allocating a big batch can fail
+	// on some systems
+	if c.mllama != nil {
+		return 1
+	}
+
+	return configuredBatchSize
+}
+
 func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
 	if c != nil && c.mllama != nil {
 		return c.mllama.EmbedSize(llamaContext)
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index a7e0e3b0..041bafb3 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -211,6 +211,7 @@ type Server struct {
 	// required for image embeddings
 	image *ImageContext
 
+	// TODO (jmorganca): make this n_batch
 	batchSize int
 
 	// parallel is the number of parallel requests to handle
@@ -302,13 +303,19 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
 func (s *Server) run(ctx context.Context) {
 	s.ready.Wait()
 
-	// logically these batches are used only within the context of processBatch
+	// Logically these batches are used only within the context of processBatch
 	// but it is better for performance to allocate them once here
-	tokenBatch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
+	tokenBatch := llama.NewBatch(s.batchSize, len(s.seqs), 0)
 	defer tokenBatch.Free()
 
-	embedBatch := llama.NewBatch(s.batchSize*len(s.seqs), s.image.EmbedSize(s.lc), len(s.seqs))
-	defer embedBatch.Free()
+	var embedBatch *llama.Batch
+	embedBatchSize := s.image.BatchSize(s.batchSize)
+	if embedBatchSize != 0 {
+		embedBatch = llama.NewBatch(embedBatchSize, len(s.seqs), s.image.EmbedSize(s.lc))
+		defer embedBatch.Free()
+	} else {
+		embedBatch = &llama.Batch{}
+	}
 
 	for {
 		select {
@@ -378,13 +385,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 				break
 			}
 
-			// todo: make this n_batch
-			if i >= s.batchSize {
+			if i >= batch.Size() {
 				break
 			}
 
 			crossAttention = seq.crossAttention
-			batch.Add(input.token, input.embed, seq.numPast, []int{seq.cache.Id}, numInputsProcessed+1 == len(seq.inputs))
+			batch.Add(input.token, input.embed, seq.numPast, numInputsProcessed+1 == len(seq.inputs), seq.cache.Id)
 			seq.numPast++
 			numInputsProcessed++
 		}
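
Back-of-the-envelope check of the figures quoted in the commit message, as a
rough sketch rather than part of the patch: the ~100 MB-per-embedding figure
comes from the message above, while the 512-entry batch size and the single
parallel sequence are assumed defaults.

package main

import "fmt"

func main() {
	const embedBytes = 100 << 20 // ~100 MB per mllama image embedding (per the commit message)
	const batchSize = 512        // assumed default n_batch
	const numSeqs = 1            // assumed single parallel sequence

	// Old scheme: the embedding batch held batchSize*numSeqs entries, each
	// sized for a full image embedding.
	oldAlloc := int64(embedBytes) * batchSize * numSeqs

	// New scheme: ImageContext.BatchSize returns 1 for mllama, so only one
	// embedding-sized entry is allocated (and none for non-multimodal models).
	newAlloc := int64(embedBytes) * 1 * numSeqs

	fmt.Printf("old: ~%d GiB, new: ~%d MiB\n", oldAlloc>>30, newAlloc>>20)
}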
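
A minimal sketch of how the reworked batch API reads after this change, written
as if it lived in llama/runner where both ImageContext and the llama package
are visible. Only NewBatch, Size, Free, Add, BatchSize, and EmbedSize come from
the diff; the helper name allocateBatches and its parameters are illustrative.

// allocateBatches mirrors the allocation pattern used by Server.run above
// (which also defers Free on the batches it allocates).
func allocateBatches(img *ImageContext, lc *llama.Context, batchSize, numSeqs int) (*llama.Batch, *llama.Batch) {
	// Token batches still hold batchSize entries per sequence; NewBatch
	// multiplies batchSize by maxSeq when calling llama_batch_init.
	tokenBatch := llama.NewBatch(batchSize, numSeqs, 0)

	// Embedding batches are sized by the image context: 0 when the model is
	// not multimodal (skip the allocation entirely), 1 for mllama, and the
	// configured batch size otherwise.
	var embedBatch *llama.Batch
	if n := img.BatchSize(batchSize); n != 0 {
		embedBatch = llama.NewBatch(n, numSeqs, img.EmbedSize(lc))
	} else {
		embedBatch = &llama.Batch{}
	}

	// Entries are then queued with the reordered signature, e.g.
	// batch.Add(token, embed, pos, logits, seqIds...), and processBatch stops
	// filling once batch.Size() entries have been added for a sequence.
	return tokenBatch, embedBatch
}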