
[ET-VK] Add VK_KHR_cooperative_matrix dispatch for linear/matmul#19009

Open
xuyanwen2012 wants to merge 2 commits into pytorch:main from sarc-acl:yanwen/pr-amend-staging

Conversation

@xuyanwen2012

Summary

Adds cooperative-matrix (WMMA) drop-in variants of the existing tiled linear_vec / matmul_vec shaders, dispatched automatically when two conditions hold:

  1. The device exposes VK_KHR_cooperative_matrix (checked via a new Adapter::supports_cooperative_matrix() helper)
  2. The output tensor is in buffer storage

When either condition fails, dispatch falls back to the existing tiled shader — no change in behavior for any existing user.

Why

Modern discrete and mobile GPUs (AMD RDNA3+, NVIDIA Turing+) expose hardware matrix-multiply-accumulate tiles through the VK_KHR_cooperative_matrix extension, typically delivering 3–4x throughput on compute-bound GEMM vs software tiling. ExecuTorch's Vulkan backend currently uses linear_vec / matmul_vec (scalar/vector compute tiles) uniformly regardless of device capability, leaving WMMA throughput on the table on capable hardware.

What changes

Area | Change
Adapter.h | +9 LOC. Adds Adapter::supports_cooperative_matrix(), querying the cooperative_matrix_features physical-device field already populated in Device.cpp
New shaders | linear_coopmat.glsl (+261) and matmul_coopmat.glsl (+227): fp16×fp16→fp32 cooperative-matrix MMA on 16×16×16 tiles; one 64×64 output tile per 512-thread workgroup, targeting subgroupSize=64
Linear.cpp / Linear.h | Adds add_linear_coopmat_node plus pickers; prepack_fp_linear_weight gains a force_buffer parameter so the coopmat path can obtain buffer-stored weights
Matmul.cpp | Dispatch branches for both the runtime-mat2 and constant-mat2 cases; routes constant-mat2 through the linear path and runtime-mat2 through add_matmul_coopmat_node
cm_utils.{h,cpp} | queryCooperativeMatrixProperties() helper that prints the device's supported coopmat configurations at startup (diagnostic only)
linear_coopmat_bench.cpp / matmul_coopmat_bench.cpp | GPU-timestamp microbenchmarks comparing coopmat vs tiled across BERT, LLM, and square shapes

How to test

1. Configure and build the core runtime

cmake . \
    -Bcmake-out-vk \
    --preset "linux" \
    -DCMAKE_INSTALL_PREFIX=cmake-out-vk \
    -DCMAKE_BUILD_TYPE=Release \
    -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
    -DEXECUTORCH_PAL_DEFAULT=posix \
    -DEXECUTORCH_BUILD_VULKAN=ON \
    -DEXECUTORCH_BUILD_TESTS=ON \
    -DCMAKE_C_COMPILER_LAUNCHER=ccache \
    -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
    -DCMAKE_CXX_FLAGS="-include algorithm"

cmake --build cmake-out-vk -j$(nproc) --target install --config Release

2. Configure and build the Vulkan custom ops (GEMM tests and benchmarks)

cmake backends/vulkan/test/custom_ops/ \
    -Bcmake-out-vk/backends/vulkan/test/custom_ops \
    -DCMAKE_INSTALL_PREFIX=cmake-out-vk \
    -DCMAKE_BUILD_TYPE=Release \
    -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
    -DEXECUTORCH_ROOT=$(pwd) \
    -DCMAKE_C_COMPILER_LAUNCHER=ccache \
    -DCMAKE_CXX_COMPILER_LAUNCHER=ccache

cmake --build cmake-out-vk/backends/vulkan/test/custom_ops -j$(nproc)

3. Run the benchmarks on a device supporting VK_KHR_cooperative_matrix

./cmake-out-vk/backends/vulkan/test/custom_ops/linear_coopmat_bench
./cmake-out-vk/backends/vulkan/test/custom_ops/matmul_coopmat_bench

Convenience helper that queries VK_KHR_cooperative_matrix feature
support on the physical device. Used by the drop-in coopmat shader
variants to fall back to the tiled dispatch path when unsupported.

Adds VK_KHR_cooperative_matrix GLSL variants of the tiled linear and
matmul shaders. Dispatch is gated by
Adapter::supports_cooperative_matrix() and buffer output storage, with
automatic fallback to the tiled shader when unsupported. An M >= 64
guard avoids a known OOB in the current coopmat store; that guard will
be removed once partial-tile bounds checking is added to the shader.

Includes linear_coopmat_bench and matmul_coopmat_bench microbenchmarks
that compare against linear_vec / matmul_vec across BERT- and LLM-sized
shapes using Vulkan query-pool timestamps.
Copilot AI review requested due to automatic review settings April 20, 2026 21:59
@pytorch-bot

pytorch-bot Bot commented Apr 20, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19009

Note: Links to docs will display an error until the docs builds have been completed.

⚠️ 11 Awaiting Approval, 1 Unrelated Failure

As of commit b26728a with merge base 8ed6e85:

AWAITING APPROVAL - The following workflows need approval before CI can run:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla

meta-cla Bot commented Apr 20, 2026

Hi @xuyanwen2012!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


Copilot AI left a comment


Pull request overview

Adds a Vulkan cooperative-matrix (VK_KHR_cooperative_matrix / WMMA-style) fast path for linear/matmul when the device supports it and the output is buffer-backed, plus diagnostic tooling and microbenchmarks to compare against the existing tiled (*_vec) shaders.

Changes:

  • Introduces cooperative-matrix GLSL shaders and shader variants for linear and matmul.
  • Adds runtime dispatch branching to select coopmat vs tiled implementations, plus a supports_cooperative_matrix() adapter helper.
  • Adds coopmat diagnostics (cm_utils) and two microbenchmarks for linear/matmul.

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 9 comments.

