410 lines
15 KiB
Diff
410 lines
15 KiB
Diff
|
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||
|
From: Michael Yang <mxyng@pm.me>
|
||
|
Date: Thu, 17 Oct 2024 17:19:25 -0700
|
||
|
Subject: [PATCH] add unpad operator
|
||
|
|
||
|
---
|
||
|
ggml/include/ggml.h | 10 ++++
|
||
|
ggml/src/ggml-cuda.cu | 4 ++
|
||
|
ggml/src/ggml-cuda/pad.cu | 46 +++++++++++++++++++
|
||
|
ggml/src/ggml-cuda/pad.cuh | 1 +
|
||
|
ggml/src/ggml-metal.m | 33 ++++++++++++++
|
||
|
ggml/src/ggml-metal.metal | 45 ++++++++++++++++++
|
||
|
ggml/src/ggml.c | 93 +++++++++++++++++++++++++++++++++++++-
|
||
|
7 files changed, 230 insertions(+), 2 deletions(-)
|
||
|
|
||
|
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
|
||
|
index ce3d92cb..962cb5f7 100644
|
||
|
--- a/ggml/include/ggml.h
|
||
|
+++ b/ggml/include/ggml.h
|
||
|
@@ -506,6 +506,7 @@ extern "C" {
|
||
|
GGML_OP_POOL_2D_BACK,
|
||
|
GGML_OP_UPSCALE, // nearest interpolate
|
||
|
GGML_OP_PAD,
|
||
|
+ GGML_OP_UNPAD,
|
||
|
GGML_OP_ARANGE,
|
||
|
GGML_OP_TIMESTEP_EMBEDDING,
|
||
|
GGML_OP_ARGSORT,
|
||
|
@@ -1764,6 +1765,15 @@ extern "C" {
|
||
|
int p2,
|
||
|
int p3);
|
||
|
|
||
|
+ // unpad each dimension: [x, ..., x, y, ..., y] -> [x, ..., x]; the result's extent along dim i is a->ne[i] - p_i
|
||
|
+ GGML_API struct ggml_tensor * ggml_unpad(
|
||
|
+ struct ggml_context * ctx,
|
||
|
+ struct ggml_tensor * a, // tensor to shrink; becomes src[0] of the result
|
||
|
+ int p0, // amount subtracted from a->ne[0]
|
||
|
+ int p1, // amount subtracted from a->ne[1]
|
||
|
+ int p2, // amount subtracted from a->ne[2]
|
||
|
+ int p3); // amount subtracted from a->ne[3]
|
||
|
+
|
||
|
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
|
||
|
// timesteps: [N,]
|
||
|
// return: [N, dim]
|
||
|
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
|
||
|
index fe77b81c..6e84af56 100644
|
||
|
--- a/ggml/src/ggml-cuda.cu
|
||
|
+++ b/ggml/src/ggml-cuda.cu
|
||
|
@@ -2270,6 +2270,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||
|
case GGML_OP_PAD:
|
||
|
ggml_cuda_op_pad(ctx, dst);
|
||
|
break;
|
||
|
+ case GGML_OP_UNPAD:
|
||
|
+ ggml_cuda_op_unpad(ctx, dst);
|
||
|
+ break;
|
||
|
case GGML_OP_ARANGE:
|
||
|
ggml_cuda_op_arange(ctx, dst);
|
||
|
break;
|
||
|
@@ -2992,6 +2995,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||
|
case GGML_OP_GROUP_NORM:
|
||
|
case GGML_OP_UPSCALE:
|
||
|
case GGML_OP_PAD:
|
||
|
+ case GGML_OP_UNPAD:
|
||
|
case GGML_OP_ARANGE:
|
||
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||
|
case GGML_OP_LEAKY_RELU:
|
||
|
diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu
|
||
|
index aba539e8..39fd4b16 100644
|
||
|
--- a/ggml/src/ggml-cuda/pad.cu
|
||
|
+++ b/ggml/src/ggml-cuda/pad.cu
|
||
|
@@ -47,3 +47,49 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||
|
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
|
||
|
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
|
||
|
}
|
||
|
+// unpad_f32: each thread copies at most one float from x to the same (dim0, dim1, dim2*dim3) coordinates in dst; both buffers are indexed as packed arrays.
|
||
|
+static __global__ void unpad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
|
||
|
+ // blockIdx.z: idx of ne2*ne3 (dst); the launcher sizes gridDim.z from dst, which need not equal ne02*ne03
|
||
|
+ // blockIdx.y: idx of ne1
|
||
|
+ // blockIdx.x: idx of ne0 / BLOCK_SIZE
|
||
|
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x; // dst index along dim 0 for this thread
|
||
|
+ if (nidx >= ne0) { // threads past the end of a dst row do nothing
|
||
|
+ return;
|
||
|
+ }
|
||
|
+
|
||
|
+ // flat dst offset with packed strides: row = ne0, plane = ne0 * gridDim.y
|
||
|
+ int offset_dst =
|
||
|
+ nidx +
|
||
|
+ blockIdx.y * ne0 +
|
||
|
+ blockIdx.z * ne0 * gridDim.y;
|
||
|
+ if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) { // skip coordinates that lie outside src
|
||
|
+ int offset_src = // flat src offset with packed strides: row = ne00, plane = ne00 * ne01
|
||
|
+ nidx +
|
||
|
+ blockIdx.y * ne00 +
|
||
|
+ blockIdx.z * ne00 * ne01;
|
||
|
+ dst[offset_dst] = x[offset_src];
|
||
|
+ }
|
||
|
+}
|
||
|
+// Launch unpad_f32 on `stream` with a (ceil(ne0 / CUDA_PAD_BLOCK_SIZE), ne1, ne2*ne3) grid of CUDA_PAD_BLOCK_SIZE-thread blocks.
|
||
|
+static void unpad_f32_cuda(const float * x, float * dst,
|
||
|
+ const int ne00, const int ne01, const int ne02, const int ne03,
|
||
|
+ const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
|
||
|
+ int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE; // round up so the tail of ne0 is covered
|
||
|
+ dim3 gridDim(num_blocks, ne1, ne2*ne3); // grid follows the dst shape
|
||
|
+ unpad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03); // ne1, ne2, ne3 reach the kernel only through the grid dimensions
|
||
|
+}
|
||
|
+// CUDA entry point for GGML_OP_UNPAD: asserts the F32 / ne[3] == 1 preconditions, then launches unpad_f32_cuda on ctx.stream().
|
||
|
+void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||
|
+ const ggml_tensor * src0 = dst->src[0];
|
||
|
+ const float * src0_d = (const float *)src0->data;
|
||
|
+ float * dst_d = (float *)dst->data;
|
||
|
+ cudaStream_t stream = ctx.stream();
|
||
|
+
|
||
|
+ GGML_ASSERT(src0->type == GGML_TYPE_F32); // data pointers above are already cast to float
|
||
|
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||
|
+ GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
|
||
|
+ // NOTE(review): contiguity of src0/dst is not asserted here although unpad_f32 computes packed offsets - confirm callers guarantee it
|
||
|
+ unpad_f32_cuda(src0_d, dst_d,
|
||
|
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
|
||
|
+ dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
|
||
|
+}
|
||
|
diff --git a/ggml/src/ggml-cuda/pad.cuh b/ggml/src/ggml-cuda/pad.cuh
|
||
|
index 8fd386b0..e2ededc3 100644
|
||
|
--- a/ggml/src/ggml-cuda/pad.cuh
|
||
|
+++ b/ggml/src/ggml-cuda/pad.cuh
|
||
|
@@ -3,3 +3,4 @@
|
||
|
#define CUDA_PAD_BLOCK_SIZE 256
|
||
|
|
||
|
void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||
|
+void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||
|
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
|
||
|
index 829c5e39..25702d85 100644
|
||
|
--- a/ggml/src/ggml-metal.m
|
||
|
+++ b/ggml/src/ggml-metal.m
|
||
|
@@ -193,6 +193,7 @@
|
||
|
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
||
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
||
|
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
||
|
+ GGML_METAL_KERNEL_TYPE_UNPAD_F32,
|
||
|
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
|
||
|
GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
|
||
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
||
|
@@ -689,6 +690,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
||
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UNPAD_F32, unpad_f32, true);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
||
|
@@ -846,6 +848,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
||
|
return false;
|
||
|
case GGML_OP_UPSCALE:
|
||
|
case GGML_OP_PAD:
|
||
|
+ case GGML_OP_UNPAD:
|
||
|
case GGML_OP_ARANGE:
|
||
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||
|
case GGML_OP_ARGSORT:
|
||
|
@@ -2655,6 +2658,36 @@ static void ggml_metal_encode_node(
|
||
|
|
||
|
const int nth = MIN(1024, ne0);
|
||
|
|
||
|
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||
|
+ } break;
|
||
|
+ case GGML_OP_UNPAD:
|
||
|
+ {
|
||
|
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||
|
+
|
||
|
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UNPAD_F32].pipeline;
|
||
|
+
|
||
|
+ [encoder setComputePipelineState:pipeline];
|
||
|
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||
|
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||
|
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||
|
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||
|
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||
|
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
||
|
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
||
|
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
||
|
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
||
|
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
||
|
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
||
|
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
||
|
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
||
|
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
||
|
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
||
|
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
||
|
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
||
|
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
||
|
+
|
||
|
+ const int nth = MIN(1024, ne0);
|
||
|
+
|
||
|
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||
|
} break;
|
||
|
case GGML_OP_ARANGE:
|
||
|
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
|
||
|
index 2b200032..09887511 100644
|
||
|
--- a/ggml/src/ggml-metal.metal
|
||
|
+++ b/ggml/src/ggml-metal.metal
|
||
|
@@ -2029,6 +2029,51 @@ kernel void kernel_pad_f32(
|
||
|
}
|
||
|
}
|
||
|
|
||
|
+kernel void kernel_unpad_f32(
|
||
|
+ device const char * src0,
|
||
|
+ device char * dst,
|
||
|
+ constant int64_t & ne00,
|
||
|
+ constant int64_t & ne01,
|
||
|
+ constant int64_t & ne02,
|
||
|
+ constant int64_t & ne03,
|
||
|
+ constant uint64_t & nb00,
|
||
|
+ constant uint64_t & nb01,
|
||
|
+ constant uint64_t & nb02,
|
||
|
+ constant uint64_t & nb03,
|
||
|
+ constant int64_t & ne0,
|
||
|
+ constant int64_t & ne1,
|
||
|
+ constant int64_t & ne2,
|
||
|
+ constant int64_t & ne3,
|
||
|
+ constant uint64_t & nb0,
|
||
|
+ constant uint64_t & nb1,
|
||
|
+ constant uint64_t & nb2,
|
||
|
+ constant uint64_t & nb3,
|
||
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
||
|
+ uint3 tpitg[[thread_position_in_threadgroup]],
|
||
|
+ uint3 ntg[[threads_per_threadgroup]]) {
|
||
|
+ // the threadgroup position selects the (i1, i2, i3) row; threads of the group step across i0 in strides of ntg.x
|
||
|
+ const int64_t i3 = tgpig.z;
|
||
|
+ const int64_t i2 = tgpig.y;
|
||
|
+ const int64_t i1 = tgpig.x;
|
||
|
+ // source coordinates are the same as the destination coordinates
|
||
|
+ const int64_t i03 = i3;
|
||
|
+ const int64_t i02 = i2;
|
||
|
+ const int64_t i01 = i1;
|
||
|
+ // row base pointers come from the byte strides nb01..nb03 / nb1..nb3; within a row, elements are indexed as consecutive floats
|
||
|
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
||
|
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
|
||
|
+
|
||
|
+ if (i1 < ne01 && i2 < ne02 && i3 < ne03) { // only rows that exist in src are copied
|
||
|
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
||
|
+ if (i0 < ne00) { // stay inside the src row
|
||
|
+ dst_ptr[i0] = src0_ptr[i0];
|
||
|
+ }
|
||
|
+ }
|
||
|
+
|
||
|
+ return; // redundant: nothing follows the enclosing if-block
|
||
|
+ }
|
||
|
+}
|
||
|
+
|
||
|
kernel void kernel_arange_f32(
|
||
|
device char * dst,
|
||
|
constant int64_t & ne0,
|
||
|
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
|
||
|
index bcbc32d9..f4864ac8 100644
|
||
|
--- a/ggml/src/ggml.c
|
||
|
+++ b/ggml/src/ggml.c
|
||
|
@@ -2997,6 +2997,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||
|
"POOL_2D_BACK",
|
||
|
"UPSCALE",
|
||
|
"PAD",
|
||
|
+ "UNPAD",
|
||
|
"ARANGE",
|
||
|
"TIMESTEP_EMBEDDING",
|
||
|
"ARGSORT",
|
||
|
@@ -3030,7 +3031,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||
|
"OPT_STEP_ADAMW",
|
||
|
};
|
||
|
|
||
|
-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
|
||
|
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
|
||
|
|
||
|
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||
|
"none",
|
||
|
@@ -3091,6 +3092,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||
|
"pool_2d_back(x)",
|
||
|
"upscale(x)",
|
||
|
"pad(x)",
|
||
|
+ "unpad(x)",
|
||
|
"arange(start, stop, step)",
|
||
|
"timestep_embedding(timesteps, dim, max_period)",
|
||
|
"argsort(x)",
|
||
|
@@ -3124,7 +3126,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||
|
"adamw(x)",
|
||
|
};
|
||
|
|
||
|
-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
|
||
|
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
|
||
|
|
||
|
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||
|
|
||
|
@@ -6955,6 +6957,32 @@ struct ggml_tensor * ggml_pad(
|
||
|
return result;
|
||
|
}
|
||
|
|
||
|
+// ggml_unpad: create a GGML_OP_UNPAD node over a whose extents are a->ne[i] - p_i
|
||
|
+
|
||
|
+struct ggml_tensor * ggml_unpad(
|
||
|
+ struct ggml_context * ctx,
|
||
|
+ struct ggml_tensor * a,
|
||
|
+ int p0, int p1, int p2, int p3) {
|
||
|
+ bool is_node = false;
|
||
|
+
|
||
|
+ if (a->grad) {
|
||
|
+ GGML_ABORT("fatal error"); // TODO: implement backward
|
||
|
+ is_node = true; // NOTE(review): presumably unreachable, assuming GGML_ABORT does not return - confirm
|
||
|
+ }
|
||
|
+ // the result keeps a's type; p0..p3 are not range-checked here and are not stored on the node
|
||
|
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
|
||
|
+ a->ne[0] - p0,
|
||
|
+ a->ne[1] - p1,
|
||
|
+ a->ne[2] - p2,
|
||
|
+ a->ne[3] - p3);
|
||
|
+
|
||
|
+ result->op = GGML_OP_UNPAD;
|
||
|
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||
|
+ result->src[0] = a;
|
||
|
+
|
||
|
+ return result;
|
||
|
+}
|
||
|
+
|
||
|
// ggml_arange
|
||
|
|
||
|
struct ggml_tensor * ggml_arange(
|
||
|
@@ -15312,6 +15340,58 @@ static void ggml_compute_forward_pad(
|
||
|
}
|
||
|
}
|
||
|
|
||
|
+static void ggml_compute_forward_unpad_f32(
|
||
|
+ const struct ggml_compute_params *params,
|
||
|
+ struct ggml_tensor *dst) {
|
||
|
+ // CPU F32 unpad: copies src0 elements to the same coordinates in dst; dst rows (i1) are interleaved across the nth threads
|
||
|
+ const struct ggml_tensor * src0 = dst->src[0];
|
||
|
+
|
||
|
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||
|
+ GGML_ASSERT( dst->nb[0] == sizeof(float));
|
||
|
+
|
||
|
+ const int ith = params->ith;
|
||
|
+ const int nth = params->nth;
|
||
|
+
|
||
|
+ GGML_TENSOR_UNARY_OP_LOCALS // presumably declares the ne0x/nb0x and ne*/nb* locals used below - macro is defined outside this patch
|
||
|
+
|
||
|
+ float * dst_ptr = (float *) dst->data;
|
||
|
+
|
||
|
+ // TODO: optimize
|
||
|
+
|
||
|
+ for (int64_t i2 = 0; i2 < ne2; ++i2) {
|
||
|
+ for (int64_t i1 = ith; i1 < ne1; i1 += nth) { // this thread handles rows ith, ith + nth, ...
|
||
|
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
|
||
|
+ for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
||
|
+ const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; // packed dst index; dst byte strides are not used
|
||
|
+
|
||
|
+ const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); // src is addressed through its byte strides
|
||
|
+
|
||
|
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { // src_ptr is only dereferenced for coordinates inside src
|
||
|
+ dst_ptr[dst_idx] = *src_ptr;
|
||
|
+ }
|
||
|
+ }
|
||
|
+ }
|
||
|
+ }
|
||
|
+ }
|
||
|
+}
|
||
|
+// ggml_compute_forward_unpad: dispatch on src0->type; only GGML_TYPE_F32 has an implementation
|
||
|
+static void ggml_compute_forward_unpad(
|
||
|
+ const struct ggml_compute_params * params,
|
||
|
+ struct ggml_tensor * dst) {
|
||
|
+
|
||
|
+ const struct ggml_tensor * src0 = dst->src[0];
|
||
|
+
|
||
|
+ switch (src0->type) {
|
||
|
+ case GGML_TYPE_F32:
|
||
|
+ {
|
||
|
+ ggml_compute_forward_unpad_f32(params, dst);
|
||
|
+ } break;
|
||
|
+ default: // every other tensor type
|
||
|
+ {
|
||
|
+ GGML_ABORT("fatal error");
|
||
|
+ }
|
||
|
+ }
|
||
|
+}
|
||
|
|
||
|
// ggml_compute_forward_arange
|
||
|
|
||
|
@@ -17294,6 +17374,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||
|
{
|
||
|
ggml_compute_forward_pad(params, tensor);
|
||
|
} break;
|
||
|
+ case GGML_OP_UNPAD:
|
||
|
+ {
|
||
|
+ ggml_compute_forward_unpad(params, tensor);
|
||
|
+ } break;
|
||
|
case GGML_OP_ARANGE:
|
||
|
{
|
||
|
ggml_compute_forward_arange(params, tensor);
|
||
|
@@ -18369,6 +18453,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||
|
{
|
||
|
GGML_ABORT("fatal error"); // TODO: not implemented
|
||
|
}
|
||
|
+ case GGML_OP_UNPAD:
|
||
|
+ {
|
||
|
+ GGML_ABORT("fatal error"); // TODO: not implemented
|
||
|
+ }
|
||
|
case GGML_OP_ARANGE:
|
||
|
{
|
||
|
GGML_ABORT("fatal error"); // TODO: not implemented
|
||
|
@@ -19165,6 +19253,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||
|
} break;
|
||
|
case GGML_OP_UPSCALE:
|
||
|
case GGML_OP_PAD:
|
||
|
+ case GGML_OP_UNPAD:
|
||
|
case GGML_OP_ARANGE:
|
||
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||
|
case GGML_OP_ARGSORT:
|