bf4018b9ec
* Refine llama.cpp vendoring workflow tools: switch from sync.sh to make-based tooling
* Run the new make sync and patch flow
51 lines
2.1 KiB
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Mon, 16 Sep 2024 15:53:14 -0700
Subject: [PATCH] embeddings

---
 src/llama.cpp | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 800dfb95..a639522d 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -16920,7 +16920,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits = !cparams.embeddings;
+    const bool has_logits =  cparams.causal_attn;
     const bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
@@ -17192,20 +17192,23 @@ static int llama_decode_internal(
             // no output
             res  = nullptr;
             embd = nullptr;
-        } else if (cparams.embeddings) {
-            res  = nullptr; // do not extract logits for embedding case
-            embd = nullptr;
+        }
+
+        if (cparams.embeddings) {
             for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
+                embd = ggml_graph_node(gf, i);
                 if (strcmp(ggml_graph_node(gf, i)->name, "result_embd_pooled") == 0) {
-                    embd = ggml_graph_node(gf, i);
                     break;
                 }
             }
-            GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
         } else {
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
+
+        if (!cparams.causal_attn) {
+            res = nullptr; // do not extract logits when not needed
+        }
+
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
         ggml_backend_sched_alloc_graph(lctx.sched, gf);
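Taken together, the two hunks decouple logits extraction from embedding extraction: has_logits is keyed on cparams.causal_attn instead of being the negation of cparams.embeddings, and llama_decode_internal resolves res and embd independently rather than through one else-if chain. The sketch below is a minimal, self-contained illustration of that control flow, not the vendored code: the node_t and cparams_t types and the toy three-node graph are stand-ins invented for the example, while the embeddings/causal_attn flags and the "result_embd_pooled" name mirror the patch.

// Minimal sketch of the patched control flow. node_t/cparams_t and the toy
// graph are assumptions for a runnable example; in llama.cpp the nodes come
// from ggml_graph_n_nodes()/ggml_graph_node().
#include <cstdio>
#include <cstring>

struct node_t { const char * name; };

static node_t graph[] = { {"inp_tokens"}, {"result_embd_pooled"}, {"result_output"} };
static const int n_nodes = 3;

struct cparams_t {
    bool embeddings;  // caller requested embedding output
    bool causal_attn; // after the patch, logits exist only for causal models
};

int main() {
    cparams_t cparams = { /*embeddings=*/true, /*causal_attn=*/false };

    node_t * res  = &graph[n_nodes - 1]; // graph output, as in llama_decode_internal
    node_t * embd = nullptr;

    // Embedding extraction no longer forces res to nullptr; it is decided on
    // its own. embd is assigned before the name check, so if the pooled
    // tensor is absent the search ends on a fallback node instead of tripping
    // the removed GGML_ASSERT.
    if (cparams.embeddings) {
        for (int i = n_nodes - 1; i >= 0; --i) {
            embd = &graph[i];
            if (strcmp(graph[i].name, "result_embd_pooled") == 0) {
                break;
            }
        }
    } else {
        embd = nullptr; // do not extract embeddings when not needed
    }

    // Logits are keyed on causal attention, matching the has_logits change
    // in llama_output_reserve.
    if (!cparams.causal_attn) {
        res = nullptr; // do not extract logits when not needed
    }

    printf("logits: %s, embeddings: %s\n",
           res  ? res->name  : "(none)",
           embd ? embd->name : "(none)");
    return 0;
}

Run with embeddings=true and causal_attn=false (a pooled embedding model) it prints "logits: (none), embeddings: result_embd_pooled", a combination the old else-if chain could not express without discarding one of the two outputs.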