31 lines
1.2 KiB
Diff
31 lines
1.2 KiB
Diff
|
From dadbed99e65252d79f81101a392d0d6497b86caa Mon Sep 17 00:00:00 2001
|
||
|
From: Shouzheng Liu <lshzh.hi@gmail.com>
|
||
|
Date: Mon, 21 Aug 2023 06:59:29 -0400
|
||
|
Subject: [PATCH] metal : fix synchronization in new matrix multiplication
|
||
|
kernel (#2686)
|
||
|
|
||
|
---
|
||
|
ggml-metal.metal | 3 ++-
|
||
|
1 file changed, 2 insertions(+), 1 deletion(-)
|
||
|
|
||
|
diff --git a/ggml-metal.metal b/ggml-metal.metal
|
||
|
index 3f31252..88d48f6 100644
|
||
|
--- a/ggml-metal.metal
|
||
|
+++ b/ggml-metal.metal
|
||
|
@@ -1898,10 +1898,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
||
|
threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
|
||
|
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
||
|
for (int i = 0; i < 8; i++) {
|
||
|
+ threadgroup_barrier(mem_flags::mem_device);
|
||
|
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
||
|
}
|
||
|
|
||
|
- threadgroup_barrier(mem_flags::mem_threadgroup);
|
||
|
+ threadgroup_barrier(mem_flags::mem_device);
|
||
|
device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
|
||
|
if (sgitg==0) {
|
||
|
for (int i = 0; i < n_rows; i++) {
|
||
|
--
|
||
|
2.41.0
|
||
|
|