runner.go: Only allocate 1 element embedding batches for mllama
Mllama has large embeddings (100 MB per image) and each embedding is represented as 1 token when passed to llama.cpp. Batches are pre- allocated for the size of the tokens times the batch size, so this results in allocations of over 50 GB at the default batch size. On some systems, these mallocs will fail. Since an image is represented as a single token and mllama doesn't support more than 1 image per request, we only need to allocate a batch size of 1, which is much more reasonable. In addition, for non-multimodal models, we don't need to allocate the embedding batches at all. Fixes #7464
This commit is contained in:
parent
8a9bb0d000
commit
a103dae01e
3 changed files with 54 additions and 21 deletions
|
@ -315,20 +315,30 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
|
|||
type Batch struct {
|
||||
c C.struct_llama_batch
|
||||
batchSize int
|
||||
maxSeq int
|
||||
embedSize int
|
||||
}
|
||||
|
||||
// Creates a new batch for either word tokens if embed is 0 or
|
||||
// image embeddings if embed is specified. Batches cannot contain
|
||||
// both types at the same time
|
||||
func NewBatch(nTokens int, embed int, maxSeq int) *Batch {
|
||||
// Creates a new batch for either word tokens or image embeddings (if embedSize is non-zero).
|
||||
// Batches cannot contain both types at the same time. batchSize is the maximum number of entries
|
||||
// that can be added per sequence
|
||||
func NewBatch(batchSize int, maxSeq int, embedSize int) *Batch {
|
||||
return &Batch{
|
||||
c: C.llama_batch_init(C.int(nTokens), C.int(embed), C.int(maxSeq)),
|
||||
batchSize: nTokens,
|
||||
embedSize: embed,
|
||||
c: C.llama_batch_init(C.int(batchSize*maxSeq), C.int(embedSize), C.int(maxSeq)),
|
||||
batchSize: batchSize,
|
||||
maxSeq: maxSeq,
|
||||
embedSize: embedSize,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Batch) Size() int {
|
||||
return b.batchSize
|
||||
}
|
||||
|
||||
func (b *Batch) allocSize() int {
|
||||
return b.batchSize * b.maxSeq
|
||||
}
|
||||
|
||||
func (b *Batch) NumTokens() int {
|
||||
return int(b.c.n_tokens)
|
||||
}
|
||||
|
@ -341,21 +351,21 @@ func (b *Batch) IsEmbedding() bool {
|
|||
// when the batch was initialized. The other argument will be ignored. Adds to the
|
||||
// batch with the given position for the given sequence ids, and optionally instructs
|
||||
// to include logits.
|
||||
func (b *Batch) Add(token int, embed []float32, pos int, seqIds []int, logits bool) {
|
||||
func (b *Batch) Add(token int, embed []float32, pos int, logits bool, seqIds ...int) {
|
||||
if !b.IsEmbedding() {
|
||||
unsafe.Slice(b.c.token, b.batchSize)[b.c.n_tokens] = C.llama_token(token)
|
||||
unsafe.Slice(b.c.token, b.allocSize())[b.c.n_tokens] = C.llama_token(token)
|
||||
} else {
|
||||
copy(unsafe.Slice((*float32)(b.c.embd), b.batchSize*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
|
||||
copy(unsafe.Slice((*float32)(b.c.embd), b.allocSize()*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
|
||||
}
|
||||
unsafe.Slice(b.c.pos, b.batchSize)[b.c.n_tokens] = C.llama_pos(pos)
|
||||
unsafe.Slice(b.c.n_seq_id, b.batchSize)[b.c.n_tokens] = C.int(len(seqIds))
|
||||
unsafe.Slice(b.c.pos, b.allocSize())[b.c.n_tokens] = C.llama_pos(pos)
|
||||
unsafe.Slice(b.c.n_seq_id, b.allocSize())[b.c.n_tokens] = C.int(len(seqIds))
|
||||
|
||||
for i, s := range seqIds {
|
||||
unsafe.Slice((unsafe.Slice(b.c.seq_id, b.batchSize)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
|
||||
unsafe.Slice((unsafe.Slice(b.c.seq_id, b.allocSize())[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
|
||||
}
|
||||
|
||||
if logits {
|
||||
unsafe.Slice(b.c.logits, b.batchSize)[b.c.n_tokens] = 1
|
||||
unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 1
|
||||
}
|
||||
|
||||
b.c.n_tokens += 1
|
||||
|
|
|
@ -89,6 +89,23 @@ func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte, aspect
|
|||
return embed
|
||||
}
|
||||
|
||||
func (c *ImageContext) BatchSize(configuredBatchSize int) int {
|
||||
// If images are not supported, we don't need to allocate embedding batches
|
||||
if c == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Mllama maps an image to 1 embedding token (llava creates many tokens)
|
||||
// and doesn't support more than a single image per request.
|
||||
// The embeddings are large (100 MB), so allocating a big batch can fail
|
||||
// on some systems
|
||||
if c.mllama != nil {
|
||||
return 1
|
||||
}
|
||||
|
||||
return configuredBatchSize
|
||||
}
|
||||
|
||||
func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
|
||||
if c != nil && c.mllama != nil {
|
||||
return c.mllama.EmbedSize(llamaContext)
|
||||
|
|
|
@ -211,6 +211,7 @@ type Server struct {
|
|||
// required for image embeddings
|
||||
image *ImageContext
|
||||
|
||||
// TODO (jmorganca): make this n_batch
|
||||
batchSize int
|
||||
|
||||
// parallel is the number of parallel requests to handle
|
||||
|
@ -302,13 +303,19 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
|
|||
func (s *Server) run(ctx context.Context) {
|
||||
s.ready.Wait()
|
||||
|
||||
// logically these batches are used only within the context of processBatch
|
||||
// Logically these batches are used only within the context of processBatch
|
||||
// but it is better for performance to allocate them once here
|
||||
tokenBatch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
|
||||
tokenBatch := llama.NewBatch(s.batchSize, len(s.seqs), 0)
|
||||
defer tokenBatch.Free()
|
||||
|
||||
embedBatch := llama.NewBatch(s.batchSize*len(s.seqs), s.image.EmbedSize(s.lc), len(s.seqs))
|
||||
defer embedBatch.Free()
|
||||
var embedBatch *llama.Batch
|
||||
embedBatchSize := s.image.BatchSize(s.batchSize)
|
||||
if embedBatchSize != 0 {
|
||||
embedBatch = llama.NewBatch(embedBatchSize, len(s.seqs), s.image.EmbedSize(s.lc))
|
||||
defer embedBatch.Free()
|
||||
} else {
|
||||
embedBatch = &llama.Batch{}
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
|
@ -378,13 +385,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||
break
|
||||
}
|
||||
|
||||
// todo: make this n_batch
|
||||
if i >= s.batchSize {
|
||||
if i >= batch.Size() {
|
||||
break
|
||||
}
|
||||
|
||||
crossAttention = seq.crossAttention
|
||||
batch.Add(input.token, input.embed, seq.numPast, []int{seq.cache.Id}, numInputsProcessed+1 == len(seq.inputs))
|
||||
batch.Add(input.token, input.embed, seq.numPast, numInputsProcessed+1 == len(seq.inputs), seq.cache.Id)
|
||||
seq.numPast++
|
||||
numInputsProcessed++
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue