Implement prompt batch processing as in main.cpp
This commit is contained in:
parent
a28cb92d8f
commit
e24c581b5a
1 changed files with 21 additions and 8 deletions
|
@ -19,6 +19,9 @@ class Llama:
|
|||
):
|
||||
self.model_path = model_path
|
||||
|
||||
self.last_n = 64
|
||||
self.max_chunk_size = 32
|
||||
|
||||
self.params = llama_cpp.llama_context_default_params()
|
||||
self.params.n_ctx = n_ctx
|
||||
self.params.n_parts = n_parts
|
||||
|
@ -59,21 +62,32 @@ class Llama:
|
|||
self.ctx, prompt.encode("utf-8"), self.tokens, self.params.n_ctx, True
|
||||
)
|
||||
|
||||
if prompt_tokens + max_tokens > self.params.n_ctx:
|
||||
if prompt_tokens + max_tokens > llama_cpp.llama_n_ctx(self.ctx):
|
||||
raise ValueError(
|
||||
f"Requested tokens exceed context window of {self.params.n_ctx}"
|
||||
)
|
||||
|
||||
for i in range(prompt_tokens):
|
||||
llama_cpp.llama_eval(
|
||||
self.ctx, (llama_cpp.c_int * 1)(self.tokens[i]), 1, i, self.n_threads
|
||||
# Process prompt in chunks to avoid running out of memory
|
||||
for i in range(0, prompt_tokens, self.max_chunk_size):
|
||||
chunk = self.tokens[i : min(prompt_tokens, i + self.max_chunk_size)]
|
||||
rc = llama_cpp.llama_eval(
|
||||
self.ctx,
|
||||
(llama_cpp.llama_token * len(chunk))(*chunk),
|
||||
len(chunk),
|
||||
max(0, i - 1),
|
||||
self.n_threads,
|
||||
)
|
||||
if rc != 0:
|
||||
raise RuntimeError(f"Failed to evaluate prompt: {rc}")
|
||||
|
||||
for i in range(max_tokens):
|
||||
tokens_seen = prompt_tokens + completion_tokens
|
||||
last_n_tokens = [0] * max(0, self.last_n - tokens_seen) + [self.tokens[j] for j in range(max(tokens_seen - self.last_n, 0), tokens_seen)]
|
||||
|
||||
token = llama_cpp.llama_sample_top_p_top_k(
|
||||
self.ctx,
|
||||
self.tokens,
|
||||
prompt_tokens + completion_tokens,
|
||||
(llama_cpp.llama_token * len(last_n_tokens))(*last_n_tokens),
|
||||
len(last_n_tokens),
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temp=temperature,
|
||||
|
@ -82,7 +96,6 @@ class Llama:
|
|||
if token == llama_cpp.llama_token_eos():
|
||||
finish_reason = "stop"
|
||||
break
|
||||
# text += llama_cpp.llama_token_to_str(self.ctx, token).decode("utf-8")
|
||||
text += llama_cpp.llama_token_to_str(self.ctx, token)
|
||||
self.tokens[prompt_tokens + i] = token
|
||||
completion_tokens += 1
|
||||
|
@ -96,7 +109,7 @@ class Llama:
|
|||
|
||||
llama_cpp.llama_eval(
|
||||
self.ctx,
|
||||
(llama_cpp.c_int * 1)(self.tokens[prompt_tokens + i]),
|
||||
(llama_cpp.llama_token * 1)(self.tokens[prompt_tokens + i]),
|
||||
1,
|
||||
prompt_tokens + completion_tokens,
|
||||
self.n_threads,
|
||||
|
|
Loading…
Reference in a new issue