diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 7f7abaa..f46d741 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -19,6 +19,9 @@ class Llama:
     ):
         self.model_path = model_path
 
+        self.last_n = 64
+        self.max_chunk_size = 32
+
         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
         self.params.n_parts = n_parts
@@ -59,21 +62,32 @@ class Llama:
             self.ctx, prompt.encode("utf-8"), self.tokens, self.params.n_ctx, True
         )
 
-        if prompt_tokens + max_tokens > self.params.n_ctx:
+        if prompt_tokens + max_tokens > llama_cpp.llama_n_ctx(self.ctx):
             raise ValueError(
                 f"Requested tokens exceed context window of {self.params.n_ctx}"
             )
 
-        for i in range(prompt_tokens):
-            llama_cpp.llama_eval(
-                self.ctx, (llama_cpp.c_int * 1)(self.tokens[i]), 1, i, self.n_threads
+        # Process prompt in chunks to avoid running out of memory
+        for i in range(0, prompt_tokens, self.max_chunk_size):
+            chunk = self.tokens[i : min(prompt_tokens, i + self.max_chunk_size)]
+            rc = llama_cpp.llama_eval(
+                self.ctx,
+                (llama_cpp.llama_token * len(chunk))(*chunk),
+                len(chunk),
+                max(0, i - 1),
+                self.n_threads,
             )
+            if rc != 0:
+                raise RuntimeError(f"Failed to evaluate prompt: {rc}")
 
         for i in range(max_tokens):
+            tokens_seen = prompt_tokens + completion_tokens
+            last_n_tokens = [0] * max(0, self.last_n - tokens_seen) + [self.tokens[j] for j in range(max(tokens_seen - self.last_n, 0), tokens_seen)]
+
             token = llama_cpp.llama_sample_top_p_top_k(
                 self.ctx,
-                self.tokens,
-                prompt_tokens + completion_tokens,
+                (llama_cpp.llama_token * len(last_n_tokens))(*last_n_tokens),
+                len(last_n_tokens),
                 top_k=top_k,
                 top_p=top_p,
                 temp=temperature,
@@ -82,7 +96,6 @@ class Llama:
             if token == llama_cpp.llama_token_eos():
                 finish_reason = "stop"
                 break
-            # text += llama_cpp.llama_token_to_str(self.ctx, token).decode("utf-8")
             text += llama_cpp.llama_token_to_str(self.ctx, token)
             self.tokens[prompt_tokens + i] = token
             completion_tokens += 1
@@ -96,7 +109,7 @@ class Llama:
 
             llama_cpp.llama_eval(
                 self.ctx,
-                (llama_cpp.c_int * 1)(self.tokens[prompt_tokens + i]),
+                (llama_cpp.llama_token * 1)(self.tokens[prompt_tokens + i]),
                 1,
                 prompt_tokens + completion_tokens,
                 self.n_threads,
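
Note (not part of the patch): below is a minimal standalone sketch of the two techniques the diff introduces, chunked prompt evaluation and a fixed-size last-N token window for repeat-penalty sampling. The helper names feed_prompt, last_n_window, and eval_chunk are hypothetical stand-ins for the llama_cpp.llama_eval and llama_cpp.llama_sample_top_p_top_k calls above; the only assumed semantics are that llama_eval takes the tokens to evaluate plus the count of tokens already evaluated (n_past) and returns a nonzero code on failure.

# Standalone sketch, not part of the diff. eval_chunk is a hypothetical
# callback standing in for llama_cpp.llama_eval(ctx, tokens, n_tokens,
# n_past, n_threads).
from typing import Callable, List


def feed_prompt(
    tokens: List[int],
    max_chunk_size: int,
    eval_chunk: Callable[[List[int], int], int],
) -> None:
    # Evaluate the prompt max_chunk_size tokens at a time instead of one
    # token per call, so each eval call amortizes its overhead while the
    # per-call working memory stays bounded.
    for i in range(0, len(tokens), max_chunk_size):
        chunk = tokens[i : i + max_chunk_size]
        rc = eval_chunk(chunk, i)  # second arg: tokens already evaluated
        if rc != 0:
            raise RuntimeError(f"prompt evaluation failed with code {rc}")


def last_n_window(tokens: List[int], tokens_seen: int, last_n: int) -> List[int]:
    # Build the repeat-penalty window the sampler sees: always exactly
    # last_n entries, left-padded with 0 until enough tokens exist, then
    # the most recent last_n tokens.
    start = max(tokens_seen - last_n, 0)
    window = tokens[start:tokens_seen]
    return [0] * (last_n - len(window)) + window


if __name__ == "__main__":
    seen: List[int] = []

    def fake_eval(chunk: List[int], n_past: int) -> int:
        # Stand-in for llama_eval: just record the tokens it was fed.
        assert n_past == len(seen)
        seen.extend(chunk)
        return 0

    feed_prompt(list(range(10)), max_chunk_size=4, eval_chunk=fake_eval)
    print(last_n_window(seen, tokens_seen=len(seen), last_n=16))
    # -> six leading zeros followed by tokens 0..9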