Commit 2c684a8 (parent: 9a87bca)

check in __init__ whether the module is already FSDP; fix exception types

1 file changed (+7, −2): torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
@@ -142,6 +142,11 @@ def __init__(
       use_all_gather_via_all_reduce: bool = True,
       mark_step_on_freeing: bool = False,
   ):
+    if isinstance(module, XlaFullyShardedDataParallel):
+      raise RuntimeError(
+          "Cannot wrap a module that is already wrapped with FSDP. For nested FSDP, "
+          "first wrap the inner child modules before wrapping the outer parent module."
+      )
     is_forward_defined = (
         hasattr(module, "forward") and hasattr(module.forward, "__func__") and
         module.forward.__func__ != torch.nn.Module.forward)
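
The new guard changes what callers must do for nested wrapping. Below is a minimal, hypothetical sketch (the model, layer sizes, and wrapping order are illustrative, not taken from this commit) of the pattern the check enforces: wrap the inner child modules first, then the outer parent; passing an existing XlaFullyShardedDataParallel instance straight back into the constructor now raises RuntimeError.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

device = xm.xla_device()
# Illustrative two-layer model (not part of the commit).
model = torch.nn.Sequential(
    torch.nn.Linear(128, 128),
    torch.nn.Linear(128, 128),
).to(device)

# Nested FSDP: wrap the inner child modules first ...
model[0] = FSDP(model[0])
model[1] = FSDP(model[1])
# ... then wrap the outer parent module. The parent itself is a plain
# nn.Sequential, so the new isinstance check in __init__ does not trigger.
model = FSDP(model)

# Re-wrapping an FSDP instance directly is now rejected:
# FSDP(model)  # RuntimeError: Cannot wrap a module that is already wrapped with FSDP. ...
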
@@ -383,10 +388,10 @@ def _shard_parameters_(self, params_to_shard) -> None:
     for module_name, m in self.named_modules():
       for n, p in m.named_parameters(recurse=False):
         if "xla" not in str(p.device):
-          raise Exception(
+          raise ValueError(
               "please moved the module to XLA device before wrapping with FSDP")
         if p.dtype != torch.float32:
-          raise Exception("only fp32 parameters are supported")
+          raise TypeError("only fp32 parameters are supported")
         if p in params_to_shard_set:
           if p in shared_full_param_memo:
             mname, shared_m, shared_n = shared_full_param_memo[p]
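
The exception-type change in _shard_parameters_ is observable at wrap time. Here is a minimal, hypothetical sketch (module shapes are illustrative, not from this commit) of what now raises ValueError versus TypeError instead of the previous bare Exception.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

# Module never moved to the XLA device: the parameter-device check now
# raises ValueError instead of a generic Exception.
try:
    FSDP(torch.nn.Linear(16, 16))
except ValueError as e:
    print(e)  # please moved the module to XLA device before wrapping with FSDP

# Module on the XLA device but with fp16 parameters: the dtype check now
# raises TypeError instead of a generic Exception.
try:
    FSDP(torch.nn.Linear(16, 16).to(xm.xla_device()).half())
except TypeError as e:
    print(e)  # only fp32 parameters are supported
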
