diff --git a/tests/test_llama.py b/tests/test_llama.py index d682eec..396ed31 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -160,55 +160,18 @@ def test_llama_pickle(): assert llama.detokenize(llama.tokenize(text)) == text -def test_utf8(mock_llama, monkeypatch): +def test_utf8(mock_llama): llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, logits_all=True) - n_ctx = llama.n_ctx() - n_vocab = llama.n_vocab() output_text = "😀" - output_tokens = llama.tokenize( - output_text.encode("utf-8"), add_bos=True, special=True - ) - token_eos = llama.token_eos() - n = 0 - - def reset(): - nonlocal n - llama.reset() - n = 0 - - ## Set up mock function - def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch): - nonlocal n - assert batch.n_tokens > 0 - assert llama.n_tokens == n - n += batch.n_tokens - return 0 - - def mock_get_logits(*args, **kwargs): - size = n_vocab * n_ctx - return (llama_cpp.c_float * size)() - - def mock_sample(*args, **kwargs): - nonlocal n - if n <= len(output_tokens): - return output_tokens[n - 1] - else: - return token_eos - - monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode) - monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits) - monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample) ## Test basic completion with utf8 multibyte - # mock_llama(llama, output_text) - reset() + mock_llama(llama, output_text) completion = llama.create_completion("", max_tokens=4) assert completion["choices"][0]["text"] == output_text ## Test basic completion with incomplete utf8 multibyte - # mock_llama(llama, output_text) - reset() + mock_llama(llama, output_text) completion = llama.create_completion("", max_tokens=1) assert completion["choices"][0]["text"] == ""