fix: Always set logits_all = True when using speculative decoding

This commit is contained in:
Andrei Betlen 2024-02-12 16:19:05 -05:00
parent 153a0049d9
commit cb791716b4

View file

@@ -281,7 +281,7 @@ class Llama:
 )
 self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
 self.context_params.mul_mat_q = mul_mat_q
-self.context_params.logits_all = logits_all
+self.context_params.logits_all = logits_all if draft_model is None else True # Must be set to True for speculative decoding
 self.context_params.embedding = embedding
 self.context_params.offload_kqv = offload_kqv