diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 2724edd..9979a67 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -470,6 +470,7 @@ class llama_model_params(Structure): # bool logits_all; // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) # bool embedding; // embedding mode only # bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU +# bool do_pooling; // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer) # }; class llama_context_params(Structure): """Parameters for llama_context @@ -496,6 +497,7 @@ class llama_context_params(Structure): logits_all (bool): the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) embedding (bool): embedding mode only offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU + do_pooling (bool): whether to pool (sum) embedding results by sequence id (ignored if no pooling layer) """ _fields_ = [ @@ -520,6 +522,7 @@ class llama_context_params(Structure): ("logits_all", c_bool), ("embedding", c_bool), ("offload_kqv", c_bool), + ("do_pooling", c_bool), ] @@ -1699,6 +1702,21 @@ _lib.llama_get_embeddings.argtypes = [llama_context_p] _lib.llama_get_embeddings.restype = c_float_p +# // Get the embeddings for the ith sequence +# // llama_get_embeddings(ctx) + i*n_embd +# LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); +def llama_get_embeddings_ith( + ctx: llama_context_p, i: Union[c_int32, int] +): # type: (...) -> Array[float] # type: ignore + """Get the embeddings for the ith sequence + llama_get_embeddings(ctx) + i*n_embd""" + return _lib.llama_get_embeddings_ith(ctx, i) + + +_lib.llama_get_embeddings_ith.argtypes = [llama_context_p, c_int32] +_lib.llama_get_embeddings_ith.restype = c_float_p + + # // # // Vocab # // diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 895407f..ea9c8e1 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 895407f31b358e3d9335e847d13f033491ec8a5b +Subproject commit ea9c8e11436ad50719987fa23a289c74b7b40d40