Implement prompt batch processing as in main.cpp

This commit is contained in:
Andrei Betlen 2023-03-24 14:33:38 -04:00
parent a28cb92d8f
commit e24c581b5a

View file

@ -19,6 +19,9 @@ class Llama:
): ):
self.model_path = model_path self.model_path = model_path
self.last_n = 64
self.max_chunk_size = 32
self.params = llama_cpp.llama_context_default_params() self.params = llama_cpp.llama_context_default_params()
self.params.n_ctx = n_ctx self.params.n_ctx = n_ctx
self.params.n_parts = n_parts self.params.n_parts = n_parts
@ -59,21 +62,32 @@ class Llama:
self.ctx, prompt.encode("utf-8"), self.tokens, self.params.n_ctx, True self.ctx, prompt.encode("utf-8"), self.tokens, self.params.n_ctx, True
) )
if prompt_tokens + max_tokens > self.params.n_ctx: if prompt_tokens + max_tokens > llama_cpp.llama_n_ctx(self.ctx):
raise ValueError( raise ValueError(
f"Requested tokens exceed context window of {self.params.n_ctx}" f"Requested tokens exceed context window of {self.params.n_ctx}"
) )
for i in range(prompt_tokens): # Process prompt in chunks to avoid running out of memory
llama_cpp.llama_eval( for i in range(0, prompt_tokens, self.max_chunk_size):
self.ctx, (llama_cpp.c_int * 1)(self.tokens[i]), 1, i, self.n_threads chunk = self.tokens[i : min(prompt_tokens, i + self.max_chunk_size)]
rc = llama_cpp.llama_eval(
self.ctx,
(llama_cpp.llama_token * len(chunk))(*chunk),
len(chunk),
max(0, i - 1),
self.n_threads,
) )
if rc != 0:
raise RuntimeError(f"Failed to evaluate prompt: {rc}")
for i in range(max_tokens): for i in range(max_tokens):
tokens_seen = prompt_tokens + completion_tokens
last_n_tokens = [0] * max(0, self.last_n - tokens_seen) + [self.tokens[j] for j in range(max(tokens_seen - self.last_n, 0), tokens_seen)]
token = llama_cpp.llama_sample_top_p_top_k( token = llama_cpp.llama_sample_top_p_top_k(
self.ctx, self.ctx,
self.tokens, (llama_cpp.llama_token * len(last_n_tokens))(*last_n_tokens),
prompt_tokens + completion_tokens, len(last_n_tokens),
top_k=top_k, top_k=top_k,
top_p=top_p, top_p=top_p,
temp=temperature, temp=temperature,
@ -82,7 +96,6 @@ class Llama:
if token == llama_cpp.llama_token_eos(): if token == llama_cpp.llama_token_eos():
finish_reason = "stop" finish_reason = "stop"
break break
# text += llama_cpp.llama_token_to_str(self.ctx, token).decode("utf-8")
text += llama_cpp.llama_token_to_str(self.ctx, token) text += llama_cpp.llama_token_to_str(self.ctx, token)
self.tokens[prompt_tokens + i] = token self.tokens[prompt_tokens + i] = token
completion_tokens += 1 completion_tokens += 1
@ -96,7 +109,7 @@ class Llama:
llama_cpp.llama_eval( llama_cpp.llama_eval(
self.ctx, self.ctx,
(llama_cpp.c_int * 1)(self.tokens[prompt_tokens + i]), (llama_cpp.llama_token * 1)(self.tokens[prompt_tokens + i]),
1, 1,
prompt_tokens + completion_tokens, prompt_tokens + completion_tokens,
self.n_threads, self.n_threads,