# llama.cpp/llama_cpp/llava_cpp.py
# (listing metadata from the code viewer: 164 lines, 6.5 KiB, Python)

import sys
import os
import ctypes
from ctypes import (
c_bool,
c_char_p,
c_int,
c_int8,
c_int32,
c_uint8,
c_uint32,
c_size_t,
c_float,
c_double,
c_void_p,
POINTER,
_Pointer, # type: ignore
Structure,
Array,
)
import pathlib
from typing import List, Union
import llama_cpp.llama_cpp as llama_cpp
# Load the library
def _load_shared_library(lib_base_name: str):
    """Locate and load a shared library with ctypes.

    By default the library is searched for in the directory containing this
    module, using the platform's naming convention:

    * Linux:   ``lib<name>.so``
    * macOS:   ``lib<name>.so``, then ``lib<name>.dylib``
    * Windows: ``<name>.dll``, then ``lib<name>.dll``

    If the ``LLAVA_CPP_LIB`` environment variable is set, its value is used
    as the path of the library and the default candidates are ignored.

    Args:
        lib_base_name: Library name without the ``lib`` prefix or extension.

    Returns:
        The loaded ``ctypes.CDLL`` handle.

    Raises:
        RuntimeError: If the platform is unsupported and no override is set,
            or if a candidate file exists but cannot be loaded (the original
            loader error is attached as ``__cause__``).
        FileNotFoundError: If none of the candidate paths exists.
    """
    _lib_paths: List[pathlib.Path]

    override = os.environ.get("LLAVA_CPP_LIB")
    if override is not None:
        # An explicit path wins over platform detection, so it also works on
        # platforms the default search does not know about.
        lib_base_name = override
        _lib = pathlib.Path(override)
        _base_path = _lib.parent.resolve()
        _lib_paths = [_lib.resolve()]
    else:
        _base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__)))
        # Candidate file names depend on the platform.
        if sys.platform.startswith("linux"):
            _lib_paths = [
                _base_path / f"lib{lib_base_name}.so",
            ]
        elif sys.platform == "darwin":
            _lib_paths = [
                _base_path / f"lib{lib_base_name}.so",
                _base_path / f"lib{lib_base_name}.dylib",
            ]
        elif sys.platform == "win32":
            _lib_paths = [
                _base_path / f"{lib_base_name}.dll",
                _base_path / f"lib{lib_base_name}.dll",
            ]
        else:
            raise RuntimeError("Unsupported platform")

    cdll_args = dict()  # type: ignore

    # Add the library directory to the DLL search path on Windows (if needed)
    if sys.platform == "win32" and sys.version_info >= (3, 8):
        os.add_dll_directory(str(_base_path))
        if "CUDA_PATH" in os.environ:
            os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "bin"))
            os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "lib"))
        cdll_args["winmode"] = ctypes.RTLD_GLOBAL

    # The first candidate that exists is decisive: it either loads or raises.
    for _lib_path in _lib_paths:
        if _lib_path.exists():
            try:
                return ctypes.CDLL(str(_lib_path), **cdll_args)
            except Exception as e:
                # Chain explicitly so the loader's own error stays attached.
                raise RuntimeError(
                    f"Failed to load shared library '{_lib_path}': {e}"
                ) from e

    raise FileNotFoundError(
        f"Shared library with base name '{lib_base_name}' not found"
    )
# Base name of the llava shared library; _load_shared_library turns it into
# platform-specific file names (e.g. "libllava.so") unless LLAVA_CPP_LIB is set.
_libllava_base_name = "llava"
# Load the library once at import time. Every binding below goes through this
# handle, and importing this module fails if _load_shared_library raises.
_libllava = _load_shared_library(_libllava_base_name)
################################################
# llava.h
################################################
# struct clip_ctx;
# The struct is only forward-declared here, so a context is passed around as
# an untyped pointer.
clip_ctx_p = ctypes.c_void_p
# struct llava_image_embed {
#     float * embed;
#     int n_image_pos;
# };
class llava_image_embed(ctypes.Structure):
    """ctypes mirror of the C ``struct llava_image_embed`` declared in llava.h."""

    _fields_ = [
        ("embed", ctypes.POINTER(ctypes.c_float)),  # float * embed
        ("n_image_pos", ctypes.c_int),  # int n_image_pos
    ]
# /** sanity check for clip <-> llava embed size match */
# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
_libllava.llava_validate_embed_size.argtypes = [llama_cpp.llama_context_p, clip_ctx_p]
_libllava.llava_validate_embed_size.restype = c_bool


def llava_validate_embed_size(
    ctx_llama: llama_cpp.llama_context_p,
    ctx_clip: clip_ctx_p,
) -> bool:
    """Sanity-check that the clip and llava embedding sizes match.

    Thin wrapper around the C function of the same name; both contexts are
    forwarded unchanged and the C function's boolean result is returned.
    """
    sizes_ok = _libllava.llava_validate_embed_size(ctx_llama, ctx_clip)
    return sizes_ok
