llama.cpp/tests/test_llama_speculative.py
Andrei fb762a6041
Add speculative decoding (#1120)
* Add draft model param to llama class, implement basic prompt lookup decoding draft model

* Use samplingcontext for sampling

* Use 1d array

* Use draft model for sampling

* Fix dumb mistake

* Allow for later extensions to the LlamaDraftModel api

* Cleanup

* Adaptive candidate prediction

* Update implementation to match hf transformers

* Tuning

* Fix bug where last token was not used for ngram prediction

* Remove heuristic for num_pred_tokens (no benefit)

* fix: n_candidates bug.

* Add draft_model_num_pred_tokens server setting

* Cleanup

* Update README
2024-01-31 14:08:14 -05:00

import numpy as np

from llama_cpp.llama_speculative import LlamaPromptLookupDecoding


def test_find_candidate_pred_tokens():
    find_candidate_pred_tokens = LlamaPromptLookupDecoding.find_candidate_pred_tokens

    # Test Case 1: Matching ngram is found
    input_ids1 = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3])
    result1 = find_candidate_pred_tokens(input_ids1, max_ngram_size=3, num_pred_tokens=2)
    assert np.array_equal(result1, np.array([1, 2]))

    # Test Case 2: Matching ngram is not found
    input_ids2 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
    result2 = find_candidate_pred_tokens(input_ids2, max_ngram_size=3, num_pred_tokens=2)
    assert np.array_equal(result2, np.array([]))
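
The two cases above suggest what prompt lookup does: take the trailing n-gram of the token sequence, look for an earlier occurrence, and propose the tokens that followed it as draft candidates. The sketch below is a pure-Python illustration of that idea under the assumption that the longest n-gram is tried first and the earliest match wins. The name `find_candidates_sketch` is hypothetical and this is not the library's implementation, which operates on NumPy arrays.

```python
from typing import List, Sequence


def find_candidates_sketch(
    input_ids: Sequence[int], max_ngram_size: int = 3, num_pred_tokens: int = 2
) -> List[int]:
    """Hypothetical stand-in for prompt lookup; not the llama_cpp implementation."""
    ids = list(input_ids)
    n = len(ids)
    # Try the longest trailing n-gram first, then fall back to shorter ones.
    for size in range(min(max_ngram_size, n - 1), 0, -1):
        tail = ids[n - size:]
        # Scan left to right for an earlier occurrence of the trailing n-gram
        # (the final position, which is the tail itself, is excluded).
        for start in range(n - size):
            if ids[start:start + size] == tail:
                follow = ids[start + size:start + size + num_pred_tokens]
                if follow:
                    return follow
    # No earlier occurrence of any trailing n-gram: no candidates.
    return []
```

For the inputs used in the test, this sketch returns `[1, 2]` for the repeating sequence and an empty list for the strictly increasing one.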