fix: llava this function takes at least 4 arguments (0 given)
This commit is contained in:
parent
34111788fe
commit
8383a9e562
1 changed files with 9 additions and 9 deletions
|
@ -1840,7 +1840,7 @@ def functionary_v1_v2_chat_handler(
|
|||
class Llava15ChatHandler:
|
||||
_clip_free = None
|
||||
|
||||
def __init__(self, clip_model_path: str, verbose: bool = False):
|
||||
def __init__(self, clip_model_path: str, verbose: bool = False):
|
||||
import llama_cpp.llava_cpp as llava_cpp
|
||||
|
||||
self._llava_cpp = llava_cpp
|
||||
|
@ -1957,10 +1957,10 @@ class Llava15ChatHandler:
|
|||
with suppress_stdout_stderr(disable=self.verbose):
|
||||
embed = (
|
||||
self._llava_cpp.llava_image_embed_make_with_bytes(
|
||||
ctx_clip=self.clip_ctx,
|
||||
n_threads=llama.context_params.n_threads,
|
||||
image_bytes=c_ubyte_ptr,
|
||||
image_bytes_length=len(image_bytes),
|
||||
self.clip_ctx,
|
||||
llama.context_params.n_threads,
|
||||
c_ubyte_ptr,
|
||||
length=len(image_bytes),
|
||||
)
|
||||
)
|
||||
try:
|
||||
|
@ -1968,10 +1968,10 @@ class Llava15ChatHandler:
|
|||
n_past_p = ctypes.pointer(n_past)
|
||||
with suppress_stdout_stderr(disable=self.verbose):
|
||||
self._llava_cpp.llava_eval_image_embed(
|
||||
ctx_llama=llama.ctx,
|
||||
embed=embed,
|
||||
n_batch=llama.n_batch,
|
||||
n_past=n_past_p,
|
||||
llama.ctx,
|
||||
embed,
|
||||
llama.n_batch,
|
||||
n_past_p,
|
||||
)
|
||||
assert llama.n_ctx() >= n_past.value
|
||||
llama.n_tokens = n_past.value
|
||||
|
|
Loading…
Reference in a new issue