runner.go: Better abstract vision model integration

-Update mllama to take the cross attention state as embeddings in
a batch, more similar to how Llava handles it. This improves
integration with the input cache.
-Pass locations in a prompt for embeddings using tags similar to Llava.
-Abstract interface to vision models so the main runner accesses Clip
and Mllama similarly

Co-authored-by: Michael Yang <mxyng@pm.me>
This commit is contained in:
Jesse Gross 2024-10-11 15:34:01 -07:00 committed by Jesse Gross
parent 712e99d477
commit c826e57475
13 changed files with 534 additions and 454 deletions

105
llama/llama.cpp vendored
View file

@ -2699,7 +2699,7 @@ struct llama_hparams {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
bool cross_attention_layer(uint32_t il) const { bool cross_attention_layers(uint32_t il) const {
return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end(); return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
} }
}; };
@ -2731,6 +2731,9 @@ struct llama_cparams {
bool offload_kqv; bool offload_kqv;
bool flash_attn; bool flash_attn;
bool no_perf; bool no_perf;
// TODO (jmorganca): this should most likely be passed in as part of a batch
// and not set on the context for all batches.
bool cross_attn = false;
enum llama_pooling_type pooling_type; enum llama_pooling_type pooling_type;
@ -3542,10 +3545,6 @@ struct llama_context {
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
// TODO (jmorganca): this should most likely be passed in as part of a batch
// and not set on the context for all batches.
float * cross_attn_state = nullptr;
bool cross_attn_state_first_pass = true;
struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061] struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061]
}; };
@ -3782,7 +3781,7 @@ static bool llama_kv_cache_init(
for (int i = 0; i < (int) n_layer; i++) { for (int i = 0; i < (int) n_layer; i++) {
// for cross attention layers // for cross attention layers
if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layer(i)) { if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) {
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i)); ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i)); ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
@ -7389,7 +7388,7 @@ static bool llm_load_tensors(
auto & layer = model.layers[i]; auto & layer = model.layers[i];
if (hparams.cross_attention_layer(i)) { if (hparams.cross_attention_layers(i)) {
layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128}); layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128});
layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024}); layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024});
layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd}); layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd});
@ -9346,7 +9345,7 @@ static struct ggml_tensor * llm_build_inp_embd(
inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens); inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
} else { } else {
lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens); lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
inpL = lctx.inp_embd; inpL = lctx.inp_embd;
ggml_set_input(lctx.inp_embd); ggml_set_input(lctx.inp_embd);
} }
@ -9368,11 +9367,10 @@ static struct ggml_tensor * llm_build_inp_cross_attn_state(
const llm_build_cb & cb) { const llm_build_cb & cb) {
const int64_t n_embd = hparams.n_embd; const int64_t n_embd = hparams.n_embd;
struct ggml_tensor * inpCAS; struct ggml_tensor * inpCAS = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
lctx.inp_cross_attn_state = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4); cb(inpCAS, "inp_cross_attn_state", -1);
cb(lctx.inp_cross_attn_state, "inp_cross_attn_state", -1); ggml_set_input(inpCAS);
ggml_set_input(lctx.inp_cross_attn_state); lctx.inp_cross_attn_state = inpCAS;
inpCAS = lctx.inp_cross_attn_state;
return inpCAS; return inpCAS;
} }
@ -10979,8 +10977,8 @@ struct llm_build_context {
LLM_NORM_RMS, cb, il); LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il); cb(cur, "attn_norm", il);
if (hparams.cross_attention_layer(il)) { if (hparams.cross_attention_layers(il)) {
if (!lctx.cross_attn_state) { if (!batch.embd && !cparams.cross_attn) {
continue; continue;
} }
@ -10991,42 +10989,28 @@ struct llm_build_context {
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
Qcur = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
cb(Qcur, "Qcur", il);
// TODO: is this required?
Qcur = ggml_cont(ctx0, Qcur);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il); Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
struct ggml_tensor * Kcur; struct ggml_tensor * Kcur, * Vcur;
if (lctx.cross_attn_state_first_pass) { if (batch.embd) {
Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS); Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
Kcur = ggml_permute(ctx0, Kcur, 0, 2, 1, 3); Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
cb(Kcur, "Kcur", il);
// TODO: is this required?
Kcur = ggml_cont(ctx0, Kcur);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il); Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il])); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il]));
} else {
Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
cb(Kcur, "Kcur (view)", il);
}
struct ggml_tensor * Vcur;
if (lctx.cross_attn_state_first_pass) {
Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS); Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
cb(Vcur, "Vcur", il); cb(Vcur, "Vcur", il);
@ -11038,6 +11022,9 @@ struct llm_build_context {
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il])); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il]));
} else { } else {
Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
cb(Kcur, "Kcur (view)", il);
Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]); Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]);
cb(Vcur, "Vcur (view)", il); cb(Vcur, "Vcur (view)", il);
} }
@ -11045,11 +11032,8 @@ struct llm_build_context {
struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur); struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur);
cb(kq, "kq", il); cb(kq, "kq", il);
kq = ggml_scale_inplace(ctx0, kq, 1.0f/sqrtf(float(n_embd_head)));
cb(kq, "kq_scaled", il);
// TODO: apply causal masks // TODO: apply causal masks
struct ggml_tensor * kq_soft_max = ggml_soft_max_inplace(ctx0, kq); struct ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, nullptr, 1.f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
cb(kq_soft_max, "kq_soft_max", il); cb(kq_soft_max, "kq_soft_max", il);
Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur)); Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur));
@ -11139,8 +11123,8 @@ struct llm_build_context {
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf, cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
if (il == n_layer - 1) { if (il == n_layer - 1) {
@ -17197,10 +17181,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
} }
if (batch.embd) { if (batch.embd) {
const int64_t n_embd = hparams.n_embd; if (lctx.inp_cross_attn_state && lctx.inp_cross_attn_state->buffer) {
const int64_t n_tokens = batch.n_tokens; ggml_backend_tensor_set(lctx.inp_cross_attn_state, batch.embd, 0, ggml_nbytes(lctx.inp_cross_attn_state));
// zero out inp_embd since it's not used
float * inp_embd_data = (float *)lctx.inp_embd->data;
for (int i = 0; i < ggml_nelements(lctx.inp_embd); ++i) {
inp_embd_data[i] = 0.0f;
}
} else {
const int64_t n_embd = hparams.n_embd;
const int64_t n_tokens = batch.n_tokens;
ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd)); ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
}
} }
if (batch.pos && lctx.inp_pos) { if (batch.pos && lctx.inp_pos) {
@ -17209,14 +17202,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
} }
// TODO (jmorganca): this might copy a lot of data on every request of a
// single generation even though it doesn't change, so we should
// find a way to not set this more than one time per image
if (lctx.inp_cross_attn_state &&
lctx.inp_cross_attn_state->buffer) {
ggml_backend_tensor_set(lctx.inp_cross_attn_state, lctx.cross_attn_state, 0, hparams.n_embd * 1601 * 4 * ggml_element_size(lctx.inp_cross_attn_state));
}
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
const int64_t n_tokens = batch.n_tokens; const int64_t n_tokens = batch.n_tokens;
@ -17789,7 +17774,7 @@ static int llama_decode_internal(
n_outputs = 1; n_outputs = 1;
} }
lctx.sbatch.from_batch(batch_all, n_embd, lctx.sbatch.from_batch(batch_all, batch_all.n_embd,
/* simple_split */ !kv_self.recurrent, /* simple_split */ !kv_self.recurrent,
/* logits_all */ n_outputs == n_tokens_all); /* logits_all */ n_outputs == n_tokens_all);
@ -17899,10 +17884,6 @@ static int llama_decode_internal(
llama_set_inputs(lctx, ubatch); llama_set_inputs(lctx, ubatch);
// TODO: replace with something better to find out if its
// our first actual pass
lctx.cross_attn_state_first_pass = false;
llama_graph_compute(lctx, gf, n_threads, threadpool); llama_graph_compute(lctx, gf, n_threads, threadpool);
// update the kv ring buffer // update the kv ring buffer
@ -18086,7 +18067,7 @@ static int llama_encode_internal(
const int64_t n_embd = hparams.n_embd; const int64_t n_embd = hparams.n_embd;
lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true); lctx.sbatch.from_batch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true);
const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens); const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
@ -20194,11 +20175,6 @@ struct llama_context * llama_new_context_with_model(
return ctx; return ctx;
} }
void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state) {
ctx->cross_attn_state_first_pass = true;
ctx->cross_attn_state = cross_attn_state;
}
void llama_free(struct llama_context * ctx) { void llama_free(struct llama_context * ctx) {
delete ctx; delete ctx;
} }
@ -21686,6 +21662,10 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
ctx->cparams.causal_attn = causal_attn; ctx->cparams.causal_attn = causal_attn;
} }
void llama_set_cross_attention(struct llama_context * ctx, bool cross_attention) {
ctx->cparams.cross_attn = cross_attention;
}
struct llama_batch llama_batch_get_one( struct llama_batch llama_batch_get_one(
llama_token * tokens, llama_token * tokens,
int32_t n_tokens, int32_t n_tokens,
@ -21695,6 +21675,7 @@ struct llama_batch llama_batch_get_one(
/*n_tokens =*/ n_tokens, /*n_tokens =*/ n_tokens,
/*tokens =*/ tokens, /*tokens =*/ tokens,
/*embd =*/ nullptr, /*embd =*/ nullptr,
/*n_embd =*/ 0,
/*pos =*/ nullptr, /*pos =*/ nullptr,
/*n_seq_id =*/ nullptr, /*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr, /*seq_id =*/ nullptr,
@ -21710,6 +21691,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
/*n_tokens =*/ 0, /*n_tokens =*/ 0,
/*tokens =*/ nullptr, /*tokens =*/ nullptr,
/*embd =*/ nullptr, /*embd =*/ nullptr,
/*n_embd =*/ 0,
/*pos =*/ nullptr, /*pos =*/ nullptr,
/*n_seq_id =*/ nullptr, /*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr, /*seq_id =*/ nullptr,
@ -21721,6 +21703,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
if (embd) { if (embd) {
batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
batch.n_embd = embd;
} else { } else {
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
} }

View file

@ -111,6 +111,28 @@ func PrintSystemInfo() string {
return C.GoString(C.llama_print_system_info()) + compiler return C.GoString(C.llama_print_system_info()) + compiler
} }
func GetModelArch(modelPath string) (string, error) {
mp := C.CString(modelPath)
defer C.free(unsafe.Pointer(mp))
gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
if gguf_ctx == nil {
return "", errors.New("unable to load model file")
}
defer C.gguf_free(gguf_ctx)
key := C.CString("general.architecture")
defer C.free(unsafe.Pointer(key))
arch_index := C.gguf_find_key(gguf_ctx, key)
if int(arch_index) < 0 {
return "", errors.New("unknown model architecture")
}
arch := C.gguf_get_val_str(gguf_ctx, arch_index)
return C.GoString(arch), nil
}
type ContextParams struct { type ContextParams struct {
c C.struct_llama_context_params c C.struct_llama_context_params
} }
@ -443,71 +465,36 @@ func Quantize(infile, outfile string, ftype uint32) error {
return nil return nil
} }
// llava // vision processing
type ClipContext struct { type ClipContext struct {
c *C.struct_clip_ctx c *C.struct_clip_ctx
m *C.struct_mllama_ctx
IsMllama bool
embedPin runtime.Pinner
pinned bool
} }
func getVisionArch(mp *C.char) (string, error) { func NewClipContext(llamaContext *Context, modelPath string) (*ClipContext, error) {
gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
if gguf_ctx == nil {
return "", errors.New("unable to load vision projector")
}
defer C.gguf_free(gguf_ctx)
arch_index := C.gguf_find_key(gguf_ctx, C.CString("general.architecture"))
if int(arch_index) < 0 {
return "", errors.New("unknown vision model architecture")
}
arch := C.gguf_get_val_str(gguf_ctx, arch_index)
return C.GoString(arch), nil
}
func NewClipContext(modelPath string) (*ClipContext, error) {
mp := C.CString(modelPath) mp := C.CString(modelPath)
defer C.free(unsafe.Pointer(mp)) defer C.free(unsafe.Pointer(mp))
c := C.clip_model_load(mp, 1)
arch, err := getVisionArch(mp) projEmbedSize := int(C.clip_n_mmproj_embd(c))
if err != nil { modelEmbedSize := llamaContext.Model().NEmbd()
return nil, err if projEmbedSize != modelEmbedSize {
return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
} }
var cc ClipContext return &ClipContext{c: c}, nil
if arch == "clip" {
cc.c = C.clip_model_load(mp, 1)
} else if arch == "mllama" {
cc.m = C.mllama_model_load(mp, 1)
cc.IsMllama = true
} else {
return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
}
// XXX: check embedding size?
return &cc, nil
} }
func (c *ClipContext) Free() { func (c *ClipContext) Free() {
if c.c != nil { C.clip_free(c.c)
C.clip_free(c.c)
}
if c.m != nil {
C.mllama_free(c.m)
}
} }
func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte) [][]float32 { func (c *ClipContext) NewEmbed(llamaContext *Context, data []byte) [][]float32 {
c := C.llava_image_embed_make_with_bytes(clipContext.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data))) l := C.llava_image_embed_make_with_bytes(c.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))
numTokens := int(c.n_image_pos) numTokens := int(l.n_image_pos)
numEmbed := llamaContext.Model().NEmbd() numEmbed := llamaContext.Model().NEmbd()
s := unsafe.Slice((*float32)(c.embed), numEmbed*numTokens) s := unsafe.Slice((*float32)(l.embed), numEmbed*numTokens)
embed := make([][]float32, numTokens) embed := make([][]float32, numTokens)
rows := make([]float32, len(s)) rows := make([]float32, len(s))
@ -517,51 +504,57 @@ func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []
embed[i] = rows[i*numEmbed : (i+1)*numEmbed] embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
} }
C.llava_image_embed_free(c) C.llava_image_embed_free(l)
return embed return embed
} }
func NewMllamaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte, aspectRatioId int) [][]float32 { type MllamaContext struct {
c *C.struct_mllama_ctx
}
func NewMllamaContext(llamaContext *Context, modelPath string) (*MllamaContext, error) {
mp := C.CString(modelPath)
defer C.free(unsafe.Pointer(mp))
c := C.mllama_model_load(mp, 1)
projEmbedSize := int(C.mllama_n_embd(c))
modelEmbedSize := llamaContext.Model().NEmbd()
if projEmbedSize != modelEmbedSize {
return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
}
return &MllamaContext{c: c}, nil
}
func (m *MllamaContext) Free() {
C.mllama_free(m.c)
}
func (m *MllamaContext) NewEmbed(llamaContext *Context, data []byte, aspectRatioId int) [][]float32 {
img := C.mllama_image_init() img := C.mllama_image_init()
defer C.mllama_image_free(img) defer C.mllama_image_free(img)
C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img) C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img)
numTokens := int(C.mllama_n_positions(clipContext.m) * C.mllama_n_tiles(clipContext.m)) rows := make([]float32, m.EmbedSize(llamaContext))
numEmbed := llamaContext.Model().NEmbd() C.mllama_image_encode(m.c, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0])))
rows := make([]float32, numEmbed*numTokens) embed := make([][]float32, 1)
C.mllama_image_encode(clipContext.m, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0]))) embed[0] = rows
embed := make([][]float32, numTokens)
for i := range embed {
embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
}
return embed return embed
} }
// This really needs to be set on a batch instead func (m *MllamaContext) EmbedSize(llamaContext *Context) int {
func MllamaSetCrossAttn(llamaContext *Context, clipContext *ClipContext, embed [][]float32) { numTokens := int(C.mllama_n_positions(m.c) * C.mllama_n_tiles(m.c))
if embed != nil { numEmbed := llamaContext.Model().NEmbd()
if clipContext.pinned {
panic("Cross attention state already pinned")
}
embedData := &embed[0][0] return numTokens * numEmbed
clipContext.embedPin.Pin(embedData) }
clipContext.pinned = true
C.llama_set_cross_attn_state(llamaContext.c, (*C.float)(unsafe.Pointer(embedData))) func (c *Context) SetCrossAttention(state bool) {
} else { C.llama_set_cross_attention(c.c, C.bool(state))
C.llama_set_cross_attn_state(llamaContext.c, (*C.float)(C.NULL))
if clipContext.pinned {
clipContext.embedPin.Unpin()
clipContext.pinned = false
}
}
} }
// sampling // sampling