# /** build an image embed from image file bytes */
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
def llava_image_embed_make_with_bytes(
    ctx_clip: clip_ctx_p,
    n_threads: Union[c_int, int],
    image_bytes: Union[bytes, bytearray, "Array[c_uint8]"],
    image_bytes_length: Union[c_int, int],
) -> "_Pointer[llava_image_embed]":
    """Build an image embed from image file bytes.

    Args:
        ctx_clip: Clip context pointer.
        n_threads: Thread count forwarded to the C function.
        image_bytes: The encoded image data. A ``bytes``/``bytearray`` value
            is copied into a ``c_uint8`` array first, because the declared
            ``POINTER(c_uint8)`` argtype does not accept Python bytes
            directly. ctypes arrays/pointers are passed through unchanged.
        image_bytes_length: Number of bytes in ``image_bytes``.

    Returns:
        Pointer to a ``llava_image_embed`` (declared restype of the C call).
    """
    if isinstance(image_bytes, (bytes, bytearray)):
        # The local name keeps the temporary buffer alive for the duration of
        # the foreign call.
        image_bytes = (c_uint8 * len(image_bytes)).from_buffer_copy(image_bytes)
    return _libllava.llava_image_embed_make_with_bytes(
        ctx_clip, n_threads, image_bytes, image_bytes_length
    )


_libllava.llava_image_embed_make_with_bytes.argtypes = [clip_ctx_p, c_int, POINTER(c_uint8), c_int]
_libllava.llava_image_embed_make_with_bytes.restype = POINTER(llava_image_embed)
# /** build an image embed from a path to an image filename */
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
_libllava.llava_image_embed_make_with_filename.argtypes = [clip_ctx_p, c_int, c_char_p]
_libllava.llava_image_embed_make_with_filename.restype = POINTER(llava_image_embed)


def llava_image_embed_make_with_filename(
    ctx_clip: clip_ctx_p,
    n_threads: Union[c_int, int],
    image_path: bytes,
) -> "_Pointer[llava_image_embed]":
    """Build an image embed from a path to an image file.

    Thin wrapper around the C function of the same name. ``image_path`` is
    passed as a C string (``c_char_p``), so it must be ``bytes``.
    """
    embed_ptr = _libllava.llava_image_embed_make_with_filename(
        ctx_clip, n_threads, image_path
    )
    return embed_ptr
# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
# /** free an embedding made with llava_image_embed_make_* */
_libllava.llava_image_embed_free.argtypes = [POINTER(llava_image_embed)]
_libllava.llava_image_embed_free.restype = None


def llava_image_embed_free(embed: "_Pointer[llava_image_embed]"):
    """Free an embedding made with one of the ``llava_image_embed_make_*`` calls.

    The C function is declared ``void``, so this always returns ``None``.
    """
    _libllava.llava_image_embed_free(embed)
# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
_libllava.llava_eval_image_embed.argtypes = [
    llama_cpp.llama_context_p,
    POINTER(llava_image_embed),
    c_int,
    POINTER(c_int),
]
_libllava.llava_eval_image_embed.restype = c_bool


def llava_eval_image_embed(
    ctx_llama: llama_cpp.llama_context_p,
    embed: "_Pointer[llava_image_embed]",
    n_batch: Union[c_int, int],
    n_past: "_Pointer[c_int]",
) -> bool:
    """Write the image represented by ``embed`` into the llama context.

    Per the C header: evaluation uses batch size ``n_batch`` and starts at
    context position ``n_past``; on completion ``n_past`` points to the next
    position in the context after the image embed. ``n_past`` is therefore
    an in/out pointer to a C ``int``.

    Returns the boolean result of the C function.
    """
    succeeded = _libllava.llava_eval_image_embed(ctx_llama, embed, n_batch, n_past)
    return succeeded
################################################
# clip.h
################################################
# /** load mmproj model */
# CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
_libllava.clip_model_load.argtypes = [c_char_p, c_int]
_libllava.clip_model_load.restype = clip_ctx_p


def clip_model_load(fname: bytes, verbosity: Union[c_int, int]) -> clip_ctx_p:
    """Load an mmproj model and return its clip context pointer.

    ``fname`` is passed as a C string (``c_char_p``), so it must be
    ``bytes``; ``verbosity`` is forwarded unchanged to the C function.
    """
    ctx = _libllava.clip_model_load(fname, verbosity)
    return ctx
# /** free mmproj model */
# CLIP_API void clip_free(struct clip_ctx * ctx);
_libllava.clip_free.argtypes = [clip_ctx_p]
_libllava.clip_free.restype = None


def clip_free(ctx: clip_ctx_p):
    """Free an mmproj model context.

    The C function is declared ``void``, so this always returns ``None``.
    """
    _libllava.clip_free(ctx)