@@ -142,6 +142,11 @@ def __init__(
      use_all_gather_via_all_reduce: bool = True,
      mark_step_on_freeing: bool = False,
  ):
+    if isinstance(module, XlaFullyShardedDataParallel):
+      raise RuntimeError(
+          "Cannot wrap a module that is already wrapped with FSDP. For nested FSDP, "
+          "first wrap the inner child modules before wrapping the outer parent module."
+      )
    is_forward_defined = (
        hasattr(module, "forward") and hasattr(module.forward, "__func__") and
        module.forward.__func__ != torch.nn.Module.forward)
@@ -383,10 +388,10 @@ def _shard_parameters_(self, params_to_shard) -> None:
    for module_name, m in self.named_modules():
      for n, p in m.named_parameters(recurse=False):
        if "xla" not in str(p.device):
-          raise Exception(
+          raise ValueError(
              "please moved the module to XLA device before wrapping with FSDP")
        if p.dtype != torch.float32:
-          raise Exception("only fp32 parameters are supported")
+          raise TypeError("only fp32 parameters are supported")
        if p in params_to_shard_set:
          if p in shared_full_param_memo:
            mname, shared_m, shared_n = shared_full_param_memo[p]