File | Description
backends/vulkan/runtime/vk_api/Adapter.h | Adds the supports_cooperative_matrix() feature check.
backends/vulkan/runtime/graph/ops/impl/Matmul.cpp | Adds the coopmat node and dispatch selection for matmul (including the constant-mat2 route via linear).
backends/vulkan/runtime/graph/ops/impl/Linear.h | Extends the prepack API and declares add_linear_coopmat_node.
backends/vulkan/runtime/graph/ops/impl/Linear.cpp | Adds the coopmat linear node, selection logic, and a force_buffer prepack option.
backends/vulkan/runtime/graph/ops/glsl/matmul_coopmat.yaml | Registers matmul coopmat shader variants (dtype).
backends/vulkan/runtime/graph/ops/glsl/matmul_coopmat.glsl | New cooperative-matrix matmul shader (buffer-only).
backends/vulkan/runtime/graph/ops/glsl/linear_coopmat.yaml | Registers linear coopmat shader variants (dtype, bias).
backends/vulkan/runtime/graph/ops/glsl/linear_coopmat.glsl | New cooperative-matrix linear shader for prepacked weights (buffer-only).
backends/vulkan/test/custom_ops/cm_utils.h | Declares the cooperative-matrix property query helper for benchmarks/diagnostics.
backends/vulkan/test/custom_ops/cm_utils.cpp | Implements cooperative-matrix property enumeration/printing.
backends/vulkan/test/custom_ops/linear_coopmat_bench.cpp | Adds the linear coopmat vs vec microbenchmark.
backends/vulkan/test/custom_ops/matmul_coopmat_bench.cpp | Adds the matmul coopmat vs vec microbenchmark.
backends/vulkan/test/custom_ops/CMakeLists.txt | Wires the new cm_utils and benchmark targets into the custom_ops build.


Comment on lines +345 to +356
  // Coopmat shader assumes M is a multiple of TILE_M (64) because the store
  // does not bounds-check. Fall back to the tiled shader otherwise.
  // TODO: remove this guard once the coopmat shader gains partial-tile
  // bounds checking.
  auto input_sizes = graph.sizes_of(input);
  int64_t M = input_sizes.size() >= 2
      ? input_sizes.at(input_sizes.size() - 2)
      : 1;
  bool use_coopmat =
      graph.context()->adapter_ptr()->supports_cooperative_matrix() &&
      graph.storage_type_of(out) == utils::kBuffer &&
      M >= 64;
  VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim);
  VK_CHECK_COND(
      graph.storage_type_of(out) == utils::kBuffer,
      "linear_coopmat requires buffer storage");
Comment on lines +124 to +131
for (uint chunkK = 0; chunkK < K; chunkK += TILE_K) {

// --- Load A tile → shared (single pass) ---
{
uint row = a_row_base + a_row_offset;
uint k_elem = chunkK + a_col * FP16_PER_VEC4;

#ifdef IS_FP16_INPUT
Comment on lines +137 to +145
for (uint chunkK = 0; chunkK < K; chunkK += TILE_K) {

// --- Load A tile -> shared (same as matmul_coopmat) ---
{
uint row = a_row_base + a_row_offset;
uint k_elem = chunkK + a_col * FP16_PER_VEC4;

#ifdef IS_FP16_INPUT
uint k_hv4 = k_elem / 4;
Comment on lines +281 to +286
  bool use_coopmat =
      graph.context()->adapter_ptr()->supports_cooperative_matrix() &&
      graph.storage_type_of(out) == utils::kBuffer;
  ValueRef packed = prepack_fp_linear_weight(
      graph, mat2, /*is_transposed=*/false, B,
      /*force_buffer=*/use_coopmat);
Comment on lines +301 to +306
  } else if (
      graph.context()->adapter_ptr()->supports_cooperative_matrix() &&
      graph.storage_type_of(out) == utils::kBuffer) {
    add_matmul_coopmat_node(graph, mat1, mat2, out);
  } else {
    add_matmul_tiled_node(graph, mat1, mat2, out);
inline bool supports_cooperative_matrix() {
#ifdef VK_KHR_cooperative_matrix
return physical_device_.cooperative_matrix_features.cooperativeMatrix ==
VK_TRUE;
#else
return false;
#endif /* VK_KHR_cooperative_matrix */
}
Comment on lines +211 to +223
#ifdef IS_FP16_INPUT
// Convert fp32 accumulator to fp16 for fp16 output buffer
coopmat<float16_t, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator> out_tile =
coopmat<float16_t, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator>(result[i][j]);
coopMatStore(
out_tile, t_output,
gi * N + gj, N,
gl_CooperativeMatrixLayoutRowMajor);
#else
coopMatStore(
result[i][j], t_output,
gi * N + gj, N,
gl_CooperativeMatrixLayoutRowMajor);
Comment on lines +221 to +239
// --- Store result ---
[[unroll]] for (uint i = 0; i < C_ROWS; ++i) {
[[unroll]] for (uint j = 0; j < C_COLS; ++j) {
uint gi = TILE_M * tileID.y + lM * (C_ROWS * warpInTile.y + i);
uint gj = TILE_N * tileID.x + lN * (C_COLS * warpInTile.x + j);
#ifdef IS_FP16_INPUT
coopmat<float16_t, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator> out_tile =
coopmat<float16_t, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator>(result[i][j]);
coopMatStore(
out_tile, t_output,
gi * N + gj, N,
gl_CooperativeMatrixLayoutRowMajor);
#else
coopMatStore(
result[i][j], t_output,
gi * N + gj, N,
gl_CooperativeMatrixLayoutRowMajor);
#endif
}
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 21, 2026
