Update llama_cpp and add kv_cache api support
parent 74061b209d
commit 1ed8cd023d
2 changed files with 32 additions and 2 deletions
@@ -1,6 +1,6 @@
 import ctypes

-from ctypes import c_int, c_float, c_char_p, c_void_p, c_bool, POINTER, Structure, Array
+from ctypes import c_int, c_float, c_char_p, c_void_p, c_bool, POINTER, Structure, Array, c_uint8, c_size_t

 import pathlib
 from itertools import chain
@@ -101,6 +101,36 @@ def llama_model_quantize(
 _lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int, c_int]
 _lib.llama_model_quantize.restype = c_int

+# Returns the KV cache that will contain the context for the
+# ongoing prediction with the model.
+def llama_get_kv_cache(ctx: llama_context_p):
+    return _lib.llama_get_kv_cache(ctx)
+
+_lib.llama_get_kv_cache.argtypes = [llama_context_p]
+_lib.llama_get_kv_cache.restype = POINTER(c_uint8)
+
+# Returns the size of the KV cache
+def llama_get_kv_cache_size(ctx: llama_context_p) -> c_size_t:
+    return _lib.llama_get_kv_cache_size(ctx)
+
+_lib.llama_get_kv_cache_size.argtypes = [llama_context_p]
+_lib.llama_get_kv_cache_size.restype = c_size_t
+
+# Returns the number of tokens in the KV cache
+def llama_get_kv_cache_token_count(ctx: llama_context_p) -> c_int:
+    return _lib.llama_get_kv_cache_token_count(ctx)
+
+_lib.llama_get_kv_cache_token_count.argtypes = [llama_context_p]
+_lib.llama_get_kv_cache_token_count.restype = c_int
+
+
+# Sets the KV cache containing the current context for the model
+def llama_set_kv_cache(ctx: llama_context_p, kv_cache, n_size: c_size_t, n_token_count: c_int):
+    return _lib.llama_set_kv_cache(ctx, kv_cache, n_size, n_token_count)
+
+_lib.llama_set_kv_cache.argtypes = [llama_context_p, POINTER(c_uint8), c_size_t, c_int]
+_lib.llama_set_kv_cache.restype = None
+

 # Run the llama inference to obtain the logits and probabilities for the next token.
 # tokens + n_tokens is the provided batch of new tokens to process
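For context, a minimal sketch (not part of this commit) of how the new bindings could be combined to snapshot a context's KV cache and restore it later. It assumes `ctx` is an already-initialized `llama_context_p` and that the functions above are in scope (e.g. imported from the `llama_cpp` module); the helper names `snapshot_kv_cache` and `restore_kv_cache` are illustrative only.

    import ctypes

    def snapshot_kv_cache(ctx):
        # Copy the cache bytes into a Python-owned buffer; llama_get_kv_cache
        # returns a pointer into the context's own storage, which later
        # evaluation would otherwise overwrite.
        size = llama_get_kv_cache_size(ctx)            # restype c_size_t -> plain int
        n_tokens = llama_get_kv_cache_token_count(ctx) # restype c_int -> plain int
        buf = (ctypes.c_uint8 * size)()
        ctypes.memmove(buf, llama_get_kv_cache(ctx), size)
        return buf, size, n_tokens

    def restore_kv_cache(ctx, buf, size, n_tokens):
        # ctypes converts the c_uint8 array to the POINTER(c_uint8)
        # declared in llama_set_kv_cache.argtypes.
        llama_set_kv_cache(ctx, buf, size, n_tokens)

Copying out of the returned pointer is the design point here: the getter exposes library-owned memory, so a durable snapshot has to be duplicated on the Python side before further inference mutates the cache.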
vendor/llama.cpp (vendored, 2 changed lines)
@@ -1 +1 @@
-Subproject commit d0a7f742e76bb48c0bd852f0b3bf09ec0b75b200
+Subproject commit d8d4e865cd481b18f10508ffee35db903767ef5c