misc: llava_cpp use ctypes function decorator for binding
This commit is contained in:
parent
8383a9e562
commit
44558cbd7a
1 changed files with 34 additions and 28 deletions
|
@ -1,6 +1,7 @@
|
|||
import sys
|
||||
import os
|
||||
import ctypes
|
||||
import functools
|
||||
from ctypes import (
|
||||
c_bool,
|
||||
c_char_p,
|
||||
|
@ -13,7 +14,7 @@ from ctypes import (
|
|||
Structure,
|
||||
)
|
||||
import pathlib
|
||||
from typing import List, Union, NewType, Optional
|
||||
from typing import List, Union, NewType, Optional, TypeVar, Callable, Any
|
||||
|
||||
import llama_cpp.llama_cpp as llama_cpp
|
||||
|
||||
|
@ -76,6 +77,31 @@ _libllava_base_name = "llava"
|
|||
# Load the library
|
||||
_libllava = _load_shared_library(_libllava_base_name)
|
||||
|
||||
# ctypes helper
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
def ctypes_function_for_shared_library(lib: ctypes.CDLL):
|
||||
def ctypes_function(
|
||||
name: str, argtypes: List[Any], restype: Any, enabled: bool = True
|
||||
):
|
||||
def decorator(f: F) -> F:
|
||||
if enabled:
|
||||
func = getattr(lib, name)
|
||||
func.argtypes = argtypes
|
||||
func.restype = restype
|
||||
functools.wraps(f)(func)
|
||||
return func
|
||||
else:
|
||||
return f
|
||||
|
||||
return decorator
|
||||
|
||||
return ctypes_function
|
||||
|
||||
|
||||
ctypes_function = ctypes_function_for_shared_library(_libllava)
|
||||
|
||||
|
||||
################################################
|
||||
# llava.h
|
||||
|
@ -97,49 +123,35 @@ class llava_image_embed(Structure):
|
|||
|
||||
# /** sanity check for clip <-> llava embed size match */
|
||||
# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
|
||||
@ctypes_function("llava_validate_embed_size", [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes], c_bool)
|
||||
def llava_validate_embed_size(ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, /) -> bool:
|
||||
...
|
||||
|
||||
llava_validate_embed_size = _libllava.llava_validate_embed_size
|
||||
llava_validate_embed_size.argtypes = [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes]
|
||||
llava_validate_embed_size.restype = c_bool
|
||||
|
||||
# /** build an image embed from image file bytes */
|
||||
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
|
||||
@ctypes_function("llava_image_embed_make_with_bytes", [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int], POINTER(llava_image_embed))
|
||||
def llava_image_embed_make_with_bytes(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_bytes: bytes, image_bytes_length: Union[c_int, int], /) -> "_Pointer[llava_image_embed]":
|
||||
...
|
||||
|
||||
llava_image_embed_make_with_bytes = _libllava.llava_image_embed_make_with_bytes
|
||||
llava_image_embed_make_with_bytes.argtypes = [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int]
|
||||
llava_image_embed_make_with_bytes.restype = POINTER(llava_image_embed)
|
||||
|
||||
# /** build an image embed from a path to an image filename */
|
||||
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
|
||||
@ctypes_function("llava_image_embed_make_with_filename", [clip_ctx_p_ctypes, c_int, c_char_p], POINTER(llava_image_embed))
|
||||
def llava_image_embed_make_with_filename(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, /) -> "_Pointer[llava_image_embed]":
|
||||
...
|
||||
|
||||
llava_image_embed_make_with_filename = _libllava.llava_image_embed_make_with_filename
|
||||
llava_image_embed_make_with_filename.argtypes = [clip_ctx_p_ctypes, c_int, c_char_p]
|
||||
llava_image_embed_make_with_filename.restype = POINTER(llava_image_embed)
|
||||
|
||||
# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
|
||||
# /** free an embedding made with llava_image_embed_make_* */
|
||||
@ctypes_function("llava_image_embed_free", [POINTER(llava_image_embed)], None)
|
||||
def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /):
|
||||
...
|
||||
|
||||
llava_image_embed_free = _libllava.llava_image_embed_free
|
||||
llava_image_embed_free.argtypes = [POINTER(llava_image_embed)]
|
||||
llava_image_embed_free.restype = None
|
||||
|
||||
# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
|
||||
# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
|
||||
@ctypes_function("llava_eval_image_embed", [llama_cpp.llama_context_p_ctypes, POINTER(llava_image_embed), c_int, POINTER(c_int)], c_bool)
|
||||
def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointer[llava_image_embed]", n_batch: Union[c_int, int], n_past: "_Pointer[c_int]", /) -> bool:
|
||||
...
|
||||
|
||||
llava_eval_image_embed = _libllava.llava_eval_image_embed
|
||||
llava_eval_image_embed.argtypes = [llama_cpp.llama_context_p_ctypes, POINTER(llava_image_embed), c_int, POINTER(c_int)]
|
||||
llava_eval_image_embed.restype = c_bool
|
||||
|
||||
|
||||
################################################
|
||||
# clip.h
|
||||
|
@ -148,18 +160,12 @@ llava_eval_image_embed.restype = c_bool
|
|||
|
||||
# /** load mmproj model */
|
||||
# CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
|
||||
@ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes)
|
||||
def clip_model_load(fname: bytes, verbosity: Union[c_int, int], /) -> Optional[clip_ctx_p]:
|
||||
...
|
||||
|
||||
clip_model_load = _libllava.clip_model_load
|
||||
clip_model_load.argtypes = [c_char_p, c_int]
|
||||
clip_model_load.restype = clip_ctx_p_ctypes
|
||||
|
||||
# /** free mmproj model */
|
||||
# CLIP_API void clip_free(struct clip_ctx * ctx);
|
||||
@ctypes_function("clip_free", [clip_ctx_p_ctypes], None)
|
||||
def clip_free(ctx: clip_ctx_p, /):
|
||||
...
|
||||
|
||||
clip_free = _libllava.clip_free
|
||||
clip_free.argtypes = [clip_ctx_p_ctypes]
|
||||
clip_free.restype = None
|
||||
|
|
Loading…
Add table
Reference in a new issue