fix: Remove deprecated cfg sampling functions
parent 727d60c28a
commit 8c71725d53
2 changed files with 1 addition and 50 deletions
@@ -357,21 +357,6 @@ class _LlamaContext:
             penalty_present,
         )
 
-    def sample_classifier_free_guidance(
-        self,
-        candidates: "_LlamaTokenDataArray",
-        guidance_ctx: "_LlamaContext",
-        scale: float,
-    ):
-        assert self.ctx is not None
-        assert guidance_ctx.ctx is not None
-        llama_cpp.llama_sample_classifier_free_guidance(
-            self.ctx,
-            llama_cpp.byref(candidates.candidates),
-            guidance_ctx.ctx,
-            scale,
-        )
-
     def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
         assert self.ctx is not None
         llama_cpp.llama_sample_softmax(
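The removed method was a thin wrapper over the deprecated C function; the guidance blend itself is a single line over raw logits. For readers tracking the migration, a minimal numpy sketch of that blend (the helper name `apply_cfg` is ours, not part of the library):

```python
import numpy as np

def apply_cfg(
    logits: np.ndarray, guidance_logits: np.ndarray, scale: float
) -> np.ndarray:
    # Classifier-free guidance blend from "Stay on topic with Classifier-Free
    # Guidance" (https://arxiv.org/abs/2306.17806): scale > 1 pushes the
    # result away from the guidance (e.g. negative-prompt) distribution.
    return guidance_logits + scale * (logits - guidance_logits)
```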
@@ -720,7 +705,7 @@ class _LlamaSamplingContext:
         return ctx_main.model.detokenize(self.prev[-n:]).decode("utf-8")
 
     def sample(
-        self, ctx_main: _LlamaContext, ctx_cfg: Optional[_LlamaContext] = None, idx: int = 0, logits_array: Optional[npt.NDArray[np.single]] = None
+        self, ctx_main: _LlamaContext, idx: int = 0, logits_array: Optional[npt.NDArray[np.single]] = None
     ):
         n_vocab = ctx_main.model.n_vocab()
         id: int = 0
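With `ctx_cfg` gone from the signature, passing it now raises a `TypeError`. A sketch of the call-site change (the variable names are illustrative, not taken from this diff):

```python
# before this commit, a guidance context could be forwarded to the sampler:
# token = sampling_ctx.sample(ctx_main, ctx_cfg=guidance_ctx, idx=0)

# after this commit, only the main context is accepted:
token = sampling_ctx.sample(ctx_main, idx=0)
```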
@@ -741,11 +726,6 @@
             ) # TODO: Only create this once
             token_data_array.copy_logits(logits_array)
 
-        if ctx_cfg is not None:
-            ctx_main.sample_classifier_free_guidance(
-                token_data_array, ctx_cfg, self.params.cfg_scale
-            )
-
         # apply penalties
         if len(self.prev) > 0:
             nl_token = ctx_main.model.token_nl()
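Guidance is no longer applied inside `sample()`, but the surviving `logits_array` parameter leaves an external migration path: blend the two logit sets yourself and hand the result in. A minimal sketch, assuming a `get_logits_ith`-style accessor on `_LlamaContext` (an assumption, not shown in this diff) and the `apply_cfg` helper sketched above:

```python
# hypothetical accessor calls; substitute whatever your build exposes
main_logits = ctx_main.get_logits_ith(idx)
guidance_logits = guidance_ctx.get_logits_ith(idx)

# blend outside the library, then sample from the blended float32 logits
blended = apply_cfg(main_logits, guidance_logits, scale=1.5)
token = sampling_ctx.sample(ctx_main, idx=idx, logits_array=blended)
```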
@@ -2129,35 +2129,6 @@ def llama_sample_apply_guidance(
     ...
 
 
-# LLAMA_API DEPRECATED(void llama_sample_classifier_free_guidance(
-#           struct llama_context * ctx,
-#         llama_token_data_array * candidates,
-#           struct llama_context * guidance_ctx,
-#                          float   scale),
-#           "use llama_sample_apply_guidance() instead");
-@ctypes_function(
-    "llama_sample_classifier_free_guidance",
-    [
-        llama_context_p_ctypes,
-        llama_token_data_array_p,
-        llama_context_p_ctypes,
-        ctypes.c_float,
-    ],
-    None,
-)
-def llama_sample_classifier_free_guidance(
-    ctx: llama_context_p,
-    candidates: Union[
-        CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
-    ],
-    guidance_ctx: llama_context_p,
-    scale: Union[ctypes.c_float, float],
-    /,
-):
-    """Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806"""
-    ...
-
-
 # /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
 # LLAMA_API void llama_sample_softmax(
 #       struct llama_context * ctx,
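The replacement named in the removed deprecation note, `llama_sample_apply_guidance`, is the binding whose body ends at the top of this hunk; per the llama.h comment it takes two raw float logit buffers instead of a second `llama_context`. A hedged sketch of calling it with numpy-backed buffers (the pointer conversions are our assumption, so check the binding's declared argtypes):

```python
import ctypes

import numpy as np

import llama_cpp.llama_cpp as llama_cpp  # same import style as the first file

def sample_apply_guidance(ctx, logits, logits_guidance, scale: float) -> None:
    # Both buffers are assumed to be contiguous float32 arrays of n_vocab
    # length; the C function blends logits_guidance into logits in place.
    assert logits.dtype == np.float32 and logits_guidance.dtype == np.float32
    llama_cpp.llama_sample_apply_guidance(
        ctx,
        logits.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        logits_guidance.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        ctypes.c_float(scale),
    )
```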