fb762a6041
* Add draft model param to llama class, implement basic prompt lookup decoding draft model
* Use samplingcontext for sampling
* Use 1d array
* Use draft model for sampling
* Fix dumb mistake
* Allow for later extensions to the LlamaDraftModel api
* Cleanup
* Adaptive candidate prediction
* Update implementation to match hf transformers
* Tuning
* Fix bug where last token was not used for ngram prediction
* Remove heuristic for num_pred_tokens (no benefit)
* fix: n_candidates bug.
* Add draft_model_num_pred_tokens server setting
* Cleanup
* Update README
16 lines · 696 B · Python
import numpy as np

from llama_cpp.llama_speculative import LlamaPromptLookupDecoding


def test_find_candidate_pred_tokens():
    find_candidate_pred_tokens = LlamaPromptLookupDecoding.find_candidate_pred_tokens

    # Test Case 1: Matching ngram is found
    input_ids1 = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3])
    result1 = find_candidate_pred_tokens(input_ids1, max_ngram_size=3, num_pred_tokens=2)
    assert np.array_equal(result1, np.array([1, 2]))

    # Test Case 2: Matching ngram is not found
    input_ids2 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
    result2 = find_candidate_pred_tokens(input_ids2, max_ngram_size=3, num_pred_tokens=2)
    assert np.array_equal(result2, np.array([]))
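To make the expected behavior concrete, here is a minimal illustrative sketch of the candidate search that `find_candidate_pred_tokens` performs (prompt lookup decoding, as popularized by the hf transformers implementation the commit message references): take the trailing n-gram of the prompt, look for an earlier occurrence of it, and propose the tokens that followed that occurrence as draft tokens. The function name `prompt_lookup_candidates` is hypothetical and this is not the library's exact code.

```python
import numpy as np


def prompt_lookup_candidates(input_ids, max_ngram_size=3, num_pred_tokens=10):
    """Illustrative sketch of prompt-lookup candidate generation
    (not the library's exact implementation)."""
    input_length = len(input_ids)
    # Prefer the longest matching n-gram, shrinking the window on failure.
    for ngram_size in range(min(max_ngram_size, input_length - 1), 0, -1):
        ngram = input_ids[-ngram_size:]
        # Scan earlier positions for the same n-gram (the trailing
        # occurrence itself is excluded by the range bound).
        for start in range(input_length - ngram_size):
            if np.array_equal(input_ids[start:start + ngram_size], ngram):
                begin = start + ngram_size
                candidates = input_ids[begin:begin + num_pred_tokens]
                if len(candidates) > 0:
                    return candidates
    # No repeated n-gram: nothing to draft.
    return np.array([], dtype=np.intp)
```

On the test's first input `[1, 2, 3, 1, 2, 3, 1, 2, 3]` the trailing 3-gram `[1, 2, 3]` recurs earlier, so the two tokens that followed it, `[1, 2]`, become the draft; on the strictly increasing second input no n-gram repeats and the result is empty, matching both assertions above.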