Fix FlashOptim FSDP2 checkpoint serialization #5
kashif wants to merge 8 commits into databricks:main
Conversation
Thanks for the contribution @kashif! Can you elaborate on what this means?
…nt-fix
# Conflicts:
#	flashoptim/optimizers.py
v0.1.4's 'continue when param has no state' also handles pre-first-step state_dict() calls, so the manual state_dict() rebuild and FSDP2 token matching we added are no longer load-bearing. Drop them; keep only the uneven-shard silent bail and the CPU-shard diagnostic, which main does not provide.
after rebasing onto v0.1.4, this fix doesn't apply anymore. v0.1.4's `if param_number not in opt_state: continue` in `state_dict()` already handles the pre-first-step case, which was the original motivation for calling it out. So now I've dropped the redundant code and kept only a regression test.
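To make that contract concrete, here is a minimal sketch in plain Python (no torch or FlashOptim imports; `is_valid_fresh_state_dict` and `fake_state_dict` are illustrative names, not part of either library) of the shape a never-stepped optimizer's `state_dict()` is expected to have:

```python
# Sketch of the pre-first-step contract the regression test pins down.
# `fake_state_dict` stands in for FlashOptimizer.state_dict(); the dict
# shape below is what Accelerate's AcceleratedOptimizer relies on.

def is_valid_fresh_state_dict(sd: dict) -> bool:
    """A never-stepped optimizer must expose empty state plus param_groups,
    mirroring vanilla torch.optim.AdamW."""
    return (
        set(sd) == {"state", "param_groups"}
        and sd["state"] == {}
        and isinstance(sd["param_groups"], list)
        and len(sd["param_groups"]) > 0
    )

# What AdamW (and, per this PR, FlashOptimizer) returns before any step():
fake_state_dict = {
    "state": {},
    "param_groups": [{"lr": 1e-3, "params": [0, 1]}],
}
print(is_valid_fresh_state_dict(fake_state_dict))  # True
```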
Pass shape and stride explicitly to DTensor.from_local for state tensors whose local shape matches the param's local shape. This is the stable public API for constructing a DTensor with a global shape that doesn't divide evenly — the previous default inference (local_size * world_size) is only correct for even splits. Round-trips correctly through both FULL_STATE_DICT and SHARDED_STATE_DICT paths for even and uneven shards. Replaces the earlier silent bail-out, which fixed FULL_STATE_DICT export but would have silently scrambled SHARDED_STATE_DICT checkpoints with uneven shards. Quantized state (different local shape from param) keeps the previous default-inference behavior; that path is out of scope here.
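A minimal sketch of why default inference goes wrong on uneven shards, assuming FSDP2's torch.chunk-style dim-0 sharding (plain Python, no torch; `shard_sizes` is a hypothetical helper, not a FlashOptim or torch function):

```python
# Illustration (not FlashOptim code) of why default global-shape inference
# breaks on uneven shards. FSDP2 chunks dim 0 like torch.chunk: each rank
# gets ceil(n / world_size) rows, with the last rank taking the remainder.
import math

def shard_sizes(dim0: int, world_size: int) -> list:
    """Per-rank dim-0 sizes under torch.chunk-style sharding (may be uneven)."""
    full = math.ceil(dim0 / world_size)
    sizes, remaining = [], dim0
    for _ in range(world_size):
        take = min(full, remaining)
        sizes.append(take)
        remaining -= take
    return sizes

# A parameter with 5 rows on 2 ranks: shards are uneven.
sizes = shard_sizes(5, 2)
print(sizes)  # [3, 2]

# Default inference assumes every rank holds an equal shard, so rank 0
# would reconstruct a global dim of local_size * world_size = 3 * 2 = 6,
# not the true 5. Passing shape/stride explicitly avoids this.
inferred = sizes[0] * 2
print(inferred, "vs true", 5)  # 6 vs true 5
```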
Exercises the proper _wrap_state_as_dtensor fix: a model whose param shapes don't divide by world_size=2 is sharded via fully_shard, trained briefly, then its optimizer state is saved and restored through DCP. Verifies (a) wrapped DTensors carry the true global shape rather than local_size * world_size, and (b) DCP save/load round-trips bit-exactly.
For state tensors whose shape matches the param (uncompressed momentum/variance), keep the explicit shape/stride wrapping so DCP round-trips on uneven shards. For tensors with a different layout (e.g. quantized state on uneven shards), re-raise as in pre-PR behavior, since default inference would silently scramble DCP loads.
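The branch described above can be sketched as follows; `wrap_state` and its tuple-based signature are hypothetical stand-ins for illustration, not FlashOptim's actual `_wrap_state_as_dtensor` (which operates on real tensors and device meshes):

```python
# Hypothetical sketch of the branch described above; names and signatures
# are illustrative, not FlashOptim's actual API. Shapes are plain tuples.

def wrap_state(local_state_shape, local_param_shape, global_param_shape):
    """Decide how to wrap an optimizer-state shard as a DTensor.

    Returns "explicit" when the state tensor mirrors the parameter shard
    (uncompressed momentum/variance): the parameter's known global shape
    (and stride) can be passed to DTensor.from_local explicitly, so uneven
    shards round-trip through DCP. Raises for any other layout (e.g.
    quantized state), matching pre-PR behavior, because default inference
    would silently corrupt DCP loads.
    """
    if local_state_shape == local_param_shape:
        # Safe: the state shares the param's sharding, so reuse the
        # param's true global shape instead of inferring one.
        return ("explicit", global_param_shape)
    raise ValueError(
        f"cannot infer global shape for state of local shape "
        f"{local_state_shape} on param shard {local_param_shape}"
    )

# Uncompressed momentum on an uneven shard: wrapped with the true global shape.
print(wrap_state((3, 4), (3, 4), (5, 4)))  # ('explicit', (5, 4))
```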
Fixes two small FSDP2/Accelerate interop issues in FlashOptim when used through HF Trainer / TRL / Accelerate. Since v0.1.4 landed on main (which added the frozen-param `continue` in `state_dict()`), the pre-first-step case is already handled upstream; this PR now only carries the remaining gaps.

Summary
- `_wrap_state_as_dtensor` previously raised `ValueError` on unevenly sharded parameters. In FSDP2 `FULL_STATE_DICT` export (accelerate/utils/fsdp_utils.py), the caller only needs a gathered state dict, not DCP-ready DTensors, so the raise aborted valid exports. Now the function silently returns, leaving tensors unwrapped (still a correct gathered export).
- CPU-offloaded shards (`fsdp_offload_params: true`) crashed with an opaque kernel error during `step()`. Now `_ensure_state_initialized` raises a targeted `ValueError` pointing at `fsdp_offload_params: false`.
- Regression test for `prepare()`. Pins that `FlashOptimizer.state_dict()` on a never-stepped optimizer returns `{"state": {}, "param_groups": [...]}` matching vanilla AdamW. Accelerate's `AcceleratedOptimizer.__init__` calls `optimizer.state_dict()` during `accelerator.prepare()` (to move state to the correct device); this test guards against regressing that contract.

Why
Single-GPU training worked, but FSDP2 checkpointing through Trainer/TRL/Accelerate exposed two serialization assumptions that don't hold in distributed preparation: (a) unevenly sharded parameters can make DTensor wrapping impossible even when a valid full-state export is still possible, and (b) CPU-offloaded shards are not a meaningful FlashOptim mode and previously failed with a low-level quantization error instead of a clear incompatibility message.
Caveats