-
Notifications
You must be signed in to change notification settings - Fork 10
feat: add fused CUDA Adam kernel with benchmark and tests #7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
meinie0826
wants to merge
4
commits into
databricks:main
Choose a base branch
from
meinie0826:feat/cuda-adam-kernel
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
ff52ec3
feat: add fused CUDA Adam kernel with benchmark and tests
meinie0826 13bc0b6
perf: optimize quantize reduce path and fix benchmark/tests
meinie0826 fb835fd
perf: optimize quantize reduce path and fix benchmark/tests
meinie0826 829e67f
fix: remove dead warp-reduce code causing register spill in quant path
meinie0826 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,242 @@ | ||
| # SPDX-FileCopyrightText: Copyright 2026 Databricks, Inc. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # bench/bench_cuda_adam.py | ||
| # | ||
| # Performance benchmark: CUDA Adam kernel vs Triton reference. | ||
| # | ||
| # Usage: | ||
| # python bench/bench_cuda_adam.py # all configs, default warmup/iters | ||
| # python bench/bench_cuda_adam.py --iters 200 --warmup 50 | ||
| # python bench/bench_cuda_adam.py --csv results.csv | ||
|
|
||
| import argparse | ||
| import csv | ||
| import sys | ||
| import time | ||
| from dataclasses import dataclass, field | ||
| from typing import Optional | ||
|
|
||
| import torch | ||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Helpers | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
| _DTYPE_MAP = { | ||
| "bf16": torch.bfloat16, | ||
| "fp16": torch.float16, | ||
| "fp32": torch.float32, | ||
| } | ||
|
|
||
|
|
||
| def _make_state(N: int, device: str, quantize: bool, dtype: torch.dtype): | ||
| param = torch.randn(N, device=device, dtype=dtype) | ||
| grad = torch.randn(N, device=device, dtype=dtype) * 0.01 | ||
| if quantize: | ||
| G = (N + 31) // 32 | ||
| mom = torch.zeros(N, device=device, dtype=torch.int8) | ||
| mom_scales = torch.ones(G, device=device, dtype=torch.float16) * 0.01 | ||
| var = torch.zeros(N, device=device, dtype=torch.uint8) | ||
| var_scales = torch.ones(G, device=device, dtype=torch.float16) * 1e-4 | ||
| else: | ||
| mom = torch.zeros(N, device=device, dtype=dtype) | ||
| mom_scales = torch.empty(0, device=device, dtype=torch.float16) | ||
| var = torch.zeros(N, device=device, dtype=dtype) | ||
| var_scales = torch.empty(0, device=device, dtype=torch.float16) | ||
| return param, grad, mom, mom_scales, var, var_scales | ||
|
|
||
|
|
||
| def _run_triton(param, grad, mom, mom_scales, var, var_scales, | ||
| quantize, decoupled, step): | ||
| import flashoptim.optimizers as opt_mod | ||
| opt_mod._try_load_cuda_adam_ext() | ||
| orig, opt_mod._cuda_adam_ext = opt_mod._cuda_adam_ext, None | ||
| try: | ||
| opt_mod._fused_adam_step( | ||
| mom, mom_scales, var, var_scales, param, grad, None, | ||
| 1e-3, 0.9, 0.999, 1e-8, 0.01, decoupled, step, | ||
| quantize_optim_states=quantize, | ||
| ) | ||
| finally: | ||
| opt_mod._cuda_adam_ext = orig | ||
|
|
||
|
|
||
| def _run_cuda(param, grad, mom, mom_scales, var, var_scales, | ||
| quantize, decoupled, step): | ||
| import flashoptim._cuda_adam as ext | ||
| ext.adam_step( | ||
| mom, mom_scales, var, var_scales, param, grad, None, | ||
| 1e-3, 0.9, 0.999, 1e-8, 0.01, step, | ||
| quantize, decoupled, 32, | ||
| ) | ||
|
|
||
|
|
||
| def _bench(fn, warmup: int, iters: int) -> float: | ||
| """Return median wall-time per call in milliseconds (GPU-synchronised).""" | ||
| torch.cuda.synchronize() | ||
| for _ in range(warmup): | ||
| fn() | ||
| torch.cuda.synchronize() | ||
|
|
||
| times = [] | ||
| for _ in range(iters): | ||
| torch.cuda.synchronize() | ||
| t0 = time.perf_counter() | ||
| fn() | ||
| torch.cuda.synchronize() | ||
| times.append((time.perf_counter() - t0) * 1e3) | ||
|
|
||
| times.sort() | ||
| n = len(times) | ||
| return times[n // 2] # median | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Benchmark configuration | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
| @dataclass | ||
| class BenchConfig: | ||
| N: int | ||
| dtype: str # "bf16" | "fp16" | "fp32" | ||
| quantize: bool | ||
| decoupled: bool | ||
|
|
||
|
|
||
| # All combinations to benchmark | ||
| CONFIGS = [ | ||
| BenchConfig(N=n, dtype=dt, quantize=q, decoupled=d) | ||
| for n in [4_096, 65_536, 1_048_576, 16_777_216] # 4K → 16M elements | ||
| for dt in ["bf16", "fp16"] | ||
| for q in [True, False] | ||
| for d in [True, False] | ||
| ] | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Main | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
| def main(): | ||
| parser = argparse.ArgumentParser(description="FlashAdam CUDA vs Triton benchmark") | ||
| parser.add_argument("--warmup", type=int, default=30, help="Warmup iterations") | ||
| parser.add_argument("--iters", type=int, default=100, help="Timed iterations") | ||
| parser.add_argument("--csv", type=str, default=None, help="Save results to CSV") | ||
| parser.add_argument("--dtype", type=str, default=None, | ||
| help="Filter dtype (bf16|fp16|fp32)") | ||
| parser.add_argument("--n", type=int, default=None, | ||
| help="Filter N (exact match)") | ||
| args = parser.parse_args() | ||
|
|
||
| if not torch.cuda.is_available(): | ||
| print("ERROR: No CUDA GPU available.", file=sys.stderr) | ||
| sys.exit(1) | ||
|
|
||
| # Check extension | ||
| try: | ||
| import flashoptim._cuda_adam # noqa: F401 | ||
| except ImportError: | ||
| print("ERROR: flashoptim._cuda_adam not compiled. Run:\n" | ||
| " FLASHOPTIM_BUILD_CUDA=1 python setup.py build_ext --inplace", | ||
| file=sys.stderr) | ||
| sys.exit(1) | ||
|
|
||
| import flashoptim.optimizers as opt_mod | ||
| opt_mod._try_load_cuda_adam_ext() | ||
|
|
||
| # Pre-warm Triton JIT for all (dtype, quantize, decoupled) combos | ||
| # Use N=65536 to ensure the exact same Triton kernel (tiled for that size) is compiled. | ||
| device = "cuda" | ||
| print("Pre-warming Triton JIT...", end=" ", flush=True) | ||
| seen = set() | ||
| for cfg in CONFIGS: | ||
| key = (cfg.dtype, cfg.quantize, cfg.decoupled) | ||
| if key in seen: | ||
| continue | ||
| seen.add(key) | ||
| dtype = _DTYPE_MAP[cfg.dtype] | ||
| p, g, m, ms, v, vs = _make_state(65536, device, cfg.quantize, dtype) | ||
| for _ in range(3): | ||
| _run_triton(p.clone(), g, m.clone(), ms.clone(), v.clone(), vs.clone(), | ||
| cfg.quantize, cfg.decoupled, 1) | ||
| torch.cuda.synchronize() | ||
| print("done") | ||
| gpu_name = torch.cuda.get_device_name(0) | ||
| p = torch.cuda.get_device_properties(0) | ||
| print(f"\n{'='*72}") | ||
| print(f" GPU : {gpu_name} (SM {p.major}.{p.minor})") | ||
| print(f" Warmup: {args.warmup} Timed: {args.iters}") | ||
| print(f"{'='*72}") | ||
| print(f"{'N':>12} {'dtype':>5} {'quant':>5} {'decoup':>6} " | ||
| f"{'Triton(ms)':>10} {'CUDA(ms)':>10} {'Speedup':>7} " | ||
| f"{'BW Triton':>10} {'BW CUDA':>10}") | ||
| print(f"{'-'*92}") | ||
|
|
||
| configs = CONFIGS | ||
| if args.dtype: | ||
| configs = [c for c in configs if c.dtype == args.dtype] | ||
| if args.n: | ||
| configs = [c for c in configs if c.N == args.n] | ||
|
|
||
| rows = [] | ||
| for cfg in configs: | ||
| dtype = _DTYPE_MAP[cfg.dtype] | ||
| param, grad, mom, ms, var, vs = _make_state(cfg.N, device, cfg.quantize, dtype) | ||
|
|
||
| # Clones for each backend so state doesn't accumulate differences | ||
| def triton_step(step=[1]): | ||
| _run_triton(param.clone(), grad, mom.clone(), ms.clone(), | ||
| var.clone(), vs.clone(), cfg.quantize, cfg.decoupled, step[0]) | ||
| step[0] += 1 | ||
|
|
||
| def cuda_step(step=[1]): | ||
| _run_cuda(param.clone(), grad, mom.clone(), ms.clone(), | ||
| var.clone(), vs.clone(), cfg.quantize, cfg.decoupled, step[0]) | ||
| step[0] += 1 | ||
|
|
||
| t_triton = _bench(triton_step, args.warmup, args.iters) | ||
| t_cuda = _bench(cuda_step, args.warmup, args.iters) | ||
| speedup = t_triton / t_cuda if t_cuda > 0 else float("inf") | ||
|
|
||
| # Approximate memory bandwidth (bytes read + written per step) | ||
| elem_bytes = 2 if cfg.dtype in ("bf16", "fp16") else 4 | ||
| if cfg.quantize: | ||
| # mom(i8) + var(u8) + mom_scales(f16) + var_scales(f16) + param + grad | ||
| G = (cfg.N + 31) // 32 | ||
| bw_elems = 2 * cfg.N + 2 * G * 2 + 2 * cfg.N * elem_bytes | ||
| else: | ||
| bw_elems = 4 * cfg.N * elem_bytes # mom + var + param + grad | ||
| bw_gb_triton = bw_elems / (t_triton * 1e-3) / 1e9 | ||
| bw_gb_cuda = bw_elems / (t_cuda * 1e-3) / 1e9 | ||
|
|
||
| print(f"{cfg.N:>12,} {cfg.dtype:>5} {str(cfg.quantize):>5} " | ||
| f"{str(cfg.decoupled):>6} " | ||
| f"{t_triton:>10.3f} {t_cuda:>10.3f} {speedup:>7.2f}x " | ||
| f"{bw_gb_triton:>9.1f}G {bw_gb_cuda:>9.1f}G") | ||
|
|
||
| rows.append({ | ||
| "gpu": gpu_name, | ||
| "N": cfg.N, | ||
| "dtype": cfg.dtype, | ||
| "quantize": cfg.quantize, | ||
| "decoupled": cfg.decoupled, | ||
| "triton_ms": round(t_triton, 4), | ||
| "cuda_ms": round(t_cuda, 4), | ||
| "speedup": round(speedup, 4), | ||
| "bw_triton_GBs": round(bw_gb_triton, 2), | ||
| "bw_cuda_GBs": round(bw_gb_cuda, 2), | ||
| }) | ||
|
|
||
| print(f"{'='*92}\n") | ||
|
|
||
| if args.csv: | ||
| with open(args.csv, "w", newline="") as f: | ||
| writer = csv.DictWriter(f, fieldnames=list(rows[0].keys())) | ||
| writer.writeheader() | ||
| writer.writerows(rows) | ||
| print(f"Results saved to {args.csv}") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are unused imports (
fieldfrom dataclasses andOptionalfrom typing). These will fail ruff/pyflakes checks if linting is enabled. Remove unused imports.