Merge branch 'main' of github.com:abetlen/llama_cpp_python into main

Author: Andrei Betlen
Date:   2023-06-26 08:50:48 -04:00
Commit: 3379dc40a1

llama_cpp/llama.py

@@ -141,7 +141,9 @@ class LlamaDiskCache(BaseLlamaCache):
         if _key is None:
             raise KeyError("Key not found")
         value: "LlamaState" = self.cache.pop(_key)  # type: ignore
-        self.cache.push(_key, side="front")  # type: ignore
+        # NOTE: This puts an integer as key in cache, which breaks,
+        # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
+        # self.cache.push(_key, side="front")  # type: ignore
         return value

     def __contains__(self, key: Sequence[int]) -> bool:
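
The push() being commented out was meant as an LRU "move to front" touch, but diskcache's Cache.push() stores its argument under a fresh auto-generated integer key. The lookup path then compares keys with Llama.longest_token_prefix(k, key), which expects token sequences on both sides. A minimal sketch of that failure mode, with longest_token_prefix written out as a plain function (the real helper is a static method on Llama with this behavior, so treat the standalone form as an illustration):

    from typing import Sequence

    def longest_token_prefix(a: Sequence[int], b: Sequence[int]) -> int:
        # Length of the longest shared prefix of two token id sequences.
        n = 0
        for x, y in zip(a, b):
            if x != y:
                break
            n += 1
        return n

    longest_token_prefix((1, 2, 3, 4), (1, 2, 9))  # -> 2, normal tuple keys
    longest_token_prefix(500000000001, (1, 2, 9))  # TypeError: the integer key
                                                   # left behind by push() is
                                                   # not iterable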
@@ -168,7 +170,7 @@ class LlamaState:
         eval_logits: Deque[List[float]],
         input_ids: npt.NDArray[np.intc],
         scores: npt.NDArray[np.single],
-        llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
+        llama_state: bytes,
         llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
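
LlamaState now holds the saved state as plain bytes rather than a ctypes array. bytes() takes an immutable copy, so the stored state no longer aliases a mutable C buffer and serializes like any other Python object. A small stdlib-only illustration of why that copy matters:

    import ctypes

    buf = (ctypes.c_uint8 * 3)(10, 20, 30)  # what the C API writes into
    snapshot = bytes(buf)                   # immutable copy of the buffer
    buf[0] = 99                             # mutating the ctypes array...
    assert snapshot == b"\x0a\x14\x1e"      # ...leaves the snapshot untouched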
@@ -1512,7 +1514,7 @@ class Llama:
             eval_logits=self.eval_logits.copy(),
             scores=self._scores.copy(),
             input_ids=self._input_ids.copy(),
-            llama_state=llama_state_compact,
+            llama_state=bytes(llama_state_compact),
            llama_state_size=n_bytes,
         )

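This is the save_state() side of the change: llama_state_compact is a ctypes c_uint8 array trimmed to the state's actual size, and wrapping it in bytes() is what satisfies the new bytes annotation above. A hedged sketch of how such a compact buffer is typically produced with the low-level llama_cpp bindings (the surrounding code is assumed, not quoted from this diff; ctx stands in for the instance's context pointer):

    import ctypes
    import llama_cpp

    state_size = llama_cpp.llama_get_state_size(ctx)              # upper bound
    llama_state = (llama_cpp.c_uint8 * int(state_size))()
    n_bytes = llama_cpp.llama_copy_state_data(ctx, llama_state)   # bytes written
    llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))()
    ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
    payload = bytes(llama_state_compact)    # what LlamaState now stores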
@@ -1523,7 +1525,10 @@ class Llama:
         self._scores = state.scores.copy()
         self._input_ids = state.input_ids.copy()
         state_size = state.llama_state_size
-        if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
+        LLamaStateArrayType = (llama_cpp.c_uint8 * state_size)
+        llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
+
+        if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")

     def n_ctx(self) -> int:
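
On the load_state() side, llama_set_state_data() still needs a C buffer, so the stored bytes are copied back into a ctypes array first. from_buffer_copy() is the right tool here because bytes is immutable: the related from_buffer() requires a writable buffer and would reject it. A self-contained demonstration of the round trip:

    import ctypes

    data = b"\x01\x02\x03\x04"                      # stands in for state.llama_state
    ArrayType = ctypes.c_uint8 * len(data)          # fixed-size array type, as in the hunk
    llama_state = ArrayType.from_buffer_copy(data)  # writable C buffer, contents copied in
    assert bytes(llama_state) == data               # lossless round trip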