diff --git a/src/llama.cpp b/src/llama.cpp
index 721b8f4e..cfe7ac40 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8420,14 +8420,14 @@ struct llm_build_context {
     }
 
     struct ggml_tensor * build_inp_mean() {
-        lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
+        lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, cparams.n_seq_max);
         cb(lctx.inp_mean, "inp_mean", -1);
         ggml_set_input(lctx.inp_mean);
         return lctx.inp_mean;
     }
 
     struct ggml_tensor * build_inp_cls() {
-        lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_seq_max);
         cb(lctx.inp_cls, "inp_cls", -1);
         ggml_set_input(lctx.inp_cls);
         return lctx.inp_cls;
@@ -13847,19 +13847,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
 
         float * data = (float *) lctx.inp_mean->data;
-        memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
+        memset(lctx.inp_mean->data, 0, n_tokens * cparams.n_seq_max * ggml_element_size(lctx.inp_mean));
 
         std::vector<uint64_t> sum(n_tokens, 0);
         for (int i = 0; i < n_tokens; ++i) {
             const llama_seq_id seq_id = batch.seq_id[i][0];
-
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
-
             sum[seq_id] += 1;
         }
 
-        std::vector<float> div(n_tokens, 0.0f);
-        for (int i = 0; i < n_tokens; ++i) {
+        std::vector<float> div(cparams.n_seq_max, 0.0f);
+        for (uint32_t i = 0; i < cparams.n_seq_max; ++i) {
             const uint64_t s = sum[i];
             if (s > 0) {
                 div[i] = 1.0f/float(s);
@@ -13879,14 +13876,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
 
         uint32_t * data = (uint32_t *) lctx.inp_cls->data;
-        memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
+        memset(lctx.inp_cls->data, 0, cparams.n_seq_max * ggml_element_size(lctx.inp_cls));
 
         for (int i = 0; i < n_tokens; ++i) {
             const llama_seq_id seq_id = batch.seq_id[i][0];
             const llama_pos pos = batch.pos[i];
-
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
-
             if (pos == 0) {
                 data[seq_id] = i;
             }
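
Note, not part of the patch: the hunks above size the pooling inputs inp_mean and inp_cls by cparams.n_seq_max instead of n_tokens, so a batch whose sequence ids exceed the token count no longer needs the removed asserts. The sketch below is a self-contained illustration of what the inp_mean matrix computes; the sizes and the names seq_id, embd, n_embd are made up for the example and are not llama.cpp symbols. Row s of the n_seq_max x n_tokens matrix holds 1/count(s) at the columns of tokens assigned to sequence s, so multiplying it with the token embeddings yields one mean-pooled vector per sequence.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int n_tokens  = 5;
    const int n_seq_max = 3;   // plays the role of cparams.n_seq_max
    const int n_embd    = 2;

    // first sequence id of each token, as in batch.seq_id[i][0]
    const int seq_id[n_tokens] = {0, 0, 2, 2, 2};

    // token embeddings: n_tokens rows of n_embd values
    const float embd[n_tokens][n_embd] = {
        {1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10},
    };

    // n_seq_max x n_tokens mean matrix, zeroed like the memset in the patch
    std::vector<float> mean(n_seq_max * n_tokens, 0.0f);

    // count tokens per sequence, then place 1/count at (seq, token)
    std::vector<uint64_t> sum(n_seq_max, 0);
    for (int i = 0; i < n_tokens; ++i) {
        sum[seq_id[i]] += 1;
    }
    for (int i = 0; i < n_tokens; ++i) {
        mean[seq_id[i]*n_tokens + i] = 1.0f/float(sum[seq_id[i]]);
    }

    // mean matrix times embeddings: one pooled vector per sequence;
    // rows of sequences absent from the batch stay all-zero
    for (int s = 0; s < n_seq_max; ++s) {
        float out[n_embd] = {0.0f, 0.0f};
        for (int i = 0; i < n_tokens; ++i) {
            for (int e = 0; e < n_embd; ++e) {
                out[e] += mean[s*n_tokens + i] * embd[i][e];
            }
        }
        printf("seq %d: (%g, %g)\n", s, out[0], out[1]);
    }
    return 0;
}

Expected output: seq 0 averages tokens 0-1 to (2, 3), seq 1 has no tokens and stays (0, 0), seq 2 averages tokens 2-4 to (7, 8).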