fix: llava this function takes at least 4 arguments (0 given)

This commit is contained in:
Andrei Betlen 2024-02-26 11:03:20 -05:00
parent 34111788fe
commit 8383a9e562

View file

@ -1840,7 +1840,7 @@ def functionary_v1_v2_chat_handler(
class Llava15ChatHandler:
_clip_free = None
def __init__(self, clip_model_path: str, verbose: bool = False):
def __init__(self, clip_model_path: str, verbose: bool = False):
import llama_cpp.llava_cpp as llava_cpp
self._llava_cpp = llava_cpp
@ -1957,10 +1957,10 @@ class Llava15ChatHandler:
with suppress_stdout_stderr(disable=self.verbose):
embed = (
self._llava_cpp.llava_image_embed_make_with_bytes(
ctx_clip=self.clip_ctx,
n_threads=llama.context_params.n_threads,
image_bytes=c_ubyte_ptr,
image_bytes_length=len(image_bytes),
self.clip_ctx,
llama.context_params.n_threads,
c_ubyte_ptr,
length=len(image_bytes),
)
)
try:
@ -1968,10 +1968,10 @@ class Llava15ChatHandler:
n_past_p = ctypes.pointer(n_past)
with suppress_stdout_stderr(disable=self.verbose):
self._llava_cpp.llava_eval_image_embed(
ctx_llama=llama.ctx,
embed=embed,
n_batch=llama.n_batch,
n_past=n_past_p,
llama.ctx,
embed,
llama.n_batch,
n_past_p,
)
assert llama.n_ctx() >= n_past.value
llama.n_tokens = n_past.value