llm: allow gemma 2 to context shift (#5534)

This commit is contained in:
Jeffrey Morgan 2024-07-07 13:41:51 -04:00 committed by GitHub
parent 571dc61955
commit d8def1ff94
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1688,22 +1688,8 @@ struct llama_server_context
} }
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
char buf[256];
llama_model_meta_val_str(model, "general.architecture", buf, 256);
bool gemma2 = strcmp(buf, "gemma2") == 0;
int32_t truncate_at = slot.n_ctx;
// truncate at 2/3 of the context length for gemma2 models
// as they do not support context shifts (from the sliding window implementation).
// this way, prompts that almost fit the context length can still generate a full
// response without a sudden stop from hitting the context limit
if (gemma2) {
truncate_at = 2 * slot.n_ctx / 3;
}
// if input prompt is too big, truncate it, if group attention self-extend is disabled // if input prompt is too big, truncate it, if group attention self-extend is disabled
if (slot.ga_n == 1 && slot.n_prompt_tokens >= truncate_at) if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx)
{ {
const int n_left = slot.n_ctx - slot.params.n_keep; const int n_left = slot.n_ctx - slot.params.n_keep;
const int n_shift = n_left / 2; const int n_shift = n_left / 2;
@ -1731,19 +1717,6 @@ struct llama_server_context
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
} }
// Models with sliding window attention do not work with context shifts, so
// limit their prediction to the context length
if (gemma2) {
int32_t limit = slot.n_ctx - slot.n_prompt_tokens;
slot.n_predict = limit;
slot.params.n_predict = limit;
LOG_INFO("model does not support sliding window, limiting generation", {
{"n_ctx", slot.n_ctx},
{"n_prompt_tokens", slot.n_prompt_tokens},
{"n_predict", slot.n_predict}
});
}
if (!slot.params.cache_prompt) if (!slot.params.cache_prompt)
{ {
llama_sampling_reset(slot.ctx_sampling); llama_sampling_reset(slot.ctx_sampling);