feat(low-level-api): Improve API static type-safety and performance (#1205)
This commit is contained in:
parent
0f8aa4ab5c
commit
7f51b6071f
5 changed files with 858 additions and 743 deletions
|
@ -108,7 +108,7 @@ class _LlamaModel:
|
||||||
scale,
|
scale,
|
||||||
path_base_model.encode("utf-8")
|
path_base_model.encode("utf-8")
|
||||||
if path_base_model is not None
|
if path_base_model is not None
|
||||||
else llama_cpp.c_char_p(0),
|
else ctypes.c_char_p(0),
|
||||||
n_threads,
|
n_threads,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -303,8 +303,8 @@ class _LlamaContext:
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
assert batch.batch is not None
|
assert batch.batch is not None
|
||||||
return_code = llama_cpp.llama_decode(
|
return_code = llama_cpp.llama_decode(
|
||||||
ctx=self.ctx,
|
self.ctx,
|
||||||
batch=batch.batch,
|
batch.batch,
|
||||||
)
|
)
|
||||||
if return_code != 0:
|
if return_code != 0:
|
||||||
raise RuntimeError(f"llama_decode returned {return_code}")
|
raise RuntimeError(f"llama_decode returned {return_code}")
|
||||||
|
@ -493,7 +493,7 @@ class _LlamaBatch:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
|
self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
|
||||||
):
|
):
|
||||||
self.n_tokens = n_tokens
|
self._n_tokens = n_tokens
|
||||||
self.embd = embd
|
self.embd = embd
|
||||||
self.n_seq_max = n_seq_max
|
self.n_seq_max = n_seq_max
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
@ -502,7 +502,7 @@ class _LlamaBatch:
|
||||||
|
|
||||||
self.batch = None
|
self.batch = None
|
||||||
self.batch = llama_cpp.llama_batch_init(
|
self.batch = llama_cpp.llama_batch_init(
|
||||||
self.n_tokens, self.embd, self.n_seq_max
|
self._n_tokens, self.embd, self.n_seq_max
|
||||||
)
|
)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
@ -570,12 +570,13 @@ class _LlamaTokenDataArray:
|
||||||
self.candidates.data = self.candidates_data.ctypes.data_as(
|
self.candidates.data = self.candidates_data.ctypes.data_as(
|
||||||
llama_cpp.llama_token_data_p
|
llama_cpp.llama_token_data_p
|
||||||
)
|
)
|
||||||
self.candidates.sorted = llama_cpp.c_bool(False)
|
self.candidates.sorted = ctypes.c_bool(False)
|
||||||
self.candidates.size = llama_cpp.c_size_t(self.n_vocab)
|
self.candidates.size = ctypes.c_size_t(self.n_vocab)
|
||||||
|
|
||||||
|
|
||||||
# Python wrappers over common/common
|
# Python wrappers over common/common
|
||||||
def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> list[int]:
|
def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> list[int]:
|
||||||
|
assert model.model is not None
|
||||||
n_tokens = len(text) + 1 if add_bos else len(text)
|
n_tokens = len(text) + 1 if add_bos else len(text)
|
||||||
result = (llama_cpp.llama_token * n_tokens)()
|
result = (llama_cpp.llama_token * n_tokens)()
|
||||||
n_tokens = llama_cpp.llama_tokenize(
|
n_tokens = llama_cpp.llama_tokenize(
|
||||||
|
|
|
@ -1818,7 +1818,7 @@ class Llama:
|
||||||
self.input_ids = state.input_ids.copy()
|
self.input_ids = state.input_ids.copy()
|
||||||
self.n_tokens = state.n_tokens
|
self.n_tokens = state.n_tokens
|
||||||
state_size = state.llama_state_size
|
state_size = state.llama_state_size
|
||||||
LLamaStateArrayType = llama_cpp.c_uint8 * state_size
|
LLamaStateArrayType = ctypes.c_uint8 * state_size
|
||||||
llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
|
llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
|
||||||
|
|
||||||
if llama_cpp.llama_set_state_data(self._ctx.ctx, llama_state) != state_size:
|
if llama_cpp.llama_set_state_data(self._ctx.ctx, llama_state) != state_size:
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -5,21 +5,15 @@ from ctypes import (
|
||||||
c_bool,
|
c_bool,
|
||||||
c_char_p,
|
c_char_p,
|
||||||
c_int,
|
c_int,
|
||||||
c_int8,
|
|
||||||
c_int32,
|
|
||||||
c_uint8,
|
c_uint8,
|
||||||
c_uint32,
|
|
||||||
c_size_t,
|
|
||||||
c_float,
|
c_float,
|
||||||
c_double,
|
|
||||||
c_void_p,
|
c_void_p,
|
||||||
POINTER,
|
POINTER,
|
||||||
_Pointer, # type: ignore
|
_Pointer, # type: ignore
|
||||||
Structure,
|
Structure,
|
||||||
Array,
|
|
||||||
)
|
)
|
||||||
import pathlib
|
import pathlib
|
||||||
from typing import List, Union
|
from typing import List, Union, NewType, Optional
|
||||||
|
|
||||||
import llama_cpp.llama_cpp as llama_cpp
|
import llama_cpp.llama_cpp as llama_cpp
|
||||||
|
|
||||||
|
@ -67,7 +61,7 @@ def _load_shared_library(lib_base_name: str):
|
||||||
for _lib_path in _lib_paths:
|
for _lib_path in _lib_paths:
|
||||||
if _lib_path.exists():
|
if _lib_path.exists():
|
||||||
try:
|
try:
|
||||||
return ctypes.CDLL(str(_lib_path), **cdll_args)
|
return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
|
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
|
||||||
|
|
||||||
|
@ -88,7 +82,8 @@ _libllava = _load_shared_library(_libllava_base_name)
|
||||||
################################################
|
################################################
|
||||||
|
|
||||||
# struct clip_ctx;
|
# struct clip_ctx;
|
||||||
clip_ctx_p = c_void_p
|
clip_ctx_p = NewType("clip_ctx_p", int)
|
||||||
|
clip_ctx_p_ctypes = c_void_p
|
||||||
|
|
||||||
# struct llava_image_embed {
|
# struct llava_image_embed {
|
||||||
# float * embed;
|
# float * embed;
|
||||||
|
@ -102,43 +97,48 @@ class llava_image_embed(Structure):
|
||||||
|
|
||||||
# /** sanity check for clip <-> llava embed size match */
|
# /** sanity check for clip <-> llava embed size match */
|
||||||
# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
|
# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
|
||||||
def llava_validate_embed_size(ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p) -> bool:
|
def llava_validate_embed_size(ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, /) -> bool:
|
||||||
return _libllava.llava_validate_embed_size(ctx_llama, ctx_clip)
|
...
|
||||||
|
|
||||||
_libllava.llava_validate_embed_size.argtypes = [llama_cpp.llama_context_p, clip_ctx_p]
|
llava_validate_embed_size = _libllava.llava_validate_embed_size
|
||||||
_libllava.llava_validate_embed_size.restype = c_bool
|
llava_validate_embed_size.argtypes = [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes]
|
||||||
|
llava_validate_embed_size.restype = c_bool
|
||||||
|
|
||||||
# /** build an image embed from image file bytes */
|
# /** build an image embed from image file bytes */
|
||||||
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
|
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
|
||||||
def llava_image_embed_make_with_bytes(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_bytes: bytes, image_bytes_length: Union[c_int, int]) -> "_Pointer[llava_image_embed]":
|
def llava_image_embed_make_with_bytes(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_bytes: bytes, image_bytes_length: Union[c_int, int], /) -> "_Pointer[llava_image_embed]":
|
||||||
return _libllava.llava_image_embed_make_with_bytes(ctx_clip, n_threads, image_bytes, image_bytes_length)
|
...
|
||||||
|
|
||||||
_libllava.llava_image_embed_make_with_bytes.argtypes = [clip_ctx_p, c_int, POINTER(c_uint8), c_int]
|
llava_image_embed_make_with_bytes = _libllava.llava_image_embed_make_with_bytes
|
||||||
_libllava.llava_image_embed_make_with_bytes.restype = POINTER(llava_image_embed)
|
llava_image_embed_make_with_bytes.argtypes = [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int]
|
||||||
|
llava_image_embed_make_with_bytes.restype = POINTER(llava_image_embed)
|
||||||
|
|
||||||
# /** build an image embed from a path to an image filename */
|
# /** build an image embed from a path to an image filename */
|
||||||
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
|
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
|
||||||
def llava_image_embed_make_with_filename(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes) -> "_Pointer[llava_image_embed]":
|
def llava_image_embed_make_with_filename(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, /) -> "_Pointer[llava_image_embed]":
|
||||||
return _libllava.llava_image_embed_make_with_filename(ctx_clip, n_threads, image_path)
|
...
|
||||||
|
|
||||||
_libllava.llava_image_embed_make_with_filename.argtypes = [clip_ctx_p, c_int, c_char_p]
|
llava_image_embed_make_with_filename = _libllava.llava_image_embed_make_with_filename
|
||||||
_libllava.llava_image_embed_make_with_filename.restype = POINTER(llava_image_embed)
|
llava_image_embed_make_with_filename.argtypes = [clip_ctx_p_ctypes, c_int, c_char_p]
|
||||||
|
llava_image_embed_make_with_filename.restype = POINTER(llava_image_embed)
|
||||||
|
|
||||||
# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
|
# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
|
||||||
# /** free an embedding made with llava_image_embed_make_* */
|
# /** free an embedding made with llava_image_embed_make_* */
|
||||||
def llava_image_embed_free(embed: "_Pointer[llava_image_embed]"):
|
def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /):
|
||||||
return _libllava.llava_image_embed_free(embed)
|
...
|
||||||
|
|
||||||
_libllava.llava_image_embed_free.argtypes = [POINTER(llava_image_embed)]
|
llava_image_embed_free = _libllava.llava_image_embed_free
|
||||||
_libllava.llava_image_embed_free.restype = None
|
llava_image_embed_free.argtypes = [POINTER(llava_image_embed)]
|
||||||
|
llava_image_embed_free.restype = None
|
||||||
|
|
||||||
# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
|
# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
|
||||||
# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
|
# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
|
||||||
def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointer[llava_image_embed]", n_batch: Union[c_int, int], n_past: "_Pointer[c_int]") -> bool:
|
def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointer[llava_image_embed]", n_batch: Union[c_int, int], n_past: "_Pointer[c_int]", /) -> bool:
|
||||||
return _libllava.llava_eval_image_embed(ctx_llama, embed, n_batch, n_past)
|
...
|
||||||
|
|
||||||
_libllava.llava_eval_image_embed.argtypes = [llama_cpp.llama_context_p, POINTER(llava_image_embed), c_int, POINTER(c_int)]
|
llava_eval_image_embed = _libllava.llava_eval_image_embed
|
||||||
_libllava.llava_eval_image_embed.restype = c_bool
|
llava_eval_image_embed.argtypes = [llama_cpp.llama_context_p_ctypes, POINTER(llava_image_embed), c_int, POINTER(c_int)]
|
||||||
|
llava_eval_image_embed.restype = c_bool
|
||||||
|
|
||||||
|
|
||||||
################################################
|
################################################
|
||||||
|
@ -148,16 +148,18 @@ _libllava.llava_eval_image_embed.restype = c_bool
|
||||||
|
|
||||||
# /** load mmproj model */
|
# /** load mmproj model */
|
||||||
# CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
|
# CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
|
||||||
def clip_model_load(fname: bytes, verbosity: Union[c_int, int]) -> clip_ctx_p:
|
def clip_model_load(fname: bytes, verbosity: Union[c_int, int], /) -> Optional[clip_ctx_p]:
|
||||||
return _libllava.clip_model_load(fname, verbosity)
|
...
|
||||||
|
|
||||||
_libllava.clip_model_load.argtypes = [c_char_p, c_int]
|
clip_model_load = _libllava.clip_model_load
|
||||||
_libllava.clip_model_load.restype = clip_ctx_p
|
clip_model_load.argtypes = [c_char_p, c_int]
|
||||||
|
clip_model_load.restype = clip_ctx_p_ctypes
|
||||||
|
|
||||||
# /** free mmproj model */
|
# /** free mmproj model */
|
||||||
# CLIP_API void clip_free(struct clip_ctx * ctx);
|
# CLIP_API void clip_free(struct clip_ctx * ctx);
|
||||||
def clip_free(ctx: clip_ctx_p):
|
def clip_free(ctx: clip_ctx_p, /):
|
||||||
return _libllava.clip_free(ctx)
|
...
|
||||||
|
|
||||||
_libllava.clip_free.argtypes = [clip_ctx_p]
|
clip_free = _libllava.clip_free
|
||||||
_libllava.clip_free.restype = None
|
clip_free.argtypes = [clip_ctx_p_ctypes]
|
||||||
|
clip_free.restype = None
|
||||||
|
|
|
@ -54,7 +54,7 @@ def mock_llama(monkeypatch):
|
||||||
output_tokens = llama.tokenize(
|
output_tokens = llama.tokenize(
|
||||||
output_text.encode("utf-8"), add_bos=True, special=True
|
output_text.encode("utf-8"), add_bos=True, special=True
|
||||||
)
|
)
|
||||||
logits = (llama_cpp.c_float * (n_vocab * n_ctx))(-100.0)
|
logits = (ctypes.c_float * (n_vocab * n_ctx))(-100.0)
|
||||||
for i in range(n_ctx):
|
for i in range(n_ctx):
|
||||||
output_idx = i + 1 # logits for first tokens predict second token
|
output_idx = i + 1 # logits for first tokens predict second token
|
||||||
if output_idx < len(output_tokens):
|
if output_idx < len(output_tokens):
|
||||||
|
@ -90,9 +90,9 @@ def mock_llama(monkeypatch):
|
||||||
assert n > 0, "mock_llama_decode not called"
|
assert n > 0, "mock_llama_decode not called"
|
||||||
assert last_n_tokens > 0, "mock_llama_decode not called"
|
assert last_n_tokens > 0, "mock_llama_decode not called"
|
||||||
# Return view of logits for last_n_tokens
|
# Return view of logits for last_n_tokens
|
||||||
return (llama_cpp.c_float * (last_n_tokens * n_vocab)).from_address(
|
return (ctypes.c_float * (last_n_tokens * n_vocab)).from_address(
|
||||||
ctypes.addressof(logits)
|
ctypes.addressof(logits)
|
||||||
+ (n - last_n_tokens) * n_vocab * ctypes.sizeof(llama_cpp.c_float)
|
+ (n - last_n_tokens) * n_vocab * ctypes.sizeof(ctypes.c_float)
|
||||||
)
|
)
|
||||||
|
|
||||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
|
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
|
||||||
|
|
Loading…
Reference in a new issue