Fix FlashOptim FSDP2 checkpoint serialization#5

Open
kashif wants to merge 8 commits into databricks:main from kashif:kashif/fsdp2-checkpoint-fix

Conversation

@kashif

@kashif kashif commented Mar 31, 2026

Fixes two small FSDP2/Accelerate interop issues in FlashOptim when used through HF Trainer / TRL / Accelerate. Since v0.1.4 landed on main (which added the frozen-param continue in state_dict()), the pre-first-step case is already handled upstream — this PR now only carries the remaining gaps.

Summary

  • Uneven-shard fallback during full-state-dict export. _wrap_state_as_dtensor previously raised ValueError on unevenly sharded parameters. In FSDP2 FULL_STATE_DICT export (accelerate/utils/fsdp_utils.py), the caller only needs a gathered state dict — not DCP-ready DTensors — so the raise aborted valid exports. Now the function silently returns, leaving tensors unwrapped (still a correct gathered export).
  • Clearer error for CPU-offloaded shards. FlashOptim's quantized state and fused CUDA kernels require CUDA-resident shards at step time. Previously, a CPU-offloaded shard (e.g. fsdp_offload_params: true) crashed with an opaque kernel error during step(). Now _ensure_state_initialized raises a targeted ValueError pointing at fsdp_offload_params: false.
  • Regression test for Accelerate prepare(). Pins that FlashOptimizer.state_dict() on a never-stepped optimizer returns {"state": {}, "param_groups": [...]} matching vanilla AdamW. Accelerate's AcceleratedOptimizer.__init__ calls optimizer.state_dict() during accelerator.prepare() (to move state to the correct device) — this test guards against regressing that contract.
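
The pinned contract can be sketched in plain Python (a simplified stand-in: the real optimizer tracks torch tensors and parameter indices, and `FlashOptimizer` itself is not reproduced here):

```python
# Sketch of the state_dict() contract pinned by the regression test:
# parameters with no allocated state are simply absent from "state",
# so a never-stepped optimizer returns {"state": {}, "param_groups": [...]}.
def state_dict(param_groups, opt_state):
    state = {}
    for group in param_groups:
        for pid in group["params"]:
            if pid not in opt_state:  # pre-first-step or frozen param: skip
                continue
            state[pid] = opt_state[pid]
    return {"state": state, "param_groups": param_groups}

# Before the first step, no state has been allocated:
groups = [{"params": [0, 1], "lr": 1e-3}]
sd = state_dict(groups, opt_state={})
assert sd == {"state": {}, "param_groups": groups}
```

This is the shape Accelerate relies on when `AcceleratedOptimizer.__init__` inspects the optimizer during `prepare()`: an empty-but-well-formed dict, never an exception.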

Why

Single-GPU training worked, but FSDP2 checkpointing through Trainer/TRL/Accelerate exposed two serialization assumptions that don't hold in distributed preparation: (a) unevenly sharded parameters can make DTensor wrapping impossible even when a valid full-state export is still possible, and (b) CPU-offloaded shards are not a meaningful FlashOptim mode and previously failed with a low-level quantization error instead of a clear incompatibility message.

Caveats

  • Uneven-shard handling now favors successful full-state-compatible export over strict DTensor wrapping for that case.
  • CPU-offloaded parameter shards remain unsupported; this PR only improves the error message.

@kashif kashif requested a review from josejg as a code owner March 31, 2026 15:57
@josejg
Collaborator

josejg commented Apr 17, 2026

Thanks for the contribution @kashif

can you elaborate what

  • make FlashOptimizer.state_dict() behave like a standard PyTorch optimizer before the first step, since accelerate may inspect optimizer state during prepare()

means?

kashif added 2 commits April 18, 2026 07:39
…nt-fix

# Conflicts:
#	flashoptim/optimizers.py
v0.1.4's 'continue when param has no state' also handles pre-first-step
state_dict() calls, so the manual state_dict() rebuild and FSDP2 token
matching we added are no longer load-bearing. Drop them; keep only the
uneven-shard silent bail and the CPU-shard diagnostic, which main does
not provide.
@kashif
Author

kashif commented Apr 18, 2026

Thanks for the contribution @kashif

can you elaborate what

  • make FlashOptimizer.state_dict() behave like a standard PyTorch optimizer before the first step, since accelerate may inspect optimizer state during prepare()

means?

after rebasing onto v0.1.4, this fix doesn't apply anymore. v0.1.4's

if param_number not in opt_state:
    continue

in state_dict() already handles the pre-first-step case (same condition as frozen params: no state allocated → skip).

The original motivation for calling it out was that accelerate.AcceleratedOptimizer.__init__ calls optimizer.state_dict() during accelerator.prepare() to move state to the target device when device_placement=True — if FlashOptim had raised or returned garbage on empty state, Trainer + Accelerate would break before step 1.

So now, I've dropped the redundant code and kept only a regression test (test_state_dict_before_first_step_matches_torch_empty_state) that pins the contract, so future refactors can't regress it. I've also updated the PR description to reflect the narrower scope.

kashif added 4 commits April 18, 2026 08:14
Pass shape and stride explicitly to DTensor.from_local for state tensors
whose local shape matches the param's local shape. This is the stable
public API for constructing a DTensor with a global shape that doesn't
divide evenly — the previous default inference (local_size * world_size)
is only correct for even splits.

Round-trips correctly through both FULL_STATE_DICT and SHARDED_STATE_DICT
paths for even and uneven shards. Replaces the earlier silent bail-out,
which fixed FULL_STATE_DICT export but would have silently scrambled
SHARDED_STATE_DICT checkpoints with uneven shards.

Quantized state (different local shape from param) keeps the previous
default-inference behavior; that path is out of scope here.
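
The failure mode with default inference comes down to simple arithmetic (a sketch under the assumption of a 1-D shard layout; real FSDP2 sharding metadata is richer than a list of shard sizes):

```python
# DTensor.from_local, without an explicit shape, infers the global size
# as local_size * world_size -- correct only when every rank holds an
# equal shard. With uneven shards the true global size must be supplied.
def inferred_global_numel(local_numel, world_size):
    return local_numel * world_size  # DTensor's even-split default

def true_global_numel(shard_numels):
    return sum(shard_numels)  # what the explicit `shape` should encode

# A 10-element param sharded across world_size=2 as [6, 4]:
shards = [6, 4]
assert true_global_numel(shards) == 10
# Rank 0's default inference would claim 6 * 2 = 12 elements -- wrong:
assert inferred_global_numel(shards[0], 2) == 12
```

Passing the param's true global shape and stride to `DTensor.from_local` avoids this mismatch, which is why the commit switches to the explicit form for state tensors that share the param's local shape.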
Exercises the proper _wrap_state_as_dtensor fix: a model whose param
shapes don't divide by world_size=2 is sharded via fully_shard, trained
briefly, then its optimizer state is saved and restored through DCP.
Verifies (a) wrapped DTensors carry the true global shape rather than
local_size * world_size, and (b) DCP save/load round-trips bit-exactly.
For state tensors whose shape matches the param (uncompressed momentum
/variance), keep the explicit shape/stride wrapping so DCP round-trips
on uneven shards. For tensors with a different layout (e.g. quantized
state on uneven shards), re-raise as in pre-PR behavior — default
inference would silently scramble DCP loads.
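
The branch described above can be sketched as a small decision helper (hypothetical names; this is not FlashOptim's actual `_wrap_state_as_dtensor` code):

```python
import collections

Decision = collections.namedtuple("Decision", "action shape stride")

def plan_wrap(state_shape, param_local_shape, param_global_shape, param_stride):
    """Decide how to wrap one optimizer-state tensor as a DTensor."""
    if state_shape == param_local_shape:
        # Same layout as the param (uncompressed momentum/variance):
        # wrap with the param's explicit global shape/stride so DCP
        # round-trips even on uneven shards.
        return Decision("from_local_explicit", param_global_shape, param_stride)
    # Different layout (e.g. quantized state on uneven shards): re-raise,
    # since default shape inference would silently scramble DCP loads.
    return Decision("raise", None, None)

assert plan_wrap((6,), (6,), (10,), (1,)).action == "from_local_explicit"
assert plan_wrap((3,), (6,), (10,), (1,)).action == "raise"
```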
