diff --git a/cmd/cmd.go b/cmd/cmd.go index dc288e43..4d3a48c7 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -21,7 +21,6 @@ import ( "path/filepath" "regexp" "runtime" - "slices" "strconv" "strings" "sync/atomic" @@ -453,7 +452,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { return err } - opts.MultiModal = slices.Contains(info.Details.Families, "clip") + opts.MultiModal = len(info.ProjectorInfo) != 0 opts.ParentModel = info.Details.ParentModel if interactive { diff --git a/cmd/interactive.go b/cmd/interactive.go index 1b4b187b..abbf05f4 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -494,28 +494,22 @@ func buildModelfile(opts runOptions) string { } func normalizeFilePath(fp string) string { - // Define a map of escaped characters and their replacements - replacements := map[string]string{ - "\\ ": " ", // Escaped space - "\\(": "(", // Escaped left parenthesis - "\\)": ")", // Escaped right parenthesis - "\\[": "[", // Escaped left square bracket - "\\]": "]", // Escaped right square bracket - "\\{": "{", // Escaped left curly brace - "\\}": "}", // Escaped right curly brace - "\\$": "$", // Escaped dollar sign - "\\&": "&", // Escaped ampersand - "\\;": ";", // Escaped semicolon - "\\'": "'", // Escaped single quote - "\\\\": "\\", // Escaped backslash - "\\*": "*", // Escaped asterisk - "\\?": "?", // Escaped question mark - } - - for escaped, actual := range replacements { - fp = strings.ReplaceAll(fp, escaped, actual) - } - return fp + return strings.NewReplacer( + "\\ ", " ", // Escaped space + "\\(", "(", // Escaped left parenthesis + "\\)", ")", // Escaped right parenthesis + "\\[", "[", // Escaped left square bracket + "\\]", "]", // Escaped right square bracket + "\\{", "{", // Escaped left curly brace + "\\}", "}", // Escaped right curly brace + "\\$", "$", // Escaped dollar sign + "\\&", "&", // Escaped ampersand + "\\;", ";", // Escaped semicolon + "\\'", "'", // Escaped single quote + "\\\\", "\\", // Escaped backslash + "\\*", "*", // Escaped asterisk + "\\?", "?", // Escaped question mark + ).Replace(fp) } func extractFileNames(input string) []string { @@ -535,10 +529,9 @@ func extractFileData(input string) (string, []api.ImageData, error) { for _, fp := range filePaths { nfp := normalizeFilePath(fp) data, err := getImageData(nfp) - if err != nil { - if os.IsNotExist(err) { - continue - } + if errors.Is(err, os.ErrNotExist) { + continue + } else if err != nil { fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err) return "", imgs, err } @@ -546,7 +539,7 @@ func extractFileData(input string) (string, []api.ImageData, error) { input = strings.ReplaceAll(input, fp, "") imgs = append(imgs, data) } - return input, imgs, nil + return strings.TrimSpace(input), imgs, nil } func getImageData(filePath string) ([]byte, error) { diff --git a/convert/convert_test.go b/convert/convert_test.go index 2969673d..48a2b1d4 100644 --- a/convert/convert_test.go +++ b/convert/convert_test.go @@ -29,7 +29,7 @@ type tensorData struct { Shape []int `json:"shape"` } -func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, llm.Tensors) { +func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, *llm.Tensors) { t.Helper() f, err := os.CreateTemp(t.TempDir(), "f16") @@ -60,7 +60,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, llm.Tensors) { return r, m.KV(), m.Tensors() } -func generateResultsJSON(t *testing.T, f *os.File, kv llm.KV, tensors llm.Tensors) map[string]string { +func generateResultsJSON(t *testing.T, f *os.File, kv llm.KV, 
tensors *llm.Tensors) map[string]string {
 	actual := make(map[string]string)
 	for k, v := range kv {
 		if s, ok := v.(json.Marshaler); !ok {
diff --git a/go.mod b/go.mod
index 6e437c73..1e61d3cb 100644
--- a/go.mod
+++ b/go.mod
@@ -22,6 +22,7 @@ require (
 	github.com/mattn/go-runewidth v0.0.14
 	github.com/nlpodyssey/gopickle v0.3.0
 	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
+	golang.org/x/image v0.14.0
 )
 
 require (
diff --git a/go.sum b/go.sum
index 926ed26d..d4d1c9a9 100644
--- a/go.sum
+++ b/go.sum
@@ -230,6 +230,8 @@ golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+o
 golang.org/x/image v0.0.0-20200618115811-c13761719519/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
 golang.org/x/image v0.0.0-20201208152932-35266b937fa6/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
 golang.org/x/image v0.0.0-20210216034530-4410531fe030/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
+golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
+golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
 golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
 golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
 golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
diff --git a/llama/ggml-cuda.cu b/llama/ggml-cuda.cu
index 48258b11..179dfab5 100644
--- a/llama/ggml-cuda.cu
+++ b/llama/ggml-cuda.cu
@@ -2296,6 +2296,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_PAD:
             ggml_cuda_op_pad(ctx, dst);
             break;
+        case GGML_OP_UNPAD:
+            ggml_cuda_op_unpad(ctx, dst);
+            break;
         case GGML_OP_ARANGE:
             ggml_cuda_op_arange(ctx, dst);
             break;
@@ -3018,6 +3021,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_GROUP_NORM:
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
+        case GGML_OP_UNPAD:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
diff --git a/llama/ggml-cuda/pad.cu b/llama/ggml-cuda/pad.cu
index 07fd81d3..c0a5b464 100644
--- a/llama/ggml-cuda/pad.cu
+++ b/llama/ggml-cuda/pad.cu
@@ -73,3 +73,49 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
 }
+
+static __global__ void unpad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
+    // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
+    // blockIdx.y: idx of ne1
+    // blockIDx.x: idx of ne0 / BLOCK_SIZE
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    // operation
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * ne01;
+        dst[offset_dst] = x[offset_src];
+    }
+}
+
+static void unpad_f32_cuda(const float * x, float * dst,
+    const int ne00, const int ne01, const int ne02, const int ne03,
+    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2*ne3);
+    unpad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
+}
+
+void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx,
ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors + + unpad_f32_cuda(src0_d, dst_d, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream); +} diff --git a/llama/ggml-cuda/pad.cuh b/llama/ggml-cuda/pad.cuh index 54d83d9c..ab2b1480 100644 --- a/llama/ggml-cuda/pad.cuh +++ b/llama/ggml-cuda/pad.cuh @@ -29,3 +29,4 @@ #define CUDA_PAD_BLOCK_SIZE 256 void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama/ggml-metal.metal b/llama/ggml-metal.metal index bbaf5939..72df3e49 100644 --- a/llama/ggml-metal.metal +++ b/llama/ggml-metal.metal @@ -2055,6 +2055,51 @@ kernel void kernel_pad_f32( } } +kernel void kernel_unpad_f32( + device const char * src0, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3; + const int64_t i02 = i2; + const int64_t i01 = i1; + + device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); + device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); + + if (i1 < ne01 && i2 < ne02 && i3 < ne03) { + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + if (i0 < ne00) { + dst_ptr[i0] = src0_ptr[i0]; + } + } + + return; + } +} + kernel void kernel_arange_f32( device char * dst, constant int64_t & ne0, diff --git a/llama/ggml-metal_darwin_arm64.m b/llama/ggml-metal_darwin_arm64.m index d2473aac..1ab065db 100644 --- a/llama/ggml-metal_darwin_arm64.m +++ b/llama/ggml-metal_darwin_arm64.m @@ -219,6 +219,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_IM2COL_F32, GGML_METAL_KERNEL_TYPE_UPSCALE_F32, GGML_METAL_KERNEL_TYPE_PAD_F32, + GGML_METAL_KERNEL_TYPE_UNPAD_F32, GGML_METAL_KERNEL_TYPE_ARANGE_F32, GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, @@ -715,6 +716,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UNPAD_F32, unpad_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); @@ -872,6 +874,7 @@ static bool ggml_metal_supports_op(const struct 
ggml_backend_metal_context * ctx return false; case GGML_OP_UPSCALE: case GGML_OP_PAD: + case GGML_OP_UNPAD: case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: @@ -2681,6 +2684,36 @@ static void ggml_metal_encode_node( const int nth = MIN(1024, ne0); + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_UNPAD: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UNPAD_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + + const int nth = MIN(1024, ne0); + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_ARANGE: diff --git a/llama/ggml.c b/llama/ggml.c index 4722c9a1..7f7a20e4 100644 --- a/llama/ggml.c +++ b/llama/ggml.c @@ -3023,6 +3023,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "POOL_2D_BACK", "UPSCALE", "PAD", + "UNPAD", "ARANGE", "TIMESTEP_EMBEDDING", "ARGSORT", @@ -3056,7 +3057,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "OPT_STEP_ADAMW", }; -static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); +static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3117,6 +3118,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "pool_2d_back(x)", "upscale(x)", "pad(x)", + "unpad(x)", "arange(start, stop, step)", "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", @@ -3150,7 +3152,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "adamw(x)", }; -static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); +static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -6981,6 +6983,32 @@ struct ggml_tensor * ggml_pad( return result; } +// ggml_unpad + +struct ggml_tensor * ggml_unpad( + struct ggml_context * ctx, + struct ggml_tensor * a, + int p0, int p1, int p2, int p3) { + bool is_node = false; + + if (a->grad) { + GGML_ABORT("fatal error"); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, + a->ne[0] - p0, + a->ne[1] - p1, + a->ne[2] - p2, + a->ne[3] - p3); + + result->op = GGML_OP_UNPAD; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + // ggml_arange struct ggml_tensor * ggml_arange( @@ -15338,6 +15366,58 @@ static void ggml_compute_forward_pad( } } +static void ggml_compute_forward_unpad_f32( + const struct ggml_compute_params *params, + struct ggml_tensor *dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT( dst->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float * dst_ptr = (float *) dst->data; + + // TODO: optimize + + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = ith; i1 < ne1; i1 += nth) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + for (int64_t i3 = 0; i3 < ne3; ++i3) { + const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; + + const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + dst_ptr[dst_idx] = *src_ptr; + } + } + } + } + } +} + +static void ggml_compute_forward_unpad( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_unpad_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} // ggml_compute_forward_arange @@ -17320,6 +17400,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_pad(params, tensor); } break; + case GGML_OP_UNPAD: + { + ggml_compute_forward_unpad(params, tensor); + } break; case GGML_OP_ARANGE: { ggml_compute_forward_arange(params, tensor); @@ -18395,6 +18479,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ABORT("fatal error"); // TODO: not implemented } + case GGML_OP_UNPAD: + { + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_ARANGE: { GGML_ABORT("fatal error"); // TODO: not implemented @@ -19191,6 +19279,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_OP_UPSCALE: case GGML_OP_PAD: + case GGML_OP_UNPAD: case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: diff --git a/llama/ggml.h b/llama/ggml.h index 9cf7085a..73deed07 100644 --- a/llama/ggml.h +++ b/llama/ggml.h @@ -532,6 +532,7 @@ extern "C" { GGML_OP_POOL_2D_BACK, GGML_OP_UPSCALE, // nearest interpolate GGML_OP_PAD, + GGML_OP_UNPAD, GGML_OP_ARANGE, GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, @@ -1790,6 +1791,15 @@ extern "C" { int p2, int p3); + // unpad each dimension: [x, ..., x, y, ..., y] -> [x, ..., x] + GGML_API struct ggml_tensor * ggml_unpad( + struct ggml_context * ctx, + struct ggml_tensor * a, + int p0, + int p1, + int p2, + int p3); + // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151 // timesteps: [N,] // return: [N, dim] diff --git a/llama/llama.cpp b/llama/llama.cpp index 1cdba5bf..87d0148b 100644 --- a/llama/llama.cpp +++ b/llama/llama.cpp @@ -195,6 +195,7 @@ static std::string format(const char * fmt, ...) 
{
 
 enum llm_arch {
     LLM_ARCH_LLAMA,
+    LLM_ARCH_MLLAMA,
     LLM_ARCH_FALCON,
     LLM_ARCH_BAICHUAN,
     LLM_ARCH_GROK,
@@ -249,6 +250,7 @@ enum llm_arch {
 
 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LLAMA, "llama" },
+    { LLM_ARCH_MLLAMA, "mllama" },
     { LLM_ARCH_FALCON, "falcon" },
     { LLM_ARCH_GROK, "grok" },
     { LLM_ARCH_GPT2, "gpt2" },
@@ -356,6 +358,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
+    LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_FREQ_BASE,
@@ -465,6 +468,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
     { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
     { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection.%d" },
+    { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" },
 
     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
@@ -639,6 +643,14 @@ enum llm_tensor {
     LLM_TENSOR_CLS,
     LLM_TENSOR_CLS_OUT,
     LLM_TENSOR_BSKCN_TV,
+    LLM_TENSOR_CROSS_ATTN_K_NORM,
+    LLM_TENSOR_CROSS_ATTN_K_PROJ,
+    LLM_TENSOR_CROSS_ATTN_O_PROJ,
+    LLM_TENSOR_CROSS_ATTN_Q_NORM,
+    LLM_TENSOR_CROSS_ATTN_Q_PROJ,
+    LLM_TENSOR_CROSS_ATTN_V_PROJ,
+    LLM_TENSOR_CROSS_ATTN_ATTN_GATE,
+    LLM_TENSOR_CROSS_ATTN_MLP_GATE,
 };
 
 static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
@@ -668,6 +680,40 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_MLLAMA,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_CROSS_ATTN_K_NORM, "blk.%d.cross_attn_k_norm" },
+            { LLM_TENSOR_CROSS_ATTN_K_PROJ, "blk.%d.cross_attn_k_proj" },
+            { LLM_TENSOR_CROSS_ATTN_O_PROJ, "blk.%d.cross_attn_o_proj" },
+            { LLM_TENSOR_CROSS_ATTN_Q_NORM, "blk.%d.cross_attn_q_norm" },
+            { LLM_TENSOR_CROSS_ATTN_Q_PROJ, "blk.%d.cross_attn_q_proj" },
+            { LLM_TENSOR_CROSS_ATTN_V_PROJ, "blk.%d.cross_attn_v_proj" },
+            { LLM_TENSOR_CROSS_ATTN_ATTN_GATE, "blk.%d.cross_attn_attn_gate" },
+            { LLM_TENSOR_CROSS_ATTN_MLP_GATE, "blk.%d.cross_attn_mlp_gate" },
+        },
+    },
     {
         LLM_ARCH_BAICHUAN,
         {
@@ -2416,6 +2462,7 @@ enum e_model {
     MODEL_40B,
     MODEL_65B,
     MODEL_70B,
+    MODEL_90B,
     MODEL_236B,
     MODEL_314B,
     MODEL_SMALL,
@@ -2460,6 +2507,7 @@ struct llama_hparams {
     std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
 
     std::array<std::array<uint32_t, LLAMA_MAX_LAYERS>, 4> n_bskcn_arr;
+    std::array<uint32_t, LLAMA_MAX_LAYERS> cross_attn_layers;
 
     uint32_t n_layer_dense_lead = 0;
     uint32_t n_lora_q = 0;
@@ -2528,10 +2576,11 @@ struct llama_hparams {
         if (this->n_expert != other.n_expert) return true;
         if (this->n_expert_used !=
other.n_expert_used) return true; - if (this->n_head_arr != other.n_head_arr) return true; - if (this->n_head_kv_arr != other.n_head_kv_arr) return true; - if (this->n_ff_arr != other.n_ff_arr) return true; - if (this->n_bskcn_arr != other.n_bskcn_arr) return true; + if (this->n_head_arr != other.n_head_arr) return true; + if (this->n_head_kv_arr != other.n_head_kv_arr) return true; + if (this->n_ff_arr != other.n_ff_arr) return true; + if (this->n_bskcn_arr != other.n_bskcn_arr) return true; + if (this->cross_attn_layers != other.cross_attn_layers) return true; if (this->n_rel_attn_bkts != other.n_rel_attn_bkts) return true; if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true; @@ -2649,6 +2698,10 @@ struct llama_hparams { GGML_ABORT("fatal error"); } + + bool cross_attention_layer(uint32_t il) const { + return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end(); + } }; static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable"); @@ -2832,6 +2885,16 @@ struct llama_layer { struct ggml_tensor * ffn_down_scale; struct ggml_tensor * bskcn_tv; + + // cross attention + struct ggml_tensor * cross_attn_k_norm; + struct ggml_tensor * cross_attn_k_proj; + struct ggml_tensor * cross_attn_o_proj; + struct ggml_tensor * cross_attn_q_norm; + struct ggml_tensor * cross_attn_q_proj; + struct ggml_tensor * cross_attn_v_proj; + struct ggml_tensor * cross_attn_attn_gate; + struct ggml_tensor * cross_attn_mlp_gate; }; // very similar to llama_batch, @@ -3478,6 +3541,12 @@ struct llama_context { struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] + + // TODO (jmorganca): this should most likely be passed in as part of a batch + // and not set on the context for all batches. + float * cross_attn_state = nullptr; + bool cross_attn_state_first_pass = true; + struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061] }; struct llama_lora_weight { @@ -3712,6 +3781,18 @@ static bool llama_kv_cache_init( cache.v_l.reserve(n_layer); for (int i = 0; i < (int) n_layer; i++) { + // for cross attention layers + if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layer(i)) { + struct ggml_context * ctx = offload ? 
ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); + ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i)); + ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i)); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + cache.k_l.push_back(k); + cache.v_l.push_back(v); + continue; + } + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); @@ -5486,12 +5567,14 @@ static void llm_load_hparams( } // zero-out the per-layer hparams - std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); - std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); - std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); + std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); + std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); + std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); + std::fill(hparams.cross_attn_layers.begin(), hparams.cross_attn_layers.end(), -1); - ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer); + ml.get_arr(LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, hparams.cross_attn_layers, false); // n_head_kv is optional, default to n_head hparams.n_head_kv_arr = hparams.n_head_arr; @@ -5540,7 +5623,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); - if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) { + if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_MLLAMA || model.arch == LLM_ARCH_FALCON) { if (hparams.n_rot != hparams.n_embd_head_k) { throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); } @@ -5580,6 +5663,16 @@ static void llm_load_hparams( } } } break; + case LLM_ARCH_MLLAMA: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 40: model.type = e_model::MODEL_11B; break; + case 100: model.type = e_model::MODEL_90B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_MINICPM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -7275,6 +7368,55 @@ static bool llm_load_tensors( layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? 
llama_model_loader::TENSOR_DUPLICATED : 0)); } } break; + case LLM_ARCH_MLLAMA: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + if (hparams.cross_attention_layer(i)) { + layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128}); + layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024}); + layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd}); + layer.cross_attn_q_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_Q_NORM, "weight", i), {128}); + layer.cross_attn_q_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_Q_PROJ, "weight", i), {n_embd, n_embd}); + layer.cross_attn_v_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_V_PROJ, "weight", i), {n_embd, 1024}); + layer.cross_attn_attn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_ATTN_GATE, i), {1}); + layer.cross_attn_mlp_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_MLP_GATE, i), {1}); + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + } else { + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? 
llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } + } break; case LLM_ARCH_GROK: { if (n_expert == 0) { @@ -9119,7 +9261,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE && model.hparams.n_vocab != model.vocab.id_to_token.size()) { - throw std::runtime_error("vocab size mismatch"); + LLAMA_LOG_WARN("%s: vocab mismatch %u !- %zu ...\n", __func__, model.hparams.n_vocab, model.vocab.id_to_token.size()); } if (params.vocab_only) { @@ -9204,7 +9346,7 @@ static struct ggml_tensor * llm_build_inp_embd( inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens); } else { - lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens); + lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens); inpL = lctx.inp_embd; ggml_set_input(lctx.inp_embd); } @@ -9219,6 +9361,22 @@ static struct ggml_tensor * llm_build_inp_embd( return inpL; } +static struct ggml_tensor * llm_build_inp_cross_attn_state( + struct ggml_context * ctx, + struct llama_context & lctx, + const llama_hparams & hparams, + const llm_build_cb & cb) { + const int64_t n_embd = hparams.n_embd; + + struct ggml_tensor * inpCAS; + lctx.inp_cross_attn_state = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4); + cb(lctx.inp_cross_attn_state, "inp_cross_attn_state", -1); + ggml_set_input(lctx.inp_cross_attn_state); + inpCAS = lctx.inp_cross_attn_state; + + return inpCAS; +} + static void llm_build_kv_store( struct ggml_context * ctx, const llama_hparams & hparams, @@ -10193,6 +10351,7 @@ struct llm_build_context { lctx.inp_pos_bucket = nullptr; lctx.inp_embd_enc = nullptr; lctx.inp_KQ_mask_cross = nullptr; + lctx.inp_cross_attn_state = nullptr; } void free() { @@ -10780,6 +10939,253 @@ struct llm_build_context { LLM_NORM_RMS, cb, -1); cb(cur, "result_norm", -1); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_mllama() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + struct ggml_tensor * inpCAS; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpCAS = llm_build_inp_cross_attn_state(ctx0, lctx, hparams, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + if (hparams.cross_attention_layer(il)) { + if (!lctx.cross_attn_state) { + continue; + } + + // cross attention layer + struct ggml_tensor * 
Qcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_q_proj, cur); + cb(Qcur, "Qcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + cb(Qcur, "Qcur", il); + + Qcur = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + cb(Qcur, "Qcur", il); + + // TODO: is this required? + Qcur = ggml_cont(ctx0, Qcur); + cb(Qcur, "Qcur", il); + + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur; + if (lctx.cross_attn_state_first_pass) { + Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS); + cb(Kcur, "Kcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404); + cb(Kcur, "Kcur", il); + + Kcur = ggml_permute(ctx0, Kcur, 0, 2, 1, 3); + cb(Kcur, "Kcur", il); + + // TODO: is this required? + Kcur = ggml_cont(ctx0, Kcur); + cb(Kcur, "Kcur", il); + + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Kcur, "Kcur", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il])); + } else { + Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]); + cb(Kcur, "Kcur (view)", il); + } + + struct ggml_tensor * Vcur; + if (lctx.cross_attn_state_first_pass) { + Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS); + cb(Vcur, "Vcur", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, 6404); + cb(Vcur, "Vcur", il); + + Vcur = ggml_permute(ctx0, Vcur, 0, 2, 1, 3); + cb(Vcur, "Vcur", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il])); + } else { + Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]); + cb(Vcur, "Vcur (view)", il); + } + + struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur); + cb(kq, "kq", il); + + kq = ggml_scale_inplace(ctx0, kq, 1.0f/sqrtf(float(n_embd_head))); + cb(kq, "kq_scaled", il); + + // TODO: apply causal masks + struct ggml_tensor * kq_soft_max = ggml_soft_max_inplace(ctx0, kq); + cb(kq_soft_max, "kq_soft_max", il); + + Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur)); + cb(Vcur, "Vcur", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, Vcur, kq_soft_max); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); + cb(cur, "kqv_merged_cont", il); + + cur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_o_proj, cur); + cb(cur, "cur", il); + + // TODO: do this in place once? + cur = ggml_mul(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_attn_gate)); + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + // TODO: do this inplace once? 
+ cur = ggml_add_inplace(ctx0, ggml_mul_inplace(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_mlp_gate)), ffn_inp); + cb(cur, "ffn_out", il); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } else { + // self attention layer + + // rope freq factors for llama3; may return nullptr for llama2 and other models + struct ggml_tensor * rope_factors = build_rope_factors(il); + + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + // lm_head cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); @@ -16527,6 +16933,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_llama(); } break; + case LLM_ARCH_MLLAMA: + { + result = llm.build_mllama(); + } break; case LLM_ARCH_BAICHUAN: { result = llm.build_baichuan(); @@ -16799,6 +17209,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); } + // TODO (jmorganca): this might 
copy a lot of data on every request of a + // single generation even though it doesn't change, so we should + // find a way to not set this more than one time per image + if (lctx.inp_cross_attn_state && + lctx.inp_cross_attn_state->buffer) { + ggml_backend_tensor_set(lctx.inp_cross_attn_state, lctx.cross_attn_state, 0, hparams.n_embd * 1601 * 4 * ggml_element_size(lctx.inp_cross_attn_state)); + } + if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); const int64_t n_tokens = batch.n_tokens; @@ -17481,6 +17899,10 @@ static int llama_decode_internal( llama_set_inputs(lctx, ubatch); + // TODO: replace with something better to find out if its + // our first actual pass + lctx.cross_attn_state_first_pass = false; + llama_graph_compute(lctx, gf, n_threads, threadpool); // update the kv ring buffer @@ -18674,7 +19096,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (llama_model_has_encoder(&model)) { n_attn_layer *= 3; } - GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected"); + if (qs.n_attention_wv != n_attn_layer) { + LLAMA_LOG_WARN("%s: n_attention_wv is unexpected, expected: %d, found: %d\n", __func__, n_attn_layer, qs.n_attention_wv); + } } size_t total_size_org = 0; @@ -19770,6 +20194,11 @@ struct llama_context * llama_new_context_with_model( return ctx; } +void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state) { + ctx->cross_attn_state_first_pass = true; + ctx->cross_attn_state = cross_attn_state; +} + void llama_free(struct llama_context * ctx) { delete ctx; } @@ -19840,6 +20269,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { // use what we call a normal RoPE, operating on pairs of consecutive head values case LLM_ARCH_LLAMA: + case LLM_ARCH_MLLAMA: case LLM_ARCH_BAICHUAN: case LLM_ARCH_STARCODER: case LLM_ARCH_PLAMO: diff --git a/llama/llama.go b/llama/llama.go index c3e3ab87..47786a36 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -60,7 +60,9 @@ package llama #include #include "llama.h" #include "clip.h" +#include "ggml.h" #include "llava.h" +#include "mllama.h" #include "sampling_ext.h" bool llamaProgressCallback(float progress, void *user_data); @@ -410,18 +412,60 @@ func Quantize(infile, outfile string, ftype uint32) error { // llava type ClipContext struct { - c *C.struct_clip_ctx + c *C.struct_clip_ctx + m *C.struct_mllama_ctx + IsMllama bool + embedPin runtime.Pinner + pinned bool } -func NewClipContext(modelPath string) *ClipContext { +func getVisionArch(mp *C.char) (string, error) { + gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)}) + if gguf_ctx == nil { + return "", errors.New("unable to load vision projector") + } + defer C.gguf_free(gguf_ctx) + + arch_index := C.gguf_find_key(gguf_ctx, C.CString("general.architecture")) + if int(arch_index) < 0 { + return "", errors.New("unknown vision model architecture") + } + + arch := C.gguf_get_val_str(gguf_ctx, arch_index) + + return C.GoString(arch), nil +} + +func NewClipContext(modelPath string) (*ClipContext, error) { mp := C.CString(modelPath) defer C.free(unsafe.Pointer(mp)) - cc := C.clip_model_load(mp, 1) - return &ClipContext{c: cc} + + arch, err := getVisionArch(mp) + if err != nil { + return nil, err + } + + var cc ClipContext + if arch == "clip" { + cc.c = C.clip_model_load(mp, 1) + } else if arch == "mllama" { + 
cc.m = C.mllama_model_load(mp, 1) + cc.IsMllama = true + } else { + return nil, fmt.Errorf("unknown vision model architecture: %s", arch) + } + + // XXX: check embedding size? + return &cc, nil } func (c *ClipContext) Free() { - C.clip_free(c.c) + if c.c != nil { + C.clip_free(c.c) + } + if c.m != nil { + C.mllama_free(c.m) + } } func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte) [][]float32 { @@ -445,6 +489,48 @@ func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data [] return embed } +func NewMllamaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte, aspectRatioId int) [][]float32 { + img := C.mllama_image_init() + defer C.mllama_image_free(img) + + C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img) + + numTokens := int(C.mllama_n_positions(clipContext.m) * C.mllama_n_tiles(clipContext.m)) + numEmbed := llamaContext.Model().NEmbd() + + rows := make([]float32, numEmbed*numTokens) + C.mllama_image_encode(clipContext.m, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0]))) + + embed := make([][]float32, numTokens) + for i := range embed { + embed[i] = rows[i*numEmbed : (i+1)*numEmbed] + } + + return embed +} + +// This really needs to be set on a batch instead +func MllamaSetCrossAttn(llamaContext *Context, clipContext *ClipContext, embed [][]float32) { + if embed != nil { + if clipContext.pinned { + panic("Cross attention state already pinned") + } + + embedData := &embed[0][0] + clipContext.embedPin.Pin(embedData) + clipContext.pinned = true + + C.llama_set_cross_attn_state(llamaContext.c, (*C.float)(unsafe.Pointer(embedData))) + } else { + C.llama_set_cross_attn_state(llamaContext.c, (*C.float)(C.NULL)) + + if clipContext.pinned { + clipContext.embedPin.Unpin() + clipContext.pinned = false + } + } +} + // sampling // TODO: this is a temporary wrapper to allow calling C++ code from CGo type SamplingContext struct { diff --git a/llama/llama.h b/llama/llama.h index db90a41a..5f04fc86 100644 --- a/llama/llama.h +++ b/llama/llama.h @@ -449,6 +449,10 @@ extern "C" { struct llama_model * model, struct llama_context_params params); + // TODO (jmorganca): this should most likely be passed in as part of a batch + // and not set on the context for all batches. + LLAMA_API void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state); + // Frees all allocated memory LLAMA_API void llama_free(struct llama_context * ctx); diff --git a/llama/mllama.cpp b/llama/mllama.cpp new file mode 100644 index 00000000..d4f2ed2a --- /dev/null +++ b/llama/mllama.cpp @@ -0,0 +1,900 @@ +// NOTE: This is modified from clip.cpp for Mllama only +#include "mllama.h" + +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "ggml.h" + +#ifdef GGML_USE_CUDA +#include "ggml-cuda.h" +#endif + +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + +#ifdef GGML_USE_CANN +#include "ggml-cann.h" +#endif + +#ifdef GGML_USE_VULKAN +#include "ggml-vulkan.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#define REQUIRE(x) \ + do { \ + if (!(x)) { \ + throw std::runtime_error("REQUIRE failed: " #x); \ + } \ + } while (0) + +#define LOG(fmt, ...) 
fprintf(stderr, "%s: " fmt "\n", __func__, ##__VA_ARGS__)
+
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+    #define NOMINMAX
+#endif
+#include <windows.h>
+#if __GLIBCXX__
+#include <cstdio>
+#include <ext/stdio_filebuf.h>
+#include <fcntl.h>
+#endif
+#endif
+
+struct mllama_image {
+    int width;
+    int height;
+
+    int num_channels = 3;
+    int num_tiles = 4;
+
+    int aspect_ratio_id;
+
+    std::vector<float> data;
+};
+
+static std::string format(const char *fmt, ...) {
+    va_list args;
+    va_start(args, fmt);
+    std::vector<char> b(128);
+    int n = vsnprintf(b.data(), b.size(), fmt, args);
+    REQUIRE(n >= 0 && n < b.size());
+    va_end(args);
+    return std::string(b.data(), b.size());
+}
+
+//
+// utilities to get data from a gguf file
+//
+
+static int get_key_index(const gguf_context *ctx, const char *key) {
+    int key_index = gguf_find_key(ctx, key);
+    REQUIRE(key_index != -1);
+    return key_index;
+}
+
+static std::vector<uint32_t> get_u32_array(const gguf_context *ctx, const std::string &key) {
+    const int i = get_key_index(ctx, key.c_str());
+    const int n = gguf_get_arr_n(ctx, i);
+    const uint32_t *data = (uint32_t *)gguf_get_arr_data(ctx, i);
+
+    std::vector<uint32_t> s(n);
+    for (size_t j = 0; j < s.size(); j++) {
+        s[j] = data[j];
+    }
+
+    return s;
+}
+
+static uint32_t get_u32(const gguf_context *ctx, const std::string &key) {
+    return gguf_get_val_u32(ctx, get_key_index(ctx, key.c_str()));
+}
+
+static float get_f32(const gguf_context *ctx, const std::string &key) {
+    return gguf_get_val_f32(ctx, get_key_index(ctx, key.c_str()));
+}
+
+static std::string get_ftype(int ftype) {
+    return ggml_type_name(static_cast<ggml_type>(ftype));
+}
+
+//
+// mllama layers
+//
+
+struct mllama_hparams {
+    uint32_t image_size;
+    uint32_t patch_size;
+    uint32_t hidden_size;
+    uint32_t n_intermediate;
+    uint32_t projection_dim;
+    uint32_t n_head;
+    uint32_t n_layer;
+    uint32_t n_global_layer;
+    uint32_t n_tiles;
+
+    float eps;
+
+    std::vector<bool> intermediate_layers;
+};
+
+struct mllama_layer {
+    // attention
+    struct ggml_tensor *k_w;
+    struct ggml_tensor *k_b;
+    struct ggml_tensor *q_w;
+    struct ggml_tensor *q_b;
+    struct ggml_tensor *v_w;
+    struct ggml_tensor *v_b;
+
+    struct ggml_tensor *o_w;
+    struct ggml_tensor *o_b;
+
+    struct ggml_tensor *attn_gate;
+
+    // layernorm 1
+    struct ggml_tensor *ln_1_w;
+    struct ggml_tensor *ln_1_b;
+
+    // ff
+    struct ggml_tensor *ff_i_w;
+    struct ggml_tensor *ff_i_b;
+
+    struct ggml_tensor *ff_o_w;
+    struct ggml_tensor *ff_o_b;
+
+    struct ggml_tensor *ff_gate;
+
+    // layernorm 2
+    struct ggml_tensor *ln_2_w;
+    struct ggml_tensor *ln_2_b;
+};
+
+struct mllama_vision_model {
+    struct mllama_hparams hparams;
+
+    // embeddings
+    struct ggml_tensor *class_embedding;
+    struct ggml_tensor *patch_embeddings;
+    struct ggml_tensor *position_embeddings;
+    struct ggml_tensor *position_embeddings_gate;
+    struct ggml_tensor *tile_position_embeddings;
+    struct ggml_tensor *tile_position_embeddings_gate;
+    struct ggml_tensor *pre_tile_position_embeddings;
+    struct ggml_tensor *pre_tile_position_embeddings_gate;
+    struct ggml_tensor *post_tile_position_embeddings;
+    struct ggml_tensor *post_tile_position_embeddings_gate;
+
+    struct ggml_tensor *pre_ln_w;
+    struct ggml_tensor *pre_ln_b;
+
+    std::vector<mllama_layer> layers;
+    std::vector<mllama_layer> global_layers;
+
+    struct ggml_tensor *post_ln_w;
+    struct ggml_tensor *post_ln_b;
+
+    struct ggml_tensor *mm_0_w;
+    struct ggml_tensor *mm_0_b;
+};
+
+struct mllama_ctx {
+    struct mllama_vision_model vision_model;
+
+    uint32_t ftype = 1;
+
+    struct gguf_context *ctx_gguf;
+    struct ggml_context *ctx_data;
+
+    std::vector<uint8_t> buf_compute_meta;
+
+    //
memory buffers to evaluate the model + ggml_backend_buffer_t params_buffer = nullptr; + + ggml_backend_t backend = nullptr; + ggml_gallocr_t compute_alloc = nullptr; +}; + +static ggml_tensor *mllama_image_build_encoder_layer( + struct ggml_context *ctx0, const size_t il, const struct mllama_layer &layer, struct ggml_tensor *embeddings, + const float eps, const int hidden_size, const int batch_size, const int n_head, const int d_head) { + struct ggml_tensor *cur = embeddings; + + { + // layernorm1 + cur = ggml_norm(ctx0, cur, eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.ln_1_w), layer.ln_1_b); + ggml_set_name(cur, format("%d pre layernorm", il).c_str()); + } + + { + // self-attention + struct ggml_tensor *Q = ggml_mul_mat(ctx0, layer.q_w, cur); + if (layer.q_b != nullptr) { + Q = ggml_add(ctx0, Q, layer.q_b); + } + + Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, Q->ne[1], batch_size); + Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); + ggml_set_name(Q, format("%d query", il).c_str()); + + struct ggml_tensor *K = ggml_mul_mat(ctx0, layer.k_w, cur); + if (layer.k_b != nullptr) { + K = ggml_add(ctx0, K, layer.k_b); + } + + K = ggml_reshape_4d(ctx0, K, d_head, n_head, K->ne[1], batch_size); + K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); + ggml_set_name(K, format("%d key", il).c_str()); + + struct ggml_tensor *V = ggml_mul_mat(ctx0, layer.v_w, cur); + if (layer.v_b != nullptr) { + V = ggml_add(ctx0, V, layer.v_b); + } + + V = ggml_reshape_4d(ctx0, V, d_head, n_head, V->ne[1], batch_size); + V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); + ggml_set_name(V, format("%d value", il).c_str()); + + struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q); + KQ = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf((float)d_head)); + KQ = ggml_soft_max_inplace(ctx0, KQ); + ggml_set_name(KQ, format("%d KQ", il).c_str()); + + struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = ggml_reshape_4d(ctx0, KQV, d_head, KQV->ne[1], n_head, batch_size); + KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + KQV = ggml_cont_3d(ctx0, KQV, hidden_size, KQV->ne[2], batch_size); + ggml_set_name(KQV, format("%d KQV", il).c_str()); + + cur = ggml_mul_mat(ctx0, layer.o_w, KQV); + if (layer.o_b != nullptr) { + cur = ggml_add(ctx0, cur, layer.o_b); + } + ggml_set_name(cur, format("%d self attention", il).c_str()); + + if (layer.attn_gate != nullptr) { + cur = ggml_mul_inplace(ctx0, cur, layer.attn_gate); + ggml_set_name(cur, format("%d self attention gate", il).c_str()); + } + } + + cur = ggml_add(ctx0, cur, embeddings); + ggml_set_name(cur, format("%d residual", il).c_str()); + + embeddings = cur; + + { + // layernorm2 + cur = ggml_norm(ctx0, cur, eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.ln_2_w), layer.ln_2_b); + ggml_set_name(cur, format("%d post layernorm", il).c_str()); + } + + { + // feed forward + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, layer.ff_i_w, cur), layer.ff_i_b); + cur = ggml_gelu_inplace(ctx0, cur); + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, layer.ff_o_w, cur), layer.ff_o_b); + ggml_set_name(cur, format("%d feed forward", il).c_str()); + + if (layer.ff_gate != nullptr) { + cur = ggml_mul_inplace(ctx0, cur, layer.ff_gate); + ggml_set_name(cur, format("%d feed forward gate", il).c_str()); + } + } + + // residual 2 + cur = ggml_add(ctx0, cur, embeddings); + ggml_set_name(cur, format("%d residual", il).c_str()); + + embeddings = cur; + + return embeddings; +} + +static ggml_cgraph *mllama_image_build_graph(mllama_ctx *ctx, const mllama_image_batch *imgs) { + const auto &model = 
ctx->vision_model; + const auto &hparams = model.hparams; + + const int image_size = hparams.image_size; + const int image_size_width = image_size; + const int image_size_height = image_size; + + const int patch_size = hparams.patch_size; + const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); + const int num_positions = num_patches + (model.class_embedding == nullptr ? 0 : 1); + const int hidden_size = hparams.hidden_size; + const int n_head = hparams.n_head; + const int d_head = hidden_size / n_head; + + const int batch_size = imgs->size; + REQUIRE(batch_size == 1); + + int num_tiles = 4; + int num_channels = 3; + if (imgs->data != nullptr) { + num_tiles = imgs->data[0].num_tiles > 0 ? imgs->data[0].num_tiles : num_tiles; + num_channels = imgs->data[0].num_channels > 0 ? imgs->data[0].num_channels : num_channels; + } + + struct ggml_init_params params = { + ctx->buf_compute_meta.size(), // mem_size + ctx->buf_compute_meta.data(), // mem_buffer + true, // no_alloc + }; + + struct ggml_context *ctx0 = ggml_init(params); + struct ggml_cgraph *gf = ggml_new_graph(ctx0); + + struct ggml_tensor *inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, num_channels, num_tiles); + ggml_set_name(inp_raw, "inp_raw"); + ggml_set_input(inp_raw); + + struct ggml_tensor *inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + + inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, num_tiles); + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); + + struct ggml_tensor *aspect_ratios = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, imgs->size); + ggml_set_name(aspect_ratios, "aspect_ratios"); + ggml_set_input(aspect_ratios); + + if (model.pre_tile_position_embeddings != nullptr) { + struct ggml_tensor *pre_tile_position_embeddings = ggml_get_rows(ctx0, model.pre_tile_position_embeddings, aspect_ratios); + ggml_set_name(pre_tile_position_embeddings, "pre_tile_position_embeddings"); + + pre_tile_position_embeddings = ggml_reshape_3d(ctx0, pre_tile_position_embeddings, hidden_size, 1, num_tiles); + if (model.pre_tile_position_embeddings_gate != nullptr) { + pre_tile_position_embeddings = ggml_mul_inplace(ctx0, pre_tile_position_embeddings, model.pre_tile_position_embeddings_gate); + } + + inp = ggml_add(ctx0, inp, pre_tile_position_embeddings); + } + + struct ggml_tensor *embeddings = inp; + + if (model.class_embedding != nullptr) { + // concat class_embeddings and patch_embeddings + embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, num_tiles); + ggml_set_name(embeddings, "embeddings"); + ggml_set_input(embeddings); + for (int i = 0; i < num_tiles; ++i) { + // repeat class embeddings for each tile + embeddings = ggml_acc_inplace(ctx0, embeddings, model.class_embedding, embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], i * embeddings->nb[2]); + } + + embeddings = ggml_acc_inplace(ctx0, embeddings, inp, embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]); + } + + struct ggml_tensor *positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + struct ggml_tensor *position_embd = ggml_get_rows(ctx0, model.position_embeddings, positions); + if (model.position_embeddings_gate != nullptr) { + position_embd = ggml_mul_inplace(ctx0, position_embd, model.position_embeddings_gate); + } + + embeddings = ggml_add(ctx0, embeddings, position_embd); + + 
if (model.tile_position_embeddings != nullptr) { + struct ggml_tensor *tile_position_embeddings = ggml_get_rows(ctx0, model.tile_position_embeddings, aspect_ratios); + ggml_set_name(tile_position_embeddings, "tile_position_embeddings"); + + tile_position_embeddings = ggml_reshape_3d(ctx0, tile_position_embeddings, hidden_size, num_positions, num_tiles); + if (model.tile_position_embeddings_gate != nullptr) { + tile_position_embeddings = ggml_mul_inplace(ctx0, tile_position_embeddings, model.tile_position_embeddings_gate); + } + + embeddings = ggml_add(ctx0, embeddings, tile_position_embeddings); + } + + // pre-layernorm + if (model.pre_ln_w != nullptr) { + embeddings = ggml_mul(ctx0, ggml_norm(ctx0, embeddings, hparams.eps), model.pre_ln_w); + if (model.pre_ln_b != nullptr) { + embeddings = ggml_add(ctx0, embeddings, model.pre_ln_b); + } + + ggml_set_name(embeddings, "pre layernorm"); + } + + const int num_padding_patches = 8 - (embeddings->ne[1] % 8) % 8; + + embeddings = ggml_pad(ctx0, embeddings, 0, num_padding_patches, 0, 0); + embeddings = ggml_view_3d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1] * embeddings->ne[2], batch_size, embeddings->nb[1], embeddings->nb[2] * embeddings->ne[3], 0); + + std::vector intermediate_embeddings; + + // encoder + for (size_t il = 0; il < model.layers.size(); il++) { + if (hparams.intermediate_layers[il]) { + intermediate_embeddings.push_back(embeddings); + } + + embeddings = mllama_image_build_encoder_layer( + ctx0, il, model.layers[il], embeddings, + hparams.eps, hidden_size, batch_size, n_head, d_head); + } + + // post-layernorm + if (model.post_ln_w != nullptr) { + embeddings = ggml_mul(ctx0, ggml_norm(ctx0, embeddings, hparams.eps), model.post_ln_w); + if (model.post_ln_b != nullptr) { + embeddings = ggml_add(ctx0, embeddings, model.post_ln_b); + } + + ggml_set_name(embeddings, "post layernorm"); + } + + embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, num_positions + num_padding_patches, num_tiles); + + if (model.post_tile_position_embeddings != nullptr) { + struct ggml_tensor *post_tile_position_embeddings = ggml_get_rows(ctx0, model.post_tile_position_embeddings, aspect_ratios); + ggml_set_name(post_tile_position_embeddings, "post_tile_position_embeddings"); + + post_tile_position_embeddings = ggml_reshape_3d(ctx0, post_tile_position_embeddings, hidden_size, 1, num_tiles); + if (model.post_tile_position_embeddings_gate != nullptr) { + post_tile_position_embeddings = ggml_mul(ctx0, post_tile_position_embeddings, model.post_tile_position_embeddings_gate); + } + + embeddings = ggml_add(ctx0, embeddings, post_tile_position_embeddings); + } + + embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, num_tiles * (num_positions + num_padding_patches), 1); + + // global encoder + for (size_t il = 0; il < model.global_layers.size(); il++) { + embeddings = mllama_image_build_encoder_layer( + ctx0, il, model.global_layers[il], embeddings, + hparams.eps, hidden_size, batch_size, n_head, d_head); + } + + struct ggml_tensor *stacked_embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 0, hidden_size, (num_positions + num_padding_patches) * num_tiles); + for (size_t i = 0; i < intermediate_embeddings.size(); ++i) { + stacked_embeddings = ggml_concat(ctx0, stacked_embeddings, ggml_reshape_3d(ctx0, intermediate_embeddings[i], 1, intermediate_embeddings[i]->ne[0], intermediate_embeddings[i]->ne[1]), 0); + } + + stacked_embeddings = ggml_reshape_4d(ctx0, stacked_embeddings, intermediate_embeddings.size() * hidden_size, num_positions + 
num_padding_patches, num_tiles, batch_size); + stacked_embeddings = ggml_unpad(ctx0, stacked_embeddings, 0, num_padding_patches, 0, 0); + + embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, num_positions + num_padding_patches, num_tiles); + embeddings = ggml_unpad(ctx0, embeddings, 0, num_padding_patches, 0, 0); + embeddings = ggml_concat(ctx0, embeddings, stacked_embeddings, 0); + + // mllama projector + embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_0_w, embeddings), model.mm_0_b); + ggml_set_name(embeddings, "multi modal projector"); + + // build the graph + ggml_build_forward_expand(gf, embeddings); + + ggml_free(ctx0); + + return gf; +} + +static struct ggml_tensor *mllama_tensor_load(struct ggml_context *ctx, const char *name, const bool optional) { + struct ggml_tensor *cur = ggml_get_tensor(ctx, name); + REQUIRE(cur != nullptr || optional); + return cur; +} + +static std::vector mllama_layers_load(struct ggml_context *ctx, const char *prefix, const int n) { + std::vector layers(n); + for (size_t i = 0; i < layers.size(); i++) { + auto &layer = layers[i]; + layer.ln_1_w = mllama_tensor_load(ctx, format("%s.blk.%d.ln1.weight", prefix, i).c_str(), false); + layer.ln_1_b = mllama_tensor_load(ctx, format("%s.blk.%d.ln1.bias", prefix, i).c_str(), false); + layer.ln_2_w = mllama_tensor_load(ctx, format("%s.blk.%d.ln2.weight", prefix, i).c_str(), false); + layer.ln_2_b = mllama_tensor_load(ctx, format("%s.blk.%d.ln2.bias", prefix, i).c_str(), false); + + layer.k_w = mllama_tensor_load(ctx, format("%s.blk.%d.attn_k.weight", prefix, i).c_str(), false); + layer.k_b = mllama_tensor_load(ctx, format("%s.blk.%d.attn_k.bias", prefix, i).c_str(), true); + layer.q_w = mllama_tensor_load(ctx, format("%s.blk.%d.attn_q.weight", prefix, i).c_str(), false); + layer.q_b = mllama_tensor_load(ctx, format("%s.blk.%d.attn_q.bias", prefix, i).c_str(), true); + layer.v_w = mllama_tensor_load(ctx, format("%s.blk.%d.attn_v.weight", prefix, i).c_str(), false); + layer.v_b = mllama_tensor_load(ctx, format("%s.blk.%d.attn_v.bias", prefix, i).c_str(), true); + layer.o_w = mllama_tensor_load(ctx, format("%s.blk.%d.attn_out.weight", prefix, i).c_str(), false); + layer.o_b = mllama_tensor_load(ctx, format("%s.blk.%d.attn_out.bias", prefix, i).c_str(), true); + + layer.ff_i_w = mllama_tensor_load(ctx, format("%s.blk.%d.ffn_down.weight", prefix, i).c_str(), false); + layer.ff_i_b = mllama_tensor_load(ctx, format("%s.blk.%d.ffn_down.bias", prefix, i).c_str(), false); + layer.ff_o_w = mllama_tensor_load(ctx, format("%s.blk.%d.ffn_up.weight", prefix, i).c_str(), false); + layer.ff_o_b = mllama_tensor_load(ctx, format("%s.blk.%d.ffn_up.bias", prefix, i).c_str(), false); + + layer.attn_gate = mllama_tensor_load(ctx, format("%s.blk.%d.attn_gate", prefix, i).c_str(), true); + layer.ff_gate = mllama_tensor_load(ctx, format("%s.blk.%d.ffn_gate", prefix, i).c_str(), true); + } + + return layers; +} + +// read and create ggml_context containing the tensors and their data +struct mllama_ctx *mllama_model_load(const char *fname, const int verbosity = 1) { + struct ggml_context *meta = nullptr; + + struct gguf_init_params params = { + true, // no_alloc + &meta, // ctx + }; + + struct gguf_context *ctx = gguf_init_from_file(fname, params); + REQUIRE(ctx != nullptr); + + if (verbosity >= 1) { + const int n_tensors = gguf_get_n_tensors(ctx); + const int n_kv = gguf_get_n_kv(ctx); + const std::string ftype = get_ftype(get_u32(ctx, "general.file_type")); + const int idx_desc = get_key_index(ctx, "general.description"); + 
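For orientation, the loader and encoder being defined here are the ones exposed through llama/mllama.h further down in this diff. A hedged usage sketch (not code from the repository; the header path, thread count, and image arguments are placeholders, and error checking is omitted):

#include <vector>
#include "mllama.h"

std::vector<float> encode_one_image(const char *model_path,
                                    const void *pixels, int n_bytes,
                                    int width, int height,
                                    int num_tiles, int aspect_ratio_id) {
    // load the vision model from a GGUF file
    mllama_ctx *ctx = mllama_model_load(model_path, /*verbosity=*/1);

    // wrap preprocessed, tiled pixel data in an mllama_image
    mllama_image *img = mllama_image_init();
    mllama_image_load_from_data(pixels, n_bytes, width, height,
                                /*num_channels=*/3, num_tiles,
                                aspect_ratio_id, img);

    // output size is positions * tiles * projection_dim floats
    std::vector<float> embed(mllama_n_embd_bytes(ctx) / sizeof(float));
    mllama_image_encode(ctx, /*n_threads=*/4, img, embed.data());

    mllama_image_free(img);
    mllama_free(ctx);
    return embed;
}
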
const std::string description = gguf_get_val_str(ctx, idx_desc); + const int idx_name = gguf_find_key(ctx, "general.name"); + if (idx_name != -1) { // make name optional temporarily as some of the uploaded models missing it due to a bug + const std::string name = gguf_get_val_str(ctx, idx_name); + LOG("model name: %s", name.c_str()); + } + LOG("description: %s", description.c_str()); + LOG("GGUF version: %d", gguf_get_version(ctx)); + LOG("alignment: %zu", gguf_get_alignment(ctx)); + LOG("n_tensors: %d", n_tensors); + LOG("n_kv: %d", n_kv); + LOG("ftype: %s", ftype.c_str()); + LOG(""); + } + const int n_tensors = gguf_get_n_tensors(ctx); + + mllama_ctx *new_mllama = new mllama_ctx{}; + +#ifdef GGML_USE_CUDA + new_mllama->backend = ggml_backend_cuda_init(0); + LOG("vision using CUDA backend"); +#endif + +#ifdef GGML_USE_METAL + new_mllama->backend = ggml_backend_metal_init(); + LOG("vision using Metal backend"); +#endif + +#ifdef GGML_USE_CANN + new_mllama->backend = ggml_backend_cann_init(0); + LOG("vision using CANN backend"); +#endif + +#ifdef GGML_USE_VULKAN + new_mllama->backend = ggml_backend_vk_init(0); + LOG("vision using Vulkan backend"); +#endif + + if (!new_mllama->backend) { + new_mllama->backend = ggml_backend_cpu_init(); + LOG("vision using CPU backend"); + } + + // load tensors + { + std::vector read_buf; + struct ggml_init_params params = { + (n_tensors + 1) * ggml_tensor_overhead(), // mem_size + nullptr, // mem_buffer + true, // no_alloc + }; + + new_mllama->ctx_data = ggml_init(params); + if (!new_mllama->ctx_data) { + LOG("ggml_init() failed"); + mllama_free(new_mllama); + gguf_free(ctx); + return nullptr; + } + +#ifdef _WIN32 + int wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, NULL, 0); + if (!wlen) { + return NULL; + } + wchar_t * wbuf = (wchar_t *) malloc(wlen * sizeof(wchar_t)); + wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, wbuf, wlen); + if (!wlen) { + free(wbuf); + return NULL; + } +#if __GLIBCXX__ + int fd = _wopen(wbuf, _O_RDONLY | _O_BINARY); + __gnu_cxx::stdio_filebuf buffer(fd, std::ios_base::in); + std::istream fin(&buffer); +#else // MSVC + // unused in our current build + auto fin = std::ifstream(wbuf, std::ios::binary); +#endif + free(wbuf); +#else + auto fin = std::ifstream(fname, std::ios::binary); +#endif + if (!fin) { + LOG("cannot open model file for loading tensors\n"); + mllama_free(new_mllama); + gguf_free(ctx); + return nullptr; + } + + // add tensors to context + for (int i = 0; i < n_tensors; ++i) { + const char *name = gguf_get_tensor_name(ctx, i); + struct ggml_tensor *t = ggml_get_tensor(meta, name); + struct ggml_tensor *cur = ggml_dup_tensor(new_mllama->ctx_data, t); + ggml_set_name(cur, name); + } + + // alloc memory and offload data + new_mllama->params_buffer = ggml_backend_alloc_ctx_tensors(new_mllama->ctx_data, new_mllama->backend); + for (int i = 0; i < n_tensors; ++i) { + const char *name = gguf_get_tensor_name(ctx, i); + struct ggml_tensor *cur = ggml_get_tensor(new_mllama->ctx_data, name); + const size_t offset = gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, i); + fin.seekg(offset, std::ios::beg); + if (!fin) { + LOG("failed to seek for tensor %s\n", name); + mllama_free(new_mllama); + gguf_free(ctx); + return nullptr; + } + int num_bytes = ggml_nbytes(cur); + if (ggml_backend_buffer_is_host(new_mllama->params_buffer)) { + // for the CPU and Metal backend, we can read directly into the tensor + fin.read(reinterpret_cast(cur->data), num_bytes); + } else { + // read into a temporary buffer first, then copy to 
device memory + read_buf.resize(num_bytes); + fin.read(reinterpret_cast(read_buf.data()), num_bytes); + ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); + } + } + +#if defined(_WIN32) && defined(__GLIBCXX__) + close(fd); +#else + fin.close(); +#endif + } + + // vision model + // load vision model + auto &vision_model = new_mllama->vision_model; + auto &hparams = vision_model.hparams; + hparams.hidden_size = get_u32(ctx, "mllama.vision.embedding_length"); + hparams.n_head = get_u32(ctx, "mllama.vision.attention.head_count"); + hparams.n_intermediate = get_u32(ctx, "mllama.vision.feed_forward_length"); + hparams.n_layer = get_u32(ctx, "mllama.vision.block_count"); + hparams.n_global_layer = get_u32(ctx, "mllama.vision.global.block_count"); + hparams.n_tiles = get_u32(ctx, "mllama.vision.max_num_tiles"); + hparams.image_size = get_u32(ctx, "mllama.vision.image_size"); + hparams.patch_size = get_u32(ctx, "mllama.vision.patch_size"); + hparams.projection_dim = get_u32(ctx, "mllama.vision.projection_dim"); + hparams.eps = get_f32(ctx, "mllama.vision.attention.layer_norm_epsilon"); + + std::vector intermediate_layers_indices = get_u32_array(ctx, "mllama.vision.intermediate_layers_indices"); + hparams.intermediate_layers.resize(hparams.n_layer); + for (size_t i = 0; i < intermediate_layers_indices.size(); i++) { + hparams.intermediate_layers[intermediate_layers_indices[i]] = true; + } + + if (verbosity >= 2) { + LOG(""); + LOG("vision model hparams"); + LOG("image_size %d", hparams.image_size); + LOG("patch_size %d", hparams.patch_size); + LOG("v_hidden_size %d", hparams.hidden_size); + LOG("v_n_intermediate %d", hparams.n_intermediate); + LOG("v_projection_dim %d", hparams.projection_dim); + LOG("v_n_head %d", hparams.n_head); + LOG("v_n_layer %d", hparams.n_layer); + LOG("v_n_global_layer %d", hparams.n_global_layer); + LOG("v_eps %f", hparams.eps); + } + + vision_model.class_embedding = mllama_tensor_load(new_mllama->ctx_data, "v.class_embd", true); + vision_model.patch_embeddings = mllama_tensor_load(new_mllama->ctx_data, "v.patch_embd.weight", true); + + vision_model.position_embeddings = mllama_tensor_load(new_mllama->ctx_data, "v.position_embd.weight", true); + vision_model.position_embeddings_gate = mllama_tensor_load(new_mllama->ctx_data, "v.position_embd.gate", true); + + vision_model.pre_ln_w = mllama_tensor_load(new_mllama->ctx_data, "v.pre_ln.weight", true); + vision_model.pre_ln_b = mllama_tensor_load(new_mllama->ctx_data, "v.pre_ln.bias", true); + vision_model.post_ln_w = mllama_tensor_load(new_mllama->ctx_data, "v.post_ln.weight", true); + vision_model.post_ln_b = mllama_tensor_load(new_mllama->ctx_data, "v.post_ln.bias", true); + + vision_model.tile_position_embeddings = mllama_tensor_load(new_mllama->ctx_data, "v.tile_position_embd.weight", true); + vision_model.tile_position_embeddings_gate = mllama_tensor_load(new_mllama->ctx_data, "v.tile_position_embd.gate", true); + + vision_model.pre_tile_position_embeddings = mllama_tensor_load(new_mllama->ctx_data, "v.pre_tile_position_embd.weight", true); + vision_model.pre_tile_position_embeddings_gate = mllama_tensor_load(new_mllama->ctx_data, "v.pre_tile_position_embd.gate", true); + + vision_model.post_tile_position_embeddings = mllama_tensor_load(new_mllama->ctx_data, "v.post_tile_position_embd.weight", true); + vision_model.post_tile_position_embeddings_gate = mllama_tensor_load(new_mllama->ctx_data, "v.post_tile_position_embd.gate", true); + + vision_model.mm_0_w = mllama_tensor_load(new_mllama->ctx_data, 
"mm.0.weight", false); + vision_model.mm_0_b = mllama_tensor_load(new_mllama->ctx_data, "mm.0.bias", false); + + vision_model.layers = mllama_layers_load(new_mllama->ctx_data, "v", hparams.n_layer); + vision_model.global_layers = mllama_layers_load(new_mllama->ctx_data, "v.global", hparams.n_global_layer); + + ggml_free(meta); + + new_mllama->ctx_gguf = ctx; + + { + // measure mem requirement and allocate + new_mllama->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead()); + new_mllama->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_mllama->backend)); + struct mllama_image_batch batch; + batch.size = 1; + ggml_cgraph *gf = mllama_image_build_graph(new_mllama, &batch); + ggml_gallocr_reserve(new_mllama->compute_alloc, gf); + size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_mllama->compute_alloc, 0); + LOG("compute allocated memory: %.2f MB", compute_memory_buffer_size / 1024.0 / 1024.0); + } + + return new_mllama; +} + +struct mllama_image *mllama_image_init() { + return new mllama_image(); +} + +void mllama_image_free(struct mllama_image *img) { delete img; } +void mllama_image_batch_free(struct mllama_image_batch *batch) { + if (batch->size > 0) { + delete[] batch->data; + batch->size = 0; + } +} + +bool mllama_image_load_from_data(const void *data, const int n, const int width, const int height, const int num_channels, const int num_tiles, const int aspect_ratio_id, struct mllama_image *img) { + img->width = width; + img->height = height; + img->num_channels = num_channels; + img->num_tiles = num_tiles; + img->aspect_ratio_id = aspect_ratio_id; + img->data.resize(n); + + memcpy(img->data.data(), data, n); + return true; +} + +inline int mllama(int x, int lower, int upper) { + return std::max(lower, std::min(x, upper)); +} + +void mllama_free(mllama_ctx *ctx) { + ggml_free(ctx->ctx_data); + gguf_free(ctx->ctx_gguf); + + ggml_backend_buffer_free(ctx->params_buffer); + ggml_backend_free(ctx->backend); + ggml_gallocr_free(ctx->compute_alloc); + delete ctx; +} + +bool mllama_image_encode(struct mllama_ctx *ctx, const int n_threads, mllama_image *img, float *vec) { + mllama_image_batch imgs{}; + imgs.size = 1; + imgs.data = img; + return mllama_image_batch_encode(ctx, n_threads, &imgs, vec); +} + +bool mllama_image_batch_encode(mllama_ctx *ctx, const int n_threads, const mllama_image_batch *imgs, float *vec) { + int batch_size = imgs->size; + REQUIRE(batch_size == 1); + + // build the inference graph + ggml_cgraph *gf = mllama_image_build_graph(ctx, imgs); + ggml_gallocr_alloc_graph(ctx->compute_alloc, gf); + + // set inputs + const auto &model = ctx->vision_model; + const auto &hparams = model.hparams; + + const int image_size = hparams.image_size; + int image_size_width = image_size; + int image_size_height = image_size; + + const int patch_size = hparams.patch_size; + const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); + const int num_positions = num_patches + (model.class_embedding == nullptr ? 
0 : 1); + + { + struct ggml_tensor *inp_raw = ggml_graph_get_tensor(gf, "inp_raw"); + ggml_backend_tensor_set(inp_raw, imgs->data[0].data.data(), 0, ggml_nbytes(inp_raw)); + } + + { + struct ggml_tensor *embeddings = ggml_graph_get_tensor(gf, "embeddings"); + if (embeddings != nullptr) { + void *zeros = malloc(ggml_nbytes(embeddings)); + memset(zeros, 0, ggml_nbytes(embeddings)); + ggml_backend_tensor_set(embeddings, zeros, 0, ggml_nbytes(embeddings)); + free(zeros); + } + } + + { + struct ggml_tensor *positions = ggml_graph_get_tensor(gf, "positions"); + if (positions != nullptr) { + int *positions_data = (int *)malloc(ggml_nbytes(positions)); + for (int i = 0; i < num_positions; i++) { + positions_data[i] = i; + } + ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); + free(positions_data); + } + } + + { + struct ggml_tensor *aspect_ratios = ggml_graph_get_tensor(gf, "aspect_ratios"); + if (aspect_ratios != nullptr) { + int *aspect_ratios_data = (int *)malloc(ggml_nbytes(aspect_ratios)); + aspect_ratios_data[0] = imgs->data[0].aspect_ratio_id; + ggml_backend_tensor_set(aspect_ratios, aspect_ratios_data, 0, ggml_nbytes(aspect_ratios)); + free(aspect_ratios_data); + } + } + + if (ggml_backend_is_cpu(ctx->backend)) { + ggml_backend_cpu_set_n_threads(ctx->backend, n_threads); + } + + ggml_backend_graph_compute(ctx->backend, gf); + + // the last node is the embedding tensor + struct ggml_tensor *embeddings = ggml_graph_node(gf, ggml_graph_n_nodes(gf) - 1); + + // copy the embeddings to the location passed by the user + ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings)); + + return true; +} + +int32_t mllama_image_size(const struct mllama_ctx *ctx) { + return ctx->vision_model.hparams.image_size; +} + +int32_t mllama_patch_size(const struct mllama_ctx *ctx) { + return ctx->vision_model.hparams.patch_size; +} + +int32_t mllama_hidden_size(const struct mllama_ctx *ctx) { + return ctx->vision_model.hparams.hidden_size; +} + +int mllama_n_patches(const struct mllama_ctx *ctx) { + const auto &hparams = ctx->vision_model.hparams; + return (hparams.image_size / hparams.patch_size) * (hparams.image_size / hparams.patch_size); +} + +int mllama_n_positions(const struct mllama_ctx *ctx) { + return mllama_n_patches(ctx) + (ctx->vision_model.class_embedding == nullptr ? 
0 : 1); +} + +int mllama_n_tiles(const struct mllama_ctx *ctx) { + return ctx->vision_model.hparams.n_tiles; +} + +int mllama_n_embd(const struct mllama_ctx *ctx) { + return ctx->vision_model.hparams.projection_dim; +} + +size_t mllama_n_embd_bytes(const struct mllama_ctx *ctx) { + return mllama_n_positions(ctx) * mllama_n_embd(ctx) * mllama_n_tiles(ctx) * sizeof(float); +} diff --git a/llama/mllama.h b/llama/mllama.h new file mode 100644 index 00000000..446dbb9e --- /dev/null +++ b/llama/mllama.h @@ -0,0 +1,61 @@ +#ifndef MLLAMA_H +#define MLLAMA_H + +#include +#include + +#ifdef LLAMA_SHARED +#if defined(_WIN32) && !defined(__MINGW32__) +#ifdef LLAMA_BUILD +#define MLLAMA_API __declspec(dllexport) +#else +#define MLLAMA_API __declspec(dllimport) +#endif +#else +#define MLLAMA_API __attribute__((visibility("default"))) +#endif +#else +#define MLLAMA_API +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +struct mllama_ctx; + +struct mllama_image_batch { + struct mllama_image *data; + size_t size; +}; + +MLLAMA_API struct mllama_ctx *mllama_model_load(const char *fname, int verbosity); +MLLAMA_API struct mllama_ctx *mllama_model_load_cpu(const char *fname, int verbosity); + +MLLAMA_API void mllama_free(struct mllama_ctx *ctx); + +MLLAMA_API int32_t mllama_image_size(const struct mllama_ctx *ctx); +MLLAMA_API int32_t mllama_patch_size(const struct mllama_ctx *ctx); +MLLAMA_API int32_t mllama_hidden_size(const struct mllama_ctx *ctx); + +MLLAMA_API int mllama_n_patches(const struct mllama_ctx *ctx); +MLLAMA_API int mllama_n_positions(const struct mllama_ctx *ctx); +MLLAMA_API int mllama_n_tiles(const struct mllama_ctx *ctx); +MLLAMA_API int mllama_n_embd(const struct mllama_ctx *ctx); +MLLAMA_API size_t mllama_n_embd_bytes(const struct mllama_ctx *ctx); + +MLLAMA_API struct mllama_image *mllama_image_init(); + +MLLAMA_API void mllama_image_free(struct mllama_image *img); +MLLAMA_API void mllama_image_batch_free(struct mllama_image_batch *batch); + +MLLAMA_API bool mllama_image_load_from_data(const void *data, const int n, const int nx, const int ny, const int nc, const int nt, const int aspect_ratio_id, struct mllama_image *img); + +MLLAMA_API bool mllama_image_encode(struct mllama_ctx *ctx, int n_threads, struct mllama_image *img, float *vec); +MLLAMA_API bool mllama_image_batch_encode(struct mllama_ctx *ctx, int n_threads, const struct mllama_image_batch *imgs, float *vec); + +#ifdef __cplusplus +} +#endif + +#endif // MLLAMA_H diff --git a/llama/patches/0010-add-mllama-support.patch b/llama/patches/0010-add-mllama-support.patch new file mode 100644 index 00000000..c6dd72a7 --- /dev/null +++ b/llama/patches/0010-add-mllama-support.patch @@ -0,0 +1,690 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: jmorganca +Date: Thu, 17 Oct 2024 15:18:22 -0700 +Subject: [PATCH] add mllama support + +mllama adds cross-attention layers to the standard llama architecture +it also requires a way to input a new tensor: cross_attention_state +once per generation + +cross-attention layers don't change and so they are cached in the +kv cache once per run + +remaining is to implement the cross attention mask +--- + include/llama.h | 4 + + src/llama.cpp | 456 ++++++++++++++++++++++++++++++++++++++++++++++-- + 2 files changed, 447 insertions(+), 13 deletions(-) + +diff --git a/include/llama.h b/include/llama.h +index 7cae1bbe..122e3cf1 100644 +--- a/include/llama.h ++++ b/include/llama.h +@@ -423,6 +423,10 @@ extern "C" { + struct llama_model * model, + struct 
llama_context_params params); + ++ // TODO (jmorganca): this should most likely be passed in as part of a batch ++ // and not set on the context for all batches. ++ LLAMA_API void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state); ++ + // Frees all allocated memory + LLAMA_API void llama_free(struct llama_context * ctx); + +diff --git a/src/llama.cpp b/src/llama.cpp +index 83b80b59..b189a19a 100644 +--- a/src/llama.cpp ++++ b/src/llama.cpp +@@ -169,6 +169,7 @@ static std::string format(const char * fmt, ...) { + + enum llm_arch { + LLM_ARCH_LLAMA, ++ LLM_ARCH_MLLAMA, + LLM_ARCH_FALCON, + LLM_ARCH_BAICHUAN, + LLM_ARCH_GROK, +@@ -223,6 +224,7 @@ enum llm_arch { + + static const std::map LLM_ARCH_NAMES = { + { LLM_ARCH_LLAMA, "llama" }, ++ { LLM_ARCH_MLLAMA, "mllama" }, + { LLM_ARCH_FALCON, "falcon" }, + { LLM_ARCH_GROK, "grok" }, + { LLM_ARCH_GPT2, "gpt2" }, +@@ -330,6 +332,7 @@ enum llm_kv { + LLM_KV_ATTENTION_SLIDING_WINDOW, + LLM_KV_ATTENTION_SCALE, + LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, ++ LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, + + LLM_KV_ROPE_DIMENSION_COUNT, + LLM_KV_ROPE_FREQ_BASE, +@@ -439,6 +442,7 @@ static const std::map LLM_KV_NAMES = { + { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, + { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, + { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection.%d" }, ++ { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" }, + + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, +@@ -613,6 +617,14 @@ enum llm_tensor { + LLM_TENSOR_CLS, + LLM_TENSOR_CLS_OUT, + LLM_TENSOR_BSKCN_TV, ++ LLM_TENSOR_CROSS_ATTN_K_NORM, ++ LLM_TENSOR_CROSS_ATTN_K_PROJ, ++ LLM_TENSOR_CROSS_ATTN_O_PROJ, ++ LLM_TENSOR_CROSS_ATTN_Q_NORM, ++ LLM_TENSOR_CROSS_ATTN_Q_PROJ, ++ LLM_TENSOR_CROSS_ATTN_V_PROJ, ++ LLM_TENSOR_CROSS_ATTN_ATTN_GATE, ++ LLM_TENSOR_CROSS_ATTN_MLP_GATE, + }; + + static const std::map> LLM_TENSOR_NAMES = { +@@ -642,6 +654,40 @@ static const std::map> LLM_TENSOR_NA + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, ++ { ++ LLM_ARCH_MLLAMA, ++ { ++ { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, ++ { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, ++ { LLM_TENSOR_OUTPUT, "output" }, ++ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, ++ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, ++ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, ++ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, ++ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, ++ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, ++ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, ++ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, ++ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, ++ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, ++ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, ++ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, ++ { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, ++ { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, ++ { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, ++ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, ++ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, ++ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, ++ { LLM_TENSOR_CROSS_ATTN_K_NORM, "blk.%d.cross_attn_k_norm" }, ++ { LLM_TENSOR_CROSS_ATTN_K_PROJ, "blk.%d.cross_attn_k_proj" }, ++ { LLM_TENSOR_CROSS_ATTN_O_PROJ, "blk.%d.cross_attn_o_proj" }, ++ { LLM_TENSOR_CROSS_ATTN_Q_NORM, "blk.%d.cross_attn_q_norm" }, ++ { LLM_TENSOR_CROSS_ATTN_Q_PROJ, "blk.%d.cross_attn_q_proj" }, ++ { 
LLM_TENSOR_CROSS_ATTN_V_PROJ, "blk.%d.cross_attn_v_proj" }, ++ { LLM_TENSOR_CROSS_ATTN_ATTN_GATE, "blk.%d.cross_attn_attn_gate" }, ++ { LLM_TENSOR_CROSS_ATTN_MLP_GATE, "blk.%d.cross_attn_mlp_gate" }, ++ }, ++ }, + { + LLM_ARCH_BAICHUAN, + { +@@ -2390,6 +2436,7 @@ enum e_model { + MODEL_40B, + MODEL_65B, + MODEL_70B, ++ MODEL_90B, + MODEL_236B, + MODEL_314B, + MODEL_SMALL, +@@ -2434,6 +2481,7 @@ struct llama_hparams { + std::array n_ff_arr; + + std::array, 4> n_bskcn_arr; ++ std::array cross_attn_layers; + + uint32_t n_layer_dense_lead = 0; + uint32_t n_lora_q = 0; +@@ -2502,10 +2550,11 @@ struct llama_hparams { + if (this->n_expert != other.n_expert) return true; + if (this->n_expert_used != other.n_expert_used) return true; + +- if (this->n_head_arr != other.n_head_arr) return true; +- if (this->n_head_kv_arr != other.n_head_kv_arr) return true; +- if (this->n_ff_arr != other.n_ff_arr) return true; +- if (this->n_bskcn_arr != other.n_bskcn_arr) return true; ++ if (this->n_head_arr != other.n_head_arr) return true; ++ if (this->n_head_kv_arr != other.n_head_kv_arr) return true; ++ if (this->n_ff_arr != other.n_ff_arr) return true; ++ if (this->n_bskcn_arr != other.n_bskcn_arr) return true; ++ if (this->cross_attn_layers != other.cross_attn_layers) return true; + + if (this->n_rel_attn_bkts != other.n_rel_attn_bkts) return true; + if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true; +@@ -2623,6 +2672,10 @@ struct llama_hparams { + + GGML_ABORT("fatal error"); + } ++ ++ bool cross_attention_layer(uint32_t il) const { ++ return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end(); ++ } + }; + + static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable"); +@@ -2806,6 +2859,16 @@ struct llama_layer { + struct ggml_tensor * ffn_down_scale; + + struct ggml_tensor * bskcn_tv; ++ ++ // cross attention ++ struct ggml_tensor * cross_attn_k_norm; ++ struct ggml_tensor * cross_attn_k_proj; ++ struct ggml_tensor * cross_attn_o_proj; ++ struct ggml_tensor * cross_attn_q_norm; ++ struct ggml_tensor * cross_attn_q_proj; ++ struct ggml_tensor * cross_attn_v_proj; ++ struct ggml_tensor * cross_attn_attn_gate; ++ struct ggml_tensor * cross_attn_mlp_gate; + }; + + // very similar to llama_batch, +@@ -3452,6 +3515,12 @@ struct llama_context { + struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] + struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] + struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] ++ ++ // TODO (jmorganca): this should most likely be passed in as part of a batch ++ // and not set on the context for all batches. ++ float * cross_attn_state = nullptr; ++ bool cross_attn_state_first_pass = true; ++ struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061] + }; + + struct llama_lora_weight { +@@ -3686,6 +3755,18 @@ static bool llama_kv_cache_init( + cache.v_l.reserve(n_layer); + + for (int i = 0; i < (int) n_layer; i++) { ++ // for cross attention layers ++ if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layer(i)) { ++ struct ggml_context * ctx = offload ? 
ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); ++ ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i)); ++ ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i)); ++ ggml_format_name(k, "cache_k_l%d", i); ++ ggml_format_name(v, "cache_v_l%d", i); ++ cache.k_l.push_back(k); ++ cache.v_l.push_back(v); ++ continue; ++ } ++ + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + +@@ -5460,12 +5541,14 @@ static void llm_load_hparams( + } + + // zero-out the per-layer hparams +- std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); +- std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); +- std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); ++ std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); ++ std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); ++ std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); ++ std::fill(hparams.cross_attn_layers.begin(), hparams.cross_attn_layers.end(), -1); + +- ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer); +- ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer); ++ ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer); ++ ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer); ++ ml.get_arr(LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, hparams.cross_attn_layers, false); + + // n_head_kv is optional, default to n_head + hparams.n_head_kv_arr = hparams.n_head_arr; +@@ -5514,7 +5597,7 @@ static void llm_load_hparams( + + ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); + +- if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) { ++ if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_MLLAMA || model.arch == LLM_ARCH_FALCON) { + if (hparams.n_rot != hparams.n_embd_head_k) { + throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); + } +@@ -5554,6 +5637,16 @@ static void llm_load_hparams( + } + } + } break; ++ case LLM_ARCH_MLLAMA: ++ { ++ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ++ ++ switch (hparams.n_layer) { ++ case 40: model.type = e_model::MODEL_11B; break; ++ case 100: model.type = e_model::MODEL_90B; break; ++ default: model.type = e_model::MODEL_UNKNOWN; ++ } ++ } break; + case LLM_ARCH_MINICPM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); +@@ -7249,6 +7342,55 @@ static bool llm_load_tensors( + layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? 
llama_model_loader::TENSOR_DUPLICATED : 0)); + } + } break; ++ case LLM_ARCH_MLLAMA: ++ { ++ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8}); ++ ++ // output ++ { ++ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); ++ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); ++ ++ // if output is NULL, init from the input tok embed ++ if (model.output == NULL) { ++ model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); ++ } ++ } ++ ++ for (int i = 0; i < n_layer; ++i) { ++ ggml_context * ctx_layer = ctx_for_layer(i); ++ ggml_context * ctx_split = ctx_for_layer_split(i); ++ ++ auto & layer = model.layers[i]; ++ ++ if (hparams.cross_attention_layer(i)) { ++ layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128}); ++ layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024}); ++ layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd}); ++ layer.cross_attn_q_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_Q_NORM, "weight", i), {128}); ++ layer.cross_attn_q_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_Q_PROJ, "weight", i), {n_embd, n_embd}); ++ layer.cross_attn_v_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_V_PROJ, "weight", i), {n_embd, 1024}); ++ layer.cross_attn_attn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_ATTN_GATE, i), {1}); ++ layer.cross_attn_mlp_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_MLP_GATE, i), {1}); ++ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); ++ layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); ++ layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); ++ layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); ++ layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); ++ } else { ++ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); ++ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); ++ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); ++ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); ++ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); ++ layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); ++ layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? 
llama_model_loader::TENSOR_DUPLICATED : 0)); ++ layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); ++ layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); ++ layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); ++ } ++ } ++ } break; + case LLM_ARCH_GROK: + { + if (n_expert == 0) { +@@ -9093,7 +9235,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam + + if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE && + model.hparams.n_vocab != model.vocab.id_to_token.size()) { +- throw std::runtime_error("vocab size mismatch"); ++ LLAMA_LOG_WARN("%s: vocab mismatch %u !- %zu ...\n", __func__, model.hparams.n_vocab, model.vocab.id_to_token.size()); + } + + if (params.vocab_only) { +@@ -9178,7 +9320,7 @@ static struct ggml_tensor * llm_build_inp_embd( + + inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens); + } else { +- lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens); ++ lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens); + inpL = lctx.inp_embd; + ggml_set_input(lctx.inp_embd); + } +@@ -9193,6 +9335,22 @@ static struct ggml_tensor * llm_build_inp_embd( + return inpL; + } + ++static struct ggml_tensor * llm_build_inp_cross_attn_state( ++ struct ggml_context * ctx, ++ struct llama_context & lctx, ++ const llama_hparams & hparams, ++ const llm_build_cb & cb) { ++ const int64_t n_embd = hparams.n_embd; ++ ++ struct ggml_tensor * inpCAS; ++ lctx.inp_cross_attn_state = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4); ++ cb(lctx.inp_cross_attn_state, "inp_cross_attn_state", -1); ++ ggml_set_input(lctx.inp_cross_attn_state); ++ inpCAS = lctx.inp_cross_attn_state; ++ ++ return inpCAS; ++} ++ + static void llm_build_kv_store( + struct ggml_context * ctx, + const llama_hparams & hparams, +@@ -10167,6 +10325,7 @@ struct llm_build_context { + lctx.inp_pos_bucket = nullptr; + lctx.inp_embd_enc = nullptr; + lctx.inp_KQ_mask_cross = nullptr; ++ lctx.inp_cross_attn_state = nullptr; + } + + void free() { +@@ -10754,6 +10913,253 @@ struct llm_build_context { + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + ++ cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); ++ cb(cur, "result_output", -1); ++ ++ ggml_build_forward_expand(gf, cur); ++ ++ return gf; ++ } ++ ++ struct ggml_cgraph * build_mllama() { ++ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); ++ ++ // mutable variable, needed during the last layer of the computation to skip unused tokens ++ int32_t n_tokens = this->n_tokens; ++ ++ const int64_t n_embd_head = hparams.n_embd_head_v; ++ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); ++ GGML_ASSERT(n_embd_head == hparams.n_rot); ++ ++ struct ggml_tensor * cur; ++ struct ggml_tensor * inpL; ++ struct ggml_tensor * inpCAS; ++ ++ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); ++ inpCAS = llm_build_inp_cross_attn_state(ctx0, lctx, hparams, cb); ++ ++ // inp_pos - contains the positions ++ struct ggml_tensor * inp_pos = build_inp_pos(); ++ ++ // KQ_mask (mask for 1 head, it will be broadcasted to all heads) ++ struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); ++ ++ for (int il = 0; il < n_layer; ++il) { ++ struct ggml_tensor * inpSA = inpL; ++ ++ // norm ++ cur = llm_build_norm(ctx0, inpL, hparams, ++ model.layers[il].attn_norm, NULL, ++ LLM_NORM_RMS, cb, il); ++ cb(cur, "attn_norm", il); ++ ++ if 
(hparams.cross_attention_layer(il)) { ++ if (!lctx.cross_attn_state) { ++ continue; ++ } ++ ++ // cross attention layer ++ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_q_proj, cur); ++ cb(Qcur, "Qcur", il); ++ ++ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); ++ cb(Qcur, "Qcur", il); ++ ++ Qcur = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); ++ cb(Qcur, "Qcur", il); ++ ++ // TODO: is this required? ++ Qcur = ggml_cont(ctx0, Qcur); ++ cb(Qcur, "Qcur", il); ++ ++ Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il); ++ cb(Qcur, "Qcur", il); ++ ++ struct ggml_tensor * Kcur; ++ if (lctx.cross_attn_state_first_pass) { ++ Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS); ++ cb(Kcur, "Kcur", il); ++ ++ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404); ++ cb(Kcur, "Kcur", il); ++ ++ Kcur = ggml_permute(ctx0, Kcur, 0, 2, 1, 3); ++ cb(Kcur, "Kcur", il); ++ ++ // TODO: is this required? ++ Kcur = ggml_cont(ctx0, Kcur); ++ cb(Kcur, "Kcur", il); ++ ++ Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il); ++ cb(Kcur, "Kcur", il); ++ ++ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il])); ++ } else { ++ Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]); ++ cb(Kcur, "Kcur (view)", il); ++ } ++ ++ struct ggml_tensor * Vcur; ++ if (lctx.cross_attn_state_first_pass) { ++ Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS); ++ cb(Vcur, "Vcur", il); ++ ++ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, 6404); ++ cb(Vcur, "Vcur", il); ++ ++ Vcur = ggml_permute(ctx0, Vcur, 0, 2, 1, 3); ++ cb(Vcur, "Vcur", il); ++ ++ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il])); ++ } else { ++ Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]); ++ cb(Vcur, "Vcur (view)", il); ++ } ++ ++ struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur); ++ cb(kq, "kq", il); ++ ++ kq = ggml_scale_inplace(ctx0, kq, 1.0f/sqrtf(float(n_embd_head))); ++ cb(kq, "kq_scaled", il); ++ ++ // TODO: apply causal masks ++ struct ggml_tensor * kq_soft_max = ggml_soft_max_inplace(ctx0, kq); ++ cb(kq_soft_max, "kq_soft_max", il); ++ ++ Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur)); ++ cb(Vcur, "Vcur", il); ++ ++ struct ggml_tensor * kqv = ggml_mul_mat(ctx0, Vcur, kq_soft_max); ++ cb(kqv, "kqv", il); ++ ++ struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); ++ cb(kqv_merged, "kqv_merged", il); ++ ++ cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); ++ cb(cur, "kqv_merged_cont", il); ++ ++ cur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_o_proj, cur); ++ cb(cur, "cur", il); ++ ++ // TODO: do this in place once? ++ cur = ggml_mul(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_attn_gate)); ++ ++ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); ++ cb(ffn_inp, "ffn_inp", il); ++ ++ // feed-forward network ++ cur = llm_build_norm(ctx0, ffn_inp, hparams, ++ model.layers[il].ffn_norm, NULL, ++ LLM_NORM_RMS, cb, il); ++ cb(cur, "ffn_norm", il); ++ ++ cur = llm_build_ffn(ctx0, lctx, cur, ++ model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, ++ model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, ++ model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, ++ NULL, ++ LLM_FFN_SILU, LLM_FFN_PAR, cb, il); ++ cb(cur, "ffn_out", il); ++ ++ // TODO: do this inplace once? 
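Both branch outputs in this cross-attention block re-enter the residual stream through a tanh-gated add (cross_attn_attn_gate above, cross_attn_mlp_gate below), so the layer reduces to the plain residual whenever a gate is zero. A scalar sketch of that update, separate from the ggml graph code:

#include <cmath>

// residual_in: the layer input (inpSA / ffn_inp above)
// branch_out:  the cross-attention or MLP output before gating
// gate:        the value of the learned scalar gate tensor
float gated_residual(float residual_in, float branch_out, float gate) {
    return residual_in + std::tanh(gate) * branch_out; // tanh(0) == 0 -> no-op
}
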
++ cur = ggml_add_inplace(ctx0, ggml_mul_inplace(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_mlp_gate)), ffn_inp); ++ cb(cur, "ffn_out", il); ++ ++ cur = lctx.cvec.apply_to(ctx0, cur, il); ++ cb(cur, "l_out", il); ++ ++ // input for next layer ++ inpL = cur; ++ } else { ++ // self attention layer ++ ++ // rope freq factors for llama3; may return nullptr for llama2 and other models ++ struct ggml_tensor * rope_factors = build_rope_factors(il); ++ ++ // compute Q and K and RoPE them ++ struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); ++ cb(Qcur, "Qcur", il); ++ if (model.layers[il].bq) { ++ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); ++ cb(Qcur, "Qcur", il); ++ } ++ ++ struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); ++ cb(Kcur, "Kcur", il); ++ if (model.layers[il].bk) { ++ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); ++ cb(Kcur, "Kcur", il); ++ } ++ ++ struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); ++ cb(Vcur, "Vcur", il); ++ if (model.layers[il].bv) { ++ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); ++ cb(Vcur, "Vcur", il); ++ } ++ ++ Qcur = ggml_rope_ext( ++ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, ++ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ++ ext_factor, attn_factor, beta_fast, beta_slow ++ ); ++ cb(Qcur, "Qcur", il); ++ ++ Kcur = ggml_rope_ext( ++ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, ++ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ++ ext_factor, attn_factor, beta_fast, beta_slow ++ ); ++ cb(Kcur, "Kcur", il); ++ ++ cur = llm_build_kv(ctx0, lctx, kv_self, gf, ++ model.layers[il].wo, model.layers[il].bo, ++ Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); ++ ++ ++ if (il == n_layer - 1) { ++ // skip computing output for unused tokens ++ struct ggml_tensor * inp_out_ids = build_inp_out_ids(); ++ n_tokens = n_outputs; ++ cur = ggml_get_rows(ctx0, cur, inp_out_ids); ++ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); ++ } ++ ++ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); ++ cb(ffn_inp, "ffn_inp", il); ++ ++ // feed-forward network ++ cur = llm_build_norm(ctx0, ffn_inp, hparams, ++ model.layers[il].ffn_norm, NULL, ++ LLM_NORM_RMS, cb, il); ++ cb(cur, "ffn_norm", il); ++ ++ cur = llm_build_ffn(ctx0, lctx, cur, ++ model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, ++ model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, ++ model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, ++ NULL, ++ LLM_FFN_SILU, LLM_FFN_PAR, cb, il); ++ cb(cur, "ffn_out", il); ++ ++ cur = ggml_add(ctx0, cur, ffn_inp); ++ cb(cur, "ffn_out", il); ++ ++ cur = lctx.cvec.apply_to(ctx0, cur, il); ++ cb(cur, "l_out", il); ++ ++ // input for next layer ++ inpL = cur; ++ } ++ } ++ ++ cur = inpL; ++ ++ cur = llm_build_norm(ctx0, cur, hparams, ++ model.output_norm, NULL, ++ LLM_NORM_RMS, cb, -1); ++ cb(cur, "result_norm", -1); ++ + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); +@@ -16501,6 +16907,10 @@ static struct ggml_cgraph * llama_build_graph( + { + result = llm.build_llama(); + } break; ++ case LLM_ARCH_MLLAMA: ++ { ++ result = llm.build_mllama(); ++ } break; + case LLM_ARCH_BAICHUAN: + { + result = llm.build_baichuan(); +@@ -16773,6 +17183,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { + 
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); + } + ++ // TODO (jmorganca): this might copy a lot of data on every request of a ++ // single generation even though it doesn't change, so we should ++ // find a way to not set this more than one time per image ++ if (lctx.inp_cross_attn_state && ++ lctx.inp_cross_attn_state->buffer) { ++ ggml_backend_tensor_set(lctx.inp_cross_attn_state, lctx.cross_attn_state, 0, hparams.n_embd * 1601 * 4 * ggml_element_size(lctx.inp_cross_attn_state)); ++ } ++ + if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); + const int64_t n_tokens = batch.n_tokens; +@@ -17455,6 +17873,10 @@ static int llama_decode_internal( + + llama_set_inputs(lctx, ubatch); + ++ // TODO: replace with something better to find out if its ++ // our first actual pass ++ lctx.cross_attn_state_first_pass = false; ++ + llama_graph_compute(lctx, gf, n_threads, threadpool); + + // update the kv ring buffer +@@ -18648,7 +19070,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s + if (llama_model_has_encoder(&model)) { + n_attn_layer *= 3; + } +- GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected"); ++ if (qs.n_attention_wv != n_attn_layer) { ++ LLAMA_LOG_WARN("%s: n_attention_wv is unexpected, expected: %d, found: %d\n", __func__, n_attn_layer, qs.n_attention_wv); ++ } + } + + size_t total_size_org = 0; +@@ -19744,6 +20168,11 @@ struct llama_context * llama_new_context_with_model( + return ctx; + } + ++void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state) { ++ ctx->cross_attn_state_first_pass = true; ++ ctx->cross_attn_state = cross_attn_state; ++} ++ + void llama_free(struct llama_context * ctx) { + delete ctx; + } +@@ -19814,6 +20243,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { + + // use what we call a normal RoPE, operating on pairs of consecutive head values + case LLM_ARCH_LLAMA: ++ case LLM_ARCH_MLLAMA: + case LLM_ARCH_BAICHUAN: + case LLM_ARCH_STARCODER: + case LLM_ARCH_PLAMO: diff --git a/llama/patches/0011-add-unpad-operator.patch b/llama/patches/0011-add-unpad-operator.patch new file mode 100644 index 00000000..c362c528 --- /dev/null +++ b/llama/patches/0011-add-unpad-operator.patch @@ -0,0 +1,409 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Michael Yang +Date: Thu, 17 Oct 2024 17:19:25 -0700 +Subject: [PATCH] add unpad operator + +--- + ggml/include/ggml.h | 10 ++++ + ggml/src/ggml-cuda.cu | 4 ++ + ggml/src/ggml-cuda/pad.cu | 46 +++++++++++++++++++ + ggml/src/ggml-cuda/pad.cuh | 1 + + ggml/src/ggml-metal.m | 33 ++++++++++++++ + ggml/src/ggml-metal.metal | 45 ++++++++++++++++++ + ggml/src/ggml.c | 93 +++++++++++++++++++++++++++++++++++++- + 7 files changed, 230 insertions(+), 2 deletions(-) + +diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h +index ce3d92cb..962cb5f7 100644 +--- a/ggml/include/ggml.h ++++ b/ggml/include/ggml.h +@@ -506,6 +506,7 @@ extern "C" { + GGML_OP_POOL_2D_BACK, + GGML_OP_UPSCALE, // nearest interpolate + GGML_OP_PAD, ++ GGML_OP_UNPAD, + GGML_OP_ARANGE, + GGML_OP_TIMESTEP_EMBEDDING, + GGML_OP_ARGSORT, +@@ -1764,6 +1765,15 @@ extern "C" { + int p2, + int p3); + ++ // unpad each dimension: [x, ..., x, y, ..., y] -> [x, ..., x] ++ GGML_API struct ggml_tensor * ggml_unpad( ++ struct ggml_context * ctx, ++ struct ggml_tensor * a, ++ int p0, 
++ int p1, ++ int p2, ++ int p3); ++ + // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151 + // timesteps: [N,] + // return: [N, dim] +diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu +index fe77b81c..6e84af56 100644 +--- a/ggml/src/ggml-cuda.cu ++++ b/ggml/src/ggml-cuda.cu +@@ -2270,6 +2270,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg + case GGML_OP_PAD: + ggml_cuda_op_pad(ctx, dst); + break; ++ case GGML_OP_UNPAD: ++ ggml_cuda_op_unpad(ctx, dst); ++ break; + case GGML_OP_ARANGE: + ggml_cuda_op_arange(ctx, dst); + break; +@@ -2992,6 +2995,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons + case GGML_OP_GROUP_NORM: + case GGML_OP_UPSCALE: + case GGML_OP_PAD: ++ case GGML_OP_UNPAD: + case GGML_OP_ARANGE: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_LEAKY_RELU: +diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu +index aba539e8..39fd4b16 100644 +--- a/ggml/src/ggml-cuda/pad.cu ++++ b/ggml/src/ggml-cuda/pad.cu +@@ -47,3 +47,49 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream); + } ++ ++static __global__ void unpad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) { ++ // blockIdx.z: idx of ne2*ne3, aka ne02*ne03 ++ // blockIdx.y: idx of ne1 ++ // blockIDx.x: idx of ne0 / BLOCK_SIZE ++ int nidx = threadIdx.x + blockIdx.x * blockDim.x; ++ if (nidx >= ne0) { ++ return; ++ } ++ ++ // operation ++ int offset_dst = ++ nidx + ++ blockIdx.y * ne0 + ++ blockIdx.z * ne0 * gridDim.y; ++ if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) { ++ int offset_src = ++ nidx + ++ blockIdx.y * ne00 + ++ blockIdx.z * ne00 * ne01; ++ dst[offset_dst] = x[offset_src]; ++ } ++} ++ ++static void unpad_f32_cuda(const float * x, float * dst, ++ const int ne00, const int ne01, const int ne02, const int ne03, ++ const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) { ++ int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE; ++ dim3 gridDim(num_blocks, ne1, ne2*ne3); ++ unpad_f32<<>>(x, dst, ne0, ne00, ne01, ne02, ne03); ++} ++ ++void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ++ const ggml_tensor * src0 = dst->src[0]; ++ const float * src0_d = (const float *)src0->data; ++ float * dst_d = (float *)dst->data; ++ cudaStream_t stream = ctx.stream(); ++ ++ GGML_ASSERT(src0->type == GGML_TYPE_F32); ++ GGML_ASSERT(dst->type == GGML_TYPE_F32); ++ GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors ++ ++ unpad_f32_cuda(src0_d, dst_d, ++ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ++ dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream); ++} +diff --git a/ggml/src/ggml-cuda/pad.cuh b/ggml/src/ggml-cuda/pad.cuh +index 8fd386b0..e2ededc3 100644 +--- a/ggml/src/ggml-cuda/pad.cuh ++++ b/ggml/src/ggml-cuda/pad.cuh +@@ -3,3 +3,4 @@ + #define CUDA_PAD_BLOCK_SIZE 256 + + void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst); ++void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m +index 829c5e39..25702d85 100644 +--- a/ggml/src/ggml-metal.m ++++ b/ggml/src/ggml-metal.m +@@ -193,6 +193,7 @@ + GGML_METAL_KERNEL_TYPE_IM2COL_F32, + GGML_METAL_KERNEL_TYPE_UPSCALE_F32, + 
GGML_METAL_KERNEL_TYPE_PAD_F32, ++ GGML_METAL_KERNEL_TYPE_UNPAD_F32, + GGML_METAL_KERNEL_TYPE_ARANGE_F32, + GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, + GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, +@@ -689,6 +690,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); ++ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UNPAD_F32, unpad_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); +@@ -846,6 +848,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx + return false; + case GGML_OP_UPSCALE: + case GGML_OP_PAD: ++ case GGML_OP_UNPAD: + case GGML_OP_ARANGE: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_ARGSORT: +@@ -2655,6 +2658,36 @@ static void ggml_metal_encode_node( + + const int nth = MIN(1024, ne0); + ++ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; ++ } break; ++ case GGML_OP_UNPAD: ++ { ++ GGML_ASSERT(src0->type == GGML_TYPE_F32); ++ ++ id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UNPAD_F32].pipeline; ++ ++ [encoder setComputePipelineState:pipeline]; ++ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; ++ [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; ++ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; ++ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; ++ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; ++ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; ++ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; ++ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; ++ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; ++ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; ++ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; ++ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; ++ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; ++ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; ++ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; ++ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; ++ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; ++ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; ++ ++ const int nth = MIN(1024, ne0); ++ + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ARANGE: +diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal +index 2b200032..09887511 100644 +--- a/ggml/src/ggml-metal.metal ++++ b/ggml/src/ggml-metal.metal +@@ -2029,6 +2029,51 @@ kernel void kernel_pad_f32( + } + } + ++kernel void kernel_unpad_f32( ++ device const char * src0, ++ device char * dst, ++ constant int64_t & ne00, ++ constant int64_t & ne01, ++ constant int64_t & ne02, ++ constant int64_t & ne03, ++ constant uint64_t & nb00, ++ constant uint64_t & nb01, ++ constant uint64_t & nb02, ++ constant uint64_t & nb03, ++ constant int64_t & ne0, ++ constant int64_t & ne1, ++ constant int64_t & ne2, ++ constant int64_t & ne3, ++ constant uint64_t & nb0, ++ constant uint64_t & nb1, ++ constant uint64_t & nb2, ++ constant 
uint64_t & nb3, ++ uint3 tgpig[[threadgroup_position_in_grid]], ++ uint3 tpitg[[thread_position_in_threadgroup]], ++ uint3 ntg[[threads_per_threadgroup]]) { ++ ++ const int64_t i3 = tgpig.z; ++ const int64_t i2 = tgpig.y; ++ const int64_t i1 = tgpig.x; ++ ++ const int64_t i03 = i3; ++ const int64_t i02 = i2; ++ const int64_t i01 = i1; ++ ++ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); ++ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); ++ ++ if (i1 < ne01 && i2 < ne02 && i3 < ne03) { ++ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { ++ if (i0 < ne00) { ++ dst_ptr[i0] = src0_ptr[i0]; ++ } ++ } ++ ++ return; ++ } ++} ++ + kernel void kernel_arange_f32( + device char * dst, + constant int64_t & ne0, +diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c +index bcbc32d9..f4864ac8 100644 +--- a/ggml/src/ggml.c ++++ b/ggml/src/ggml.c +@@ -2997,6 +2997,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { + "POOL_2D_BACK", + "UPSCALE", + "PAD", ++ "UNPAD", + "ARANGE", + "TIMESTEP_EMBEDDING", + "ARGSORT", +@@ -3030,7 +3031,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { + "OPT_STEP_ADAMW", + }; + +-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); ++static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); + + static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { + "none", +@@ -3091,6 +3092,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { + "pool_2d_back(x)", + "upscale(x)", + "pad(x)", ++ "unpad(x)", + "arange(start, stop, step)", + "timestep_embedding(timesteps, dim, max_period)", + "argsort(x)", +@@ -3124,7 +3126,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { + "adamw(x)", + }; + +-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); ++static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); + + static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); + +@@ -6955,6 +6957,32 @@ struct ggml_tensor * ggml_pad( + return result; + } + ++// ggml_unpad ++ ++struct ggml_tensor * ggml_unpad( ++ struct ggml_context * ctx, ++ struct ggml_tensor * a, ++ int p0, int p1, int p2, int p3) { ++ bool is_node = false; ++ ++ if (a->grad) { ++ GGML_ABORT("fatal error"); // TODO: implement backward ++ is_node = true; ++ } ++ ++ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ++ a->ne[0] - p0, ++ a->ne[1] - p1, ++ a->ne[2] - p2, ++ a->ne[3] - p3); ++ ++ result->op = GGML_OP_UNPAD; ++ result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; ++ result->src[0] = a; ++ ++ return result; ++} ++ + // ggml_arange + + struct ggml_tensor * ggml_arange( +@@ -15312,6 +15340,58 @@ static void ggml_compute_forward_pad( + } + } + ++static void ggml_compute_forward_unpad_f32( ++ const struct ggml_compute_params *params, ++ struct ggml_tensor *dst) { ++ ++ const struct ggml_tensor * src0 = dst->src[0]; ++ ++ GGML_ASSERT(src0->nb[0] == sizeof(float)); ++ GGML_ASSERT( dst->nb[0] == sizeof(float)); ++ ++ const int ith = params->ith; ++ const int nth = params->nth; ++ ++ GGML_TENSOR_UNARY_OP_LOCALS ++ ++ float * dst_ptr = (float *) dst->data; ++ ++ // TODO: optimize ++ ++ for (int64_t i2 = 0; i2 < ne2; ++i2) { ++ for (int64_t i1 = ith; i1 < ne1; i1 += nth) { ++ for (int64_t i0 = 0; i0 < ne0; ++i0) { ++ for (int64_t i3 = 0; i3 < ne3; ++i3) { ++ const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; ++ ++ const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); ++ ++ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { ++ dst_ptr[dst_idx] = *src_ptr; ++ } ++ } ++ } ++ } ++ } ++} ++ ++static void ggml_compute_forward_unpad( ++ const struct ggml_compute_params * params, ++ struct ggml_tensor * dst) { ++ ++ const struct ggml_tensor * src0 = dst->src[0]; ++ ++ switch (src0->type) { ++ case GGML_TYPE_F32: ++ { ++ ggml_compute_forward_unpad_f32(params, dst); ++ } break; ++ default: ++ { ++ GGML_ABORT("fatal error"); ++ } ++ } ++} + + // ggml_compute_forward_arange + +@@ -17294,6 +17374,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm + { + ggml_compute_forward_pad(params, tensor); + } break; ++ case GGML_OP_UNPAD: ++ { ++ ggml_compute_forward_unpad(params, tensor); ++ } break; + case GGML_OP_ARANGE: + { + ggml_compute_forward_arange(params, tensor); +@@ -18369,6 +18453,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor + { + GGML_ABORT("fatal error"); // TODO: not implemented + } ++ case GGML_OP_UNPAD: ++ { ++ GGML_ABORT("fatal error"); // TODO: not implemented ++ } + case GGML_OP_ARANGE: + { + GGML_ABORT("fatal error"); // TODO: not implemented +@@ -19165,6 +19253,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { + } break; + case GGML_OP_UPSCALE: + case GGML_OP_PAD: ++ case GGML_OP_UNPAD: + case GGML_OP_ARANGE: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_ARGSORT: diff --git a/llama/runner/runner.go b/llama/runner/runner.go index f4c45e0f..9fb669a2 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -206,6 +206,26 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) { } } + if s.clip.cc != nil { + var embed [][]float32 + + if s.clip.cc.IsMllama && len(images) >= 1 { + hash := s.cache.HashImage(images[0].Data) + + s.clip.mu.Lock() + var err error + embed, err = s.cache.FindImage(hash) + if err != nil { + embed = llama.NewMllamaImageEmbed(s.lc, s.clip.cc, images[0].Data, images[0].AspectRatioID) + s.cache.AddImage(hash, embed) + } + s.clip.mu.Unlock() + } + s.mu.Lock() + llama.MllamaSetCrossAttn(s.lc, s.clip.cc, embed) + s.mu.Unlock() + } + return inputs, nil } @@ -294,6 +314,9 @@ func (s *Server) removeSequence(seqIndex int, reason string) { close(seq.responses) close(seq.embedding) seq.cache.InUse = false + if s.clip.cc != nil { + llama.MllamaSetCrossAttn(s.lc, s.clip.cc, nil) + } s.seqs[seqIndex] = nil } @@ -517,8 +540,9 @@ type Options struct { } type ImageData struct { - Data []byte `json:"data"` - ID 
int `json:"id"` + Data []byte `json:"data"` + ID int `json:"id"` + AspectRatioID int `json:"aspect_ratio_id"` } type CompletionRequest struct { @@ -770,7 +794,11 @@ func (s *Server) loadModel( } if ppath != "" { - s.clip.cc = llama.NewClipContext(ppath) + var err error + s.clip.cc, err = llama.NewClipContext(ppath) + if err != nil { + panic(err) + } } s.cache = NewInputCache(s.lc, kvSize, s.parallel, multiUserCache) diff --git a/llm/ggla.go b/llm/ggla.go index 831f6071..ec0a5941 100644 --- a/llm/ggla.go +++ b/llm/ggla.go @@ -51,8 +51,8 @@ func (llm *ggla) KV() KV { return llm.kv } -func (llm *ggla) Tensors() Tensors { - return Tensors{ +func (llm *ggla) Tensors() *Tensors { + return &Tensors{ Items: llm.tensors, Offset: llm.tensorOffset, } diff --git a/llm/ggml.go b/llm/ggml.go index aa846b97..e857d4b8 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -5,7 +5,9 @@ import ( "errors" "fmt" "io" + "slices" "strings" + "sync" "github.com/ollama/ollama/util/bufioutil" ) @@ -17,7 +19,7 @@ type GGML struct { type model interface { KV() KV - Tensors() Tensors + Tensors() *Tensors } type KV map[string]any @@ -123,25 +125,34 @@ func (kv KV) ChatTemplate() string { type Tensors struct { Items []*Tensor Offset uint64 + + layers map[string]Layer + layersOnce sync.Once } -func (ts Tensors) Layers() map[string]Layer { - layers := make(map[string]Layer) - for _, t := range ts.Items { - parts := strings.Split(t.Name, ".") - if parts[0] == "blk" { - // join first and second part, e.g. blk.%d - parts = append([]string{fmt.Sprintf("%s.%s", parts[0], parts[1])}, parts[2:]...) +func (ts *Tensors) Layers() map[string]Layer { + ts.layersOnce.Do(func() { + ts.layers = make(map[string]Layer) + for _, t := range ts.Items { + parts := strings.Split(t.Name, ".") + if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 { + if len(parts) > index+2 { + // blk and mm should have a number after them, join it + parts = append( + []string{strings.Join(parts[:index+2], ".")}, + parts[index+2:]...) 
+ } + } + + if _, ok := ts.layers[parts[0]]; !ok { + ts.layers[parts[0]] = make(Layer) + } + + ts.layers[parts[0]][strings.Join(parts[1:], ".")] = t } + }) - if _, ok := layers[parts[0]]; !ok { - layers[parts[0]] = make(Layer) - } - - layers[parts[0]][strings.Join(parts[1:], ".")] = t - } - - return layers + return ts.layers } type Layer map[string]*Tensor diff --git a/llm/gguf.go b/llm/gguf.go index 2e6bc542..c7a95490 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -110,8 +110,8 @@ func (llm *gguf) KV() KV { return llm.kv } -func (llm *gguf) Tensors() Tensors { - return Tensors{ +func (llm *gguf) Tensors() *Tensors { + return &Tensors{ Items: llm.tensors, Offset: llm.tensorOffset, } diff --git a/llm/memory.go b/llm/memory.go index d8dbf0be..16f9a743 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -3,6 +3,7 @@ package llm import ( "fmt" "log/slog" + "os" "strconv" "strings" @@ -63,6 +64,8 @@ type MemoryEstimate struct { memoryLayerOutput uint64 graphFullOffload uint64 graphPartialOffload uint64 + + projectorWeights, projectorGraph uint64 } // Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size @@ -78,7 +81,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string, var graphOffload uint64 // Projectors loaded into GPU0 only - var projectorSize uint64 + var projectorWeights uint64 + var projectorGraph uint64 // Conditional output size on GPU 0 var memoryLayerOutput uint64 @@ -103,7 +107,9 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string, slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", availableList) for _, projector := range projectors { - projectorSize += projectorMemoryRequirements(projector) + weight, graph := projectorMemoryRequirements(projector) + projectorWeights += weight + projectorGraph += graph // multimodal models require at least 2048 context opts.NumCtx = max(opts.NumCtx, 2048) @@ -149,7 +155,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string, } // Output layer handled at the end if we have space - gpuZeroOverhead := projectorSize + gpuZeroOverhead := projectorWeights + projectorGraph // Reduce set of GPUs to only those that have sufficient space to fit overhead and at least one layer var layerCount int @@ -303,6 +309,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string, memoryLayerOutput: memoryLayerOutput, graphFullOffload: graphFullOffload, graphPartialOffload: graphPartialOffload, + projectorWeights: projectorWeights, + projectorGraph: projectorGraph, } if gpus[0].Library == "cpu" { @@ -323,7 +331,19 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string, func (m MemoryEstimate) log() { overhead := envconfig.GpuOverhead() - slog.Info( + + log := slog.With() + if m.projectorWeights > 0 { + log = log.With( + slog.Group( + "projector", + "weights", format.HumanBytes2(m.projectorWeights), + "graph", format.HumanBytes2(m.projectorGraph), + ), + ) + } + + log.Info( "offload to "+m.inferenceLibrary, slog.Group( "layers", @@ -371,3 +391,52 @@ func (m MemoryEstimate) log() { ), ) } + +func projectorMemoryRequirements(filename string) (weights, graphSize uint64) { + file, err := os.Open(filename) + if err != nil { + return 0, 0 + } + defer file.Close() + + ggml, _, err := DecodeGGML(file, 0) + if err != nil { + return 0, 0 + } + + for _, layer := range ggml.Tensors().Layers() { + weights += layer.size() + } + + switch arch := 
ggml.KV().Architecture(); arch { + case "mllama": + kv := func(n string) uint64 { + if v, ok := ggml.KV()[arch+".vision."+n].(uint32); ok { + return uint64(v) + } + + return 0 + } + + imageSize := kv("image_size") + + maxNumTiles := kv("max_num_tiles") + embeddingLength := kv("embedding_length") + headCount := kv("attention.head_count") + + numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size")) + if _, ok := ggml.Tensors().Layers()["v"]["class_embd"]; ok { + numPatches++ + } + + numPaddedPatches := numPatches + 8 - (numPatches%8)%8 + + graphSize = 4 * (8 + + imageSize*imageSize*kv("num_channels")*maxNumTiles + + embeddingLength*numPatches*maxNumTiles + + 9*embeddingLength*numPaddedPatches*maxNumTiles + + numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount) + } + + return weights, graphSize +} diff --git a/llm/server.go b/llm/server.go index a16b5c19..cc4eac90 100644 --- a/llm/server.go +++ b/llm/server.go @@ -442,26 +442,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter return nil, finalErr } -func projectorMemoryRequirements(filename string) uint64 { - file, err := os.Open(filename) - if err != nil { - return 0 - } - defer file.Close() - - ggml, _, err := DecodeGGML(file, 0) - if err != nil { - return 0 - } - - var mem uint64 - for _, layer := range ggml.Tensors().Layers() { - mem += layer.size() - } - - return mem -} - type ServerStatus int const ( // iota is reset to 0 @@ -673,8 +653,9 @@ ws ::= ([ \t\n] ws)? const maxBufferSize = 512 * format.KiloByte type ImageData struct { - Data []byte `json:"data"` - ID int `json:"id"` + Data []byte `json:"data"` + ID int `json:"id"` + AspectRatioID int `json:"aspect_ratio_id"` } type completion struct { diff --git a/server/imageproc/images.go b/server/imageproc/images.go new file mode 100644 index 00000000..688cbf8a --- /dev/null +++ b/server/imageproc/images.go @@ -0,0 +1,240 @@ +package imageproc + +import ( + "bytes" + "fmt" + "image" + "image/color" + _ "image/jpeg" + _ "image/png" + "math" + "slices" + + "golang.org/x/image/draw" +) + +func GetSupportedAspectRatios(maxTiles int) []image.Point { + ratios := []image.Point{} + + for w := range maxTiles { + for h := range maxTiles { + if (w+1)*(h+1) <= maxTiles { + ratios = append(ratios, image.Point{w + 1, h + 1}) + } + } + } + + return ratios +} + +func clip(a, a_min, a_max int) int { + if a < a_min { + return a_min + } else if a > a_max { + return a_max + } + + return a +} + +func getImageSizeFitToCanvas(imageSize, canvasSize image.Point, tileSize int) image.Point { + targetWidth := clip(imageSize.X, tileSize, canvasSize.X) + targetHeight := clip(imageSize.Y, tileSize, canvasSize.Y) + + scaleWidth := float64(targetWidth) / float64(imageSize.X) + scaleHeight := float64(targetHeight) / float64(imageSize.Y) + + var w, h int + + if scaleWidth < scaleHeight { + w = targetWidth + h = min(int(math.Floor(float64(imageSize.Y)*scaleWidth)), targetHeight) + } else { + w = min(int(math.Floor(float64(imageSize.X)*scaleHeight)), targetWidth) + h = targetHeight + } + + return image.Point{w, h} +} + +func getOptimalTiledCanvas(imageSize image.Point, maxImageTiles, tileSize int) image.Point { + possibleTileArrangements := GetSupportedAspectRatios(maxImageTiles) + possibleCanvasSizes := []image.Point{} + for _, pta := range possibleTileArrangements { + possibleCanvasSizes = append(possibleCanvasSizes, image.Point{pta.X * tileSize, pta.Y * tileSize}) + } + + scales := []float64{} + + for _, pcs := range possibleCanvasSizes { + scaleHeight 
:= float64(pcs.Y) / float64(imageSize.Y) + scaleWidth := float64(pcs.X) / float64(imageSize.X) + + if scaleWidth > scaleHeight { + scales = append(scales, scaleHeight) + } else { + scales = append(scales, scaleWidth) + } + } + + var minUpscale float64 + var maxDownscale float64 + var upscale bool + + for _, s := range scales { + if s > 1.0 { + upscale = true + if minUpscale == 0 { + minUpscale = s + } else { + minUpscale = math.Min(minUpscale, s) + } + } else { + maxDownscale = math.Max(maxDownscale, s) + } + } + + selectedScale := maxDownscale + if upscale { + selectedScale = minUpscale + } + + var selectedCanvas image.Point + for n, pcs := range possibleCanvasSizes { + if scales[n] == selectedScale { + // choose the smallest possible canvas + if selectedCanvas.X == 0 && selectedCanvas.Y == 0 { + selectedCanvas = pcs + } else if pcs.X*pcs.Y < selectedCanvas.X*selectedCanvas.Y { + selectedCanvas = pcs + } + } + } + return selectedCanvas +} + +func splitToTiles(img image.Image, numTilesSize image.Point) []image.Image { + b := img.Bounds() + width := b.Max.X - b.Min.X + height := b.Max.Y - b.Min.Y + tileHeight := height / numTilesSize.Y + tileWidth := width / numTilesSize.X + + images := []image.Image{} + + for h := range numTilesSize.Y { + for w := range numTilesSize.X { + rect := image.Rect(tileWidth*w, tileHeight*h, tileWidth*(w+1), tileHeight*(h+1)) + images = append(images, img.(interface { + SubImage(image.Rectangle) image.Image + }).SubImage(rect)) + } + } + + return images +} + +// remove the "alpha" channel by drawing over a prefilled image +func compositeImage(img image.Image) image.Image { + dst := image.NewRGBA(img.Bounds()) + + white := color.RGBA{255, 255, 255, 255} + draw.Draw(dst, dst.Bounds(), &image.Uniform{white}, image.Point{}, draw.Src) + draw.Draw(dst, dst.Bounds(), img, img.Bounds().Min, draw.Over) + + return dst +} + +func ResizeImage(img image.Image, format string, outputSize image.Point, maxImageTiles int) (image.Image, image.Point) { + if format == "png" { + img = compositeImage(img) + } + + b := img.Bounds() + tileSize := outputSize.Y + + canvasSize := getOptimalTiledCanvas(b.Max, maxImageTiles, tileSize) + aspectRatio := image.Point{canvasSize.X / tileSize, canvasSize.Y / tileSize} + newSize := getImageSizeFitToCanvas(b.Max, canvasSize, tileSize) + + dst := image.NewRGBA(image.Rect(0, 0, newSize.X, newSize.Y)) + + // scaling choices: + // NearestNeighbor fast, blocky output + // ApproxBiLinear fast, medium quality + // BiLinear slow, high quality + // CatmullRom very slow, very high quality + draw.BiLinear.Scale(dst, dst.Rect, img, b, draw.Over, nil) + + return dst, aspectRatio +} + +func PadImage(img image.Image, outputSize, aspectRatio image.Point) image.Image { + paddedSize := image.Point{ + X: outputSize.X * aspectRatio.X, + Y: outputSize.Y * aspectRatio.Y, + } + + dst := image.NewRGBA(image.Rect(0, 0, paddedSize.X, paddedSize.Y)) + draw.Draw(dst, img.Bounds(), img, image.Point{0, 0}, draw.Over) + + return dst +} + +func PackImages(img image.Image, aspectRatio image.Point, mean, std [3]float32) []float32 { + subImages := splitToTiles(img, aspectRatio) + + var pixelVals []float32 + + for _, subImg := range subImages { + bounds := subImg.Bounds() + var rVals, gVals, bVals []float32 + for y := bounds.Min.Y; y < bounds.Max.Y; y++ { + for x := bounds.Min.X; x < bounds.Max.X; x++ { + c := subImg.At(x, y) + r, g, b, _ := c.RGBA() + rVal := float32(r>>8) / 255.0 + gVal := float32(g>>8) / 255.0 + bVal := float32(b>>8) / 255.0 + + rVal = (rVal - mean[0]) / std[0] + 
gVal = (gVal - mean[1]) / std[1] + bVal = (bVal - mean[2]) / std[2] + + rVals = append(rVals, rVal) + gVals = append(gVals, gVal) + bVals = append(bVals, bVal) + } + } + pixelVals = append(pixelVals, rVals...) + pixelVals = append(pixelVals, gVals...) + pixelVals = append(pixelVals, bVals...) + } + + return pixelVals +} + +func Preprocess(imageData []byte) ([]float32, int, error) { + // todo: need guard in here for bad image data + + // mllama values + outputSize := image.Point{560, 560} + maxTiles := 4 + + // clip values + mean := [3]float32{0.48145466, 0.4578275, 0.40821073} + std := [3]float32{0.26862954, 0.26130258, 0.27577711} + + img, format, err := image.Decode(bytes.NewReader(imageData)) + if err != nil { + return nil, 0, fmt.Errorf("failed to decode image: %w", err) + } + + newImage, aspectRatio := ResizeImage(img, format, outputSize, maxTiles) + newImage = PadImage(newImage, outputSize, aspectRatio) + + data := PackImages(newImage, aspectRatio, mean, std) + aspectRatioIndex := slices.Index(GetSupportedAspectRatios(maxTiles), aspectRatio) + 1 + + return data, aspectRatioIndex, nil +} diff --git a/server/imageproc/images_test.go b/server/imageproc/images_test.go new file mode 100644 index 00000000..5772fcbb --- /dev/null +++ b/server/imageproc/images_test.go @@ -0,0 +1,344 @@ +package imageproc + +import ( + "bytes" + "image" + "image/png" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestAspectRatios(t *testing.T) { + type aspectCase struct { + MaxTiles int + Expected []image.Point + } + + cases := []aspectCase{ + { + MaxTiles: 1, + Expected: []image.Point{{1, 1}}, + }, + { + MaxTiles: 2, + Expected: []image.Point{{1, 1}, {1, 2}, {2, 1}}, + }, + { + MaxTiles: 3, + Expected: []image.Point{{1, 1}, {1, 2}, {1, 3}, {2, 1}, {3, 1}}, + }, + { + MaxTiles: 4, + Expected: []image.Point{{1, 1}, {1, 2}, {1, 3}, {1, 4}, {2, 1}, {2, 2}, {3, 1}, {4, 1}}, + }, + } + + for _, c := range cases { + actual := GetSupportedAspectRatios(c.MaxTiles) + + if diff := cmp.Diff(actual, c.Expected); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + } +} + +func TestGetImageSizeFitToCanvas(t *testing.T) { + type imageSizeCase struct { + ImageRect image.Point + CanvasRect image.Point + TileSize int + Expected image.Point + } + + cases := []imageSizeCase{ + { + ImageRect: image.Point{400, 400}, + CanvasRect: image.Point{640, 480}, + TileSize: 200, + Expected: image.Point{400, 400}, + }, + { + ImageRect: image.Point{1024, 768}, + CanvasRect: image.Point{640, 480}, + TileSize: 200, + Expected: image.Point{640, 480}, + }, + { + ImageRect: image.Point{500, 500}, + CanvasRect: image.Point{1000, 1000}, + TileSize: 750, + Expected: image.Point{750, 750}, + }, + { + ImageRect: image.Point{500, 1000}, + CanvasRect: image.Point{2000, 2000}, + TileSize: 2000, + Expected: image.Point{1000, 2000}, + }, + { + ImageRect: image.Point{4000, 3000}, + CanvasRect: image.Point{2000, 1000}, + TileSize: 1000, + Expected: image.Point{1333, 1000}, + }, + { + ImageRect: image.Point{667, 1000}, + CanvasRect: image.Point{1000, 1000}, + TileSize: 560, + Expected: image.Point{667, 1000}, + }, + } + + for _, c := range cases { + actual := getImageSizeFitToCanvas(c.ImageRect, c.CanvasRect, c.TileSize) + + if actual != c.Expected { + t.Errorf("incorrect image rect: '%#v'. 
expected: '%#v'", actual, c.Expected) + } + } +} + +func TestGetOptimalTiledCanvas(t *testing.T) { + type tiledCanvasSizeCase struct { + ImageSize image.Point + MaxImageTiles int + TileSize int + Expected image.Point + } + + cases := []tiledCanvasSizeCase{ + { + ImageSize: image.Point{1024, 768}, + MaxImageTiles: 4, + TileSize: 1000, + Expected: image.Point{2000, 1000}, + }, + { + ImageSize: image.Point{1024, 768}, + MaxImageTiles: 4, + TileSize: 560, + Expected: image.Point{1120, 1120}, + }, + } + + for _, c := range cases { + actual := getOptimalTiledCanvas(c.ImageSize, c.MaxImageTiles, c.TileSize) + + if actual != c.Expected { + t.Errorf("incorrect tiled canvas: '%#v'. expected: '%#v'", actual, c.Expected) + } + } +} + +func TestSplitToTiles(t *testing.T) { + type splitCase struct { + TestImage image.Image + NumTilesSize image.Point + Expected []image.Image + } + + cases := []splitCase{ + { + TestImage: image.NewRGBA(image.Rect(0, 0, 1024, 768)), + NumTilesSize: image.Point{1, 1}, + Expected: []image.Image{image.NewRGBA(image.Rect(0, 0, 1024, 768))}, + }, + { + TestImage: image.NewRGBA(image.Rect(0, 0, 1000, 500)), + NumTilesSize: image.Point{2, 1}, + Expected: []image.Image{ + image.NewRGBA(image.Rect(0, 0, 500, 500)), + image.NewRGBA(image.Rect(500, 0, 1000, 500)), + }, + }, + { + TestImage: image.NewRGBA(image.Rect(0, 0, 1000, 1000)), + NumTilesSize: image.Point{2, 2}, + Expected: []image.Image{ + image.NewRGBA(image.Rect(0, 0, 500, 500)), + image.NewRGBA(image.Rect(500, 0, 1000, 500)), + image.NewRGBA(image.Rect(0, 500, 500, 1000)), + image.NewRGBA(image.Rect(500, 500, 1000, 1000)), + }, + }, + } + + for _, c := range cases { + actual := splitToTiles(c.TestImage, c.NumTilesSize) + + if len(actual) != len(c.Expected) { + t.Errorf("incorrect number of images '%d': expected: '%d'", len(actual), len(c.Expected)) + } + + for i := range actual { + if actual[i].Bounds() != c.Expected[i].Bounds() { + t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual[i].Bounds(), c.Expected[i].Bounds()) + } + } + } +} + +func TestResize(t *testing.T) { + type resizeCase struct { + TestImage image.Image + OutputSize image.Point + MaxImageTiles int + ExpectedImage image.Image + ExpectedAspectRatio image.Point + } + + cases := []resizeCase{ + { + TestImage: image.NewRGBA(image.Rect(0, 0, 200, 200)), + OutputSize: image.Point{100, 100}, + MaxImageTiles: 1, + ExpectedImage: image.NewRGBA(image.Rect(0, 0, 100, 100)), + ExpectedAspectRatio: image.Point{1, 1}, + }, + { + TestImage: image.NewRGBA(image.Rect(0, 0, 200, 200)), + OutputSize: image.Point{100, 100}, + MaxImageTiles: 2, + ExpectedImage: image.NewRGBA(image.Rect(0, 0, 100, 100)), + ExpectedAspectRatio: image.Point{1, 1}, + }, + { + TestImage: image.NewRGBA(image.Rect(0, 0, 10, 10)), + OutputSize: image.Point{560, 560}, + MaxImageTiles: 4, + ExpectedImage: image.NewRGBA(image.Rect(0, 0, 560, 560)), + ExpectedAspectRatio: image.Point{1, 1}, + }, + { + TestImage: image.NewRGBA(image.Rect(0, 0, 2560, 1920)), + OutputSize: image.Point{560, 560}, + MaxImageTiles: 4, + ExpectedImage: image.NewRGBA(image.Rect(0, 0, 1120, 840)), + ExpectedAspectRatio: image.Point{2, 2}, + }, + { + TestImage: image.NewRGBA(image.Rect(0, 0, 1024, 768)), + OutputSize: image.Point{560, 560}, + MaxImageTiles: 4, + ExpectedImage: image.NewRGBA(image.Rect(0, 0, 1024, 768)), + ExpectedAspectRatio: image.Point{2, 2}, + }, + } + + for _, c := range cases { + actualImage, actualAspectRatio := ResizeImage(c.TestImage, "png", c.OutputSize, c.MaxImageTiles) + + if 
actualImage.Bounds() != c.ExpectedImage.Bounds() { + t.Errorf("image size incorrect: '%#v': expected: '%#v'", actualImage.Bounds(), c.ExpectedImage.Bounds()) + } + + if actualAspectRatio != c.ExpectedAspectRatio { + t.Errorf("aspect ratio incorrect: '%#v': expected: '%#v'", actualAspectRatio, c.ExpectedAspectRatio) + } + } +} + +func TestPad(t *testing.T) { + type padCase struct { + TestImage image.Image + OutputSize image.Point + AspectRatio image.Point + Expected image.Image + } + + cases := []padCase{ + { + TestImage: image.NewRGBA(image.Rect(0, 0, 1000, 667)), + OutputSize: image.Point{560, 560}, + AspectRatio: image.Point{2, 2}, + Expected: image.NewRGBA(image.Rect(0, 0, 1120, 1120)), + }, + } + + for _, c := range cases { + actual := PadImage(c.TestImage, c.OutputSize, c.AspectRatio) + + if actual.Bounds() != c.Expected.Bounds() { + t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual.Bounds(), c.Expected.Bounds()) + } + } +} + +func TestPackImages(t *testing.T) { + type packCase struct { + TestImage image.Image + AspectRatio image.Point + ExpectedVals int + } + + mean := [3]float32{0.48145466, 0.4578275, 0.40821073} + std := [3]float32{0.26862954, 0.26130258, 0.27577711} + + cases := []packCase{ + { + TestImage: image.NewRGBA(image.Rect(0, 0, 1120, 1120)), + AspectRatio: image.Point{2, 2}, + ExpectedVals: 2 * 2 * 3 * 560 * 560, + }, + { + TestImage: image.NewRGBA(image.Rect(0, 0, 560, 560)), + AspectRatio: image.Point{1, 1}, + ExpectedVals: 1 * 1 * 3 * 560 * 560, + }, + { + TestImage: image.NewRGBA(image.Rect(0, 0, 1120, 560)), + AspectRatio: image.Point{1, 2}, + ExpectedVals: 1 * 2 * 3 * 560 * 560, + }, + } + + for _, c := range cases { + actualVals := PackImages(c.TestImage, c.AspectRatio, mean, std) + if len(actualVals) != c.ExpectedVals { + t.Errorf("packed image size incorrect: '%d': expected: '%d'", len(actualVals), c.ExpectedVals) + } + } +} + +func TestPreprocess(t *testing.T) { + type preprocessCase struct { + TestImage image.Image + ExpectedVals int + ExpectedAspectRatioID int + } + + cases := []preprocessCase{ + { + TestImage: image.NewRGBA(image.Rect(0, 0, 10, 10)), + ExpectedVals: 0, + ExpectedAspectRatioID: 1, + }, + { + TestImage: image.NewRGBA(image.Rect(0, 0, 1024, 768)), + ExpectedVals: 0, + ExpectedAspectRatioID: 6, + }, + } + + for _, c := range cases { + var buf bytes.Buffer + err := png.Encode(&buf, c.TestImage) + if err != nil { + t.Fatal(err) + } + + imgData, aspectRatioID, err := Preprocess(buf.Bytes()) + if err != nil { + t.Fatalf("error processing: %q", err) + } + + if len(imgData) == 0 { + t.Errorf("no image data returned") + } + + if aspectRatioID != c.ExpectedAspectRatioID { + t.Errorf("aspect ratio incorrect: '%d': expected: '%d'", aspectRatioID, c.ExpectedAspectRatioID) + } + } +} diff --git a/server/model.go b/server/model.go index 124693d3..4926d6ce 100644 --- a/server/model.go +++ b/server/model.go @@ -194,7 +194,9 @@ func parseFromFile(ctx context.Context, command string, baseLayers []*layerGGML, mediatype := "application/vnd.ollama.image.model" if ggml.Name() == "ggla" || ggml.KV().Kind() == "adapter" { mediatype = "application/vnd.ollama.image.adapter" - } else if ggml.KV().Architecture() == "clip" { + } + + if _, ok := ggml.KV()[fmt.Sprintf("%s.vision.block_count", ggml.KV().Architecture())]; ok || ggml.KV().Kind() == "projector" { mediatype = "application/vnd.ollama.image.projector" } diff --git a/server/prompt.go b/server/prompt.go index be0d4969..1d6f5cdb 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -3,24 +3,42 @@ 
package server import ( "bytes" "context" + "encoding/binary" + "errors" + "fmt" "log/slog" + "strings" "github.com/ollama/ollama/api" "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/server/imageproc" "github.com/ollama/ollama/template" ) type tokenizeFunc func(context.Context, string) ([]int, error) +var errTooManyImages = errors.New("vision model only supports a single image per message") + // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn. // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the // latest message and 2) system messages func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) { var system []api.Message - // always include the last message + + isMllama := checkMllamaModelFamily(m) + n := len(msgs) - 1 // in reverse, find all messages that fit into context window - for i := n - 1; i >= 0; i-- { + for i := n; i >= 0; i-- { + if isMllama && len(msgs[i].Images) > 1 { + return "", nil, errTooManyImages + } + + // always include the last message + if i == n { + continue + } + system = make([]api.Message, 0) for j := range i { if msgs[j].Role == "system" { @@ -38,16 +56,16 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. return "", nil, err } - c := len(s) + ctxLen := len(s) if m.ProjectorPaths != nil { for _, m := range msgs[i:] { // images are represented as 768 sized embeddings // TODO: get embedding length from project metadata - c += 768 * len(m.Images) + ctxLen += 768 * len(m.Images) } } - if c > opts.NumCtx { + if ctxLen > opts.NumCtx { slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:])) break } else { @@ -55,20 +73,70 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. 
} } - // truncate any messages that do not fit into the context window - var b bytes.Buffer - if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil { - return "", nil, err + currMsgIdx := n + + if isMllama { + lastMsgIdx := len(msgs) - 1 + for i := lastMsgIdx; i >= currMsgIdx; i-- { + if len(msgs[i].Images) > 0 { + data, aspectRatioID, err := imageproc.Preprocess(msgs[i].Images[0]) + if err != nil { + return "", nil, err + } + + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, data) + if err != nil { + return "", nil, err + } + + imgData := llm.ImageData{ + Data: buf.Bytes(), + AspectRatioID: aspectRatioID, + } + + msgs[i].Content = strings.TrimSpace("<|image|>" + msgs[i].Content) + images = append(images, imgData) + break + } + } + } else { + for cnt, msg := range msgs[currMsgIdx:] { + prefix := "" + prompt := msg.Content + for _, i := range msg.Images { + imgData := llm.ImageData{ + ID: len(images), + Data: i, + } + + imgTag := fmt.Sprintf("[img-%d]", imgData.ID) + if !strings.Contains(prompt, "[img]") { + prefix += imgTag + } else { + prompt = strings.Replace(prompt, "[img]", imgTag, 1) + } + + images = append(images, imgData) + } + msgs[currMsgIdx+cnt].Content = strings.TrimSpace(prefix + " " + prompt) + } } - for _, m := range msgs[n:] { - for _, i := range m.Images { - images = append(images, llm.ImageData{ - ID: len(images), - Data: i, - }) - } + // truncate any messages that do not fit into the context window + var b bytes.Buffer + if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil { + return "", nil, err } return b.String(), images, nil } + +func checkMllamaModelFamily(m *Model) bool { + for _, arch := range m.Config.ModelFamilies { + if arch == "mllama" { + return true + } + } + return false +} diff --git a/server/prompt_test.go b/server/prompt_test.go index 5fe3d4c5..123a2081 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -3,6 +3,8 @@ package server import ( "bytes" "context" + "image" + "image/png" "testing" "github.com/google/go-cmp/cmp" @@ -13,18 +15,53 @@ import ( func TestChatPrompt(t *testing.T) { type expect struct { - prompt string - images [][]byte + prompt string + images [][]byte + aspectRatioID int + error error + } + + tmpl, err := template.Parse(` +{{- if .System }}{{ .System }} {{ end }} +{{- if .Prompt }}{{ .Prompt }} {{ end }} +{{- if .Response }}{{ .Response }} {{ end }}`) + if err != nil { + t.Fatal(err) + } + visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}} + mllamaModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}, Config: ConfigV2{ModelFamilies: []string{"mllama"}}} + + createImg := func(width, height int) ([]byte, error) { + img := image.NewRGBA(image.Rect(0, 0, 5, 5)) + var buf bytes.Buffer + + if err := png.Encode(&buf, img); err != nil { + return nil, err + } + + return buf.Bytes(), nil + } + + imgBuf, err := createImg(5, 5) + if err != nil { + t.Fatal(err) + } + + imgBuf2, err := createImg(6, 6) + if err != nil { + t.Fatal(err) } cases := []struct { name string + model Model limit int msgs []api.Message expect }{ { name: "messages", + model: visionModel, limit: 64, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, @@ -37,6 +74,7 @@ func TestChatPrompt(t *testing.T) { }, { name: "truncate messages", + model: visionModel, limit: 1, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, @@ -49,6 +87,7 @@ 
func TestChatPrompt(t *testing.T) { }, { name: "truncate messages with image", + model: visionModel, limit: 64, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, @@ -64,6 +103,7 @@ func TestChatPrompt(t *testing.T) { }, { name: "truncate messages with images", + model: visionModel, limit: 64, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}}, @@ -79,6 +119,7 @@ func TestChatPrompt(t *testing.T) { }, { name: "messages with images", + model: visionModel, limit: 2048, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}}, @@ -95,6 +136,7 @@ func TestChatPrompt(t *testing.T) { }, { name: "message with image tag", + model: visionModel, limit: 2048, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}}, @@ -111,6 +153,7 @@ func TestChatPrompt(t *testing.T) { }, { name: "messages with interleaved images", + model: visionModel, limit: 2048, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, @@ -129,6 +172,7 @@ func TestChatPrompt(t *testing.T) { }, { name: "truncate message with interleaved images", + model: visionModel, limit: 1024, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, @@ -146,6 +190,7 @@ func TestChatPrompt(t *testing.T) { }, { name: "message with system prompt", + model: visionModel, limit: 2048, msgs: []api.Message{ {Role: "system", Content: "You are the Test Who Lived."}, @@ -159,6 +204,7 @@ func TestChatPrompt(t *testing.T) { }, { name: "out of order system", + model: visionModel, limit: 2048, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, @@ -170,23 +216,113 @@ func TestChatPrompt(t *testing.T) { prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ", }, }, - } - - tmpl, err := template.Parse(` -{{- if .System }}{{ .System }} {{ end }} -{{- if .Prompt }}{{ .Prompt }} {{ end }} -{{- if .Response }}{{ .Response }} {{ end }}`) - if err != nil { - t.Fatal(err) + { + name: "multiple images same prompt", + model: visionModel, + limit: 2048, + msgs: []api.Message{ + {Role: "user", Content: "Compare these two pictures of hotdogs", Images: []api.ImageData{[]byte("one hotdog"), []byte("two hotdogs")}}, + }, + expect: expect{ + prompt: "[img-0][img-1] Compare these two pictures of hotdogs ", + images: [][]byte{[]byte("one hotdog"), []byte("two hotdogs")}, + }, + }, + { + name: "messages with mllama (no images)", + model: mllamaModel, + limit: 2048, + msgs: []api.Message{ + {Role: "user", Content: "You're a test, Harry!"}, + {Role: "assistant", Content: "I-I'm a what?"}, + {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."}, + }, + expect: expect{ + prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ", + }, + }, + { + name: "messages with mllama single prompt", + model: mllamaModel, + limit: 2048, + msgs: []api.Message{ + {Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}}, + }, + expect: expect{ + prompt: "<|image|>How many hotdogs are in this image? 
", + images: [][]byte{imgBuf}, + aspectRatioID: 1, + }, + }, + { + name: "messages with mllama", + model: mllamaModel, + limit: 2048, + msgs: []api.Message{ + {Role: "user", Content: "You're a test, Harry!"}, + {Role: "assistant", Content: "I-I'm a what?"}, + {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf}}, + }, + expect: expect{ + prompt: "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ", + images: [][]byte{imgBuf}, + aspectRatioID: 1, + }, + }, + { + name: "multiple messages with mllama", + model: mllamaModel, + limit: 2048, + msgs: []api.Message{ + {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{imgBuf}}, + {Role: "assistant", Content: "I-I'm a what?"}, + {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf2}}, + }, + expect: expect{ + prompt: "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ", + images: [][]byte{imgBuf2}, + aspectRatioID: 1, + }, + }, + { + name: "earlier image with mllama", + model: mllamaModel, + limit: 2048, + msgs: []api.Message{ + {Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}}, + {Role: "assistant", Content: "There are four hotdogs."}, + {Role: "user", Content: "Which ones have mustard?"}, + }, + expect: expect{ + prompt: "<|image|>How many hotdogs are in this image? There are four hotdogs. Which ones have mustard? ", + images: [][]byte{imgBuf}, + aspectRatioID: 1, + }, + }, + { + name: "too many images with mllama", + model: mllamaModel, + limit: 2048, + msgs: []api.Message{ + {Role: "user", Content: "You're a test, Harry!"}, + {Role: "assistant", Content: "I-I'm a what?"}, + {Role: "user", Content: "A test. 
And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf, imgBuf}}, + }, + expect: expect{ + error: errTooManyImages, + }, + }, } for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { - model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}} + model := tt.model opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}} prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil) - if err != nil { + if tt.error == nil && err != nil { t.Fatal(err) + } else if tt.error != nil && err != tt.error { + t.Fatalf("expected err '%q', got '%q'", tt.error, err) } if diff := cmp.Diff(prompt, tt.prompt); diff != "" { @@ -202,8 +338,14 @@ func TestChatPrompt(t *testing.T) { t.Errorf("expected ID %d, got %d", i, images[i].ID) } - if !bytes.Equal(images[i].Data, tt.images[i]) { - t.Errorf("expected %q, got %q", tt.images[i], images[i]) + if len(model.Config.ModelFamilies) == 0 { + if !bytes.Equal(images[i].Data, tt.images[i]) { + t.Errorf("expected %q, got %q", tt.images[i], images[i].Data) + } + } else { + if images[i].AspectRatioID != tt.aspectRatioID { + t.Errorf("expected aspect ratio %d, got %d", tt.aspectRatioID, images[i].AspectRatioID) + } } } }) diff --git a/server/routes.go b/server/routes.go index c2b9b241..7aff9235 100644 --- a/server/routes.go +++ b/server/routes.go @@ -119,20 +119,21 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } + model, err := GetModel(req.Model) + if err != nil { + switch { + case os.IsNotExist(err): + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) + case err.Error() == "invalid model name": + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + default: + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + } + return + } + // expire the runner if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 { - model, err := GetModel(req.Model) - if err != nil { - switch { - case os.IsNotExist(err): - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) - case err.Error() == "invalid model name": - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - default: - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - } - return - } s.sched.expireRunner(model) c.JSON(http.StatusOK, api.GenerateResponse{ @@ -169,6 +170,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { checkpointLoaded := time.Now() + // load the model if req.Prompt == "" { c.JSON(http.StatusOK, api.GenerateResponse{ Model: req.Model, @@ -179,6 +181,12 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } + isMllama := checkMllamaModelFamily(model) + if isMllama && len(req.Images) > 1 { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"}) + return + } + images := make([]llm.ImageData, len(req.Images)) for i := range req.Images { images[i] = llm.ImageData{ID: i, Data: req.Images[i]} @@ -212,7 +220,11 @@ func (s *Server) GenerateHandler(c *gin.Context) { } for _, i := range images { - msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)}) + if isMllama { + msgs = append(msgs, api.Message{Role: "user", Content: "<|image|>"}) + } else { + msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)}) + } } values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt}) diff --git a/server/routes_generate_test.go 
b/server/routes_generate_test.go index 9cadf56a..53501cc6 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -421,22 +421,22 @@ func TestGenerate(t *testing.T) { t.Run("missing body", func(t *testing.T) { w := createRequest(t, s.GenerateHandler, nil) - if w.Code != http.StatusBadRequest { - t.Errorf("expected status 400, got %d", w.Code) + if w.Code != http.StatusNotFound { + t.Errorf("expected status 404, got %d", w.Code) } - if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" { + if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" { t.Errorf("mismatch (-got +want):\n%s", diff) } }) t.Run("missing model", func(t *testing.T) { w := createRequest(t, s.GenerateHandler, api.GenerateRequest{}) - if w.Code != http.StatusBadRequest { - t.Errorf("expected status 400, got %d", w.Code) + if w.Code != http.StatusNotFound { + t.Errorf("expected status 404, got %d", w.Code) } - if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" { + if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" { t.Errorf("mismatch (-got +want):\n%s", diff) } }) diff --git a/server/routes_test.go b/server/routes_test.go index f7a7a22b..bd5b56af 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -562,7 +562,7 @@ func TestShow(t *testing.T) { Modelfile: fmt.Sprintf( "FROM %s\nFROM %s", createBinFile(t, llm.KV{"general.architecture": "test"}, nil), - createBinFile(t, llm.KV{"general.architecture": "clip"}, nil), + createBinFile(t, llm.KV{"general.type": "projector", "general.architecture": "clip"}, nil), ), }) diff --git a/template/template.go b/template/template.go index 5dc484f4..5c886cac 100644 --- a/template/template.go +++ b/template/template.go @@ -5,7 +5,6 @@ import ( "embed" "encoding/json" "errors" - "fmt" "io" "math" "slices" @@ -302,22 +301,10 @@ func (t *Template) Execute(w io.Writer, v Values) error { // into a single message. collate also collects and returns all system messages. 
// collate mutates message content adding image tags ([img-%d]) as needed func collate(msgs []api.Message) (string, []*api.Message) { - var n int - var system []string var collated []*api.Message for i := range msgs { msg := msgs[i] - for range msg.Images { - imageTag := fmt.Sprintf("[img-%d]", n) - if !strings.Contains(msg.Content, "[img]") { - msg.Content = strings.TrimSpace("[img] " + msg.Content) - } - - msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1) - n++ - } - if msg.Role == "system" { system = append(system, msg.Content) } diff --git a/template/template_test.go b/template/template_test.go index 113e0683..616bef6a 100644 --- a/template/template_test.go +++ b/template/template_test.go @@ -317,45 +317,6 @@ What is your name?<|im_end|> <|im_start|>assistant `, }, - { - "moondream", - []template{ - // this does not have a "no response" test because it's impossible to render the same output - {"response", `{{ if .Prompt }}Question: {{ .Prompt }} - -{{ end }}Answer: {{ .Response }} - -`}, - {"messages", ` -{{- range .Messages }} -{{- if eq .Role "user" }}Question: {{ .Content }} - -{{ else if eq .Role "assistant" }}Answer: {{ .Content }} - -{{ end }} -{{- end }}Answer: `}, - }, - Values{ - Messages: []api.Message{ - {Role: "user", Content: "What's in this image?", Images: []api.ImageData{[]byte("")}}, - {Role: "assistant", Content: "It's a hot dog."}, - {Role: "user", Content: "What's in _this_ image?"}, - {Role: "user", Images: []api.ImageData{[]byte("")}}, - {Role: "user", Content: "Is it a hot dog?"}, - }, - }, - `Question: [img-0] What's in this image? - -Answer: It's a hot dog. - -Question: What's in _this_ image? - -[img-1] - -Is it a hot dog? - -Answer: `, - }, } for _, tt := range cases {
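Editor's note (not part of the patch): the new server/imageproc package added above exposes Preprocess, which resizes, pads, tiles and normalizes an image for mllama and returns the packed pixel values plus an aspect-ratio index; server/prompt.go then serializes those values little-endian into llm.ImageData. The sketch below is a hypothetical standalone caller, assuming only the APIs introduced in this diff; the main() wrapper and the input file name are illustrative.

    package main

    import (
        "bytes"
        "encoding/binary"
        "fmt"
        "log"
        "os"

        "github.com/ollama/ollama/server/imageproc"
    )

    func main() {
        // hypothetical input file; any JPEG or PNG should work
        raw, err := os.ReadFile("hotdog.png")
        if err != nil {
            log.Fatal(err)
        }

        // resize, pad, split into tiles and normalize for mllama
        data, aspectRatioID, err := imageproc.Preprocess(raw)
        if err != nil {
            log.Fatal(err)
        }

        // the runner receives the pixel values as little-endian float32s,
        // with the aspect ratio index alongside (see server/prompt.go above)
        var buf bytes.Buffer
        if err := binary.Write(&buf, binary.LittleEndian, data); err != nil {
            log.Fatal(err)
        }

        fmt.Printf("packed %d floats (%d bytes), aspect ratio id %d\n",
            len(data), buf.Len(), aspectRatioID)
    }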