fix: Set default pooling_type to mean, check for null pointer.
This commit is contained in:
parent
dd0ee56217
commit
d318cc8b83
2 changed files with 8 additions and 3 deletions
|
@ -79,6 +79,7 @@ class Llama:
|
|||
n_threads: Optional[int] = None,
|
||||
n_threads_batch: Optional[int] = None,
|
||||
rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
|
||||
pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_MEAN,
|
||||
rope_freq_base: float = 0.0,
|
||||
rope_freq_scale: float = 0.0,
|
||||
yarn_ext_factor: float = -1.0,
|
||||
|
@ -151,6 +152,7 @@ class Llama:
|
|||
n_threads: Number of threads to use for generation
|
||||
n_threads_batch: Number of threads to use for batch processing
|
||||
rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
|
||||
pooling_type: Pooling type, from `enum llama_pooling_type`.
|
||||
rope_freq_base: RoPE base frequency, 0 = from model
|
||||
rope_freq_scale: RoPE frequency scaling factor, 0 = from model
|
||||
yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
|
||||
|
@ -271,6 +273,7 @@ class Llama:
|
|||
if rope_scaling_type is not None
|
||||
else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
|
||||
)
|
||||
self.context_params.pooling_type = pooling_type
|
||||
self.context_params.rope_freq_base = (
|
||||
rope_freq_base if rope_freq_base != 0.0 else 0
|
||||
)
|
||||
|
@ -814,9 +817,12 @@ class Llama:
|
|||
|
||||
# store embeddings
|
||||
for i in range(n_seq):
|
||||
embedding: List[float] = llama_cpp.llama_get_embeddings_seq(
|
||||
ptr = llama_cpp.llama_get_embeddings_seq(
|
||||
self._ctx.ctx, i
|
||||
)[:n_embd]
|
||||
)
|
||||
if not ptr:
|
||||
raise RuntimeError("Failed to get embeddings from sequence pooling type is not set")
|
||||
embedding: List[float] = ptr[:n_embd]
|
||||
if normalize:
|
||||
norm = float(np.linalg.norm(embedding))
|
||||
embedding = [v / norm for v in embedding]
|
||||
|
|
|
@ -579,7 +579,6 @@ class llama_model_params(ctypes.Structure):
|
|||
# bool embeddings; // if true, extract embeddings (together with logits)
|
||||
# bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
||||
|
||||
|
||||
# // Abort callback
|
||||
# // if it returns true, execution of llama_decode() will be aborted
|
||||
# // currently works only with CPU execution
|
||||
|
|
Loading…
Reference in a new issue