fix: module 'llama_cpp.llama_cpp' has no attribute 'c_uint8'

Andrei Betlen 2024-02-23 11:24:53 -05:00
parent 427d816ebf
commit db776a885c

@@ -5,8 +5,10 @@ import sys
 import uuid
 import time
 import json
+import ctypes
 import fnmatch
 import multiprocessing
+
 from typing import (
     List,
     Optional,
@@ -20,7 +22,6 @@ from typing import (
 from collections import deque
 from pathlib import Path
-import ctypes
 from llama_cpp.llama_types import List
@@ -1789,7 +1790,7 @@ class Llama:
         state_size = llama_cpp.llama_get_state_size(self._ctx.ctx)
         if self.verbose:
             print(f"Llama.save_state: got state size: {state_size}", file=sys.stderr)
-        llama_state = (llama_cpp.c_uint8 * int(state_size))()
+        llama_state = (ctypes.c_uint8 * int(state_size))()
         if self.verbose:
             print("Llama.save_state: allocated state", file=sys.stderr)
         n_bytes = llama_cpp.llama_copy_state_data(self._ctx.ctx, llama_state)
@@ -1797,7 +1798,7 @@
             print(f"Llama.save_state: copied llama state: {n_bytes}", file=sys.stderr)
         if int(n_bytes) > int(state_size):
             raise RuntimeError("Failed to copy llama state data")
-        llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))()
+        llama_state_compact = (ctypes.c_uint8 * int(n_bytes))()
         llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
         if self.verbose:
             print(
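
For context: the old code reached the ctypes type through the bindings module (llama_cpp.c_uint8), which raises AttributeError once llama_cpp.llama_cpp no longer exposes that alias, so the patch imports ctypes at the top of llama.py and uses ctypes.c_uint8 directly. The buffer handling itself is plain standard-library ctypes: allocate a byte array sized to the reported state size, copy the state into it, then memmove only the used prefix into a right-sized array. Below is a minimal standalone sketch of that allocate-and-compact pattern; the size and payload are made up and stand in for what llama.py gets from llama_get_state_size() and llama_copy_state_data().

import ctypes

# Stand-ins for values llama.py obtains from the C API (made up for illustration).
state_size = 16            # upper bound the library would report
payload = b"llama state"   # pretend serialized state, n_bytes long
n_bytes = len(payload)

# Allocate a C array of unsigned bytes. ctypes.c_uint8 comes from the standard
# library, not from the llama_cpp bindings module (the AttributeError being fixed).
llama_state = (ctypes.c_uint8 * state_size)()
ctypes.memmove(llama_state, payload, n_bytes)

# Compact: copy only the bytes actually used into a right-sized buffer.
llama_state_compact = (ctypes.c_uint8 * n_bytes)()
ctypes.memmove(llama_state_compact, llama_state, n_bytes)

assert bytes(llama_state_compact) == payload

Because ctypes arrays support the buffer protocol, bytes(llama_state_compact) converts the compacted array back into an ordinary immutable Python bytes object.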