llm: allow gemma 2 to context shift (#5534)
This commit is contained in:
parent
571dc61955
commit
d8def1ff94
1 changed file with 1 addition and 28 deletions
29
llm/ext_server/server.cpp
vendored
29
llm/ext_server/server.cpp
vendored
|
@ -1688,22 +1688,8 @@ struct llama_server_context
|
||||||
}
|
}
|
||||||
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
|
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
|
||||||
|
|
||||||
char buf[256];
|
|
||||||
llama_model_meta_val_str(model, "general.architecture", buf, 256);
|
|
||||||
bool gemma2 = strcmp(buf, "gemma2") == 0;
|
|
||||||
|
|
||||||
int32_t truncate_at = slot.n_ctx;
|
|
||||||
|
|
||||||
// truncate at 2/3 of the context length for gemma2 models
|
|
||||||
// as they do not support context shifts (from the sliding window implementation).
|
|
||||||
// this way, prompts that almost fit the context length can still generate a full
|
|
||||||
// response without a sudden stop from hitting the context limit
|
|
||||||
if (gemma2) {
|
|
||||||
truncate_at = 2 * slot.n_ctx / 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
// if input prompt is too big, truncate it, if group attention self-extend is disabled
|
// if input prompt is too big, truncate it, if group attention self-extend is disabled
|
||||||
if (slot.ga_n == 1 && slot.n_prompt_tokens >= truncate_at)
|
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx)
|
||||||
{
|
{
|
||||||
const int n_left = slot.n_ctx - slot.params.n_keep;
|
const int n_left = slot.n_ctx - slot.params.n_keep;
|
||||||
const int n_shift = n_left / 2;
|
const int n_shift = n_left / 2;
|
||||||
|
@ -1731,19 +1717,6 @@ struct llama_server_context
|
||||||
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
|
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Models with sliding window attention do not work with context shifts, so
|
|
||||||
// limit their prediction to the context length
|
|
||||||
if (gemma2) {
|
|
||||||
int32_t limit = slot.n_ctx - slot.n_prompt_tokens;
|
|
||||||
slot.n_predict = limit;
|
|
||||||
slot.params.n_predict = limit;
|
|
||||||
LOG_INFO("model does not support sliding window, limiting generation", {
|
|
||||||
{"n_ctx", slot.n_ctx},
|
|
||||||
{"n_prompt_tokens", slot.n_prompt_tokens},
|
|
||||||
{"n_predict", slot.n_predict}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!slot.params.cache_prompt)
|
if (!slot.params.cache_prompt)
|
||||||
{
|
{
|
||||||
llama_sampling_reset(slot.ctx_sampling);
|
llama_sampling_reset(slot.ctx_sampling);
|
||||||
|
|
Loading…
Reference in a new issue