From d318cc8b83981833fbab2314d97af21104e52d99 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Thu, 14 Mar 2024 09:17:41 -0400
Subject: [PATCH] fix: Set default pooling_type to mean, check for null pointer.

---
 llama_cpp/llama.py     | 10 ++++++++--
 llama_cpp/llama_cpp.py |  1 -
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index aabbb7e..18db1a0 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -79,6 +79,7 @@ class Llama:
         n_threads: Optional[int] = None,
         n_threads_batch: Optional[int] = None,
         rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
+        pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_MEAN,
         rope_freq_base: float = 0.0,
         rope_freq_scale: float = 0.0,
         yarn_ext_factor: float = -1.0,
@@ -151,6 +152,7 @@ class Llama:
             n_threads: Number of threads to use for generation
             n_threads_batch: Number of threads to use for batch processing
             rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
+            pooling_type: Pooling type, from `enum llama_pooling_type`.
             rope_freq_base: RoPE base frequency, 0 = from model
             rope_freq_scale: RoPE frequency scaling factor, 0 = from model
             yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
@@ -271,6 +273,7 @@ class Llama:
             if rope_scaling_type is not None
             else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
         )
+        self.context_params.pooling_type = pooling_type
         self.context_params.rope_freq_base = (
             rope_freq_base if rope_freq_base != 0.0 else 0
         )
@@ -814,9 +817,12 @@ class Llama:
 
             # store embeddings
             for i in range(n_seq):
-                embedding: List[float] = llama_cpp.llama_get_embeddings_seq(
+                ptr = llama_cpp.llama_get_embeddings_seq(
                     self._ctx.ctx, i
-                )[:n_embd]
+                )
+                if not ptr:
+                    raise RuntimeError("Failed to get embeddings from sequence: pooling type is not set")
+                embedding: List[float] = ptr[:n_embd]
                 if normalize:
                     norm = float(np.linalg.norm(embedding))
                     embedding = [v / norm for v in embedding]
diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 884a063..dc45fd9 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -579,7 +579,6 @@ class llama_model_params(ctypes.Structure):
 #     bool embeddings;  // if true, extract embeddings (together with logits)
 #     bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
 
-
 #     // Abort callback
 #     // if it returns true, execution of llama_decode() will be aborted
 #     // currently works only with CPU execution
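
Usage note (not part of the patch): a minimal sketch of how the new default takes
effect, assuming a local GGUF embedding model at ./models/embedding.gguf (path is
hypothetical) and the existing `embedding=True` constructor flag:

    import llama_cpp
    from llama_cpp import Llama

    # pooling_type now defaults to llama_cpp.LLAMA_POOLING_TYPE_MEAN, so sequence
    # embeddings are mean-pooled unless another LLAMA_POOLING_TYPE_* value is passed.
    llm = Llama(
        model_path="./models/embedding.gguf",  # hypothetical path
        embedding=True,                        # extract embeddings instead of logits
        pooling_type=llama_cpp.LLAMA_POOLING_TYPE_MEAN,  # explicit, same as the new default
    )

    vec = llm.embed("Hello, world!")  # one vector (list of floats) for the input text
    print(len(vec))

If llama_get_embeddings_seq() returns NULL for a sequence (e.g. when no pooling is
applied), the returned ctypes pointer evaluates as falsy, so the new `if not ptr` guard
raises a descriptive RuntimeError before the slice `ptr[:n_embd]` is attempted on a
null pointer.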