3
llama/llama.h vendored
View file

@ -266,6 +266,7 @@ extern "C" {
llama_token * token; llama_token * token;
float * embd; float * embd;
int32_t n_embd;
llama_pos * pos; llama_pos * pos;
int32_t * n_seq_id; int32_t * n_seq_id;
llama_seq_id ** seq_id; llama_seq_id ** seq_id;
@ -451,7 +452,7 @@ extern "C" {
// TODO (jmorganca): this should most likely be passed in as part of a batch // TODO (jmorganca): this should most likely be passed in as part of a batch
// and not set on the context for all batches. // and not set on the context for all batches.
LLAMA_API void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state); LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state);
// Frees all allocated memory // Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx); LLAMA_API void llama_free(struct llama_context * ctx);

2
llama/llava.cpp vendored
View file

@ -435,7 +435,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
if (n_eval > n_batch) { if (n_eval > n_batch) {
n_eval = n_batch; n_eval = n_batch;
} }
llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, }; llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), n_embd, nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
if (llama_decode(ctx_llama, batch)) { if (llama_decode(ctx_llama, batch)) {
LOG_ERR("%s : failed to eval\n", __func__); LOG_ERR("%s : failed to eval\n", __func__);
return false; return false;

View file

@ -12,27 +12,49 @@ kv cache once per run
remaining is to implement the cross attention mask remaining is to implement the cross attention mask
--- ---
include/llama.h | 4 + examples/llava/llava.cpp | 2 +-
src/llama.cpp | 456 ++++++++++++++++++++++++++++++++++++++++++++++-- include/llama.h | 5 +
2 files changed, 447 insertions(+), 13 deletions(-) src/llama.cpp | 447 +++++++++++++++++++++++++++++++++++++--
3 files changed, 436 insertions(+), 18 deletions(-)
diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
index 8558c6bd..37b2f2e2 100644
--- a/examples/llava/llava.cpp
+++ b/examples/llava/llava.cpp
@@ -409,7 +409,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
if (n_eval > n_batch) {
n_eval = n_batch;
}
- llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
+ llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), n_embd, nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
if (llama_decode(ctx_llama, batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
diff --git a/include/llama.h b/include/llama.h diff --git a/include/llama.h b/include/llama.h
index 7cae1bbe..122e3cf1 100644 index 7cae1bbe..aca09310 100644
--- a/include/llama.h --- a/include/llama.h
+++ b/include/llama.h +++ b/include/llama.h
@@ -423,6 +423,10 @@ extern "C" { @@ -240,6 +240,7 @@ extern "C" {
llama_token * token;
float * embd;
+ int32_t n_embd;
llama_pos * pos;
int32_t * n_seq_id;
llama_seq_id ** seq_id;
@@ -423,6 +424,10 @@ extern "C" {
struct llama_model * model, struct llama_model * model,
struct llama_context_params params); struct llama_context_params params);
+ // TODO (jmorganca): this should most likely be passed in as part of a batch + // TODO (jmorganca): this should most likely be passed in as part of a batch
+ // and not set on the context for all batches. + // and not set on the context for all batches.
+ LLAMA_API void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state); + LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state);
+ +
// Frees all allocated memory // Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx); LLAMA_API void llama_free(struct llama_context * ctx);
diff --git a/src/llama.cpp b/src/llama.cpp diff --git a/src/llama.cpp b/src/llama.cpp
index 83b80b59..b189a19a 100644 index 83b80b59..35748488 100644
--- a/src/llama.cpp --- a/src/llama.cpp
+++ b/src/llama.cpp +++ b/src/llama.cpp
@@ -169,6 +169,7 @@ static std::string format(const char * fmt, ...) { @@ -169,6 +169,7 @@ static std::string format(const char * fmt, ...) {
@ -160,13 +182,23 @@ index 83b80b59..b189a19a 100644
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
+ +
+ bool cross_attention_layer(uint32_t il) const { + bool cross_attention_layers(uint32_t il) const {
+ return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end(); + return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
+ } + }
}; };
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable"); static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
@@ -2806,6 +2859,16 @@ struct llama_layer { @@ -2652,6 +2705,9 @@ struct llama_cparams {
bool offload_kqv;
bool flash_attn;
bool no_perf;
+ // TODO (jmorganca): this should most likely be passed in as part of a batch
+ // and not set on the context for all batches.
+ bool cross_attn = false;
enum llama_pooling_type pooling_type;
@@ -2806,6 +2862,16 @@ struct llama_layer {
struct ggml_tensor * ffn_down_scale; struct ggml_tensor * ffn_down_scale;
struct ggml_tensor * bskcn_tv; struct ggml_tensor * bskcn_tv;
@ -183,25 +215,21 @@ index 83b80b59..b189a19a 100644
}; };
// very similar to llama_batch, // very similar to llama_batch,
@@ -3452,6 +3515,12 @@ struct llama_context { @@ -3452,6 +3518,8 @@ struct llama_context {
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
+ +
+ // TODO (jmorganca): this should most likely be passed in as part of a batch
+ // and not set on the context for all batches.
+ float * cross_attn_state = nullptr;
+ bool cross_attn_state_first_pass = true;
+ struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061] + struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061]
}; };
struct llama_lora_weight { struct llama_lora_weight {
@@ -3686,6 +3755,18 @@ static bool llama_kv_cache_init( @@ -3686,6 +3754,18 @@ static bool llama_kv_cache_init(
cache.v_l.reserve(n_layer); cache.v_l.reserve(n_layer);
for (int i = 0; i < (int) n_layer; i++) { for (int i = 0; i < (int) n_layer; i++) {
+ // for cross attention layers + // for cross attention layers
+ if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layer(i)) { + if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) {
+ struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); + struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
+ ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i)); + ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
+ ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i)); + ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
@ -215,7 +243,7 @@ index 83b80b59..b189a19a 100644
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
@@ -5460,12 +5541,14 @@ static void llm_load_hparams( @@ -5460,12 +5540,14 @@ static void llm_load_hparams(
} }
// zero-out the per-layer hparams // zero-out the per-layer hparams
@ -235,7 +263,7 @@ index 83b80b59..b189a19a 100644
// n_head_kv is optional, default to n_head // n_head_kv is optional, default to n_head
hparams.n_head_kv_arr = hparams.n_head_arr; hparams.n_head_kv_arr = hparams.n_head_arr;
@@ -5514,7 +5597,7 @@ static void llm_load_hparams( @@ -5514,7 +5596,7 @@ static void llm_load_hparams(
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
@ -244,7 +272,7 @@ index 83b80b59..b189a19a 100644
if (hparams.n_rot != hparams.n_embd_head_k) { if (hparams.n_rot != hparams.n_embd_head_k) {
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
} }
@@ -5554,6 +5637,16 @@ static void llm_load_hparams( @@ -5554,6 +5636,16 @@ static void llm_load_hparams(
} }
} }
} break; } break;
@ -261,7 +289,7 @@ index 83b80b59..b189a19a 100644
case LLM_ARCH_MINICPM: case LLM_ARCH_MINICPM:
{ {
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -7249,6 +7342,55 @@ static bool llm_load_tensors( @@ -7249,6 +7341,55 @@ static bool llm_load_tensors(
layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
} }
} break; } break;
@ -286,7 +314,7 @@ index 83b80b59..b189a19a 100644
+ +
+ auto & layer = model.layers[i]; + auto & layer = model.layers[i];
+ +
+ if (hparams.cross_attention_layer(i)) { + if (hparams.cross_attention_layers(i)) {
+ layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128}); + layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128});
+ layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024}); + layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024});
+ layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd}); + layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd});
@ -317,7 +345,7 @@ index 83b80b59..b189a19a 100644
case LLM_ARCH_GROK: case LLM_ARCH_GROK:
{ {
if (n_expert == 0) { if (n_expert == 0) {
@@ -9093,7 +9235,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam @@ -9093,7 +9234,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE && if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
model.hparams.n_vocab != model.vocab.id_to_token.size()) { model.hparams.n_vocab != model.vocab.id_to_token.size()) {
@ -326,16 +354,7 @@ index 83b80b59..b189a19a 100644
} }
if (params.vocab_only) { if (params.vocab_only) {
@@ -9178,7 +9320,7 @@ static struct ggml_tensor * llm_build_inp_embd( @@ -9193,6 +9334,21 @@ static struct ggml_tensor * llm_build_inp_embd(
inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
} else {
- lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+ lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
inpL = lctx.inp_embd;
ggml_set_input(lctx.inp_embd);
}
@@ -9193,6 +9335,22 @@ static struct ggml_tensor * llm_build_inp_embd(
return inpL; return inpL;
} }
@ -346,11 +365,10 @@ index 83b80b59..b189a19a 100644
+ const llm_build_cb & cb) { + const llm_build_cb & cb) {
+ const int64_t n_embd = hparams.n_embd; + const int64_t n_embd = hparams.n_embd;
+ +
+ struct ggml_tensor * inpCAS; + struct ggml_tensor * inpCAS = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
+ lctx.inp_cross_attn_state = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4); + cb(inpCAS, "inp_cross_attn_state", -1);
+ cb(lctx.inp_cross_attn_state, "inp_cross_attn_state", -1); + ggml_set_input(inpCAS);
+ ggml_set_input(lctx.inp_cross_attn_state); + lctx.inp_cross_attn_state = inpCAS;
+ inpCAS = lctx.inp_cross_attn_state;
+ +
+ return inpCAS; + return inpCAS;
+} +}
@ -358,7 +376,7 @@ index 83b80b59..b189a19a 100644
static void llm_build_kv_store( static void llm_build_kv_store(
struct ggml_context * ctx, struct ggml_context * ctx,
const llama_hparams & hparams, const llama_hparams & hparams,
@@ -10167,6 +10325,7 @@ struct llm_build_context { @@ -10167,6 +10323,7 @@ struct llm_build_context {
lctx.inp_pos_bucket = nullptr; lctx.inp_pos_bucket = nullptr;
lctx.inp_embd_enc = nullptr; lctx.inp_embd_enc = nullptr;
lctx.inp_KQ_mask_cross = nullptr; lctx.inp_KQ_mask_cross = nullptr;
@ -366,7 +384,7 @@ index 83b80b59..b189a19a 100644
} }
void free() { void free() {
@@ -10754,6 +10913,253 @@ struct llm_build_context { @@ -10754,6 +10911,239 @@ struct llm_build_context {
LLM_NORM_RMS, cb, -1); LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1); cb(cur, "result_norm", -1);
@ -410,8 +428,8 @@ index 83b80b59..b189a19a 100644
+ LLM_NORM_RMS, cb, il); + LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il); + cb(cur, "attn_norm", il);
+ +
+ if (hparams.cross_attention_layer(il)) { + if (hparams.cross_attention_layers(il)) {
+ if (!lctx.cross_attn_state) { + if (!batch.embd && !cparams.cross_attn) {
+ continue; + continue;
+ } + }
+ +
@ -422,42 +440,28 @@ index 83b80b59..b189a19a 100644
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ cb(Qcur, "Qcur", il); + cb(Qcur, "Qcur", il);
+ +
+ Qcur = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
+ cb(Qcur, "Qcur", il);
+
+ // TODO: is this required?
+ Qcur = ggml_cont(ctx0, Qcur);
+ cb(Qcur, "Qcur", il); + cb(Qcur, "Qcur", il);
+ +
+ Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il); + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
+ cb(Qcur, "Qcur", il); + cb(Qcur, "Qcur", il);
+ +
+ struct ggml_tensor * Kcur; + struct ggml_tensor * Kcur, * Vcur;
+ if (lctx.cross_attn_state_first_pass) { + if (batch.embd) {
+ Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS); + Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
+ cb(Kcur, "Kcur", il); + cb(Kcur, "Kcur", il);
+ +
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404);
+ cb(Kcur, "Kcur", il); + cb(Kcur, "Kcur", il);
+ +
+ Kcur = ggml_permute(ctx0, Kcur, 0, 2, 1, 3); + Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+ cb(Kcur, "Kcur", il);
+
+ // TODO: is this required?
+ Kcur = ggml_cont(ctx0, Kcur);
+ cb(Kcur, "Kcur", il); + cb(Kcur, "Kcur", il);
+ +
+ Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il); + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
+ cb(Kcur, "Kcur", il); + cb(Kcur, "Kcur", il);
+ +
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il])); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il]));
+ } else {
+ Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
+ cb(Kcur, "Kcur (view)", il);
+ }
+ +
+ struct ggml_tensor * Vcur;
+ if (lctx.cross_attn_state_first_pass) {
+ Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS); + Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
+ cb(Vcur, "Vcur", il); + cb(Vcur, "Vcur", il);
+ +
@ -469,6 +473,9 @@ index 83b80b59..b189a19a 100644
+ +
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il])); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il]));
+ } else { + } else {
+ Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
+ cb(Kcur, "Kcur (view)", il);
+
+ Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]); + Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]);
+ cb(Vcur, "Vcur (view)", il); + cb(Vcur, "Vcur (view)", il);
+ } + }
@ -476,11 +483,8 @@ index 83b80b59..b189a19a 100644
+ struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur); + struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur);
+ cb(kq, "kq", il); + cb(kq, "kq", il);
+ +
+ kq = ggml_scale_inplace(ctx0, kq, 1.0f/sqrtf(float(n_embd_head)));
+ cb(kq, "kq_scaled", il);
+
+ // TODO: apply causal masks + // TODO: apply causal masks
+ struct ggml_tensor * kq_soft_max = ggml_soft_max_inplace(ctx0, kq); + struct ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, nullptr, 1.f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
+ cb(kq_soft_max, "kq_soft_max", il); + cb(kq_soft_max, "kq_soft_max", il);
+ +
+ Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur)); + Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur));
@ -570,8 +574,8 @@ index 83b80b59..b189a19a 100644
+ cb(Kcur, "Kcur", il); + cb(Kcur, "Kcur", il);
+ +
+ cur = llm_build_kv(ctx0, lctx, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+ model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].bo,
+ Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ +
+ +
+ if (il == n_layer - 1) { + if (il == n_layer - 1) {
@ -620,7 +624,7 @@ index 83b80b59..b189a19a 100644
// lm_head // lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1); cb(cur, "result_output", -1);
@@ -16501,6 +16907,10 @@ static struct ggml_cgraph * llama_build_graph( @@ -16501,6 +16891,10 @@ static struct ggml_cgraph * llama_build_graph(
{ {
result = llm.build_llama(); result = llm.build_llama();
} break; } break;
@ -631,33 +635,48 @@ index 83b80b59..b189a19a 100644
case LLM_ARCH_BAICHUAN: case LLM_ARCH_BAICHUAN:
{ {
result = llm.build_baichuan(); result = llm.build_baichuan();
@@ -16773,6 +17183,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { @@ -16761,10 +17155,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
} }
+ // TODO (jmorganca): this might copy a lot of data on every request of a if (batch.embd) {
+ // single generation even though it doesn't change, so we should - const int64_t n_embd = hparams.n_embd;
+ // find a way to not set this more than one time per image - const int64_t n_tokens = batch.n_tokens;
+ if (lctx.inp_cross_attn_state && + if (lctx.inp_cross_attn_state && lctx.inp_cross_attn_state->buffer) {
+ lctx.inp_cross_attn_state->buffer) { + ggml_backend_tensor_set(lctx.inp_cross_attn_state, batch.embd, 0, ggml_nbytes(lctx.inp_cross_attn_state));
+ ggml_backend_tensor_set(lctx.inp_cross_attn_state, lctx.cross_attn_state, 0, hparams.n_embd * 1601 * 4 * ggml_element_size(lctx.inp_cross_attn_state)); + // zero out inp_embd since it's not used
+ } + float * inp_embd_data = (float *)lctx.inp_embd->data;
+ + for (int i = 0; i < ggml_nelements(lctx.inp_embd); ++i) {
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + inp_embd_data[i] = 0.0f;
GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); + }
const int64_t n_tokens = batch.n_tokens; + } else {
@@ -17455,6 +17873,10 @@ static int llama_decode_internal( + const int64_t n_embd = hparams.n_embd;
+ const int64_t n_tokens = batch.n_tokens;
llama_set_inputs(lctx, ubatch); - ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
+ ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
+ }
}
+ // TODO: replace with something better to find out if its if (batch.pos && lctx.inp_pos) {
+ // our first actual pass @@ -17345,7 +17748,7 @@ static int llama_decode_internal(
+ lctx.cross_attn_state_first_pass = false; n_outputs = 1;
+ }
llama_graph_compute(lctx, gf, n_threads, threadpool);
// update the kv ring buffer - lctx.sbatch.from_batch(batch_all, n_embd,
@@ -18648,7 +19070,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s + lctx.sbatch.from_batch(batch_all, batch_all.n_embd,
/* simple_split */ !kv_self.recurrent,
/* logits_all */ n_outputs == n_tokens_all);
@@ -17638,7 +18041,7 @@ static int llama_encode_internal(
const int64_t n_embd = hparams.n_embd;
- lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+ lctx.sbatch.from_batch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true);
const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
@@ -18648,7 +19051,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (llama_model_has_encoder(&model)) { if (llama_model_has_encoder(&model)) {
n_attn_layer *= 3; n_attn_layer *= 3;
} }
@ -668,19 +687,7 @@ index 83b80b59..b189a19a 100644
} }
size_t total_size_org = 0; size_t total_size_org = 0;
@@ -19744,6 +20168,11 @@ struct llama_context * llama_new_context_with_model( @@ -19814,6 +20219,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
return ctx;
}
+void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state) {
+ ctx->cross_attn_state_first_pass = true;
+ ctx->cross_attn_state = cross_attn_state;
+}
+
void llama_free(struct llama_context * ctx) {
delete ctx;
}
@@ -19814,6 +20243,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
// use what we call a normal RoPE, operating on pairs of consecutive head values // use what we call a normal RoPE, operating on pairs of consecutive head values
case LLM_ARCH_LLAMA: case LLM_ARCH_LLAMA:
@ -688,3 +695,38 @@ index 83b80b59..b189a19a 100644
case LLM_ARCH_BAICHUAN: case LLM_ARCH_BAICHUAN:
case LLM_ARCH_STARCODER: case LLM_ARCH_STARCODER:
case LLM_ARCH_PLAMO: case LLM_ARCH_PLAMO:
@@ -21230,6 +21636,10 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
ctx->cparams.causal_attn = causal_attn;
}
+void llama_set_cross_attention(struct llama_context * ctx, bool cross_attention) {
+ ctx->cparams.cross_attn = cross_attention;
+}
+
struct llama_batch llama_batch_get_one(
llama_token * tokens,
int32_t n_tokens,
@@ -21239,6 +21649,7 @@ struct llama_batch llama_batch_get_one(
/*n_tokens =*/ n_tokens,
/*tokens =*/ tokens,
/*embd =*/ nullptr,
+ /*n_embd =*/ 0,
/*pos =*/ nullptr,
/*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr,
@@ -21254,6 +21665,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
/*n_tokens =*/ 0,
/*tokens =*/ nullptr,
/*embd =*/ nullptr,
+ /*n_embd =*/ 0,
/*pos =*/ nullptr,
/*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr,
@@ -21265,6 +21677,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
if (embd) {
batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
+ batch.n_embd = embd;
} else {
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
}

View file

@ -2,7 +2,6 @@ package main
import ( import (
"errors" "errors"
"hash/maphash"
"log/slog" "log/slog"
"reflect" "reflect"
"time" "time"
@ -20,10 +19,6 @@ type InputCache struct {
// optimize cache eviction for multiple users // optimize cache eviction for multiple users
multiUserCache bool multiUserCache bool
// cache of images to embeddings
images []imageCache
imageHash maphash.Hash
lc *llama.Context lc *llama.Context
} }
@ -41,7 +36,6 @@ func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache b
numCtx: kvSize / numSlots, numCtx: kvSize / numSlots,
slots: slots, slots: slots,
multiUserCache: multiUserCache, multiUserCache: multiUserCache,
images: make([]imageCache, numSlots),
lc: lc, lc: lc,
} }
} }
@ -211,55 +205,3 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int, numDiscar
} }
slot.Inputs = slot.Inputs[:len(slot.Inputs)-numDiscard] slot.Inputs = slot.Inputs[:len(slot.Inputs)-numDiscard]
} }
// Locking: Lookup and store operations on imageCache require a lock
// to be held that serializes these with each other. Hash does not
// require a lock nor they need to be serialized with InputCacheSlot.
type imageCache struct {
key uint64
val [][]float32
lastUsed time.Time
}
func (c *InputCache) HashImage(image []byte) uint64 {
c.imageHash.Reset()
_, _ = c.imageHash.Write(image)
return c.imageHash.Sum64()
}
var ErrImageNotFound = errors.New("image not found in cache")
func (c *InputCache) FindImage(hash uint64) ([][]float32, error) {
for i := range c.images {
if c.images[i].key == hash {
slog.Debug("loading image embeddings from cache", "entry", i)
c.images[i].lastUsed = time.Now()
return c.images[i].val, nil
}
}
return nil, ErrImageNotFound
}
func (c *InputCache) AddImage(hash uint64, embed [][]float32) {
best := time.Now()
var bestImage int
for i := range c.images {
if c.images[i].key == hash {
bestImage = i
break
}
if c.images[i].lastUsed.Compare(best) < 0 {
best = c.images[i].lastUsed
bestImage = i
}
}
slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed)
c.images[bestImage].key = hash
c.images[bestImage].val = embed
c.images[bestImage].lastUsed = time.Now()
}

View file

@ -1,7 +1,6 @@
package main package main
import ( import (
"reflect"
"testing" "testing"
"time" "time"
) )
@ -228,77 +227,3 @@ func TestFindCacheSlot(t *testing.T) {
}) })
} }
} }
func TestImageCache(t *testing.T) {
cache := NewInputCache(nil, 2048, 4, false)
valA := [][]float32{{0.1, 0.2}, {0.3}}
valB := [][]float32{{0.4}, {0.5}, {0.6}}
valC := [][]float32{{0.7}}
valD := [][]float32{{0.8}}
valE := [][]float32{{0.9}}
// Empty cache
result, err := cache.FindImage(0x5adb61d31933a946)
if err != ErrImageNotFound {
t.Errorf("found result in empty cache: result %v, err %v", result, err)
}
// Insert A
cache.AddImage(0x5adb61d31933a946, valA)
result, err = cache.FindImage(0x5adb61d31933a946)
if !reflect.DeepEqual(result, valA) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
// Insert B
cache.AddImage(0x011551369a34a901, valB)
result, err = cache.FindImage(0x5adb61d31933a946)
if !reflect.DeepEqual(result, valA) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.FindImage(0x011551369a34a901)
if !reflect.DeepEqual(result, valB) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
// Replace B with C
cache.AddImage(0x011551369a34a901, valC)
result, err = cache.FindImage(0x5adb61d31933a946)
if !reflect.DeepEqual(result, valA) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.FindImage(0x011551369a34a901)
if !reflect.DeepEqual(result, valC) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
// Evict A
cache.AddImage(0x756b218a517e7353, valB)
cache.AddImage(0x75e5e8d35d7e3967, valD)
cache.AddImage(0xd96f7f268ca0646e, valE)
result, err = cache.FindImage(0x5adb61d31933a946)
if reflect.DeepEqual(result, valA) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.FindImage(0x756b218a517e7353)
if !reflect.DeepEqual(result, valB) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.FindImage(0x011551369a34a901)
if !reflect.DeepEqual(result, valC) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.FindImage(0x75e5e8d35d7e3967)
if !reflect.DeepEqual(result, valD) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.FindImage(0xd96f7f268ca0646e)
if !reflect.DeepEqual(result, valE) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
}

145
llama/runner/image.go Normal file
View file

@ -0,0 +1,145 @@
package main
import (
"errors"
"fmt"
"hash/maphash"
"log/slog"
"sync"
"time"
"github.com/ollama/ollama/llama"
)
const imageCacheSize = 4
type ImageContext struct {
// mu is required to be held when generating embeddings or accessing the cache
mu sync.Mutex
clip *llama.ClipContext
mllama *llama.MllamaContext
// cache of images to embeddings
images []imageCache
imageHash maphash.Hash
}
func NewImageContext(llamaContext *llama.Context, modelPath string) (*ImageContext, error) {
arch, err := llama.GetModelArch(modelPath)
if err != nil {
return nil, fmt.Errorf("unable to determine vision architecture: %w (%s)", err, modelPath)
}
var c ImageContext
if arch == "clip" {
c.clip, err = llama.NewClipContext(llamaContext, modelPath)
} else if arch == "mllama" {
c.mllama, err = llama.NewMllamaContext(llamaContext, modelPath)
} else {
return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
}
if err != nil {
return nil, err
}
c.images = make([]imageCache, imageCacheSize)
return &c, nil
}
func (c *ImageContext) Free(modelPath string) {
if c == nil {
return
}
if c.clip != nil {
c.clip.Free()
}
if c.mllama != nil {
c.mllama.Free()
}
}
func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte, aspectRatioId int) [][]float32 {
if c == nil {
return nil
}
hash := c.hashImage(data)
c.mu.Lock()
defer c.mu.Unlock()
embed, err := c.findImage(hash)
if err != nil {
if c.mllama != nil {
embed = c.mllama.NewEmbed(llamaContext, data, aspectRatioId)
} else if c.clip != nil {
embed = c.clip.NewEmbed(llamaContext, data)
} else {
return nil
}
c.addImage(hash, embed)
}
return embed
}
func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
if c != nil && c.mllama != nil {
return c.mllama.EmbedSize(llamaContext)
} else {
return llamaContext.Model().NEmbd()
}
}
type imageCache struct {
key uint64
val [][]float32
lastUsed time.Time
}
func (c *ImageContext) hashImage(image []byte) uint64 {
c.imageHash.Reset()
_, _ = c.imageHash.Write(image)
return c.imageHash.Sum64()
}
var errImageNotFound = errors.New("image not found in cache")
func (c *ImageContext) findImage(hash uint64) ([][]float32, error) {
for i := range c.images {
if c.images[i].key == hash {
slog.Debug("loading image embeddings from cache", "entry", i)
c.images[i].lastUsed = time.Now()
return c.images[i].val, nil
}
}
return nil, errImageNotFound
}
func (c *ImageContext) addImage(hash uint64, embed [][]float32) {
best := time.Now()
var bestImage int
for i := range c.images {
if c.images[i].key == hash {
bestImage = i
break
}
if c.images[i].lastUsed.Compare(best) < 0 {
best = c.images[i].lastUsed
bestImage = i
}
}
slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed)
c.images[bestImage].key = hash
c.images[bestImage].val = embed
c.images[bestImage].lastUsed = time.Now()
}

View file

@ -0,0 +1,80 @@
package main
import (
"reflect"
"testing"
)
func TestImageCache(t *testing.T) {
cache := ImageContext{images: make([]imageCache, 4)}
valA := [][]float32{{0.1, 0.2}, {0.3}}
valB := [][]float32{{0.4}, {0.5}, {0.6}}
valC := [][]float32{{0.7}}
valD := [][]float32{{0.8}}
valE := [][]float32{{0.9}}
// Empty cache
result, err := cache.findImage(0x5adb61d31933a946)
if err != errImageNotFound {
t.Errorf("found result in empty cache: result %v, err %v", result, err)
}
// Insert A
cache.addImage(0x5adb61d31933a946, valA)
result, err = cache.findImage(0x5adb61d31933a946)
if !reflect.DeepEqual(result, valA) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
// Insert B
cache.addImage(0x011551369a34a901, valB)
result, err = cache.findImage(0x5adb61d31933a946)
if !reflect.DeepEqual(result, valA) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.findImage(0x011551369a34a901)
if !reflect.DeepEqual(result, valB) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
// Replace B with C
cache.addImage(0x011551369a34a901, valC)
result, err = cache.findImage(0x5adb61d31933a946)
if !reflect.DeepEqual(result, valA) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.findImage(0x011551369a34a901)
if !reflect.DeepEqual(result, valC) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
// Evict A
cache.addImage(0x756b218a517e7353, valB)
cache.addImage(0x75e5e8d35d7e3967, valD)
cache.addImage(0xd96f7f268ca0646e, valE)
result, err = cache.findImage(0x5adb61d31933a946)
if reflect.DeepEqual(result, valA) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.findImage(0x756b218a517e7353)
if !reflect.DeepEqual(result, valB) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.findImage(0x011551369a34a901)
if !reflect.DeepEqual(result, valC) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.findImage(0x75e5e8d35d7e3967)
if !reflect.DeepEqual(result, valD) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.findImage(0xd96f7f268ca0646e)
if !reflect.DeepEqual(result, valE) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
}

View file

@ -190,57 +190,22 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
return nil, fmt.Errorf("invalid image index: %d", n) return nil, fmt.Errorf("invalid image index: %d", n)
} }
hash := s.cache.HashImage(images[imageIndex].Data) embed := s.image.NewEmbed(s.lc, images[imageIndex].Data, images[imageIndex].AspectRatioID)
// Vision models cannot be accessed concurrently
s.clip.mu.Lock()
embed, err := s.cache.FindImage(hash)
if err != nil {
embed = llama.NewLlavaImageEmbed(s.lc, s.clip.cc, images[imageIndex].Data)
s.cache.AddImage(hash, embed)
}
s.clip.mu.Unlock()
for _, e := range embed { for _, e := range embed {
inputs = append(inputs, input{embed: e}) inputs = append(inputs, input{embed: e})
} }
} }
} }
if s.clip.cc != nil {
var embed [][]float32
if s.clip.cc.IsMllama && len(images) >= 1 {
hash := s.cache.HashImage(images[0].Data)
s.clip.mu.Lock()
var err error
embed, err = s.cache.FindImage(hash)
if err != nil {
embed = llama.NewMllamaImageEmbed(s.lc, s.clip.cc, images[0].Data, images[0].AspectRatioID)
s.cache.AddImage(hash, embed)
}
s.clip.mu.Unlock()
}
s.mu.Lock()
llama.MllamaSetCrossAttn(s.lc, s.clip.cc, embed)
s.mu.Unlock()
}
return inputs, nil return inputs, nil
} }
type clip struct {
cc *llama.ClipContext
mu sync.Mutex
}
type Server struct { type Server struct {
model *llama.Model model *llama.Model
lc *llama.Context lc *llama.Context
// required for image embeddings // required for image embeddings
clip clip image *ImageContext
batchSize int batchSize int
@ -322,14 +287,12 @@ func flushPending(seq *Sequence) bool {
func (s *Server) removeSequence(seqIndex int, reason string) { func (s *Server) removeSequence(seqIndex int, reason string) {
seq := s.seqs[seqIndex] seq := s.seqs[seqIndex]
s.lc.SetCrossAttention(false)
flushPending(seq) flushPending(seq)
seq.doneReason = reason seq.doneReason = reason
close(seq.responses) close(seq.responses)
close(seq.embedding) close(seq.embedding)
seq.cache.InUse = false seq.cache.InUse = false
if s.clip.cc != nil {
llama.MllamaSetCrossAttn(s.lc, s.clip.cc, nil)
}
s.seqs[seqIndex] = nil s.seqs[seqIndex] = nil
} }
@ -341,7 +304,7 @@ func (s *Server) run(ctx context.Context) {
tokenBatch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs)) tokenBatch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
defer tokenBatch.Free() defer tokenBatch.Free()
embedBatch := llama.NewBatch(s.batchSize*len(s.seqs), s.lc.Model().NEmbd(), len(s.seqs)) embedBatch := llama.NewBatch(s.batchSize*len(s.seqs), s.image.EmbedSize(s.lc), len(s.seqs))
defer embedBatch.Free() defer embedBatch.Free()
for { for {
@ -642,12 +605,20 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
s.mu.Lock() s.mu.Lock()
for i, sq := range s.seqs { for i, sq := range s.seqs {
if sq == nil { if sq == nil {
for _, input := range seq.inputs {
if input.embed != nil {
s.lc.SetCrossAttention(true)
break
}
}
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil { if err != nil {
s.mu.Unlock() s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
return return
} }
s.seqs[i] = seq s.seqs[i] = seq
s.cond.Signal() s.cond.Signal()
break break
@ -815,7 +786,7 @@ func (s *Server) loadModel(
if ppath != "" { if ppath != "" {
var err error var err error
s.clip.cc, err = llama.NewClipContext(ppath) s.image, err = NewImageContext(s.lc, ppath)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View file

@ -75,11 +75,16 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
currMsgIdx := n currMsgIdx := n
if isMllama { for cnt, msg := range msgs[currMsgIdx:] {
lastMsgIdx := len(msgs) - 1 prefix := ""
for i := lastMsgIdx; i >= currMsgIdx; i-- { imgPrompt := ""
if len(msgs[i].Images) > 0 { prompt := msg.Content
data, aspectRatioID, err := imageproc.Preprocess(msgs[i].Images[0])
for _, i := range msg.Images {
var imgData llm.ImageData
if isMllama {
data, aspectRatioID, err := imageproc.Preprocess(i)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@ -90,37 +95,30 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
return "", nil, err return "", nil, err
} }
imgData := llm.ImageData{ imgData = llm.ImageData{
ID: len(images),
Data: buf.Bytes(), Data: buf.Bytes(),
AspectRatioID: aspectRatioID, AspectRatioID: aspectRatioID,
} }
imgPrompt = "<|image|>"
msgs[i].Content = strings.TrimSpace("<|image|>" + msgs[i].Content) } else {
images = append(images, imgData) imgData = llm.ImageData{
break
}
}
} else {
for cnt, msg := range msgs[currMsgIdx:] {
prefix := ""
prompt := msg.Content
for _, i := range msg.Images {
imgData := llm.ImageData{
ID: len(images), ID: len(images),
Data: i, Data: i,
} }
imgPrompt = " "
imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
if !strings.Contains(prompt, "[img]") {
prefix += imgTag
} else {
prompt = strings.Replace(prompt, "[img]", imgTag, 1)
}
images = append(images, imgData)
} }
msgs[currMsgIdx+cnt].Content = strings.TrimSpace(prefix + " " + prompt)
imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
if !strings.Contains(prompt, "[img]") {
prefix += imgTag
} else {
prompt = strings.Replace(prompt, "[img]", imgTag, 1)
}
images = append(images, imgData)
} }
msgs[currMsgIdx+cnt].Content = strings.TrimSpace(prefix + imgPrompt + prompt)
} }
// truncate any messages that do not fit into the context window // truncate any messages that do not fit into the context window

View file

@ -249,7 +249,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}}, {Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}},
}, },
expect: expect{ expect: expect{
prompt: "<|image|>How many hotdogs are in this image? ", prompt: "[img-0]<|image|>How many hotdogs are in this image? ",
images: [][]byte{imgBuf}, images: [][]byte{imgBuf},
aspectRatioID: 1, aspectRatioID: 1,
}, },
@ -264,7 +264,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf}}, {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf}},
}, },
expect: expect{ expect: expect{
prompt: "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ", prompt: "You're a test, Harry! I-I'm a what? [img-0]<|image|>A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{imgBuf}, images: [][]byte{imgBuf},
aspectRatioID: 1, aspectRatioID: 1,
}, },
@ -279,8 +279,8 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf2}}, {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf2}},
}, },
expect: expect{ expect: expect{
prompt: "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ", prompt: "[img-0]<|image|>You're a test, Harry! I-I'm a what? [img-1]<|image|>A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{imgBuf2}, images: [][]byte{imgBuf, imgBuf2},
aspectRatioID: 1, aspectRatioID: 1,
}, },
}, },
@ -294,7 +294,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "Which ones have mustard?"}, {Role: "user", Content: "Which ones have mustard?"},
}, },
expect: expect{ expect: expect{
prompt: "<|image|>How many hotdogs are in this image? There are four hotdogs. Which ones have mustard? ", prompt: "[img-0]<|image|>How many hotdogs are in this image? There are four hotdogs. Which ones have mustard? ",
images: [][]byte{imgBuf}, images: [][]byte{imgBuf},
aspectRatioID: 1, aspectRatioID: 1,
}, },

View file

@ -205,7 +205,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
images[i] = llm.ImageData{Data: buf.Bytes(), AspectRatioID: aspectRatioID} images[i] = llm.ImageData{ID: i, Data: buf.Bytes(), AspectRatioID: aspectRatioID}
} else { } else {
images[i] = llm.ImageData{ID: i, Data: req.Images[i]} images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
} }
@ -239,11 +239,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
for _, i := range images { for _, i := range images {
imgPrompt := ""
if isMllama { if isMllama {
msgs = append(msgs, api.Message{Role: "user", Content: "<|image|>"}) imgPrompt = "<|image|>"
} else {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
} }
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]"+imgPrompt, i.ID)})
} }
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt}) values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})