diff --git a/server/prompt.go b/server/prompt.go
index f91b94d8..a6401983 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -27,6 +27,16 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	isMllama := checkMllamaModelFamily(m)
 
+	var imageNumTokens int
+	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
+	if isMllama {
+		// Our mllama implementation packs all of the embeddings into a single token
+		imageNumTokens = 1
+	} else {
+		// Clip images are represented as 768 tokens, each an embedding
+		imageNumTokens = 768
+	}
+
 	n := len(msgs) - 1
 	// in reverse, find all messages that fit into context window
 	for i := n; i >= 0; i-- {
@@ -59,9 +69,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 		ctxLen := len(s)
 		if m.ProjectorPaths != nil {
 			for _, m := range msgs[i:] {
-				// images are represented as 768 sized embeddings
-				// TODO: get embedding length from project metadata
-				ctxLen += 768 * len(m.Images)
+				ctxLen += imageNumTokens * len(m.Images)
 			}
 		}
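
For reference, a minimal sketch of the truncation pass this change feeds into. The `message` type and `truncateToFit` helper here are hypothetical stand-ins for `api.Message` and the reverse loop in `chatPrompt`; the real code tokenizes the rendered template rather than carrying per-message token counts, but the per-image charge works the same way.

```go
package main

import "fmt"

// message is a hypothetical stand-in for api.Message, trimmed to the
// fields the estimate needs.
type message struct {
	TextTokens int // tokens in this message's rendered text
	NumImages  int // len(m.Images) in the real code
}

// truncateToFit walks messages newest-to-oldest, as chatPrompt does, and
// returns the index of the oldest message that still fits once each image
// is charged imageNumTokens tokens.
func truncateToFit(msgs []message, numCtx, imageNumTokens int) int {
	used := 0
	for i := len(msgs) - 1; i >= 0; i-- {
		cost := msgs[i].TextTokens + imageNumTokens*msgs[i].NumImages
		if used+cost > numCtx {
			return i + 1 // messages [i+1:] fit; older ones are dropped
		}
		used += cost
	}
	return 0 // everything fits
}

func main() {
	msgs := []message{
		{TextTokens: 200, NumImages: 1},
		{TextTokens: 50, NumImages: 0},
		{TextTokens: 120, NumImages: 1},
	}

	// With a CLIP-style projector each image costs 768 tokens, so the
	// oldest message no longer fits in a 1024-token window...
	fmt.Println(truncateToFit(msgs, 1024, 768)) // -> 1

	// ...while mllama's single packed embedding token keeps it in range.
	fmt.Println(truncateToFit(msgs, 1024, 1)) // -> 0
}
```

Charging a flat per-image cost keeps the truncation pass cheap; per the TODO in the diff, the longer-term fix is to read the embedding length from the projector metadata instead of hardcoding it per model family.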