Implement prompt batch processing as in main.cpp

parent a28cb92d8f
commit e24c581b5a

1 changed file with 21 additions and 8 deletions
@@ -19,6 +19,9 @@ class Llama:
     ):
         self.model_path = model_path
 
+        self.last_n = 64
+        self.max_chunk_size = 32
+
         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
         self.params.n_parts = n_parts
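Aside (not part of the diff): these two attributes are the knobs for the changes below. max_chunk_size caps how many prompt tokens each llama_eval call receives, and last_n is the length of the recent-token window passed to the sampler, analogous to the batch size and repeat_last_n parameters in main.cpp. A minimal sketch of the chunking arithmetic, assuming a hypothetical 100-token prompt:

import math

max_chunk_size = 32   # value set in __init__ above
prompt_tokens = 100   # hypothetical prompt length, for illustration only

eval_calls = math.ceil(prompt_tokens / max_chunk_size)            # -> 4
chunk_sizes = [min(max_chunk_size, prompt_tokens - i)
               for i in range(0, prompt_tokens, max_chunk_size)]  # -> [32, 32, 32, 4]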
@@ -59,21 +62,32 @@ class Llama:
             self.ctx, prompt.encode("utf-8"), self.tokens, self.params.n_ctx, True
         )
 
-        if prompt_tokens + max_tokens > self.params.n_ctx:
+        if prompt_tokens + max_tokens > llama_cpp.llama_n_ctx(self.ctx):
             raise ValueError(
                 f"Requested tokens exceed context window of {self.params.n_ctx}"
             )
 
-        for i in range(prompt_tokens):
-            llama_cpp.llama_eval(
-                self.ctx, (llama_cpp.c_int * 1)(self.tokens[i]), 1, i, self.n_threads
+        # Process prompt in chunks to avoid running out of memory
+        for i in range(0, prompt_tokens, self.max_chunk_size):
+            chunk = self.tokens[i : min(prompt_tokens, i + self.max_chunk_size)]
+            rc = llama_cpp.llama_eval(
+                self.ctx,
+                (llama_cpp.llama_token * len(chunk))(*chunk),
+                len(chunk),
+                max(0, i - 1),
+                self.n_threads,
             )
+            if rc != 0:
+                raise RuntimeError(f"Failed to evaluate prompt: {rc}")
 
         for i in range(max_tokens):
+            tokens_seen = prompt_tokens + completion_tokens
+            last_n_tokens = [0] * max(0, self.last_n - tokens_seen) + [self.tokens[j] for j in range(max(tokens_seen - self.last_n, 0), tokens_seen)]
+
             token = llama_cpp.llama_sample_top_p_top_k(
                 self.ctx,
-                self.tokens,
-                prompt_tokens + completion_tokens,
+                (llama_cpp.llama_token * len(last_n_tokens))(*last_n_tokens),
+                len(last_n_tokens),
                 top_k=top_k,
                 top_p=top_p,
                 temp=temperature,
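For reference, a rough, self-contained sketch of the two patterns this hunk introduces, using plain Python lists and a stand-in eval_fn in place of the ctypes arrays and llama_cpp.llama_eval (all names below are illustrative, not from the diff):

def eval_prompt_in_chunks(eval_fn, tokens, prompt_tokens, max_chunk_size, n_threads):
    """Feed the prompt to the model in fixed-size chunks instead of one token per call."""
    for i in range(0, prompt_tokens, max_chunk_size):
        chunk = tokens[i : min(prompt_tokens, i + max_chunk_size)]
        # The diff passes max(0, i - 1) as the count of previously evaluated tokens.
        rc = eval_fn(chunk, len(chunk), max(0, i - 1), n_threads)
        if rc != 0:
            raise RuntimeError(f"Failed to evaluate prompt: {rc}")


def last_n_window(tokens, tokens_seen, last_n):
    """Fixed-length window of the most recent tokens, left-padded with zeros."""
    recent = [tokens[j] for j in range(max(tokens_seen - last_n, 0), tokens_seen)]
    return [0] * max(0, last_n - tokens_seen) + recent

Passing a fixed-length window rather than the whole token buffer to llama_sample_top_p_top_k limits the repeat penalty to the most recent last_n tokens, which is what main.cpp does with its last_n_tokens buffer.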
@@ -82,7 +96,6 @@ class Llama:
             if token == llama_cpp.llama_token_eos():
                 finish_reason = "stop"
                 break
-            # text += llama_cpp.llama_token_to_str(self.ctx, token).decode("utf-8")
             text += llama_cpp.llama_token_to_str(self.ctx, token)
             self.tokens[prompt_tokens + i] = token
             completion_tokens += 1
@@ -96,7 +109,7 @@ class Llama:
 
             llama_cpp.llama_eval(
                 self.ctx,
-                (llama_cpp.c_int * 1)(self.tokens[prompt_tokens + i]),
+                (llama_cpp.llama_token * 1)(self.tokens[prompt_tokens + i]),
                 1,
                 prompt_tokens + completion_tokens,
                 self.n_threads,
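Taken together, the per-token part of the generation loop now looks roughly like the sketch below. sample_fn, eval_fn and token_to_str stand in for the corresponding llama_cpp calls, and tokens is assumed to be pre-allocated to the context length, as self.tokens is in the class; this is a sketch of the flow, not the exact method.

def decode_loop(sample_fn, eval_fn, token_to_str, tokens, prompt_tokens,
                max_tokens, last_n, n_threads, eos_token):
    # tokens: pre-allocated buffer already holding the evaluated prompt
    text = ""
    completion_tokens = 0
    for i in range(max_tokens):
        tokens_seen = prompt_tokens + completion_tokens
        # Recent-token window, left-padded with zeros as in the hunk above
        window = [0] * max(0, last_n - tokens_seen) + [
            tokens[j] for j in range(max(tokens_seen - last_n, 0), tokens_seen)
        ]
        token = sample_fn(window)  # stands in for llama_sample_top_p_top_k
        if token == eos_token:
            break
        text += token_to_str(token)
        tokens[prompt_tokens + i] = token
        completion_tokens += 1
        # Feed the new token back in; the diff uses prompt_tokens + completion_tokens
        # (i.e. the count after the increment) as the past-token argument here.
        eval_fn([token], 1, prompt_tokens + completion_tokens, n_threads)
    return text, completion_tokens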