runner.go: Better abstract vision model integration
-Update mllama to take the cross attention state as embeddings in a batch, more similar to how Llava handles it. This improves integration with the input cache. -Pass locations in a prompt for embeddings using tags similar to Llava. -Abstract interface to vision models so the main runner accesses Clip and Mllama similarly Co-authored-by: Michael Yang <mxyng@pm.me>
This commit is contained in:
parent
712e99d477
commit
c826e57475
13 changed files with 534 additions and 454 deletions
105
llama/llama.cpp
vendored
105
llama/llama.cpp
vendored
|
@ -2699,7 +2699,7 @@ struct llama_hparams {
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool cross_attention_layer(uint32_t il) const {
|
bool cross_attention_layers(uint32_t il) const {
|
||||||
return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
|
return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -2731,6 +2731,9 @@ struct llama_cparams {
|
||||||
bool offload_kqv;
|
bool offload_kqv;
|
||||||
bool flash_attn;
|
bool flash_attn;
|
||||||
bool no_perf;
|
bool no_perf;
|
||||||
|
// TODO (jmorganca): this should most likely be passed in as part of a batch
|
||||||
|
// and not set on the context for all batches.
|
||||||
|
bool cross_attn = false;
|
||||||
|
|
||||||
enum llama_pooling_type pooling_type;
|
enum llama_pooling_type pooling_type;
|
||||||
|
|
||||||
|
@ -3542,10 +3545,6 @@ struct llama_context {
|
||||||
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
|
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
|
||||||
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
|
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
|
||||||
|
|
||||||
// TODO (jmorganca): this should most likely be passed in as part of a batch
|
|
||||||
// and not set on the context for all batches.
|
|
||||||
float * cross_attn_state = nullptr;
|
|
||||||
bool cross_attn_state_first_pass = true;
|
|
||||||
struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061]
|
struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061]
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -3782,7 +3781,7 @@ static bool llama_kv_cache_init(
|
||||||
|
|
||||||
for (int i = 0; i < (int) n_layer; i++) {
|
for (int i = 0; i < (int) n_layer; i++) {
|
||||||
// for cross attention layers
|
// for cross attention layers
|
||||||
if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layer(i)) {
|
if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) {
|
||||||
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
|
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
|
||||||
ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
|
ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
|
||||||
ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
|
ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
|
||||||
|
@ -7389,7 +7388,7 @@ static bool llm_load_tensors(
|
||||||
|
|
||||||
auto & layer = model.layers[i];
|
auto & layer = model.layers[i];
|
||||||
|
|
||||||
if (hparams.cross_attention_layer(i)) {
|
if (hparams.cross_attention_layers(i)) {
|
||||||
layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128});
|
layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128});
|
||||||
layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024});
|
layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024});
|
||||||
layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd});
|
layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd});
|
||||||
|
@ -9346,7 +9345,7 @@ static struct ggml_tensor * llm_build_inp_embd(
|
||||||
|
|
||||||
inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
|
inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
|
||||||
} else {
|
} else {
|
||||||
lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
|
lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
|
||||||
inpL = lctx.inp_embd;
|
inpL = lctx.inp_embd;
|
||||||
ggml_set_input(lctx.inp_embd);
|
ggml_set_input(lctx.inp_embd);
|
||||||
}
|
}
|
||||||
|
@ -9368,11 +9367,10 @@ static struct ggml_tensor * llm_build_inp_cross_attn_state(
|
||||||
const llm_build_cb & cb) {
|
const llm_build_cb & cb) {
|
||||||
const int64_t n_embd = hparams.n_embd;
|
const int64_t n_embd = hparams.n_embd;
|
||||||
|
|
||||||
struct ggml_tensor * inpCAS;
|
struct ggml_tensor * inpCAS = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
|
||||||
lctx.inp_cross_attn_state = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
|
cb(inpCAS, "inp_cross_attn_state", -1);
|
||||||
cb(lctx.inp_cross_attn_state, "inp_cross_attn_state", -1);
|
ggml_set_input(inpCAS);
|
||||||
ggml_set_input(lctx.inp_cross_attn_state);
|
lctx.inp_cross_attn_state = inpCAS;
|
||||||
inpCAS = lctx.inp_cross_attn_state;
|
|
||||||
|
|
||||||
return inpCAS;
|
return inpCAS;
|
||||||
}
|
}
|
||||||
|
@ -10979,8 +10977,8 @@ struct llm_build_context {
|
||||||
LLM_NORM_RMS, cb, il);
|
LLM_NORM_RMS, cb, il);
|
||||||
cb(cur, "attn_norm", il);
|
cb(cur, "attn_norm", il);
|
||||||
|
|
||||||
if (hparams.cross_attention_layer(il)) {
|
if (hparams.cross_attention_layers(il)) {
|
||||||
if (!lctx.cross_attn_state) {
|
if (!batch.embd && !cparams.cross_attn) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -10991,42 +10989,28 @@ struct llm_build_context {
|
||||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Qcur = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
|
||||||
cb(Qcur, "Qcur", il);
|
|
||||||
|
|
||||||
// TODO: is this required?
|
|
||||||
Qcur = ggml_cont(ctx0, Qcur);
|
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
|
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
struct ggml_tensor * Kcur;
|
struct ggml_tensor * Kcur, * Vcur;
|
||||||
if (lctx.cross_attn_state_first_pass) {
|
if (batch.embd) {
|
||||||
Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
|
Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404);
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
Kcur = ggml_permute(ctx0, Kcur, 0, 2, 1, 3);
|
Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
|
||||||
cb(Kcur, "Kcur", il);
|
|
||||||
|
|
||||||
// TODO: is this required?
|
|
||||||
Kcur = ggml_cont(ctx0, Kcur);
|
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
|
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il]));
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il]));
|
||||||
} else {
|
|
||||||
Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
|
|
||||||
cb(Kcur, "Kcur (view)", il);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor * Vcur;
|
|
||||||
if (lctx.cross_attn_state_first_pass) {
|
|
||||||
Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
|
Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
|
||||||
cb(Vcur, "Vcur", il);
|
cb(Vcur, "Vcur", il);
|
||||||
|
|
||||||
|
@ -11038,6 +11022,9 @@ struct llm_build_context {
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il]));
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il]));
|
||||||
} else {
|
} else {
|
||||||
|
Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
|
||||||
|
cb(Kcur, "Kcur (view)", il);
|
||||||
|
|
||||||
Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]);
|
Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]);
|
||||||
cb(Vcur, "Vcur (view)", il);
|
cb(Vcur, "Vcur (view)", il);
|
||||||
}
|
}
|
||||||
|
@ -11045,11 +11032,8 @@ struct llm_build_context {
|
||||||
struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur);
|
struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur);
|
||||||
cb(kq, "kq", il);
|
cb(kq, "kq", il);
|
||||||
|
|
||||||
kq = ggml_scale_inplace(ctx0, kq, 1.0f/sqrtf(float(n_embd_head)));
|
|
||||||
cb(kq, "kq_scaled", il);
|
|
||||||
|
|
||||||
// TODO: apply causal masks
|
// TODO: apply causal masks
|
||||||
struct ggml_tensor * kq_soft_max = ggml_soft_max_inplace(ctx0, kq);
|
struct ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, nullptr, 1.f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
|
||||||
cb(kq_soft_max, "kq_soft_max", il);
|
cb(kq_soft_max, "kq_soft_max", il);
|
||||||
|
|
||||||
Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur));
|
Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur));
|
||||||
|
@ -11139,8 +11123,8 @@ struct llm_build_context {
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
|
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -17197,10 +17181,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (batch.embd) {
|
if (batch.embd) {
|
||||||
const int64_t n_embd = hparams.n_embd;
|
if (lctx.inp_cross_attn_state && lctx.inp_cross_attn_state->buffer) {
|
||||||
const int64_t n_tokens = batch.n_tokens;
|
ggml_backend_tensor_set(lctx.inp_cross_attn_state, batch.embd, 0, ggml_nbytes(lctx.inp_cross_attn_state));
|
||||||
|
// zero out inp_embd since it's not used
|
||||||
|
float * inp_embd_data = (float *)lctx.inp_embd->data;
|
||||||
|
for (int i = 0; i < ggml_nelements(lctx.inp_embd); ++i) {
|
||||||
|
inp_embd_data[i] = 0.0f;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const int64_t n_embd = hparams.n_embd;
|
||||||
|
const int64_t n_tokens = batch.n_tokens;
|
||||||
|
|
||||||
ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
|
ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (batch.pos && lctx.inp_pos) {
|
if (batch.pos && lctx.inp_pos) {
|
||||||
|
@ -17209,14 +17202,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
|
||||||
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
|
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO (jmorganca): this might copy a lot of data on every request of a
|
|
||||||
// single generation even though it doesn't change, so we should
|
|
||||||
// find a way to not set this more than one time per image
|
|
||||||
if (lctx.inp_cross_attn_state &&
|
|
||||||
lctx.inp_cross_attn_state->buffer) {
|
|
||||||
ggml_backend_tensor_set(lctx.inp_cross_attn_state, lctx.cross_attn_state, 0, hparams.n_embd * 1601 * 4 * ggml_element_size(lctx.inp_cross_attn_state));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||||
GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
|
GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
|
||||||
const int64_t n_tokens = batch.n_tokens;
|
const int64_t n_tokens = batch.n_tokens;
|
||||||
|
@ -17789,7 +17774,7 @@ static int llama_decode_internal(
|
||||||
n_outputs = 1;
|
n_outputs = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
lctx.sbatch.from_batch(batch_all, n_embd,
|
lctx.sbatch.from_batch(batch_all, batch_all.n_embd,
|
||||||
/* simple_split */ !kv_self.recurrent,
|
/* simple_split */ !kv_self.recurrent,
|
||||||
/* logits_all */ n_outputs == n_tokens_all);
|
/* logits_all */ n_outputs == n_tokens_all);
|
||||||
|
|
||||||
|
@ -17899,10 +17884,6 @@ static int llama_decode_internal(
|
||||||
|
|
||||||
llama_set_inputs(lctx, ubatch);
|
llama_set_inputs(lctx, ubatch);
|
||||||
|
|
||||||
// TODO: replace with something better to find out if its
|
|
||||||
// our first actual pass
|
|
||||||
lctx.cross_attn_state_first_pass = false;
|
|
||||||
|
|
||||||
llama_graph_compute(lctx, gf, n_threads, threadpool);
|
llama_graph_compute(lctx, gf, n_threads, threadpool);
|
||||||
|
|
||||||
// update the kv ring buffer
|
// update the kv ring buffer
|
||||||
|
@ -18086,7 +18067,7 @@ static int llama_encode_internal(
|
||||||
|
|
||||||
const int64_t n_embd = hparams.n_embd;
|
const int64_t n_embd = hparams.n_embd;
|
||||||
|
|
||||||
lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
|
lctx.sbatch.from_batch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true);
|
||||||
|
|
||||||
const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
|
const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
|
||||||
|
|
||||||
|
@ -20194,11 +20175,6 @@ struct llama_context * llama_new_context_with_model(
|
||||||
return ctx;
|
return ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state) {
|
|
||||||
ctx->cross_attn_state_first_pass = true;
|
|
||||||
ctx->cross_attn_state = cross_attn_state;
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_free(struct llama_context * ctx) {
|
void llama_free(struct llama_context * ctx) {
|
||||||
delete ctx;
|
delete ctx;
|
||||||
}
|
}
|
||||||
|
@ -21686,6 +21662,10 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
|
||||||
ctx->cparams.causal_attn = causal_attn;
|
ctx->cparams.causal_attn = causal_attn;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_set_cross_attention(struct llama_context * ctx, bool cross_attention) {
|
||||||
|
ctx->cparams.cross_attn = cross_attention;
|
||||||
|
}
|
||||||
|
|
||||||
struct llama_batch llama_batch_get_one(
|
struct llama_batch llama_batch_get_one(
|
||||||
llama_token * tokens,
|
llama_token * tokens,
|
||||||
int32_t n_tokens,
|
int32_t n_tokens,
|
||||||
|
@ -21695,6 +21675,7 @@ struct llama_batch llama_batch_get_one(
|
||||||
/*n_tokens =*/ n_tokens,
|
/*n_tokens =*/ n_tokens,
|
||||||
/*tokens =*/ tokens,
|
/*tokens =*/ tokens,
|
||||||
/*embd =*/ nullptr,
|
/*embd =*/ nullptr,
|
||||||
|
/*n_embd =*/ 0,
|
||||||
/*pos =*/ nullptr,
|
/*pos =*/ nullptr,
|
||||||
/*n_seq_id =*/ nullptr,
|
/*n_seq_id =*/ nullptr,
|
||||||
/*seq_id =*/ nullptr,
|
/*seq_id =*/ nullptr,
|
||||||
|
@ -21710,6 +21691,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
|
||||||
/*n_tokens =*/ 0,
|
/*n_tokens =*/ 0,
|
||||||
/*tokens =*/ nullptr,
|
/*tokens =*/ nullptr,
|
||||||
/*embd =*/ nullptr,
|
/*embd =*/ nullptr,
|
||||||
|
/*n_embd =*/ 0,
|
||||||
/*pos =*/ nullptr,
|
/*pos =*/ nullptr,
|
||||||
/*n_seq_id =*/ nullptr,
|
/*n_seq_id =*/ nullptr,
|
||||||
/*seq_id =*/ nullptr,
|
/*seq_id =*/ nullptr,
|
||||||
|
@ -21721,6 +21703,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
|
||||||
|
|
||||||
if (embd) {
|
if (embd) {
|
||||||
batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
|
batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
|
||||||
|
batch.n_embd = embd;
|
||||||
} else {
|
} else {
|
||||||
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
|
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
|
||||||
}
|
}
|
||||||
|
|
149
llama/llama.go
149
llama/llama.go
|
@ -111,6 +111,28 @@ func PrintSystemInfo() string {
|
||||||
return C.GoString(C.llama_print_system_info()) + compiler
|
return C.GoString(C.llama_print_system_info()) + compiler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetModelArch(modelPath string) (string, error) {
|
||||||
|
mp := C.CString(modelPath)
|
||||||
|
defer C.free(unsafe.Pointer(mp))
|
||||||
|
|
||||||
|
gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
|
||||||
|
if gguf_ctx == nil {
|
||||||
|
return "", errors.New("unable to load model file")
|
||||||
|
}
|
||||||
|
defer C.gguf_free(gguf_ctx)
|
||||||
|
|
||||||
|
key := C.CString("general.architecture")
|
||||||
|
defer C.free(unsafe.Pointer(key))
|
||||||
|
arch_index := C.gguf_find_key(gguf_ctx, key)
|
||||||
|
if int(arch_index) < 0 {
|
||||||
|
return "", errors.New("unknown model architecture")
|
||||||
|
}
|
||||||
|
|
||||||
|
arch := C.gguf_get_val_str(gguf_ctx, arch_index)
|
||||||
|
|
||||||
|
return C.GoString(arch), nil
|
||||||
|
}
|
||||||
|
|
||||||
type ContextParams struct {
|
type ContextParams struct {
|
||||||
c C.struct_llama_context_params
|
c C.struct_llama_context_params
|
||||||
}
|
}
|
||||||
|
@ -443,71 +465,36 @@ func Quantize(infile, outfile string, ftype uint32) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// llava
|
// vision processing
|
||||||
type ClipContext struct {
|
type ClipContext struct {
|
||||||
c *C.struct_clip_ctx
|
c *C.struct_clip_ctx
|
||||||
m *C.struct_mllama_ctx
|
|
||||||
IsMllama bool
|
|
||||||
embedPin runtime.Pinner
|
|
||||||
pinned bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getVisionArch(mp *C.char) (string, error) {
|
func NewClipContext(llamaContext *Context, modelPath string) (*ClipContext, error) {
|
||||||
gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
|
|
||||||
if gguf_ctx == nil {
|
|
||||||
return "", errors.New("unable to load vision projector")
|
|
||||||
}
|
|
||||||
defer C.gguf_free(gguf_ctx)
|
|
||||||
|
|
||||||
arch_index := C.gguf_find_key(gguf_ctx, C.CString("general.architecture"))
|
|
||||||
if int(arch_index) < 0 {
|
|
||||||
return "", errors.New("unknown vision model architecture")
|
|
||||||
}
|
|
||||||
|
|
||||||
arch := C.gguf_get_val_str(gguf_ctx, arch_index)
|
|
||||||
|
|
||||||
return C.GoString(arch), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewClipContext(modelPath string) (*ClipContext, error) {
|
|
||||||
mp := C.CString(modelPath)
|
mp := C.CString(modelPath)
|
||||||
defer C.free(unsafe.Pointer(mp))
|
defer C.free(unsafe.Pointer(mp))
|
||||||
|
c := C.clip_model_load(mp, 1)
|
||||||
|
|
||||||
arch, err := getVisionArch(mp)
|
projEmbedSize := int(C.clip_n_mmproj_embd(c))
|
||||||
if err != nil {
|
modelEmbedSize := llamaContext.Model().NEmbd()
|
||||||
return nil, err
|
if projEmbedSize != modelEmbedSize {
|
||||||
|
return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
var cc ClipContext
|
return &ClipContext{c: c}, nil
|
||||||
if arch == "clip" {
|
|
||||||
cc.c = C.clip_model_load(mp, 1)
|
|
||||||
} else if arch == "mllama" {
|
|
||||||
cc.m = C.mllama_model_load(mp, 1)
|
|
||||||
cc.IsMllama = true
|
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
|
|
||||||
}
|
|
||||||
|
|
||||||
// XXX: check embedding size?
|
|
||||||
return &cc, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClipContext) Free() {
|
func (c *ClipContext) Free() {
|
||||||
if c.c != nil {
|
C.clip_free(c.c)
|
||||||
C.clip_free(c.c)
|
|
||||||
}
|
|
||||||
if c.m != nil {
|
|
||||||
C.mllama_free(c.m)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte) [][]float32 {
|
func (c *ClipContext) NewEmbed(llamaContext *Context, data []byte) [][]float32 {
|
||||||
c := C.llava_image_embed_make_with_bytes(clipContext.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))
|
l := C.llava_image_embed_make_with_bytes(c.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))
|
||||||
|
|
||||||
numTokens := int(c.n_image_pos)
|
numTokens := int(l.n_image_pos)
|
||||||
numEmbed := llamaContext.Model().NEmbd()
|
numEmbed := llamaContext.Model().NEmbd()
|
||||||
|
|
||||||
s := unsafe.Slice((*float32)(c.embed), numEmbed*numTokens)
|
s := unsafe.Slice((*float32)(l.embed), numEmbed*numTokens)
|
||||||
|
|
||||||
embed := make([][]float32, numTokens)
|
embed := make([][]float32, numTokens)
|
||||||
rows := make([]float32, len(s))
|
rows := make([]float32, len(s))
|
||||||
|
@ -517,51 +504,57 @@ func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []
|
||||||
embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
|
embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
|
||||||
}
|
}
|
||||||
|
|
||||||
C.llava_image_embed_free(c)
|
C.llava_image_embed_free(l)
|
||||||
|
|
||||||
return embed
|
return embed
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMllamaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte, aspectRatioId int) [][]float32 {
|
type MllamaContext struct {
|
||||||
|
c *C.struct_mllama_ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMllamaContext(llamaContext *Context, modelPath string) (*MllamaContext, error) {
|
||||||
|
mp := C.CString(modelPath)
|
||||||
|
defer C.free(unsafe.Pointer(mp))
|
||||||
|
c := C.mllama_model_load(mp, 1)
|
||||||
|
|
||||||
|
projEmbedSize := int(C.mllama_n_embd(c))
|
||||||
|
modelEmbedSize := llamaContext.Model().NEmbd()
|
||||||
|
if projEmbedSize != modelEmbedSize {
|
||||||
|
return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &MllamaContext{c: c}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MllamaContext) Free() {
|
||||||
|
C.mllama_free(m.c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MllamaContext) NewEmbed(llamaContext *Context, data []byte, aspectRatioId int) [][]float32 {
|
||||||
img := C.mllama_image_init()
|
img := C.mllama_image_init()
|
||||||
defer C.mllama_image_free(img)
|
defer C.mllama_image_free(img)
|
||||||
|
|
||||||
C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img)
|
C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img)
|
||||||
|
|
||||||
numTokens := int(C.mllama_n_positions(clipContext.m) * C.mllama_n_tiles(clipContext.m))
|
rows := make([]float32, m.EmbedSize(llamaContext))
|
||||||
numEmbed := llamaContext.Model().NEmbd()
|
C.mllama_image_encode(m.c, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0])))
|
||||||
|
|
||||||
rows := make([]float32, numEmbed*numTokens)
|
embed := make([][]float32, 1)
|
||||||
C.mllama_image_encode(clipContext.m, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0])))
|
embed[0] = rows
|
||||||
|
|
||||||
embed := make([][]float32, numTokens)
|
|
||||||
for i := range embed {
|
|
||||||
embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
|
|
||||||
}
|
|
||||||
|
|
||||||
return embed
|
return embed
|
||||||
}
|
}
|
||||||
|
|
||||||
// This really needs to be set on a batch instead
|
func (m *MllamaContext) EmbedSize(llamaContext *Context) int {
|
||||||
func MllamaSetCrossAttn(llamaContext *Context, clipContext *ClipContext, embed [][]float32) {
|
numTokens := int(C.mllama_n_positions(m.c) * C.mllama_n_tiles(m.c))
|
||||||
if embed != nil {
|
numEmbed := llamaContext.Model().NEmbd()
|
||||||
if clipContext.pinned {
|
|
||||||
panic("Cross attention state already pinned")
|
|
||||||
}
|
|
||||||
|
|
||||||
embedData := &embed[0][0]
|
return numTokens * numEmbed
|
||||||
clipContext.embedPin.Pin(embedData)
|
}
|
||||||
clipContext.pinned = true
|
|
||||||
|
|
||||||
C.llama_set_cross_attn_state(llamaContext.c, (*C.float)(unsafe.Pointer(embedData)))
|
func (c *Context) SetCrossAttention(state bool) {
|
||||||
} else {
|
C.llama_set_cross_attention(c.c, C.bool(state))
|
||||||
C.llama_set_cross_attn_state(llamaContext.c, (*C.float)(C.NULL))
|
|
||||||
|
|
||||||
if clipContext.pinned {
|
|
||||||
clipContext.embedPin.Unpin()
|
|
||||||
clipContext.pinned = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// sampling
|
// sampling
|
||||||
|
|
3
llama/llama.h
vendored
3
llama/llama.h
vendored
|
@ -266,6 +266,7 @@ extern "C" {
|
||||||
|
|
||||||
llama_token * token;
|
llama_token * token;
|
||||||
float * embd;
|
float * embd;
|
||||||
|
int32_t n_embd;
|
||||||
llama_pos * pos;
|
llama_pos * pos;
|
||||||
int32_t * n_seq_id;
|
int32_t * n_seq_id;
|
||||||
llama_seq_id ** seq_id;
|
llama_seq_id ** seq_id;
|
||||||
|
@ -451,7 +452,7 @@ extern "C" {
|
||||||
|
|
||||||
// TODO (jmorganca): this should most likely be passed in as part of a batch
|
// TODO (jmorganca): this should most likely be passed in as part of a batch
|
||||||
// and not set on the context for all batches.
|
// and not set on the context for all batches.
|
||||||
LLAMA_API void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state);
|
LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state);
|
||||||
|
|
||||||
// Frees all allocated memory
|
// Frees all allocated memory
|
||||||
LLAMA_API void llama_free(struct llama_context * ctx);
|
LLAMA_API void llama_free(struct llama_context * ctx);
|
||||||
|
|
2
llama/llava.cpp
vendored
2
llama/llava.cpp
vendored
|
@ -435,7 +435,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
|
||||||
if (n_eval > n_batch) {
|
if (n_eval > n_batch) {
|
||||||
n_eval = n_batch;
|
n_eval = n_batch;
|
||||||
}
|
}
|
||||||
llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
|
llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), n_embd, nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
|
||||||
if (llama_decode(ctx_llama, batch)) {
|
if (llama_decode(ctx_llama, batch)) {
|
||||||
LOG_ERR("%s : failed to eval\n", __func__);
|
LOG_ERR("%s : failed to eval\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -12,27 +12,49 @@ kv cache once per run
|
||||||
|
|
||||||
remaining is to implement the cross attention mask
|
remaining is to implement the cross attention mask
|
||||||
---
|
---
|
||||||
include/llama.h | 4 +
|
examples/llava/llava.cpp | 2 +-
|
||||||
src/llama.cpp | 456 ++++++++++++++++++++++++++++++++++++++++++++++--
|
include/llama.h | 5 +
|
||||||
2 files changed, 447 insertions(+), 13 deletions(-)
|
src/llama.cpp | 447 +++++++++++++++++++++++++++++++++++++--
|
||||||
|
3 files changed, 436 insertions(+), 18 deletions(-)
|
||||||
|
|
||||||
|
diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
|
||||||
|
index 8558c6bd..37b2f2e2 100644
|
||||||
|
--- a/examples/llava/llava.cpp
|
||||||
|
+++ b/examples/llava/llava.cpp
|
||||||
|
@@ -409,7 +409,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
|
||||||
|
if (n_eval > n_batch) {
|
||||||
|
n_eval = n_batch;
|
||||||
|
}
|
||||||
|
- llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
|
||||||
|
+ llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), n_embd, nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
|
||||||
|
if (llama_decode(ctx_llama, batch)) {
|
||||||
|
LOG_ERR("%s : failed to eval\n", __func__);
|
||||||
|
return false;
|
||||||
diff --git a/include/llama.h b/include/llama.h
|
diff --git a/include/llama.h b/include/llama.h
|
||||||
index 7cae1bbe..122e3cf1 100644
|
index 7cae1bbe..aca09310 100644
|
||||||
--- a/include/llama.h
|
--- a/include/llama.h
|
||||||
+++ b/include/llama.h
|
+++ b/include/llama.h
|
||||||
@@ -423,6 +423,10 @@ extern "C" {
|
@@ -240,6 +240,7 @@ extern "C" {
|
||||||
|
|
||||||
|
llama_token * token;
|
||||||
|
float * embd;
|
||||||
|
+ int32_t n_embd;
|
||||||
|
llama_pos * pos;
|
||||||
|
int32_t * n_seq_id;
|
||||||
|
llama_seq_id ** seq_id;
|
||||||
|
@@ -423,6 +424,10 @@ extern "C" {
|
||||||
struct llama_model * model,
|
struct llama_model * model,
|
||||||
struct llama_context_params params);
|
struct llama_context_params params);
|
||||||
|
|
||||||
+ // TODO (jmorganca): this should most likely be passed in as part of a batch
|
+ // TODO (jmorganca): this should most likely be passed in as part of a batch
|
||||||
+ // and not set on the context for all batches.
|
+ // and not set on the context for all batches.
|
||||||
+ LLAMA_API void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state);
|
+ LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state);
|
||||||
+
|
+
|
||||||
// Frees all allocated memory
|
// Frees all allocated memory
|
||||||
LLAMA_API void llama_free(struct llama_context * ctx);
|
LLAMA_API void llama_free(struct llama_context * ctx);
|
||||||
|
|
||||||
diff --git a/src/llama.cpp b/src/llama.cpp
|
diff --git a/src/llama.cpp b/src/llama.cpp
|
||||||
index 83b80b59..b189a19a 100644
|
index 83b80b59..35748488 100644
|
||||||
--- a/src/llama.cpp
|
--- a/src/llama.cpp
|
||||||
+++ b/src/llama.cpp
|
+++ b/src/llama.cpp
|
||||||
@@ -169,6 +169,7 @@ static std::string format(const char * fmt, ...) {
|
@@ -169,6 +169,7 @@ static std::string format(const char * fmt, ...) {
|
||||||
|
@ -160,13 +182,23 @@ index 83b80b59..b189a19a 100644
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
+
|
+
|
||||||
+ bool cross_attention_layer(uint32_t il) const {
|
+ bool cross_attention_layers(uint32_t il) const {
|
||||||
+ return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
|
+ return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
|
||||||
+ }
|
+ }
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
|
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
|
||||||
@@ -2806,6 +2859,16 @@ struct llama_layer {
|
@@ -2652,6 +2705,9 @@ struct llama_cparams {
|
||||||
|
bool offload_kqv;
|
||||||
|
bool flash_attn;
|
||||||
|
bool no_perf;
|
||||||
|
+ // TODO (jmorganca): this should most likely be passed in as part of a batch
|
||||||
|
+ // and not set on the context for all batches.
|
||||||
|
+ bool cross_attn = false;
|
||||||
|
|
||||||
|
enum llama_pooling_type pooling_type;
|
||||||
|
|
||||||
|
@@ -2806,6 +2862,16 @@ struct llama_layer {
|
||||||
struct ggml_tensor * ffn_down_scale;
|
struct ggml_tensor * ffn_down_scale;
|
||||||
|
|
||||||
struct ggml_tensor * bskcn_tv;
|
struct ggml_tensor * bskcn_tv;
|
||||||
|
@ -183,25 +215,21 @@ index 83b80b59..b189a19a 100644
|
||||||
};
|
};
|
||||||
|
|
||||||
// very similar to llama_batch,
|
// very similar to llama_batch,
|
||||||
@@ -3452,6 +3515,12 @@ struct llama_context {
|
@@ -3452,6 +3518,8 @@ struct llama_context {
|
||||||
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
|
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
|
||||||
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
|
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
|
||||||
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
|
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
|
||||||
+
|
+
|
||||||
+ // TODO (jmorganca): this should most likely be passed in as part of a batch
|
|
||||||
+ // and not set on the context for all batches.
|
|
||||||
+ float * cross_attn_state = nullptr;
|
|
||||||
+ bool cross_attn_state_first_pass = true;
|
|
||||||
+ struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061]
|
+ struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061]
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_lora_weight {
|
struct llama_lora_weight {
|
||||||
@@ -3686,6 +3755,18 @@ static bool llama_kv_cache_init(
|
@@ -3686,6 +3754,18 @@ static bool llama_kv_cache_init(
|
||||||
cache.v_l.reserve(n_layer);
|
cache.v_l.reserve(n_layer);
|
||||||
|
|
||||||
for (int i = 0; i < (int) n_layer; i++) {
|
for (int i = 0; i < (int) n_layer; i++) {
|
||||||
+ // for cross attention layers
|
+ // for cross attention layers
|
||||||
+ if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layer(i)) {
|
+ if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) {
|
||||||
+ struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
|
+ struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
|
||||||
+ ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
|
+ ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
|
||||||
+ ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
|
+ ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
|
||||||
|
@ -215,7 +243,7 @@ index 83b80b59..b189a19a 100644
|
||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
|
||||||
|
|
||||||
@@ -5460,12 +5541,14 @@ static void llm_load_hparams(
|
@@ -5460,12 +5540,14 @@ static void llm_load_hparams(
|
||||||
}
|
}
|
||||||
|
|
||||||
// zero-out the per-layer hparams
|
// zero-out the per-layer hparams
|
||||||
|
@ -235,7 +263,7 @@ index 83b80b59..b189a19a 100644
|
||||||
|
|
||||||
// n_head_kv is optional, default to n_head
|
// n_head_kv is optional, default to n_head
|
||||||
hparams.n_head_kv_arr = hparams.n_head_arr;
|
hparams.n_head_kv_arr = hparams.n_head_arr;
|
||||||
@@ -5514,7 +5597,7 @@ static void llm_load_hparams(
|
@@ -5514,7 +5596,7 @@ static void llm_load_hparams(
|
||||||
|
|
||||||
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
|
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
|
||||||
|
|
||||||
|
@ -244,7 +272,7 @@ index 83b80b59..b189a19a 100644
|
||||||
if (hparams.n_rot != hparams.n_embd_head_k) {
|
if (hparams.n_rot != hparams.n_embd_head_k) {
|
||||||
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
|
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
|
||||||
}
|
}
|
||||||
@@ -5554,6 +5637,16 @@ static void llm_load_hparams(
|
@@ -5554,6 +5636,16 @@ static void llm_load_hparams(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
@ -261,7 +289,7 @@ index 83b80b59..b189a19a 100644
|
||||||
case LLM_ARCH_MINICPM:
|
case LLM_ARCH_MINICPM:
|
||||||
{
|
{
|
||||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||||
@@ -7249,6 +7342,55 @@ static bool llm_load_tensors(
|
@@ -7249,6 +7341,55 @@ static bool llm_load_tensors(
|
||||||
layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
@ -286,7 +314,7 @@ index 83b80b59..b189a19a 100644
|
||||||
+
|
+
|
||||||
+ auto & layer = model.layers[i];
|
+ auto & layer = model.layers[i];
|
||||||
+
|
+
|
||||||
+ if (hparams.cross_attention_layer(i)) {
|
+ if (hparams.cross_attention_layers(i)) {
|
||||||
+ layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128});
|
+ layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128});
|
||||||
+ layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024});
|
+ layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024});
|
||||||
+ layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd});
|
+ layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd});
|
||||||
|
@ -317,7 +345,7 @@ index 83b80b59..b189a19a 100644
|
||||||
case LLM_ARCH_GROK:
|
case LLM_ARCH_GROK:
|
||||||
{
|
{
|
||||||
if (n_expert == 0) {
|
if (n_expert == 0) {
|
||||||
@@ -9093,7 +9235,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
|
@@ -9093,7 +9234,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
|
||||||
|
|
||||||
if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
|
if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
|
||||||
model.hparams.n_vocab != model.vocab.id_to_token.size()) {
|
model.hparams.n_vocab != model.vocab.id_to_token.size()) {
|
||||||
|
@ -326,16 +354,7 @@ index 83b80b59..b189a19a 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.vocab_only) {
|
if (params.vocab_only) {
|
||||||
@@ -9178,7 +9320,7 @@ static struct ggml_tensor * llm_build_inp_embd(
|
@@ -9193,6 +9334,21 @@ static struct ggml_tensor * llm_build_inp_embd(
|
||||||
|
|
||||||
inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
|
|
||||||
} else {
|
|
||||||
- lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
|
|
||||||
+ lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
|
|
||||||
inpL = lctx.inp_embd;
|
|
||||||
ggml_set_input(lctx.inp_embd);
|
|
||||||
}
|
|
||||||
@@ -9193,6 +9335,22 @@ static struct ggml_tensor * llm_build_inp_embd(
|
|
||||||
return inpL;
|
return inpL;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -346,11 +365,10 @@ index 83b80b59..b189a19a 100644
|
||||||
+ const llm_build_cb & cb) {
|
+ const llm_build_cb & cb) {
|
||||||
+ const int64_t n_embd = hparams.n_embd;
|
+ const int64_t n_embd = hparams.n_embd;
|
||||||
+
|
+
|
||||||
+ struct ggml_tensor * inpCAS;
|
+ struct ggml_tensor * inpCAS = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
|
||||||
+ lctx.inp_cross_attn_state = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
|
+ cb(inpCAS, "inp_cross_attn_state", -1);
|
||||||
+ cb(lctx.inp_cross_attn_state, "inp_cross_attn_state", -1);
|
+ ggml_set_input(inpCAS);
|
||||||
+ ggml_set_input(lctx.inp_cross_attn_state);
|
+ lctx.inp_cross_attn_state = inpCAS;
|
||||||
+ inpCAS = lctx.inp_cross_attn_state;
|
|
||||||
+
|
+
|
||||||
+ return inpCAS;
|
+ return inpCAS;
|
||||||
+}
|
+}
|
||||||
|
@ -358,7 +376,7 @@ index 83b80b59..b189a19a 100644
|
||||||
static void llm_build_kv_store(
|
static void llm_build_kv_store(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
const llama_hparams & hparams,
|
const llama_hparams & hparams,
|
||||||
@@ -10167,6 +10325,7 @@ struct llm_build_context {
|
@@ -10167,6 +10323,7 @@ struct llm_build_context {
|
||||||
lctx.inp_pos_bucket = nullptr;
|
lctx.inp_pos_bucket = nullptr;
|
||||||
lctx.inp_embd_enc = nullptr;
|
lctx.inp_embd_enc = nullptr;
|
||||||
lctx.inp_KQ_mask_cross = nullptr;
|
lctx.inp_KQ_mask_cross = nullptr;
|
||||||
|
@ -366,7 +384,7 @@ index 83b80b59..b189a19a 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
void free() {
|
void free() {
|
||||||
@@ -10754,6 +10913,253 @@ struct llm_build_context {
|
@@ -10754,6 +10911,239 @@ struct llm_build_context {
|
||||||
LLM_NORM_RMS, cb, -1);
|
LLM_NORM_RMS, cb, -1);
|
||||||
cb(cur, "result_norm", -1);
|
cb(cur, "result_norm", -1);
|
||||||
|
|
||||||
|
@ -410,8 +428,8 @@ index 83b80b59..b189a19a 100644
|
||||||
+ LLM_NORM_RMS, cb, il);
|
+ LLM_NORM_RMS, cb, il);
|
||||||
+ cb(cur, "attn_norm", il);
|
+ cb(cur, "attn_norm", il);
|
||||||
+
|
+
|
||||||
+ if (hparams.cross_attention_layer(il)) {
|
+ if (hparams.cross_attention_layers(il)) {
|
||||||
+ if (!lctx.cross_attn_state) {
|
+ if (!batch.embd && !cparams.cross_attn) {
|
||||||
+ continue;
|
+ continue;
|
||||||
+ }
|
+ }
|
||||||
+
|
+
|
||||||
|
@ -422,42 +440,28 @@ index 83b80b59..b189a19a 100644
|
||||||
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||||
+ cb(Qcur, "Qcur", il);
|
+ cb(Qcur, "Qcur", il);
|
||||||
+
|
+
|
||||||
+ Qcur = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
+ Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
|
||||||
+ cb(Qcur, "Qcur", il);
|
|
||||||
+
|
|
||||||
+ // TODO: is this required?
|
|
||||||
+ Qcur = ggml_cont(ctx0, Qcur);
|
|
||||||
+ cb(Qcur, "Qcur", il);
|
+ cb(Qcur, "Qcur", il);
|
||||||
+
|
+
|
||||||
+ Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
|
+ Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||||
+ cb(Qcur, "Qcur", il);
|
+ cb(Qcur, "Qcur", il);
|
||||||
+
|
+
|
||||||
+ struct ggml_tensor * Kcur;
|
+ struct ggml_tensor * Kcur, * Vcur;
|
||||||
+ if (lctx.cross_attn_state_first_pass) {
|
+ if (batch.embd) {
|
||||||
+ Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
|
+ Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
|
||||||
+ cb(Kcur, "Kcur", il);
|
+ cb(Kcur, "Kcur", il);
|
||||||
+
|
+
|
||||||
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404);
|
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404);
|
||||||
+ cb(Kcur, "Kcur", il);
|
+ cb(Kcur, "Kcur", il);
|
||||||
+
|
+
|
||||||
+ Kcur = ggml_permute(ctx0, Kcur, 0, 2, 1, 3);
|
+ Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
|
||||||
+ cb(Kcur, "Kcur", il);
|
|
||||||
+
|
|
||||||
+ // TODO: is this required?
|
|
||||||
+ Kcur = ggml_cont(ctx0, Kcur);
|
|
||||||
+ cb(Kcur, "Kcur", il);
|
+ cb(Kcur, "Kcur", il);
|
||||||
+
|
+
|
||||||
+ Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
|
+ Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||||
+ cb(Kcur, "Kcur", il);
|
+ cb(Kcur, "Kcur", il);
|
||||||
+
|
+
|
||||||
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il]));
|
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il]));
|
||||||
+ } else {
|
|
||||||
+ Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
|
|
||||||
+ cb(Kcur, "Kcur (view)", il);
|
|
||||||
+ }
|
|
||||||
+
|
+
|
||||||
+ struct ggml_tensor * Vcur;
|
|
||||||
+ if (lctx.cross_attn_state_first_pass) {
|
|
||||||
+ Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
|
+ Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
|
||||||
+ cb(Vcur, "Vcur", il);
|
+ cb(Vcur, "Vcur", il);
|
||||||
+
|
+
|
||||||
|
@ -469,6 +473,9 @@ index 83b80b59..b189a19a 100644
|
||||||
+
|
+
|
||||||
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il]));
|
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il]));
|
||||||
+ } else {
|
+ } else {
|
||||||
|
+ Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
|
||||||
|
+ cb(Kcur, "Kcur (view)", il);
|
||||||
|
+
|
||||||
+ Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]);
|
+ Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]);
|
||||||
+ cb(Vcur, "Vcur (view)", il);
|
+ cb(Vcur, "Vcur (view)", il);
|
||||||
+ }
|
+ }
|
||||||
|
@ -476,11 +483,8 @@ index 83b80b59..b189a19a 100644
|
||||||
+ struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur);
|
+ struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur);
|
||||||
+ cb(kq, "kq", il);
|
+ cb(kq, "kq", il);
|
||||||
+
|
+
|
||||||
+ kq = ggml_scale_inplace(ctx0, kq, 1.0f/sqrtf(float(n_embd_head)));
|
|
||||||
+ cb(kq, "kq_scaled", il);
|
|
||||||
+
|
|
||||||
+ // TODO: apply causal masks
|
+ // TODO: apply causal masks
|
||||||
+ struct ggml_tensor * kq_soft_max = ggml_soft_max_inplace(ctx0, kq);
|
+ struct ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, nullptr, 1.f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
|
||||||
+ cb(kq_soft_max, "kq_soft_max", il);
|
+ cb(kq_soft_max, "kq_soft_max", il);
|
||||||
+
|
+
|
||||||
+ Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur));
|
+ Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur));
|
||||||
|
@ -570,8 +574,8 @@ index 83b80b59..b189a19a 100644
|
||||||
+ cb(Kcur, "Kcur", il);
|
+ cb(Kcur, "Kcur", il);
|
||||||
+
|
+
|
||||||
+ cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
+ cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||||
+ model.layers[il].wo, model.layers[il].bo,
|
+ model.layers[il].wo, model.layers[il].bo,
|
||||||
+ Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
+ Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
+
|
+
|
||||||
+
|
+
|
||||||
+ if (il == n_layer - 1) {
|
+ if (il == n_layer - 1) {
|
||||||
|
@ -620,7 +624,7 @@ index 83b80b59..b189a19a 100644
|
||||||
// lm_head
|
// lm_head
|
||||||
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
||||||
cb(cur, "result_output", -1);
|
cb(cur, "result_output", -1);
|
||||||
@@ -16501,6 +16907,10 @@ static struct ggml_cgraph * llama_build_graph(
|
@@ -16501,6 +16891,10 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
{
|
{
|
||||||
result = llm.build_llama();
|
result = llm.build_llama();
|
||||||
} break;
|
} break;
|
||||||
|
@ -631,33 +635,48 @@ index 83b80b59..b189a19a 100644
|
||||||
case LLM_ARCH_BAICHUAN:
|
case LLM_ARCH_BAICHUAN:
|
||||||
{
|
{
|
||||||
result = llm.build_baichuan();
|
result = llm.build_baichuan();
|
||||||
@@ -16773,6 +17183,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
|
@@ -16761,10 +17155,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
|
||||||
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
+ // TODO (jmorganca): this might copy a lot of data on every request of a
|
if (batch.embd) {
|
||||||
+ // single generation even though it doesn't change, so we should
|
- const int64_t n_embd = hparams.n_embd;
|
||||||
+ // find a way to not set this more than one time per image
|
- const int64_t n_tokens = batch.n_tokens;
|
||||||
+ if (lctx.inp_cross_attn_state &&
|
+ if (lctx.inp_cross_attn_state && lctx.inp_cross_attn_state->buffer) {
|
||||||
+ lctx.inp_cross_attn_state->buffer) {
|
+ ggml_backend_tensor_set(lctx.inp_cross_attn_state, batch.embd, 0, ggml_nbytes(lctx.inp_cross_attn_state));
|
||||||
+ ggml_backend_tensor_set(lctx.inp_cross_attn_state, lctx.cross_attn_state, 0, hparams.n_embd * 1601 * 4 * ggml_element_size(lctx.inp_cross_attn_state));
|
+ // zero out inp_embd since it's not used
|
||||||
+ }
|
+ float * inp_embd_data = (float *)lctx.inp_embd->data;
|
||||||
+
|
+ for (int i = 0; i < ggml_nelements(lctx.inp_embd); ++i) {
|
||||||
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
+ inp_embd_data[i] = 0.0f;
|
||||||
GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
|
+ }
|
||||||
const int64_t n_tokens = batch.n_tokens;
|
+ } else {
|
||||||
@@ -17455,6 +17873,10 @@ static int llama_decode_internal(
|
+ const int64_t n_embd = hparams.n_embd;
|
||||||
|
+ const int64_t n_tokens = batch.n_tokens;
|
||||||
|
|
||||||
llama_set_inputs(lctx, ubatch);
|
- ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
|
||||||
|
+ ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
|
||||||
|
+ }
|
||||||
|
}
|
||||||
|
|
||||||
+ // TODO: replace with something better to find out if its
|
if (batch.pos && lctx.inp_pos) {
|
||||||
+ // our first actual pass
|
@@ -17345,7 +17748,7 @@ static int llama_decode_internal(
|
||||||
+ lctx.cross_attn_state_first_pass = false;
|
n_outputs = 1;
|
||||||
+
|
}
|
||||||
llama_graph_compute(lctx, gf, n_threads, threadpool);
|
|
||||||
|
|
||||||
// update the kv ring buffer
|
- lctx.sbatch.from_batch(batch_all, n_embd,
|
||||||
@@ -18648,7 +19070,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
+ lctx.sbatch.from_batch(batch_all, batch_all.n_embd,
|
||||||
|
/* simple_split */ !kv_self.recurrent,
|
||||||
|
/* logits_all */ n_outputs == n_tokens_all);
|
||||||
|
|
||||||
|
@@ -17638,7 +18041,7 @@ static int llama_encode_internal(
|
||||||
|
|
||||||
|
const int64_t n_embd = hparams.n_embd;
|
||||||
|
|
||||||
|
- lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
|
||||||
|
+ lctx.sbatch.from_batch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true);
|
||||||
|
|
||||||
|
const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
|
||||||
|
|
||||||
|
@@ -18648,7 +19051,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||||
if (llama_model_has_encoder(&model)) {
|
if (llama_model_has_encoder(&model)) {
|
||||||
n_attn_layer *= 3;
|
n_attn_layer *= 3;
|
||||||
}
|
}
|
||||||
|
@ -668,19 +687,7 @@ index 83b80b59..b189a19a 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t total_size_org = 0;
|
size_t total_size_org = 0;
|
||||||
@@ -19744,6 +20168,11 @@ struct llama_context * llama_new_context_with_model(
|
@@ -19814,6 +20219,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||||
return ctx;
|
|
||||||
}
|
|
||||||
|
|
||||||
+void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state) {
|
|
||||||
+ ctx->cross_attn_state_first_pass = true;
|
|
||||||
+ ctx->cross_attn_state = cross_attn_state;
|
|
||||||
+}
|
|
||||||
+
|
|
||||||
void llama_free(struct llama_context * ctx) {
|
|
||||||
delete ctx;
|
|
||||||
}
|
|
||||||
@@ -19814,6 +20243,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
|
||||||
|
|
||||||
// use what we call a normal RoPE, operating on pairs of consecutive head values
|
// use what we call a normal RoPE, operating on pairs of consecutive head values
|
||||||
case LLM_ARCH_LLAMA:
|
case LLM_ARCH_LLAMA:
|
||||||
|
@ -688,3 +695,38 @@ index 83b80b59..b189a19a 100644
|
||||||
case LLM_ARCH_BAICHUAN:
|
case LLM_ARCH_BAICHUAN:
|
||||||
case LLM_ARCH_STARCODER:
|
case LLM_ARCH_STARCODER:
|
||||||
case LLM_ARCH_PLAMO:
|
case LLM_ARCH_PLAMO:
|
||||||
|
@@ -21230,6 +21636,10 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
|
||||||
|
ctx->cparams.causal_attn = causal_attn;
|
||||||
|
}
|
||||||
|
|
||||||
|
+void llama_set_cross_attention(struct llama_context * ctx, bool cross_attention) {
|
||||||
|
+ ctx->cparams.cross_attn = cross_attention;
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
struct llama_batch llama_batch_get_one(
|
||||||
|
llama_token * tokens,
|
||||||
|
int32_t n_tokens,
|
||||||
|
@@ -21239,6 +21649,7 @@ struct llama_batch llama_batch_get_one(
|
||||||
|
/*n_tokens =*/ n_tokens,
|
||||||
|
/*tokens =*/ tokens,
|
||||||
|
/*embd =*/ nullptr,
|
||||||
|
+ /*n_embd =*/ 0,
|
||||||
|
/*pos =*/ nullptr,
|
||||||
|
/*n_seq_id =*/ nullptr,
|
||||||
|
/*seq_id =*/ nullptr,
|
||||||
|
@@ -21254,6 +21665,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
|
||||||
|
/*n_tokens =*/ 0,
|
||||||
|
/*tokens =*/ nullptr,
|
||||||
|
/*embd =*/ nullptr,
|
||||||
|
+ /*n_embd =*/ 0,
|
||||||
|
/*pos =*/ nullptr,
|
||||||
|
/*n_seq_id =*/ nullptr,
|
||||||
|
/*seq_id =*/ nullptr,
|
||||||
|
@@ -21265,6 +21677,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
|
||||||
|
|
||||||
|
if (embd) {
|
||||||
|
batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
|
||||||
|
+ batch.n_embd = embd;
|
||||||
|
} else {
|
||||||
|
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
|
||||||
|
}
|
||||||
|
|
|
@ -2,7 +2,6 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"hash/maphash"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"reflect"
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
|
@ -20,10 +19,6 @@ type InputCache struct {
|
||||||
// optimize cache eviction for multiple users
|
// optimize cache eviction for multiple users
|
||||||
multiUserCache bool
|
multiUserCache bool
|
||||||
|
|
||||||
// cache of images to embeddings
|
|
||||||
images []imageCache
|
|
||||||
imageHash maphash.Hash
|
|
||||||
|
|
||||||
lc *llama.Context
|
lc *llama.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,7 +36,6 @@ func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache b
|
||||||
numCtx: kvSize / numSlots,
|
numCtx: kvSize / numSlots,
|
||||||
slots: slots,
|
slots: slots,
|
||||||
multiUserCache: multiUserCache,
|
multiUserCache: multiUserCache,
|
||||||
images: make([]imageCache, numSlots),
|
|
||||||
lc: lc,
|
lc: lc,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -211,55 +205,3 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int, numDiscar
|
||||||
}
|
}
|
||||||
slot.Inputs = slot.Inputs[:len(slot.Inputs)-numDiscard]
|
slot.Inputs = slot.Inputs[:len(slot.Inputs)-numDiscard]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Locking: Lookup and store operations on imageCache require a lock
|
|
||||||
// to be held that serializes these with each other. Hash does not
|
|
||||||
// require a lock nor they need to be serialized with InputCacheSlot.
|
|
||||||
|
|
||||||
type imageCache struct {
|
|
||||||
key uint64
|
|
||||||
val [][]float32
|
|
||||||
lastUsed time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *InputCache) HashImage(image []byte) uint64 {
|
|
||||||
c.imageHash.Reset()
|
|
||||||
_, _ = c.imageHash.Write(image)
|
|
||||||
return c.imageHash.Sum64()
|
|
||||||
}
|
|
||||||
|
|
||||||
var ErrImageNotFound = errors.New("image not found in cache")
|
|
||||||
|
|
||||||
func (c *InputCache) FindImage(hash uint64) ([][]float32, error) {
|
|
||||||
for i := range c.images {
|
|
||||||
if c.images[i].key == hash {
|
|
||||||
slog.Debug("loading image embeddings from cache", "entry", i)
|
|
||||||
c.images[i].lastUsed = time.Now()
|
|
||||||
return c.images[i].val, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, ErrImageNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *InputCache) AddImage(hash uint64, embed [][]float32) {
|
|
||||||
best := time.Now()
|
|
||||||
var bestImage int
|
|
||||||
|
|
||||||
for i := range c.images {
|
|
||||||
if c.images[i].key == hash {
|
|
||||||
bestImage = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.images[i].lastUsed.Compare(best) < 0 {
|
|
||||||
best = c.images[i].lastUsed
|
|
||||||
bestImage = i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed)
|
|
||||||
c.images[bestImage].key = hash
|
|
||||||
c.images[bestImage].val = embed
|
|
||||||
c.images[bestImage].lastUsed = time.Now()
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -228,77 +227,3 @@ func TestFindCacheSlot(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestImageCache(t *testing.T) {
|
|
||||||
cache := NewInputCache(nil, 2048, 4, false)
|
|
||||||
|
|
||||||
valA := [][]float32{{0.1, 0.2}, {0.3}}
|
|
||||||
valB := [][]float32{{0.4}, {0.5}, {0.6}}
|
|
||||||
valC := [][]float32{{0.7}}
|
|
||||||
valD := [][]float32{{0.8}}
|
|
||||||
valE := [][]float32{{0.9}}
|
|
||||||
|
|
||||||
// Empty cache
|
|
||||||
result, err := cache.FindImage(0x5adb61d31933a946)
|
|
||||||
if err != ErrImageNotFound {
|
|
||||||
t.Errorf("found result in empty cache: result %v, err %v", result, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert A
|
|
||||||
cache.AddImage(0x5adb61d31933a946, valA)
|
|
||||||
|
|
||||||
result, err = cache.FindImage(0x5adb61d31933a946)
|
|
||||||
if !reflect.DeepEqual(result, valA) {
|
|
||||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert B
|
|
||||||
cache.AddImage(0x011551369a34a901, valB)
|
|
||||||
|
|
||||||
result, err = cache.FindImage(0x5adb61d31933a946)
|
|
||||||
if !reflect.DeepEqual(result, valA) {
|
|
||||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
|
||||||
}
|
|
||||||
result, err = cache.FindImage(0x011551369a34a901)
|
|
||||||
if !reflect.DeepEqual(result, valB) {
|
|
||||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Replace B with C
|
|
||||||
cache.AddImage(0x011551369a34a901, valC)
|
|
||||||
|
|
||||||
result, err = cache.FindImage(0x5adb61d31933a946)
|
|
||||||
if !reflect.DeepEqual(result, valA) {
|
|
||||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
|
||||||
}
|
|
||||||
result, err = cache.FindImage(0x011551369a34a901)
|
|
||||||
if !reflect.DeepEqual(result, valC) {
|
|
||||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Evict A
|
|
||||||
cache.AddImage(0x756b218a517e7353, valB)
|
|
||||||
cache.AddImage(0x75e5e8d35d7e3967, valD)
|
|
||||||
cache.AddImage(0xd96f7f268ca0646e, valE)
|
|
||||||
|
|
||||||
result, err = cache.FindImage(0x5adb61d31933a946)
|
|
||||||
if reflect.DeepEqual(result, valA) {
|
|
||||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
|
||||||
}
|
|
||||||
result, err = cache.FindImage(0x756b218a517e7353)
|
|
||||||
if !reflect.DeepEqual(result, valB) {
|
|
||||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
|
||||||
}
|
|
||||||
result, err = cache.FindImage(0x011551369a34a901)
|
|
||||||
if !reflect.DeepEqual(result, valC) {
|
|
||||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
|
||||||
}
|
|
||||||
result, err = cache.FindImage(0x75e5e8d35d7e3967)
|
|
||||||
if !reflect.DeepEqual(result, valD) {
|
|
||||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
|
||||||
}
|
|
||||||
result, err = cache.FindImage(0xd96f7f268ca0646e)
|
|
||||||
if !reflect.DeepEqual(result, valE) {
|
|
||||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
145
llama/runner/image.go
Normal file
145
llama/runner/image.go
Normal file
|
@ -0,0 +1,145 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"hash/maphash"
|
||||||
|
"log/slog"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/llama"
|
||||||
|
)
|
||||||
|
|
||||||
|
const imageCacheSize = 4
|
||||||
|
|
||||||
|
type ImageContext struct {
|
||||||
|
// mu is required to be held when generating embeddings or accessing the cache
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
clip *llama.ClipContext
|
||||||
|
mllama *llama.MllamaContext
|
||||||
|
|
||||||
|
// cache of images to embeddings
|
||||||
|
images []imageCache
|
||||||
|
imageHash maphash.Hash
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewImageContext(llamaContext *llama.Context, modelPath string) (*ImageContext, error) {
|
||||||
|
arch, err := llama.GetModelArch(modelPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to determine vision architecture: %w (%s)", err, modelPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
var c ImageContext
|
||||||
|
if arch == "clip" {
|
||||||
|
c.clip, err = llama.NewClipContext(llamaContext, modelPath)
|
||||||
|
} else if arch == "mllama" {
|
||||||
|
c.mllama, err = llama.NewMllamaContext(llamaContext, modelPath)
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.images = make([]imageCache, imageCacheSize)
|
||||||
|
|
||||||
|
return &c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ImageContext) Free(modelPath string) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.clip != nil {
|
||||||
|
c.clip.Free()
|
||||||
|
}
|
||||||
|
if c.mllama != nil {
|
||||||
|
c.mllama.Free()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte, aspectRatioId int) [][]float32 {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hash := c.hashImage(data)
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
embed, err := c.findImage(hash)
|
||||||
|
if err != nil {
|
||||||
|
if c.mllama != nil {
|
||||||
|
embed = c.mllama.NewEmbed(llamaContext, data, aspectRatioId)
|
||||||
|
} else if c.clip != nil {
|
||||||
|
embed = c.clip.NewEmbed(llamaContext, data)
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.addImage(hash, embed)
|
||||||
|
}
|
||||||
|
|
||||||
|
return embed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
|
||||||
|
if c != nil && c.mllama != nil {
|
||||||
|
return c.mllama.EmbedSize(llamaContext)
|
||||||
|
} else {
|
||||||
|
return llamaContext.Model().NEmbd()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type imageCache struct {
|
||||||
|
key uint64
|
||||||
|
val [][]float32
|
||||||
|
lastUsed time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ImageContext) hashImage(image []byte) uint64 {
|
||||||
|
c.imageHash.Reset()
|
||||||
|
_, _ = c.imageHash.Write(image)
|
||||||
|
return c.imageHash.Sum64()
|
||||||
|
}
|
||||||
|
|
||||||
|
var errImageNotFound = errors.New("image not found in cache")
|
||||||
|
|
||||||
|
func (c *ImageContext) findImage(hash uint64) ([][]float32, error) {
|
||||||
|
for i := range c.images {
|
||||||
|
if c.images[i].key == hash {
|
||||||
|
slog.Debug("loading image embeddings from cache", "entry", i)
|
||||||
|
c.images[i].lastUsed = time.Now()
|
||||||
|
return c.images[i].val, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errImageNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ImageContext) addImage(hash uint64, embed [][]float32) {
|
||||||
|
best := time.Now()
|
||||||
|
var bestImage int
|
||||||
|
|
||||||
|
for i := range c.images {
|
||||||
|
if c.images[i].key == hash {
|
||||||
|
bestImage = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.images[i].lastUsed.Compare(best) < 0 {
|
||||||
|
best = c.images[i].lastUsed
|
||||||
|
bestImage = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed)
|
||||||
|
c.images[bestImage].key = hash
|
||||||
|
c.images[bestImage].val = embed
|
||||||
|
c.images[bestImage].lastUsed = time.Now()
|
||||||
|
}
|
80
llama/runner/image_test.go
Normal file
80
llama/runner/image_test.go
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestImageCache(t *testing.T) {
|
||||||
|
cache := ImageContext{images: make([]imageCache, 4)}
|
||||||
|
|
||||||
|
valA := [][]float32{{0.1, 0.2}, {0.3}}
|
||||||
|
valB := [][]float32{{0.4}, {0.5}, {0.6}}
|
||||||
|
valC := [][]float32{{0.7}}
|
||||||
|
valD := [][]float32{{0.8}}
|
||||||
|
valE := [][]float32{{0.9}}
|
||||||
|
|
||||||
|
// Empty cache
|
||||||
|
result, err := cache.findImage(0x5adb61d31933a946)
|
||||||
|
if err != errImageNotFound {
|
||||||
|
t.Errorf("found result in empty cache: result %v, err %v", result, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert A
|
||||||
|
cache.addImage(0x5adb61d31933a946, valA)
|
||||||
|
|
||||||
|
result, err = cache.findImage(0x5adb61d31933a946)
|
||||||
|
if !reflect.DeepEqual(result, valA) {
|
||||||
|
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert B
|
||||||
|
cache.addImage(0x011551369a34a901, valB)
|
||||||
|
|
||||||
|
result, err = cache.findImage(0x5adb61d31933a946)
|
||||||
|
if !reflect.DeepEqual(result, valA) {
|
||||||
|
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||||
|
}
|
||||||
|
result, err = cache.findImage(0x011551369a34a901)
|
||||||
|
if !reflect.DeepEqual(result, valB) {
|
||||||
|
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace B with C
|
||||||
|
cache.addImage(0x011551369a34a901, valC)
|
||||||
|
|
||||||
|
result, err = cache.findImage(0x5adb61d31933a946)
|
||||||
|
if !reflect.DeepEqual(result, valA) {
|
||||||
|
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||||
|
}
|
||||||
|
result, err = cache.findImage(0x011551369a34a901)
|
||||||
|
if !reflect.DeepEqual(result, valC) {
|
||||||
|
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Evict A
|
||||||
|
cache.addImage(0x756b218a517e7353, valB)
|
||||||
|
cache.addImage(0x75e5e8d35d7e3967, valD)
|
||||||
|
cache.addImage(0xd96f7f268ca0646e, valE)
|
||||||
|
|
||||||
|
result, err = cache.findImage(0x5adb61d31933a946)
|
||||||
|
if reflect.DeepEqual(result, valA) {
|
||||||
|
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||||
|
}
|
||||||
|
result, err = cache.findImage(0x756b218a517e7353)
|
||||||
|
if !reflect.DeepEqual(result, valB) {
|
||||||
|
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||||
|
}
|
||||||
|
result, err = cache.findImage(0x011551369a34a901)
|
||||||
|
if !reflect.DeepEqual(result, valC) {
|
||||||
|
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||||
|
}
|
||||||
|
result, err = cache.findImage(0x75e5e8d35d7e3967)
|
||||||
|
if !reflect.DeepEqual(result, valD) {
|
||||||
|
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||||
|
}
|
||||||
|
result, err = cache.findImage(0xd96f7f268ca0646e)
|
||||||
|
if !reflect.DeepEqual(result, valE) {
|
||||||
|
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||||
|
}
|
||||||
|
}
|
|
@ -190,57 +190,22 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
|
||||||
return nil, fmt.Errorf("invalid image index: %d", n)
|
return nil, fmt.Errorf("invalid image index: %d", n)
|
||||||
}
|
}
|
||||||
|
|
||||||
hash := s.cache.HashImage(images[imageIndex].Data)
|
embed := s.image.NewEmbed(s.lc, images[imageIndex].Data, images[imageIndex].AspectRatioID)
|
||||||
|
|
||||||
// Vision models cannot be accessed concurrently
|
|
||||||
s.clip.mu.Lock()
|
|
||||||
embed, err := s.cache.FindImage(hash)
|
|
||||||
if err != nil {
|
|
||||||
embed = llama.NewLlavaImageEmbed(s.lc, s.clip.cc, images[imageIndex].Data)
|
|
||||||
s.cache.AddImage(hash, embed)
|
|
||||||
}
|
|
||||||
s.clip.mu.Unlock()
|
|
||||||
|
|
||||||
for _, e := range embed {
|
for _, e := range embed {
|
||||||
inputs = append(inputs, input{embed: e})
|
inputs = append(inputs, input{embed: e})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.clip.cc != nil {
|
|
||||||
var embed [][]float32
|
|
||||||
|
|
||||||
if s.clip.cc.IsMllama && len(images) >= 1 {
|
|
||||||
hash := s.cache.HashImage(images[0].Data)
|
|
||||||
|
|
||||||
s.clip.mu.Lock()
|
|
||||||
var err error
|
|
||||||
embed, err = s.cache.FindImage(hash)
|
|
||||||
if err != nil {
|
|
||||||
embed = llama.NewMllamaImageEmbed(s.lc, s.clip.cc, images[0].Data, images[0].AspectRatioID)
|
|
||||||
s.cache.AddImage(hash, embed)
|
|
||||||
}
|
|
||||||
s.clip.mu.Unlock()
|
|
||||||
}
|
|
||||||
s.mu.Lock()
|
|
||||||
llama.MllamaSetCrossAttn(s.lc, s.clip.cc, embed)
|
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
return inputs, nil
|
return inputs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type clip struct {
|
|
||||||
cc *llama.ClipContext
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
model *llama.Model
|
model *llama.Model
|
||||||
lc *llama.Context
|
lc *llama.Context
|
||||||
|
|
||||||
// required for image embeddings
|
// required for image embeddings
|
||||||
clip clip
|
image *ImageContext
|
||||||
|
|
||||||
batchSize int
|
batchSize int
|
||||||
|
|
||||||
|
@ -322,14 +287,12 @@ func flushPending(seq *Sequence) bool {
|
||||||
func (s *Server) removeSequence(seqIndex int, reason string) {
|
func (s *Server) removeSequence(seqIndex int, reason string) {
|
||||||
seq := s.seqs[seqIndex]
|
seq := s.seqs[seqIndex]
|
||||||
|
|
||||||
|
s.lc.SetCrossAttention(false)
|
||||||
flushPending(seq)
|
flushPending(seq)
|
||||||
seq.doneReason = reason
|
seq.doneReason = reason
|
||||||
close(seq.responses)
|
close(seq.responses)
|
||||||
close(seq.embedding)
|
close(seq.embedding)
|
||||||
seq.cache.InUse = false
|
seq.cache.InUse = false
|
||||||
if s.clip.cc != nil {
|
|
||||||
llama.MllamaSetCrossAttn(s.lc, s.clip.cc, nil)
|
|
||||||
}
|
|
||||||
s.seqs[seqIndex] = nil
|
s.seqs[seqIndex] = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -341,7 +304,7 @@ func (s *Server) run(ctx context.Context) {
|
||||||
tokenBatch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
|
tokenBatch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
|
||||||
defer tokenBatch.Free()
|
defer tokenBatch.Free()
|
||||||
|
|
||||||
embedBatch := llama.NewBatch(s.batchSize*len(s.seqs), s.lc.Model().NEmbd(), len(s.seqs))
|
embedBatch := llama.NewBatch(s.batchSize*len(s.seqs), s.image.EmbedSize(s.lc), len(s.seqs))
|
||||||
defer embedBatch.Free()
|
defer embedBatch.Free()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
@ -642,12 +605,20 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
for i, sq := range s.seqs {
|
for i, sq := range s.seqs {
|
||||||
if sq == nil {
|
if sq == nil {
|
||||||
|
for _, input := range seq.inputs {
|
||||||
|
if input.embed != nil {
|
||||||
|
s.lc.SetCrossAttention(true)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.seqs[i] = seq
|
s.seqs[i] = seq
|
||||||
s.cond.Signal()
|
s.cond.Signal()
|
||||||
break
|
break
|
||||||
|
@ -815,7 +786,7 @@ func (s *Server) loadModel(
|
||||||
|
|
||||||
if ppath != "" {
|
if ppath != "" {
|
||||||
var err error
|
var err error
|
||||||
s.clip.cc, err = llama.NewClipContext(ppath)
|
s.image, err = NewImageContext(s.lc, ppath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -75,11 +75,16 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||||
|
|
||||||
currMsgIdx := n
|
currMsgIdx := n
|
||||||
|
|
||||||
if isMllama {
|
for cnt, msg := range msgs[currMsgIdx:] {
|
||||||
lastMsgIdx := len(msgs) - 1
|
prefix := ""
|
||||||
for i := lastMsgIdx; i >= currMsgIdx; i-- {
|
imgPrompt := ""
|
||||||
if len(msgs[i].Images) > 0 {
|
prompt := msg.Content
|
||||||
data, aspectRatioID, err := imageproc.Preprocess(msgs[i].Images[0])
|
|
||||||
|
for _, i := range msg.Images {
|
||||||
|
var imgData llm.ImageData
|
||||||
|
|
||||||
|
if isMllama {
|
||||||
|
data, aspectRatioID, err := imageproc.Preprocess(i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
@ -90,37 +95,30 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
imgData := llm.ImageData{
|
imgData = llm.ImageData{
|
||||||
|
ID: len(images),
|
||||||
Data: buf.Bytes(),
|
Data: buf.Bytes(),
|
||||||
AspectRatioID: aspectRatioID,
|
AspectRatioID: aspectRatioID,
|
||||||
}
|
}
|
||||||
|
imgPrompt = "<|image|>"
|
||||||
msgs[i].Content = strings.TrimSpace("<|image|>" + msgs[i].Content)
|
} else {
|
||||||
images = append(images, imgData)
|
imgData = llm.ImageData{
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for cnt, msg := range msgs[currMsgIdx:] {
|
|
||||||
prefix := ""
|
|
||||||
prompt := msg.Content
|
|
||||||
for _, i := range msg.Images {
|
|
||||||
imgData := llm.ImageData{
|
|
||||||
ID: len(images),
|
ID: len(images),
|
||||||
Data: i,
|
Data: i,
|
||||||
}
|
}
|
||||||
|
imgPrompt = " "
|
||||||
imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
|
|
||||||
if !strings.Contains(prompt, "[img]") {
|
|
||||||
prefix += imgTag
|
|
||||||
} else {
|
|
||||||
prompt = strings.Replace(prompt, "[img]", imgTag, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
images = append(images, imgData)
|
|
||||||
}
|
}
|
||||||
msgs[currMsgIdx+cnt].Content = strings.TrimSpace(prefix + " " + prompt)
|
|
||||||
|
imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
|
||||||
|
if !strings.Contains(prompt, "[img]") {
|
||||||
|
prefix += imgTag
|
||||||
|
} else {
|
||||||
|
prompt = strings.Replace(prompt, "[img]", imgTag, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
images = append(images, imgData)
|
||||||
}
|
}
|
||||||
|
msgs[currMsgIdx+cnt].Content = strings.TrimSpace(prefix + imgPrompt + prompt)
|
||||||
}
|
}
|
||||||
|
|
||||||
// truncate any messages that do not fit into the context window
|
// truncate any messages that do not fit into the context window
|
||||||
|
|
|
@ -249,7 +249,7 @@ func TestChatPrompt(t *testing.T) {
|
||||||
{Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}},
|
{Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}},
|
||||||
},
|
},
|
||||||
expect: expect{
|
expect: expect{
|
||||||
prompt: "<|image|>How many hotdogs are in this image? ",
|
prompt: "[img-0]<|image|>How many hotdogs are in this image? ",
|
||||||
images: [][]byte{imgBuf},
|
images: [][]byte{imgBuf},
|
||||||
aspectRatioID: 1,
|
aspectRatioID: 1,
|
||||||
},
|
},
|
||||||
|
@ -264,7 +264,7 @@ func TestChatPrompt(t *testing.T) {
|
||||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf}},
|
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf}},
|
||||||
},
|
},
|
||||||
expect: expect{
|
expect: expect{
|
||||||
prompt: "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ",
|
prompt: "You're a test, Harry! I-I'm a what? [img-0]<|image|>A test. And a thumping good one at that, I'd wager. ",
|
||||||
images: [][]byte{imgBuf},
|
images: [][]byte{imgBuf},
|
||||||
aspectRatioID: 1,
|
aspectRatioID: 1,
|
||||||
},
|
},
|
||||||
|
@ -279,8 +279,8 @@ func TestChatPrompt(t *testing.T) {
|
||||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf2}},
|
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf2}},
|
||||||
},
|
},
|
||||||
expect: expect{
|
expect: expect{
|
||||||
prompt: "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ",
|
prompt: "[img-0]<|image|>You're a test, Harry! I-I'm a what? [img-1]<|image|>A test. And a thumping good one at that, I'd wager. ",
|
||||||
images: [][]byte{imgBuf2},
|
images: [][]byte{imgBuf, imgBuf2},
|
||||||
aspectRatioID: 1,
|
aspectRatioID: 1,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -294,7 +294,7 @@ func TestChatPrompt(t *testing.T) {
|
||||||
{Role: "user", Content: "Which ones have mustard?"},
|
{Role: "user", Content: "Which ones have mustard?"},
|
||||||
},
|
},
|
||||||
expect: expect{
|
expect: expect{
|
||||||
prompt: "<|image|>How many hotdogs are in this image? There are four hotdogs. Which ones have mustard? ",
|
prompt: "[img-0]<|image|>How many hotdogs are in this image? There are four hotdogs. Which ones have mustard? ",
|
||||||
images: [][]byte{imgBuf},
|
images: [][]byte{imgBuf},
|
||||||
aspectRatioID: 1,
|
aspectRatioID: 1,
|
||||||
},
|
},
|
||||||
|
|
|
@ -205,7 +205,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
images[i] = llm.ImageData{Data: buf.Bytes(), AspectRatioID: aspectRatioID}
|
images[i] = llm.ImageData{ID: i, Data: buf.Bytes(), AspectRatioID: aspectRatioID}
|
||||||
} else {
|
} else {
|
||||||
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
|
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
|
||||||
}
|
}
|
||||||
|
@ -239,11 +239,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, i := range images {
|
for _, i := range images {
|
||||||
|
imgPrompt := ""
|
||||||
if isMllama {
|
if isMllama {
|
||||||
msgs = append(msgs, api.Message{Role: "user", Content: "<|image|>"})
|
imgPrompt = "<|image|>"
|
||||||
} else {
|
|
||||||
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
|
|
||||||
}
|
}
|
||||||
|
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]"+imgPrompt, i.ID)})
|
||||||
}
|
}
|
||||||
|
|
||||||
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
||||||
|
|
Loading…
Reference in a new issue