diff --git a/llama/llama.cpp b/llama/llama.cpp index 87d0148b..34970e54 100644 --- a/llama/llama.cpp +++ b/llama/llama.cpp @@ -2699,7 +2699,7 @@ struct llama_hparams { GGML_ABORT("fatal error"); } - bool cross_attention_layer(uint32_t il) const { + bool cross_attention_layers(uint32_t il) const { return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end(); } }; @@ -2731,6 +2731,9 @@ struct llama_cparams { bool offload_kqv; bool flash_attn; bool no_perf; + // TODO (jmorganca): this should most likely be passed in as part of a batch + // and not set on the context for all batches. + bool cross_attn = false; enum llama_pooling_type pooling_type; @@ -3542,10 +3545,6 @@ struct llama_context { struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] - // TODO (jmorganca): this should most likely be passed in as part of a batch - // and not set on the context for all batches. - float * cross_attn_state = nullptr; - bool cross_attn_state_first_pass = true; struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061] }; @@ -3782,7 +3781,7 @@ static bool llama_kv_cache_init( for (int i = 0; i < (int) n_layer; i++) { // for cross attention layers - if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layer(i)) { + if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) { struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i)); ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i)); @@ -7389,7 +7388,7 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - if (hparams.cross_attention_layer(i)) { + if (hparams.cross_attention_layers(i)) { layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128}); layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024}); layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd}); @@ -9346,7 +9345,7 @@ static struct ggml_tensor * llm_build_inp_embd( inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens); } else { - lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens); + lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens); inpL = lctx.inp_embd; ggml_set_input(lctx.inp_embd); } @@ -9368,11 +9367,10 @@ static struct ggml_tensor * llm_build_inp_cross_attn_state( const llm_build_cb & cb) { const int64_t n_embd = hparams.n_embd; - struct ggml_tensor * inpCAS; - lctx.inp_cross_attn_state = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4); - cb(lctx.inp_cross_attn_state, "inp_cross_attn_state", -1); - ggml_set_input(lctx.inp_cross_attn_state); - inpCAS = lctx.inp_cross_attn_state; + struct ggml_tensor * inpCAS = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4); + cb(inpCAS, "inp_cross_attn_state", -1); + ggml_set_input(inpCAS); + lctx.inp_cross_attn_state = inpCAS; return inpCAS; } @@ -10979,8 +10977,8 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "attn_norm", il); - if (hparams.cross_attention_layer(il)) { - if (!lctx.cross_attn_state) { + if (hparams.cross_attention_layers(il)) { + if (!batch.embd && !cparams.cross_attn) { continue; } @@ -10991,42 +10989,28 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cb(Qcur, "Qcur", il); - Qcur = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - cb(Qcur, "Qcur", il); - - // TODO: is this required? - Qcur = ggml_cont(ctx0, Qcur); + Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3)); cb(Qcur, "Qcur", il); Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur; - if (lctx.cross_attn_state_first_pass) { + struct ggml_tensor * Kcur, * Vcur; + if (batch.embd) { Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS); cb(Kcur, "Kcur", il); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404); cb(Kcur, "Kcur", il); - Kcur = ggml_permute(ctx0, Kcur, 0, 2, 1, 3); - cb(Kcur, "Kcur", il); - - // TODO: is this required? - Kcur = ggml_cont(ctx0, Kcur); + Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); cb(Kcur, "Kcur", il); Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il); cb(Kcur, "Kcur", il); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il])); - } else { - Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]); - cb(Kcur, "Kcur (view)", il); - } - struct ggml_tensor * Vcur; - if (lctx.cross_attn_state_first_pass) { Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS); cb(Vcur, "Vcur", il); @@ -11038,6 +11022,9 @@ struct llm_build_context { ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il])); } else { + Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]); + cb(Kcur, "Kcur (view)", il); + Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]); cb(Vcur, "Vcur (view)", il); } @@ -11045,11 +11032,8 @@ struct llm_build_context { struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur); cb(kq, "kq", il); - kq = ggml_scale_inplace(ctx0, kq, 1.0f/sqrtf(float(n_embd_head))); - cb(kq, "kq_scaled", il); - // TODO: apply causal masks - struct ggml_tensor * kq_soft_max = ggml_soft_max_inplace(ctx0, kq); + struct ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, nullptr, 1.f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias); cb(kq_soft_max, "kq_soft_max", il); Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur)); @@ -11139,8 +11123,8 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); if (il == n_layer - 1) { @@ -17197,10 +17181,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } if (batch.embd) { - const int64_t n_embd = hparams.n_embd; - const int64_t n_tokens = batch.n_tokens; + if (lctx.inp_cross_attn_state && lctx.inp_cross_attn_state->buffer) { + ggml_backend_tensor_set(lctx.inp_cross_attn_state, batch.embd, 0, ggml_nbytes(lctx.inp_cross_attn_state)); + // zero out inp_embd since it's not used + float * inp_embd_data = (float *)lctx.inp_embd->data; + for (int i = 0; i < ggml_nelements(lctx.inp_embd); ++i) { + inp_embd_data[i] = 0.0f; + } + } else { + const int64_t n_embd = hparams.n_embd; + const int64_t n_tokens = batch.n_tokens; - ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd)); + ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd)); + } } if (batch.pos && lctx.inp_pos) { @@ -17209,14 +17202,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); } - // TODO (jmorganca): this might copy a lot of data on every request of a - // single generation even though it doesn't change, so we should - // find a way to not set this more than one time per image - if (lctx.inp_cross_attn_state && - lctx.inp_cross_attn_state->buffer) { - ggml_backend_tensor_set(lctx.inp_cross_attn_state, lctx.cross_attn_state, 0, hparams.n_embd * 1601 * 4 * ggml_element_size(lctx.inp_cross_attn_state)); - } - if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); const int64_t n_tokens = batch.n_tokens; @@ -17789,7 +17774,7 @@ static int llama_decode_internal( n_outputs = 1; } - lctx.sbatch.from_batch(batch_all, n_embd, + lctx.sbatch.from_batch(batch_all, batch_all.n_embd, /* simple_split */ !kv_self.recurrent, /* logits_all */ n_outputs == n_tokens_all); @@ -17899,10 +17884,6 @@ static int llama_decode_internal( llama_set_inputs(lctx, ubatch); - // TODO: replace with something better to find out if its - // our first actual pass - lctx.cross_attn_state_first_pass = false; - llama_graph_compute(lctx, gf, n_threads, threadpool); // update the kv ring buffer @@ -18086,7 +18067,7 @@ static int llama_encode_internal( const int64_t n_embd = hparams.n_embd; - lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true); + lctx.sbatch.from_batch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true); const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens); @@ -20194,11 +20175,6 @@ struct llama_context * llama_new_context_with_model( return ctx; } -void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state) { - ctx->cross_attn_state_first_pass = true; - ctx->cross_attn_state = cross_attn_state; -} - void llama_free(struct llama_context * ctx) { delete ctx; } @@ -21686,6 +21662,10 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) { ctx->cparams.causal_attn = causal_attn; } +void llama_set_cross_attention(struct llama_context * ctx, bool cross_attention) { + ctx->cparams.cross_attn = cross_attention; +} + struct llama_batch llama_batch_get_one( llama_token * tokens, int32_t n_tokens, @@ -21695,6 +21675,7 @@ struct llama_batch llama_batch_get_one( /*n_tokens =*/ n_tokens, /*tokens =*/ tokens, /*embd =*/ nullptr, + /*n_embd =*/ 0, /*pos =*/ nullptr, /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, @@ -21710,6 +21691,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ /*n_tokens =*/ 0, /*tokens =*/ nullptr, /*embd =*/ nullptr, + /*n_embd =*/ 0, /*pos =*/ nullptr, /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, @@ -21721,6 +21703,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ if (embd) { batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); + batch.n_embd = embd; } else { batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); } diff --git a/llama/llama.go b/llama/llama.go index 7663e446..2fb19ae7 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -111,6 +111,28 @@ func PrintSystemInfo() string { return C.GoString(C.llama_print_system_info()) + compiler } +func GetModelArch(modelPath string) (string, error) { + mp := C.CString(modelPath) + defer C.free(unsafe.Pointer(mp)) + + gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)}) + if gguf_ctx == nil { + return "", errors.New("unable to load model file") + } + defer C.gguf_free(gguf_ctx) + + key := C.CString("general.architecture") + defer C.free(unsafe.Pointer(key)) + arch_index := C.gguf_find_key(gguf_ctx, key) + if int(arch_index) < 0 { + return "", errors.New("unknown model architecture") + } + + arch := C.gguf_get_val_str(gguf_ctx, arch_index) + + return C.GoString(arch), nil +} + type ContextParams struct { c C.struct_llama_context_params } @@ -443,71 +465,36 @@ func Quantize(infile, outfile string, ftype uint32) error { return nil } -// llava +// vision processing type ClipContext struct { - c *C.struct_clip_ctx - m *C.struct_mllama_ctx - IsMllama bool - embedPin runtime.Pinner - pinned bool + c *C.struct_clip_ctx } -func getVisionArch(mp *C.char) (string, error) { - gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)}) - if gguf_ctx == nil { - return "", errors.New("unable to load vision projector") - } - defer C.gguf_free(gguf_ctx) - - arch_index := C.gguf_find_key(gguf_ctx, C.CString("general.architecture")) - if int(arch_index) < 0 { - return "", errors.New("unknown vision model architecture") - } - - arch := C.gguf_get_val_str(gguf_ctx, arch_index) - - return C.GoString(arch), nil -} - -func NewClipContext(modelPath string) (*ClipContext, error) { +func NewClipContext(llamaContext *Context, modelPath string) (*ClipContext, error) { mp := C.CString(modelPath) defer C.free(unsafe.Pointer(mp)) + c := C.clip_model_load(mp, 1) - arch, err := getVisionArch(mp) - if err != nil { - return nil, err + projEmbedSize := int(C.clip_n_mmproj_embd(c)) + modelEmbedSize := llamaContext.Model().NEmbd() + if projEmbedSize != modelEmbedSize { + return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize) } - var cc ClipContext - if arch == "clip" { - cc.c = C.clip_model_load(mp, 1) - } else if arch == "mllama" { - cc.m = C.mllama_model_load(mp, 1) - cc.IsMllama = true - } else { - return nil, fmt.Errorf("unknown vision model architecture: %s", arch) - } - - // XXX: check embedding size? - return &cc, nil + return &ClipContext{c: c}, nil } func (c *ClipContext) Free() { - if c.c != nil { - C.clip_free(c.c) - } - if c.m != nil { - C.mllama_free(c.m) - } + C.clip_free(c.c) } -func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte) [][]float32 { - c := C.llava_image_embed_make_with_bytes(clipContext.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data))) +func (c *ClipContext) NewEmbed(llamaContext *Context, data []byte) [][]float32 { + l := C.llava_image_embed_make_with_bytes(c.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data))) - numTokens := int(c.n_image_pos) + numTokens := int(l.n_image_pos) numEmbed := llamaContext.Model().NEmbd() - s := unsafe.Slice((*float32)(c.embed), numEmbed*numTokens) + s := unsafe.Slice((*float32)(l.embed), numEmbed*numTokens) embed := make([][]float32, numTokens) rows := make([]float32, len(s)) @@ -517,51 +504,57 @@ func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data [] embed[i] = rows[i*numEmbed : (i+1)*numEmbed] } - C.llava_image_embed_free(c) + C.llava_image_embed_free(l) return embed } -func NewMllamaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte, aspectRatioId int) [][]float32 { +type MllamaContext struct { + c *C.struct_mllama_ctx +} + +func NewMllamaContext(llamaContext *Context, modelPath string) (*MllamaContext, error) { + mp := C.CString(modelPath) + defer C.free(unsafe.Pointer(mp)) + c := C.mllama_model_load(mp, 1) + + projEmbedSize := int(C.mllama_n_embd(c)) + modelEmbedSize := llamaContext.Model().NEmbd() + if projEmbedSize != modelEmbedSize { + return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize) + } + + return &MllamaContext{c: c}, nil +} + +func (m *MllamaContext) Free() { + C.mllama_free(m.c) +} + +func (m *MllamaContext) NewEmbed(llamaContext *Context, data []byte, aspectRatioId int) [][]float32 { img := C.mllama_image_init() defer C.mllama_image_free(img) C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img) - numTokens := int(C.mllama_n_positions(clipContext.m) * C.mllama_n_tiles(clipContext.m)) - numEmbed := llamaContext.Model().NEmbd() + rows := make([]float32, m.EmbedSize(llamaContext)) + C.mllama_image_encode(m.c, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0]))) - rows := make([]float32, numEmbed*numTokens) - C.mllama_image_encode(clipContext.m, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0]))) - - embed := make([][]float32, numTokens) - for i := range embed { - embed[i] = rows[i*numEmbed : (i+1)*numEmbed] - } + embed := make([][]float32, 1) + embed[0] = rows return embed } -// This really needs to be set on a batch instead -func MllamaSetCrossAttn(llamaContext *Context, clipContext *ClipContext, embed [][]float32) { - if embed != nil { - if clipContext.pinned { - panic("Cross attention state already pinned") - } +func (m *MllamaContext) EmbedSize(llamaContext *Context) int { + numTokens := int(C.mllama_n_positions(m.c) * C.mllama_n_tiles(m.c)) + numEmbed := llamaContext.Model().NEmbd() - embedData := &embed[0][0] - clipContext.embedPin.Pin(embedData) - clipContext.pinned = true + return numTokens * numEmbed +} - C.llama_set_cross_attn_state(llamaContext.c, (*C.float)(unsafe.Pointer(embedData))) - } else { - C.llama_set_cross_attn_state(llamaContext.c, (*C.float)(C.NULL)) - - if clipContext.pinned { - clipContext.embedPin.Unpin() - clipContext.pinned = false - } - } +func (c *Context) SetCrossAttention(state bool) { + C.llama_set_cross_attention(c.c, C.bool(state)) } // sampling diff --git a/llama/llama.h b/llama/llama.h index 5f04fc86..dea03f76 100644 --- a/llama/llama.h +++ b/llama/llama.h @@ -266,6 +266,7 @@ extern "C" { llama_token * token; float * embd; + int32_t n_embd; llama_pos * pos; int32_t * n_seq_id; llama_seq_id ** seq_id; @@ -451,7 +452,7 @@ extern "C" { // TODO (jmorganca): this should most likely be passed in as part of a batch // and not set on the context for all batches. - LLAMA_API void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state); + LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state); // Frees all allocated memory LLAMA_API void llama_free(struct llama_context * ctx); diff --git a/llama/llava.cpp b/llama/llava.cpp index 9839de93..e759900e 100644 --- a/llama/llava.cpp +++ b/llama/llava.cpp @@ -435,7 +435,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_ if (n_eval > n_batch) { n_eval = n_batch; } - llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, }; + llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), n_embd, nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, }; if (llama_decode(ctx_llama, batch)) { LOG_ERR("%s : failed to eval\n", __func__); return false; diff --git a/llama/patches/0010-add-mllama-support.patch b/llama/patches/0010-add-mllama-support.patch index c6dd72a7..de8e919c 100644 --- a/llama/patches/0010-add-mllama-support.patch +++ b/llama/patches/0010-add-mllama-support.patch @@ -12,27 +12,49 @@ kv cache once per run remaining is to implement the cross attention mask --- - include/llama.h | 4 + - src/llama.cpp | 456 ++++++++++++++++++++++++++++++++++++++++++++++-- - 2 files changed, 447 insertions(+), 13 deletions(-) + examples/llava/llava.cpp | 2 +- + include/llama.h | 5 + + src/llama.cpp | 447 +++++++++++++++++++++++++++++++++++++-- + 3 files changed, 436 insertions(+), 18 deletions(-) +diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp +index 8558c6bd..37b2f2e2 100644 +--- a/examples/llava/llava.cpp ++++ b/examples/llava/llava.cpp +@@ -409,7 +409,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_ + if (n_eval > n_batch) { + n_eval = n_batch; + } +- llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, }; ++ llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), n_embd, nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, }; + if (llama_decode(ctx_llama, batch)) { + LOG_ERR("%s : failed to eval\n", __func__); + return false; diff --git a/include/llama.h b/include/llama.h -index 7cae1bbe..122e3cf1 100644 +index 7cae1bbe..aca09310 100644 --- a/include/llama.h +++ b/include/llama.h -@@ -423,6 +423,10 @@ extern "C" { +@@ -240,6 +240,7 @@ extern "C" { + + llama_token * token; + float * embd; ++ int32_t n_embd; + llama_pos * pos; + int32_t * n_seq_id; + llama_seq_id ** seq_id; +@@ -423,6 +424,10 @@ extern "C" { struct llama_model * model, struct llama_context_params params); + // TODO (jmorganca): this should most likely be passed in as part of a batch + // and not set on the context for all batches. -+ LLAMA_API void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state); ++ LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state); + // Frees all allocated memory LLAMA_API void llama_free(struct llama_context * ctx); diff --git a/src/llama.cpp b/src/llama.cpp -index 83b80b59..b189a19a 100644 +index 83b80b59..35748488 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -169,6 +169,7 @@ static std::string format(const char * fmt, ...) { @@ -160,13 +182,23 @@ index 83b80b59..b189a19a 100644 GGML_ABORT("fatal error"); } + -+ bool cross_attention_layer(uint32_t il) const { ++ bool cross_attention_layers(uint32_t il) const { + return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end(); + } }; static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable"); -@@ -2806,6 +2859,16 @@ struct llama_layer { +@@ -2652,6 +2705,9 @@ struct llama_cparams { + bool offload_kqv; + bool flash_attn; + bool no_perf; ++ // TODO (jmorganca): this should most likely be passed in as part of a batch ++ // and not set on the context for all batches. ++ bool cross_attn = false; + + enum llama_pooling_type pooling_type; + +@@ -2806,6 +2862,16 @@ struct llama_layer { struct ggml_tensor * ffn_down_scale; struct ggml_tensor * bskcn_tv; @@ -183,25 +215,21 @@ index 83b80b59..b189a19a 100644 }; // very similar to llama_batch, -@@ -3452,6 +3515,12 @@ struct llama_context { +@@ -3452,6 +3518,8 @@ struct llama_context { struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] + -+ // TODO (jmorganca): this should most likely be passed in as part of a batch -+ // and not set on the context for all batches. -+ float * cross_attn_state = nullptr; -+ bool cross_attn_state_first_pass = true; + struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061] }; struct llama_lora_weight { -@@ -3686,6 +3755,18 @@ static bool llama_kv_cache_init( +@@ -3686,6 +3754,18 @@ static bool llama_kv_cache_init( cache.v_l.reserve(n_layer); for (int i = 0; i < (int) n_layer; i++) { + // for cross attention layers -+ if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layer(i)) { ++ if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) { + struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); + ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i)); + ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i)); @@ -215,7 +243,7 @@ index 83b80b59..b189a19a 100644 const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); -@@ -5460,12 +5541,14 @@ static void llm_load_hparams( +@@ -5460,12 +5540,14 @@ static void llm_load_hparams( } // zero-out the per-layer hparams @@ -235,7 +263,7 @@ index 83b80b59..b189a19a 100644 // n_head_kv is optional, default to n_head hparams.n_head_kv_arr = hparams.n_head_arr; -@@ -5514,7 +5597,7 @@ static void llm_load_hparams( +@@ -5514,7 +5596,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); @@ -244,7 +272,7 @@ index 83b80b59..b189a19a 100644 if (hparams.n_rot != hparams.n_embd_head_k) { throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); } -@@ -5554,6 +5637,16 @@ static void llm_load_hparams( +@@ -5554,6 +5636,16 @@ static void llm_load_hparams( } } } break; @@ -261,7 +289,7 @@ index 83b80b59..b189a19a 100644 case LLM_ARCH_MINICPM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); -@@ -7249,6 +7342,55 @@ static bool llm_load_tensors( +@@ -7249,6 +7341,55 @@ static bool llm_load_tensors( layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); } } break; @@ -286,7 +314,7 @@ index 83b80b59..b189a19a 100644 + + auto & layer = model.layers[i]; + -+ if (hparams.cross_attention_layer(i)) { ++ if (hparams.cross_attention_layers(i)) { + layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128}); + layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024}); + layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd}); @@ -317,7 +345,7 @@ index 83b80b59..b189a19a 100644 case LLM_ARCH_GROK: { if (n_expert == 0) { -@@ -9093,7 +9235,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam +@@ -9093,7 +9234,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE && model.hparams.n_vocab != model.vocab.id_to_token.size()) { @@ -326,16 +354,7 @@ index 83b80b59..b189a19a 100644 } if (params.vocab_only) { -@@ -9178,7 +9320,7 @@ static struct ggml_tensor * llm_build_inp_embd( - - inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens); - } else { -- lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens); -+ lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens); - inpL = lctx.inp_embd; - ggml_set_input(lctx.inp_embd); - } -@@ -9193,6 +9335,22 @@ static struct ggml_tensor * llm_build_inp_embd( +@@ -9193,6 +9334,21 @@ static struct ggml_tensor * llm_build_inp_embd( return inpL; } @@ -346,11 +365,10 @@ index 83b80b59..b189a19a 100644 + const llm_build_cb & cb) { + const int64_t n_embd = hparams.n_embd; + -+ struct ggml_tensor * inpCAS; -+ lctx.inp_cross_attn_state = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4); -+ cb(lctx.inp_cross_attn_state, "inp_cross_attn_state", -1); -+ ggml_set_input(lctx.inp_cross_attn_state); -+ inpCAS = lctx.inp_cross_attn_state; ++ struct ggml_tensor * inpCAS = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4); ++ cb(inpCAS, "inp_cross_attn_state", -1); ++ ggml_set_input(inpCAS); ++ lctx.inp_cross_attn_state = inpCAS; + + return inpCAS; +} @@ -358,7 +376,7 @@ index 83b80b59..b189a19a 100644 static void llm_build_kv_store( struct ggml_context * ctx, const llama_hparams & hparams, -@@ -10167,6 +10325,7 @@ struct llm_build_context { +@@ -10167,6 +10323,7 @@ struct llm_build_context { lctx.inp_pos_bucket = nullptr; lctx.inp_embd_enc = nullptr; lctx.inp_KQ_mask_cross = nullptr; @@ -366,7 +384,7 @@ index 83b80b59..b189a19a 100644 } void free() { -@@ -10754,6 +10913,253 @@ struct llm_build_context { +@@ -10754,6 +10911,239 @@ struct llm_build_context { LLM_NORM_RMS, cb, -1); cb(cur, "result_norm", -1); @@ -410,8 +428,8 @@ index 83b80b59..b189a19a 100644 + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + -+ if (hparams.cross_attention_layer(il)) { -+ if (!lctx.cross_attn_state) { ++ if (hparams.cross_attention_layers(il)) { ++ if (!batch.embd && !cparams.cross_attn) { + continue; + } + @@ -422,42 +440,28 @@ index 83b80b59..b189a19a 100644 + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + cb(Qcur, "Qcur", il); + -+ Qcur = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); -+ cb(Qcur, "Qcur", il); -+ -+ // TODO: is this required? -+ Qcur = ggml_cont(ctx0, Qcur); ++ Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3)); + cb(Qcur, "Qcur", il); + + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Qcur, "Qcur", il); + -+ struct ggml_tensor * Kcur; -+ if (lctx.cross_attn_state_first_pass) { ++ struct ggml_tensor * Kcur, * Vcur; ++ if (batch.embd) { + Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS); + cb(Kcur, "Kcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404); + cb(Kcur, "Kcur", il); + -+ Kcur = ggml_permute(ctx0, Kcur, 0, 2, 1, 3); -+ cb(Kcur, "Kcur", il); -+ -+ // TODO: is this required? -+ Kcur = ggml_cont(ctx0, Kcur); ++ Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + cb(Kcur, "Kcur", il); + + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Kcur, "Kcur", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il])); -+ } else { -+ Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]); -+ cb(Kcur, "Kcur (view)", il); -+ } + -+ struct ggml_tensor * Vcur; -+ if (lctx.cross_attn_state_first_pass) { + Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS); + cb(Vcur, "Vcur", il); + @@ -469,6 +473,9 @@ index 83b80b59..b189a19a 100644 + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il])); + } else { ++ Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]); ++ cb(Kcur, "Kcur (view)", il); ++ + Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]); + cb(Vcur, "Vcur (view)", il); + } @@ -476,11 +483,8 @@ index 83b80b59..b189a19a 100644 + struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur); + cb(kq, "kq", il); + -+ kq = ggml_scale_inplace(ctx0, kq, 1.0f/sqrtf(float(n_embd_head))); -+ cb(kq, "kq_scaled", il); -+ + // TODO: apply causal masks -+ struct ggml_tensor * kq_soft_max = ggml_soft_max_inplace(ctx0, kq); ++ struct ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, nullptr, 1.f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias); + cb(kq_soft_max, "kq_soft_max", il); + + Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur)); @@ -570,8 +574,8 @@ index 83b80b59..b189a19a 100644 + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, -+ model.layers[il].wo, model.layers[il].bo, -+ Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); ++ model.layers[il].wo, model.layers[il].bo, ++ Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + + + if (il == n_layer - 1) { @@ -620,7 +624,7 @@ index 83b80b59..b189a19a 100644 // lm_head cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); -@@ -16501,6 +16907,10 @@ static struct ggml_cgraph * llama_build_graph( +@@ -16501,6 +16891,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_llama(); } break; @@ -631,33 +635,48 @@ index 83b80b59..b189a19a 100644 case LLM_ARCH_BAICHUAN: { result = llm.build_baichuan(); -@@ -16773,6 +17183,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { - ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); +@@ -16761,10 +17155,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } -+ // TODO (jmorganca): this might copy a lot of data on every request of a -+ // single generation even though it doesn't change, so we should -+ // find a way to not set this more than one time per image -+ if (lctx.inp_cross_attn_state && -+ lctx.inp_cross_attn_state->buffer) { -+ ggml_backend_tensor_set(lctx.inp_cross_attn_state, lctx.cross_attn_state, 0, hparams.n_embd * 1601 * 4 * ggml_element_size(lctx.inp_cross_attn_state)); -+ } -+ - if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { - GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); - const int64_t n_tokens = batch.n_tokens; -@@ -17455,6 +17873,10 @@ static int llama_decode_internal( + if (batch.embd) { +- const int64_t n_embd = hparams.n_embd; +- const int64_t n_tokens = batch.n_tokens; ++ if (lctx.inp_cross_attn_state && lctx.inp_cross_attn_state->buffer) { ++ ggml_backend_tensor_set(lctx.inp_cross_attn_state, batch.embd, 0, ggml_nbytes(lctx.inp_cross_attn_state)); ++ // zero out inp_embd since it's not used ++ float * inp_embd_data = (float *)lctx.inp_embd->data; ++ for (int i = 0; i < ggml_nelements(lctx.inp_embd); ++i) { ++ inp_embd_data[i] = 0.0f; ++ } ++ } else { ++ const int64_t n_embd = hparams.n_embd; ++ const int64_t n_tokens = batch.n_tokens; - llama_set_inputs(lctx, ubatch); +- ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd)); ++ ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd)); ++ } + } -+ // TODO: replace with something better to find out if its -+ // our first actual pass -+ lctx.cross_attn_state_first_pass = false; -+ - llama_graph_compute(lctx, gf, n_threads, threadpool); + if (batch.pos && lctx.inp_pos) { +@@ -17345,7 +17748,7 @@ static int llama_decode_internal( + n_outputs = 1; + } - // update the kv ring buffer -@@ -18648,7 +19070,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s +- lctx.sbatch.from_batch(batch_all, n_embd, ++ lctx.sbatch.from_batch(batch_all, batch_all.n_embd, + /* simple_split */ !kv_self.recurrent, + /* logits_all */ n_outputs == n_tokens_all); + +@@ -17638,7 +18041,7 @@ static int llama_encode_internal( + + const int64_t n_embd = hparams.n_embd; + +- lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true); ++ lctx.sbatch.from_batch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true); + + const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens); + +@@ -18648,7 +19051,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (llama_model_has_encoder(&model)) { n_attn_layer *= 3; } @@ -668,19 +687,7 @@ index 83b80b59..b189a19a 100644 } size_t total_size_org = 0; -@@ -19744,6 +20168,11 @@ struct llama_context * llama_new_context_with_model( - return ctx; - } - -+void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state) { -+ ctx->cross_attn_state_first_pass = true; -+ ctx->cross_attn_state = cross_attn_state; -+} -+ - void llama_free(struct llama_context * ctx) { - delete ctx; - } -@@ -19814,6 +20243,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { +@@ -19814,6 +20219,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { // use what we call a normal RoPE, operating on pairs of consecutive head values case LLM_ARCH_LLAMA: @@ -688,3 +695,38 @@ index 83b80b59..b189a19a 100644 case LLM_ARCH_BAICHUAN: case LLM_ARCH_STARCODER: case LLM_ARCH_PLAMO: +@@ -21230,6 +21636,10 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) { + ctx->cparams.causal_attn = causal_attn; + } + ++void llama_set_cross_attention(struct llama_context * ctx, bool cross_attention) { ++ ctx->cparams.cross_attn = cross_attention; ++} ++ + struct llama_batch llama_batch_get_one( + llama_token * tokens, + int32_t n_tokens, +@@ -21239,6 +21649,7 @@ struct llama_batch llama_batch_get_one( + /*n_tokens =*/ n_tokens, + /*tokens =*/ tokens, + /*embd =*/ nullptr, ++ /*n_embd =*/ 0, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, +@@ -21254,6 +21665,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ + /*n_tokens =*/ 0, + /*tokens =*/ nullptr, + /*embd =*/ nullptr, ++ /*n_embd =*/ 0, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, +@@ -21265,6 +21677,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ + + if (embd) { + batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); ++ batch.n_embd = embd; + } else { + batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); + } diff --git a/llama/runner/cache.go b/llama/runner/cache.go index ef8f6cfb..75c1d874 100644 --- a/llama/runner/cache.go +++ b/llama/runner/cache.go @@ -2,7 +2,6 @@ package main import ( "errors" - "hash/maphash" "log/slog" "reflect" "time" @@ -20,10 +19,6 @@ type InputCache struct { // optimize cache eviction for multiple users multiUserCache bool - // cache of images to embeddings - images []imageCache - imageHash maphash.Hash - lc *llama.Context } @@ -41,7 +36,6 @@ func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache b numCtx: kvSize / numSlots, slots: slots, multiUserCache: multiUserCache, - images: make([]imageCache, numSlots), lc: lc, } } @@ -211,55 +205,3 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int, numDiscar } slot.Inputs = slot.Inputs[:len(slot.Inputs)-numDiscard] } - -// Locking: Lookup and store operations on imageCache require a lock -// to be held that serializes these with each other. Hash does not -// require a lock nor they need to be serialized with InputCacheSlot. - -type imageCache struct { - key uint64 - val [][]float32 - lastUsed time.Time -} - -func (c *InputCache) HashImage(image []byte) uint64 { - c.imageHash.Reset() - _, _ = c.imageHash.Write(image) - return c.imageHash.Sum64() -} - -var ErrImageNotFound = errors.New("image not found in cache") - -func (c *InputCache) FindImage(hash uint64) ([][]float32, error) { - for i := range c.images { - if c.images[i].key == hash { - slog.Debug("loading image embeddings from cache", "entry", i) - c.images[i].lastUsed = time.Now() - return c.images[i].val, nil - } - } - - return nil, ErrImageNotFound -} - -func (c *InputCache) AddImage(hash uint64, embed [][]float32) { - best := time.Now() - var bestImage int - - for i := range c.images { - if c.images[i].key == hash { - bestImage = i - break - } - - if c.images[i].lastUsed.Compare(best) < 0 { - best = c.images[i].lastUsed - bestImage = i - } - } - - slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed) - c.images[bestImage].key = hash - c.images[bestImage].val = embed - c.images[bestImage].lastUsed = time.Now() -} diff --git a/llama/runner/cache_test.go b/llama/runner/cache_test.go index cc13b5f2..0e38c67d 100644 --- a/llama/runner/cache_test.go +++ b/llama/runner/cache_test.go @@ -1,7 +1,6 @@ package main import ( - "reflect" "testing" "time" ) @@ -228,77 +227,3 @@ func TestFindCacheSlot(t *testing.T) { }) } } - -func TestImageCache(t *testing.T) { - cache := NewInputCache(nil, 2048, 4, false) - - valA := [][]float32{{0.1, 0.2}, {0.3}} - valB := [][]float32{{0.4}, {0.5}, {0.6}} - valC := [][]float32{{0.7}} - valD := [][]float32{{0.8}} - valE := [][]float32{{0.9}} - - // Empty cache - result, err := cache.FindImage(0x5adb61d31933a946) - if err != ErrImageNotFound { - t.Errorf("found result in empty cache: result %v, err %v", result, err) - } - - // Insert A - cache.AddImage(0x5adb61d31933a946, valA) - - result, err = cache.FindImage(0x5adb61d31933a946) - if !reflect.DeepEqual(result, valA) { - t.Errorf("failed to find expected value: result %v, err %v", result, err) - } - - // Insert B - cache.AddImage(0x011551369a34a901, valB) - - result, err = cache.FindImage(0x5adb61d31933a946) - if !reflect.DeepEqual(result, valA) { - t.Errorf("failed to find expected value: result %v, err %v", result, err) - } - result, err = cache.FindImage(0x011551369a34a901) - if !reflect.DeepEqual(result, valB) { - t.Errorf("failed to find expected value: result %v, err %v", result, err) - } - - // Replace B with C - cache.AddImage(0x011551369a34a901, valC) - - result, err = cache.FindImage(0x5adb61d31933a946) - if !reflect.DeepEqual(result, valA) { - t.Errorf("failed to find expected value: result %v, err %v", result, err) - } - result, err = cache.FindImage(0x011551369a34a901) - if !reflect.DeepEqual(result, valC) { - t.Errorf("failed to find expected value: result %v, err %v", result, err) - } - - // Evict A - cache.AddImage(0x756b218a517e7353, valB) - cache.AddImage(0x75e5e8d35d7e3967, valD) - cache.AddImage(0xd96f7f268ca0646e, valE) - - result, err = cache.FindImage(0x5adb61d31933a946) - if reflect.DeepEqual(result, valA) { - t.Errorf("failed to find expected value: result %v, err %v", result, err) - } - result, err = cache.FindImage(0x756b218a517e7353) - if !reflect.DeepEqual(result, valB) { - t.Errorf("failed to find expected value: result %v, err %v", result, err) - } - result, err = cache.FindImage(0x011551369a34a901) - if !reflect.DeepEqual(result, valC) { - t.Errorf("failed to find expected value: result %v, err %v", result, err) - } - result, err = cache.FindImage(0x75e5e8d35d7e3967) - if !reflect.DeepEqual(result, valD) { - t.Errorf("failed to find expected value: result %v, err %v", result, err) - } - result, err = cache.FindImage(0xd96f7f268ca0646e) - if !reflect.DeepEqual(result, valE) { - t.Errorf("failed to find expected value: result %v, err %v", result, err) - } -} diff --git a/llama/runner/image.go b/llama/runner/image.go new file mode 100644 index 00000000..d50645e8 --- /dev/null +++ b/llama/runner/image.go @@ -0,0 +1,145 @@ +package main + +import ( + "errors" + "fmt" + "hash/maphash" + "log/slog" + "sync" + "time" + + "github.com/ollama/ollama/llama" +) + +const imageCacheSize = 4 + +type ImageContext struct { + // mu is required to be held when generating embeddings or accessing the cache + mu sync.Mutex + + clip *llama.ClipContext + mllama *llama.MllamaContext + + // cache of images to embeddings + images []imageCache + imageHash maphash.Hash +} + +func NewImageContext(llamaContext *llama.Context, modelPath string) (*ImageContext, error) { + arch, err := llama.GetModelArch(modelPath) + if err != nil { + return nil, fmt.Errorf("unable to determine vision architecture: %w (%s)", err, modelPath) + } + + var c ImageContext + if arch == "clip" { + c.clip, err = llama.NewClipContext(llamaContext, modelPath) + } else if arch == "mllama" { + c.mllama, err = llama.NewMllamaContext(llamaContext, modelPath) + } else { + return nil, fmt.Errorf("unknown vision model architecture: %s", arch) + } + + if err != nil { + return nil, err + } + + c.images = make([]imageCache, imageCacheSize) + + return &c, nil +} + +func (c *ImageContext) Free(modelPath string) { + if c == nil { + return + } + + if c.clip != nil { + c.clip.Free() + } + if c.mllama != nil { + c.mllama.Free() + } +} + +func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte, aspectRatioId int) [][]float32 { + if c == nil { + return nil + } + + hash := c.hashImage(data) + + c.mu.Lock() + defer c.mu.Unlock() + + embed, err := c.findImage(hash) + if err != nil { + if c.mllama != nil { + embed = c.mllama.NewEmbed(llamaContext, data, aspectRatioId) + } else if c.clip != nil { + embed = c.clip.NewEmbed(llamaContext, data) + } else { + return nil + } + + c.addImage(hash, embed) + } + + return embed +} + +func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int { + if c != nil && c.mllama != nil { + return c.mllama.EmbedSize(llamaContext) + } else { + return llamaContext.Model().NEmbd() + } +} + +type imageCache struct { + key uint64 + val [][]float32 + lastUsed time.Time +} + +func (c *ImageContext) hashImage(image []byte) uint64 { + c.imageHash.Reset() + _, _ = c.imageHash.Write(image) + return c.imageHash.Sum64() +} + +var errImageNotFound = errors.New("image not found in cache") + +func (c *ImageContext) findImage(hash uint64) ([][]float32, error) { + for i := range c.images { + if c.images[i].key == hash { + slog.Debug("loading image embeddings from cache", "entry", i) + c.images[i].lastUsed = time.Now() + return c.images[i].val, nil + } + } + + return nil, errImageNotFound +} + +func (c *ImageContext) addImage(hash uint64, embed [][]float32) { + best := time.Now() + var bestImage int + + for i := range c.images { + if c.images[i].key == hash { + bestImage = i + break + } + + if c.images[i].lastUsed.Compare(best) < 0 { + best = c.images[i].lastUsed + bestImage = i + } + } + + slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed) + c.images[bestImage].key = hash + c.images[bestImage].val = embed + c.images[bestImage].lastUsed = time.Now() +} diff --git a/llama/runner/image_test.go b/llama/runner/image_test.go new file mode 100644 index 00000000..4f1d265a --- /dev/null +++ b/llama/runner/image_test.go @@ -0,0 +1,80 @@ +package main + +import ( + "reflect" + "testing" +) + +func TestImageCache(t *testing.T) { + cache := ImageContext{images: make([]imageCache, 4)} + + valA := [][]float32{{0.1, 0.2}, {0.3}} + valB := [][]float32{{0.4}, {0.5}, {0.6}} + valC := [][]float32{{0.7}} + valD := [][]float32{{0.8}} + valE := [][]float32{{0.9}} + + // Empty cache + result, err := cache.findImage(0x5adb61d31933a946) + if err != errImageNotFound { + t.Errorf("found result in empty cache: result %v, err %v", result, err) + } + + // Insert A + cache.addImage(0x5adb61d31933a946, valA) + + result, err = cache.findImage(0x5adb61d31933a946) + if !reflect.DeepEqual(result, valA) { + t.Errorf("failed to find expected value: result %v, err %v", result, err) + } + + // Insert B + cache.addImage(0x011551369a34a901, valB) + + result, err = cache.findImage(0x5adb61d31933a946) + if !reflect.DeepEqual(result, valA) { + t.Errorf("failed to find expected value: result %v, err %v", result, err) + } + result, err = cache.findImage(0x011551369a34a901) + if !reflect.DeepEqual(result, valB) { + t.Errorf("failed to find expected value: result %v, err %v", result, err) + } + + // Replace B with C + cache.addImage(0x011551369a34a901, valC) + + result, err = cache.findImage(0x5adb61d31933a946) + if !reflect.DeepEqual(result, valA) { + t.Errorf("failed to find expected value: result %v, err %v", result, err) + } + result, err = cache.findImage(0x011551369a34a901) + if !reflect.DeepEqual(result, valC) { + t.Errorf("failed to find expected value: result %v, err %v", result, err) + } + + // Evict A + cache.addImage(0x756b218a517e7353, valB) + cache.addImage(0x75e5e8d35d7e3967, valD) + cache.addImage(0xd96f7f268ca0646e, valE) + + result, err = cache.findImage(0x5adb61d31933a946) + if reflect.DeepEqual(result, valA) { + t.Errorf("failed to find expected value: result %v, err %v", result, err) + } + result, err = cache.findImage(0x756b218a517e7353) + if !reflect.DeepEqual(result, valB) { + t.Errorf("failed to find expected value: result %v, err %v", result, err) + } + result, err = cache.findImage(0x011551369a34a901) + if !reflect.DeepEqual(result, valC) { + t.Errorf("failed to find expected value: result %v, err %v", result, err) + } + result, err = cache.findImage(0x75e5e8d35d7e3967) + if !reflect.DeepEqual(result, valD) { + t.Errorf("failed to find expected value: result %v, err %v", result, err) + } + result, err = cache.findImage(0xd96f7f268ca0646e) + if !reflect.DeepEqual(result, valE) { + t.Errorf("failed to find expected value: result %v, err %v", result, err) + } +} diff --git a/llama/runner/runner.go b/llama/runner/runner.go index bbd1c0fb..a137f879 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -190,57 +190,22 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) { return nil, fmt.Errorf("invalid image index: %d", n) } - hash := s.cache.HashImage(images[imageIndex].Data) - - // Vision models cannot be accessed concurrently - s.clip.mu.Lock() - embed, err := s.cache.FindImage(hash) - if err != nil { - embed = llama.NewLlavaImageEmbed(s.lc, s.clip.cc, images[imageIndex].Data) - s.cache.AddImage(hash, embed) - } - s.clip.mu.Unlock() - + embed := s.image.NewEmbed(s.lc, images[imageIndex].Data, images[imageIndex].AspectRatioID) for _, e := range embed { inputs = append(inputs, input{embed: e}) } } } - if s.clip.cc != nil { - var embed [][]float32 - - if s.clip.cc.IsMllama && len(images) >= 1 { - hash := s.cache.HashImage(images[0].Data) - - s.clip.mu.Lock() - var err error - embed, err = s.cache.FindImage(hash) - if err != nil { - embed = llama.NewMllamaImageEmbed(s.lc, s.clip.cc, images[0].Data, images[0].AspectRatioID) - s.cache.AddImage(hash, embed) - } - s.clip.mu.Unlock() - } - s.mu.Lock() - llama.MllamaSetCrossAttn(s.lc, s.clip.cc, embed) - s.mu.Unlock() - } - return inputs, nil } -type clip struct { - cc *llama.ClipContext - mu sync.Mutex -} - type Server struct { model *llama.Model lc *llama.Context // required for image embeddings - clip clip + image *ImageContext batchSize int @@ -322,14 +287,12 @@ func flushPending(seq *Sequence) bool { func (s *Server) removeSequence(seqIndex int, reason string) { seq := s.seqs[seqIndex] + s.lc.SetCrossAttention(false) flushPending(seq) seq.doneReason = reason close(seq.responses) close(seq.embedding) seq.cache.InUse = false - if s.clip.cc != nil { - llama.MllamaSetCrossAttn(s.lc, s.clip.cc, nil) - } s.seqs[seqIndex] = nil } @@ -341,7 +304,7 @@ func (s *Server) run(ctx context.Context) { tokenBatch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs)) defer tokenBatch.Free() - embedBatch := llama.NewBatch(s.batchSize*len(s.seqs), s.lc.Model().NEmbd(), len(s.seqs)) + embedBatch := llama.NewBatch(s.batchSize*len(s.seqs), s.image.EmbedSize(s.lc), len(s.seqs)) defer embedBatch.Free() for { @@ -642,12 +605,20 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { s.mu.Lock() for i, sq := range s.seqs { if sq == nil { + for _, input := range seq.inputs { + if input.embed != nil { + s.lc.SetCrossAttention(true) + break + } + } + seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) return } + s.seqs[i] = seq s.cond.Signal() break @@ -815,7 +786,7 @@ func (s *Server) loadModel( if ppath != "" { var err error - s.clip.cc, err = llama.NewClipContext(ppath) + s.image, err = NewImageContext(s.lc, ppath) if err != nil { panic(err) } diff --git a/server/prompt.go b/server/prompt.go index 1d6f5cdb..f91b94d8 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -75,11 +75,16 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. currMsgIdx := n - if isMllama { - lastMsgIdx := len(msgs) - 1 - for i := lastMsgIdx; i >= currMsgIdx; i-- { - if len(msgs[i].Images) > 0 { - data, aspectRatioID, err := imageproc.Preprocess(msgs[i].Images[0]) + for cnt, msg := range msgs[currMsgIdx:] { + prefix := "" + imgPrompt := "" + prompt := msg.Content + + for _, i := range msg.Images { + var imgData llm.ImageData + + if isMllama { + data, aspectRatioID, err := imageproc.Preprocess(i) if err != nil { return "", nil, err } @@ -90,37 +95,30 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. return "", nil, err } - imgData := llm.ImageData{ + imgData = llm.ImageData{ + ID: len(images), Data: buf.Bytes(), AspectRatioID: aspectRatioID, } - - msgs[i].Content = strings.TrimSpace("<|image|>" + msgs[i].Content) - images = append(images, imgData) - break - } - } - } else { - for cnt, msg := range msgs[currMsgIdx:] { - prefix := "" - prompt := msg.Content - for _, i := range msg.Images { - imgData := llm.ImageData{ + imgPrompt = "<|image|>" + } else { + imgData = llm.ImageData{ ID: len(images), Data: i, } - - imgTag := fmt.Sprintf("[img-%d]", imgData.ID) - if !strings.Contains(prompt, "[img]") { - prefix += imgTag - } else { - prompt = strings.Replace(prompt, "[img]", imgTag, 1) - } - - images = append(images, imgData) + imgPrompt = " " } - msgs[currMsgIdx+cnt].Content = strings.TrimSpace(prefix + " " + prompt) + + imgTag := fmt.Sprintf("[img-%d]", imgData.ID) + if !strings.Contains(prompt, "[img]") { + prefix += imgTag + } else { + prompt = strings.Replace(prompt, "[img]", imgTag, 1) + } + + images = append(images, imgData) } + msgs[currMsgIdx+cnt].Content = strings.TrimSpace(prefix + imgPrompt + prompt) } // truncate any messages that do not fit into the context window diff --git a/server/prompt_test.go b/server/prompt_test.go index 123a2081..6d04db53 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -249,7 +249,7 @@ func TestChatPrompt(t *testing.T) { {Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}}, }, expect: expect{ - prompt: "<|image|>How many hotdogs are in this image? ", + prompt: "[img-0]<|image|>How many hotdogs are in this image? ", images: [][]byte{imgBuf}, aspectRatioID: 1, }, @@ -264,7 +264,7 @@ func TestChatPrompt(t *testing.T) { {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf}}, }, expect: expect{ - prompt: "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ", + prompt: "You're a test, Harry! I-I'm a what? [img-0]<|image|>A test. And a thumping good one at that, I'd wager. ", images: [][]byte{imgBuf}, aspectRatioID: 1, }, @@ -279,8 +279,8 @@ func TestChatPrompt(t *testing.T) { {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf2}}, }, expect: expect{ - prompt: "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ", - images: [][]byte{imgBuf2}, + prompt: "[img-0]<|image|>You're a test, Harry! I-I'm a what? [img-1]<|image|>A test. And a thumping good one at that, I'd wager. ", + images: [][]byte{imgBuf, imgBuf2}, aspectRatioID: 1, }, }, @@ -294,7 +294,7 @@ func TestChatPrompt(t *testing.T) { {Role: "user", Content: "Which ones have mustard?"}, }, expect: expect{ - prompt: "<|image|>How many hotdogs are in this image? There are four hotdogs. Which ones have mustard? ", + prompt: "[img-0]<|image|>How many hotdogs are in this image? There are four hotdogs. Which ones have mustard? ", images: [][]byte{imgBuf}, aspectRatioID: 1, }, diff --git a/server/routes.go b/server/routes.go index eb2268c7..d5c4172a 100644 --- a/server/routes.go +++ b/server/routes.go @@ -205,7 +205,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - images[i] = llm.ImageData{Data: buf.Bytes(), AspectRatioID: aspectRatioID} + images[i] = llm.ImageData{ID: i, Data: buf.Bytes(), AspectRatioID: aspectRatioID} } else { images[i] = llm.ImageData{ID: i, Data: req.Images[i]} } @@ -239,11 +239,11 @@ func (s *Server) GenerateHandler(c *gin.Context) { } for _, i := range images { + imgPrompt := "" if isMllama { - msgs = append(msgs, api.Message{Role: "user", Content: "<|image|>"}) - } else { - msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)}) + imgPrompt = "<|image|>" } + msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]"+imgPrompt, i.ID)}) } values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})