Skip to content

Commit e52532a

Browse files
Merge pull request #2596 from kabu1204/msl-coop-mat
MSL: add initial cooperative matrix support
2 parents ffb16a2 + 7594d2b commit e52532a

15 files changed

+857
-5
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#include <metal_stdlib>
2+
#include <simd/simd.h>
3+
#include <metal_simdgroup_matrix>
4+
5+
using namespace metal;
6+
7+
struct SSBO
8+
{
9+
bfloat data[1];
10+
};
11+
12+
kernel void main0(device SSBO& ssbo [[buffer(0)]])
13+
{
14+
simdgroup_bfloat8x8 _21;
15+
simdgroup_load(_21, &ssbo.data[0u], 8u);
16+
simdgroup_bfloat8x8 _22;
17+
simdgroup_load(_22, &ssbo.data[0u], 8u);
18+
simdgroup_bfloat8x8 _23;
19+
simdgroup_load(_23, &ssbo.data[0u], 8u);
20+
simdgroup_bfloat8x8 _24;
21+
simdgroup_multiply_accumulate(_24, _21, _22, _23);
22+
simdgroup_store(_24, &ssbo.data[0u], 8u);
23+
}
24+
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#include <metal_stdlib>
2+
#include <simd/simd.h>
3+
#include <metal_simdgroup_matrix>
4+
5+
using namespace metal;
6+
7+
struct SSBO
8+
{
9+
uint data[1];
10+
};
11+
12+
kernel void main0(device SSBO& ssbo [[buffer(0)]])
13+
{
14+
ssbo.data[0u] = uint(sizeof(simdgroup_float8x8::storage_type) / sizeof(float));
15+
}
16+
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#include <metal_stdlib>
2+
#include <simd/simd.h>
3+
#include <metal_simdgroup_matrix>
4+
5+
using namespace metal;
6+
7+
struct SSBO
8+
{
9+
float data[1];
10+
};
11+
12+
kernel void main0(device SSBO& ssbo [[buffer(0)]])
13+
{
14+
simdgroup_float8x8 _20;
15+
simdgroup_load(_20, &ssbo.data[0u], 8u);
16+
simdgroup_store(_20, &ssbo.data[0u], 8u);
17+
simdgroup_float8x8 _21;
18+
simdgroup_load(_21, &ssbo.data[0u], 8u, ulong2(0), true);
19+
simdgroup_store(_21, &ssbo.data[0u], 8u, ulong2(0), true);
20+
}
21+
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#include <metal_stdlib>
2+
#include <simd/simd.h>
3+
#include <metal_simdgroup_matrix>
4+
5+
using namespace metal;
6+
7+
struct SSBO32
8+
{
9+
float data[1];
10+
};
11+
12+
struct SSBO16
13+
{
14+
half data[1];
15+
};
16+
17+
kernel void main0(device SSBO32& ssbo32 [[buffer(0)]], device SSBO16& ssbo16 [[buffer(1)]])
18+
{
19+
simdgroup_float8x8 _30;
20+
simdgroup_load(_30, &ssbo32.data[0u], 8u);
21+
simdgroup_float8x8 _31;
22+
simdgroup_load(_31, &ssbo32.data[0u], 8u);
23+
simdgroup_float8x8 _32;
24+
simdgroup_load(_32, &ssbo32.data[0u], 8u);
25+
simdgroup_float8x8 _33;
26+
simdgroup_multiply_accumulate(_33, _30, _31, _32);
27+
simdgroup_store(_33, &ssbo32.data[0u], 8u);
28+
simdgroup_half8x8 _35;
29+
simdgroup_load(_35, &ssbo16.data[0u], 8u);
30+
simdgroup_half8x8 _36;
31+
simdgroup_load(_36, &ssbo16.data[0u], 8u);
32+
simdgroup_half8x8 _37;
33+
simdgroup_load(_37, &ssbo16.data[0u], 8u);
34+
simdgroup_half8x8 _38;
35+
simdgroup_multiply_accumulate(_38, _35, _36, _37);
36+
simdgroup_store(_38, &ssbo16.data[0u], 8u);
37+
}
38+
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#pragma clang diagnostic ignored "-Wmissing-prototypes"
2+
#pragma clang diagnostic ignored "-Wmissing-braces"
3+
4+
#include <metal_stdlib>
5+
#include <simd/simd.h>
6+
#include <metal_simdgroup_matrix>
7+
8+
using namespace metal;
9+
10+
template<typename T, size_t Num>
11+
struct spvUnsafeArray
12+
{
13+
T elements[Num ? Num : 1];
14+
15+
thread T& operator [] (size_t pos) thread
16+
{
17+
return elements[pos];
18+
}
19+
constexpr const thread T& operator [] (size_t pos) const thread
20+
{
21+
return elements[pos];
22+
}
23+
24+
device T& operator [] (size_t pos) device
25+
{
26+
return elements[pos];
27+
}
28+
constexpr const device T& operator [] (size_t pos) const device
29+
{
30+
return elements[pos];
31+
}
32+
33+
constexpr const constant T& operator [] (size_t pos) const constant
34+
{
35+
return elements[pos];
36+
}
37+
38+
threadgroup T& operator [] (size_t pos) threadgroup
39+
{
40+
return elements[pos];
41+
}
42+
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
43+
{
44+
return elements[pos];
45+
}
46+
};
47+
48+
kernel void main0()
49+
{
50+
threadgroup spvUnsafeArray<uchar, 128> _15;
51+
_15[0u] = uchar(0);
52+
simdgroup_half8x8 _20;
53+
simdgroup_load(_20, reinterpret_cast<threadgroup half*>(&_15[0u]), (16u) / 2u);
54+
simdgroup_store(_20, reinterpret_cast<threadgroup half*>(&_15[0u]), (16u) / 2u);
55+
}
56+
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#pragma clang diagnostic ignored "-Wmissing-prototypes"
2+
#pragma clang diagnostic ignored "-Wmissing-braces"
3+
4+
#include <metal_stdlib>
5+
#include <simd/simd.h>
6+
#include <metal_simdgroup_matrix>
7+
8+
using namespace metal;
9+
10+
template<typename T, size_t Num>
11+
struct spvUnsafeArray
12+
{
13+
T elements[Num ? Num : 1];
14+
15+
thread T& operator [] (size_t pos) thread
16+
{
17+
return elements[pos];
18+
}
19+
constexpr const thread T& operator [] (size_t pos) const thread
20+
{
21+
return elements[pos];
22+
}
23+
24+
device T& operator [] (size_t pos) device
25+
{
26+
return elements[pos];
27+
}
28+
constexpr const device T& operator [] (size_t pos) const device
29+
{
30+
return elements[pos];
31+
}
32+
33+
constexpr const constant T& operator [] (size_t pos) const constant
34+
{
35+
return elements[pos];
36+
}
37+
38+
threadgroup T& operator [] (size_t pos) threadgroup
39+
{
40+
return elements[pos];
41+
}
42+
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
43+
{
44+
return elements[pos];
45+
}
46+
};
47+
48+
kernel void main0()
49+
{
50+
threadgroup spvUnsafeArray<float, 64> _14;
51+
simdgroup_float8x8 _18;
52+
simdgroup_load(_18, &_14[0u], 8u);
53+
simdgroup_store(_18, &_14[0u], 8u);
54+
simdgroup_float8x8 _19;
55+
simdgroup_load(_19, &_14[0u], 8u, ulong2(0), true);
56+
simdgroup_store(_19, &_14[0u], 8u, ulong2(0), true);
57+
}
58+
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
; SPIR-V
2+
; Version: 1.6
3+
; Generator: Khronos SPIR-V Tools Assembler; 0
4+
; Bound: 50
5+
; Schema: 0
6+
OpCapability Shader
7+
OpCapability CooperativeMatrixKHR
8+
OpCapability BFloat16TypeKHR
9+
OpCapability BFloat16CooperativeMatrixKHR
10+
OpCapability VulkanMemoryModel
11+
OpExtension "SPV_KHR_cooperative_matrix"
12+
OpExtension "SPV_KHR_bfloat16"
13+
OpExtension "SPV_KHR_vulkan_memory_model"
14+
OpMemoryModel Logical Vulkan
15+
OpEntryPoint GLCompute %main "main"
16+
OpExecutionMode %main LocalSize 32 1 1
17+
OpName %main "main"
18+
OpName %SSBO "SSBO"
19+
OpMemberName %SSBO 0 "data"
20+
OpName %ssbo "ssbo"
21+
OpDecorate %arr_bf16 ArrayStride 2
22+
OpMemberDecorate %SSBO 0 Offset 0
23+
OpDecorate %SSBO Block
24+
OpDecorate %ssbo DescriptorSet 0
25+
OpDecorate %ssbo Binding 0
26+
%void = OpTypeVoid
27+
%3 = OpTypeFunction %void
28+
%bfloat = OpTypeFloat 16 BFloat16KHR
29+
%uint = OpTypeInt 32 0
30+
%uint_0 = OpConstant %uint 0
31+
%uint_1 = OpConstant %uint 1
32+
%uint_2 = OpConstant %uint 2
33+
%uint_3 = OpConstant %uint 3
34+
%uint_8 = OpConstant %uint 8
35+
%arr_bf16 = OpTypeRuntimeArray %bfloat
36+
%SSBO = OpTypeStruct %arr_bf16
37+
%ptr_ssbo_SSBO = OpTypePointer StorageBuffer %SSBO
38+
%ssbo = OpVariable %ptr_ssbo_SSBO StorageBuffer
39+
%ptr_ssbo_bf16 = OpTypePointer StorageBuffer %bfloat
40+
%coopmat_bf16_A = OpTypeCooperativeMatrixKHR %bfloat %uint_3 %uint_8 %uint_8 %uint_0
41+
%coopmat_bf16_B = OpTypeCooperativeMatrixKHR %bfloat %uint_3 %uint_8 %uint_8 %uint_1
42+
%coopmat_bf16_acc = OpTypeCooperativeMatrixKHR %bfloat %uint_3 %uint_8 %uint_8 %uint_2
43+
%main = OpFunction %void None %3
44+
%5 = OpLabel
45+
%p0 = OpAccessChain %ptr_ssbo_bf16 %ssbo %uint_0 %uint_0
46+
%bf_A = OpCooperativeMatrixLoadKHR %coopmat_bf16_A %p0 %uint_0 %uint_8
47+
%bf_B = OpCooperativeMatrixLoadKHR %coopmat_bf16_B %p0 %uint_0 %uint_8
48+
%bf_C = OpCooperativeMatrixLoadKHR %coopmat_bf16_acc %p0 %uint_0 %uint_8
49+
%bf_D = OpCooperativeMatrixMulAddKHR %coopmat_bf16_acc %bf_A %bf_B %bf_C
50+
OpCooperativeMatrixStoreKHR %p0 %bf_D %uint_0 %uint_8
51+
OpReturn
52+
OpFunctionEnd
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
; SPIR-V
2+
; Version: 1.6
3+
; Generator: Khronos SPIR-V Tools Assembler; 0
4+
; Bound: 24
5+
; Schema: 0
6+
OpCapability Shader
7+
OpCapability CooperativeMatrixKHR
8+
OpCapability VulkanMemoryModel
9+
OpExtension "SPV_KHR_cooperative_matrix"
10+
OpExtension "SPV_KHR_vulkan_memory_model"
11+
OpMemoryModel Logical Vulkan
12+
OpEntryPoint GLCompute %main "main"
13+
OpExecutionMode %main LocalSize 32 1 1
14+
OpName %main "main"
15+
OpName %SSBO "SSBO"
16+
OpMemberName %SSBO 0 "data"
17+
OpName %ssbo "ssbo"
18+
OpDecorate %arr_uint ArrayStride 4
19+
OpMemberDecorate %SSBO 0 Offset 0
20+
OpDecorate %SSBO Block
21+
OpDecorate %ssbo DescriptorSet 0
22+
OpDecorate %ssbo Binding 0
23+
%void = OpTypeVoid
24+
%3 = OpTypeFunction %void
25+
%uint = OpTypeInt 32 0
26+
%float = OpTypeFloat 32
27+
%uint_0 = OpConstant %uint 0
28+
%uint_3 = OpConstant %uint 3
29+
%uint_8 = OpConstant %uint 8
30+
%arr_uint = OpTypeRuntimeArray %uint
31+
%SSBO = OpTypeStruct %arr_uint
32+
%ptr_ssbo_SSBO = OpTypePointer StorageBuffer %SSBO
33+
%ssbo = OpVariable %ptr_ssbo_SSBO StorageBuffer
34+
%ptr_ssbo_uint = OpTypePointer StorageBuffer %uint
35+
%coopmat_a = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_0
36+
%main = OpFunction %void None %3
37+
%5 = OpLabel
38+
%len = OpCooperativeMatrixLengthKHR %uint %coopmat_a
39+
%p = OpAccessChain %ptr_ssbo_uint %ssbo %uint_0 %uint_0
40+
OpStore %p %len
41+
OpReturn
42+
OpFunctionEnd
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
; SPIR-V
2+
; Version: 1.6
3+
; Generator: Khronos SPIR-V Tools Assembler; 0
4+
; Bound: 50
5+
; Schema: 0
6+
OpCapability Shader
7+
OpCapability CooperativeMatrixKHR
8+
OpCapability VulkanMemoryModel
9+
OpExtension "SPV_KHR_cooperative_matrix"
10+
OpExtension "SPV_KHR_vulkan_memory_model"
11+
OpMemoryModel Logical Vulkan
12+
OpEntryPoint GLCompute %main "main"
13+
OpExecutionMode %main LocalSize 32 1 1
14+
OpName %main "main"
15+
OpName %SSBO "SSBO"
16+
OpMemberName %SSBO 0 "data"
17+
OpName %ssbo "ssbo"
18+
OpDecorate %arr_float ArrayStride 4
19+
OpMemberDecorate %SSBO 0 Offset 0
20+
OpDecorate %SSBO Block
21+
OpDecorate %ssbo DescriptorSet 0
22+
OpDecorate %ssbo Binding 0
23+
%void = OpTypeVoid
24+
%3 = OpTypeFunction %void
25+
%float = OpTypeFloat 32
26+
%uint = OpTypeInt 32 0
27+
%uint_0 = OpConstant %uint 0
28+
%uint_1 = OpConstant %uint 1
29+
%uint_2 = OpConstant %uint 2
30+
%uint_3 = OpConstant %uint 3
31+
%uint_8 = OpConstant %uint 8
32+
%arr_float = OpTypeRuntimeArray %float
33+
%SSBO = OpTypeStruct %arr_float
34+
%ptr_ssbo_SSBO = OpTypePointer StorageBuffer %SSBO
35+
%ssbo = OpVariable %ptr_ssbo_SSBO StorageBuffer
36+
%ptr_ssbo_float = OpTypePointer StorageBuffer %float
37+
%coopmat_a = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_0
38+
%coopmat_acc = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_2
39+
%main = OpFunction %void None %3
40+
%5 = OpLabel
41+
; Row-major load from offset 0
42+
%p0 = OpAccessChain %ptr_ssbo_float %ssbo %uint_0 %uint_0
43+
%mat_a = OpCooperativeMatrixLoadKHR %coopmat_a %p0 %uint_0 %uint_8
44+
; Row-major store to offset 0
45+
OpCooperativeMatrixStoreKHR %p0 %mat_a %uint_0 %uint_8
46+
; Column-major load from offset 0
47+
%mat_col = OpCooperativeMatrixLoadKHR %coopmat_acc %p0 %uint_1 %uint_8
48+
; Column-major store to offset 0
49+
OpCooperativeMatrixStoreKHR %p0 %mat_col %uint_1 %uint_8
50+
OpReturn
51+
OpFunctionEnd

0 commit comments

Comments
 (0)