@@ -223,6 +223,16 @@ class TestLossModuleBase:
223
223
"sample_log_prob_key" : "sample_log_prob" ,
224
224
"state_action_value_key" : "state_action_value" ,
225
225
},
226
+ DreamerModelLoss : {
227
+ "reward_key" : "reward" ,
228
+ "true_reward_key" : "true_reward" ,
229
+ "prior_mean_key" : "prior_mean" ,
230
+ "prior_std_key" : "prior_std" ,
231
+ "posterior_mean_key" : "posterior_mean" ,
232
+ "posterior_std_key" : "posterior_std" ,
233
+ "pixels_key" : "pixels" ,
234
+ "reco_pixels_key" : "reco_pixels" ,
235
+ },
226
236
DreamerActorLoss : {
227
237
"belief_key" : "belief" ,
228
238
"reward_key" : "reward" ,
@@ -437,6 +447,81 @@ def _create_value_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
437
447
value_model (td )
438
448
return value_model
439
449
450
def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
    """Assemble a Dreamer world model backed by a mock pixel environment.

    The returned ``WorldModelWrapper`` pairs two components:

    * a ``SafeSequential`` that encodes ``("next", "pixels")``, runs an RSSM
      rollout (prior step followed by posterior step) and decodes
      ``("next", "state")`` / ``("next", "belief")`` into
      ``("next", "reco_pixels")``;
    * an MLP reward head that maps ``("next", "state")`` and
      ``("next", "belief")`` to ``("next", "reward")``.

    Before returning, the model is called once under ``torch.no_grad()`` on
    a mock rollout (presumably to materialise lazily-shaped layers -- TODO
    confirm).

    Args:
        rssm_hidden_dim: size of the ``"belief"`` entry; also used as
            ``hidden_dim`` / ``rnn_hidden_dim`` of the RSSM modules.
        state_dim: size of the ``"state"`` entry.
        mlp_num_units: ``num_cells`` of the reward MLP.

    Returns:
        The assembled ``WorldModelWrapper``.
    """
    # Mock environment primed with zero-valued "state" and "belief" entries.
    env = TransformedEnv(ContinuousActionConvMockEnv(pixel_shape=[3, 64, 64]))
    env.append_transform(
        TensorDictPrimer(
            random=False,
            default_value=0,
            state=UnboundedContinuousTensorSpec(state_dim),
            belief=UnboundedContinuousTensorSpec(rssm_hidden_dim),
        )
    )

    # Raw networks, built in the same order as before so that any RNG
    # consumption during construction is unchanged.
    encoder_net = ObsEncoder()
    decoder_net = ObsDecoder()
    prior_net = RSSMPrior(
        hidden_dim=rssm_hidden_dim,
        rnn_hidden_dim=rssm_hidden_dim,
        state_dim=state_dim,
        action_spec=env.action_spec,
    )
    posterior_net = RSSMPosterior(hidden_dim=rssm_hidden_dim, state_dim=state_dim)

    # RSSM rollout: prior step, then posterior step.
    prior_step = SafeModule(
        prior_net,
        in_keys=["state", "belief", "action"],
        out_keys=[
            ("next", "prior_mean"),
            ("next", "prior_std"),
            "_",
            ("next", "belief"),
        ],
    )
    posterior_step = SafeModule(
        posterior_net,
        in_keys=[("next", "belief"), ("next", "encoded_latents")],
        out_keys=[
            ("next", "posterior_mean"),
            ("next", "posterior_std"),
            ("next", "state"),
        ],
    )
    rollout = RSSMRollout(prior_step, posterior_step)

    # Reward network (wrapped into a SafeModule further down).
    reward_net = MLP(
        out_features=1,
        depth=2,
        num_cells=mlp_num_units,
        activation_class=nn.ELU,
    )

    # Encoder -> RSSM rollout -> decoder.
    encode_step = SafeModule(
        encoder_net,
        in_keys=[("next", "pixels")],
        out_keys=[("next", "encoded_latents")],
    )
    decode_step = SafeModule(
        decoder_net,
        in_keys=[("next", "state"), ("next", "belief")],
        out_keys=[("next", "reco_pixels")],
    )
    world_modeler = SafeSequential(encode_step, rollout, decode_step)

    reward_head = SafeModule(
        reward_net,
        in_keys=[("next", "state"), ("next", "belief")],
        out_keys=[("next", "reward")],
    )
    world_model = WorldModelWrapper(world_modeler, reward_head)

    # Run the model once on a 10-step rollout with a leading batch dimension.
    with torch.no_grad():
        td = env.rollout(10).unsqueeze(0).to_tensordict()
        td["state"] = torch.zeros((1, 10, state_dim))
        td["belief"] = torch.zeros((1, 10, rssm_hidden_dim))
        world_model(td)
    return world_model
524
+
440
525
def _construct_loss (self , loss_module , ** kwargs ):
441
526
print (f"{ loss_module = } " )
442
527
if loss_module in [
@@ -466,6 +551,9 @@ def _construct_loss(self, loss_module, **kwargs):
466
551
actor = self ._create_mock_actor (action_spec_type = "one_hot" )
467
552
qvalue = self ._create_mock_qvalue ()
468
553
return loss_module (actor , qvalue , actor .spec ["action" ].space .n , ** kwargs )
554
+ elif loss_module in [DreamerModelLoss ]:
555
+ world_model = self ._create_world_model_model (10 , 5 )
556
+ return DreamerModelLoss (world_model )
469
557
elif loss_module in [DreamerActorLoss ]:
470
558
mb_env = self ._create_mb_env (10 , 5 )
471
559
actor_model = self ._create_actor_model (10 , 5 )
@@ -497,6 +585,7 @@ def _construct_loss(self, loss_module, **kwargs):
497
585
498
586
@pytest .mark .parametrize ("loss_module" , LOSS_MODULES )
499
587
def test_tensordict_keys_unknown_key (self , loss_module ):
588
+ """Test that exception is raised if an unknown key is set via .set_keys()"""
500
589
loss_fn = self ._construct_loss (loss_module )
501
590
502
591
with pytest .raises (ValueError ):
@@ -512,6 +601,7 @@ def test_tensordict_keys_default_values(self, loss_module):
512
601
513
602
@pytest .mark .parametrize ("loss_module" , LOSS_MODULES )
514
603
def test_tensordict_set_keys (self , loss_module ):
604
+ """Test setting of tensordict keys via .set_keys()"""
515
605
default_keys = self .DEFAULT_KEYS [loss_module ]
516
606
517
607
loss_fn = self ._construct_loss (loss_module )
@@ -529,6 +619,7 @@ def test_tensordict_set_keys(self, loss_module):
529
619
530
620
@pytest .mark .parametrize ("loss_module" , LOSS_MODULES )
531
621
def test_tensordict_deprecated_ctor (self , loss_module ):
622
+ """Test that a warning is raised if a deprecated tensordict key is set via the ctor."""
532
623
try :
533
624
dep_keys = self .DEPRECATED_CTOR_KEYS [loss_module ]
534
625
except KeyError :
@@ -546,6 +637,15 @@ def test_tensordict_deprecated_ctor(self, loss_module):
546
637
if def_key != key :
547
638
assert getattr (loss_fn , def_key ) == def_value
548
639
640
@pytest.mark.parametrize("loss_module", LOSS_MODULES)
def test_tensordict_all_keys_tested(self, loss_module):
    """Ensure DEFAULT_KEYS lists every tensordict key the loss module exposes.

    The check fails when the constructed loss reports a key in
    ``tensordict_keys`` that has no entry in ``DEFAULT_KEYS`` for that
    loss module.
    """
    # Keys this test suite knows about for the given loss module.
    expected = set(self.DEFAULT_KEYS[loss_module])

    # Keys actually exposed by a freshly constructed loss.
    loss_fn = self._construct_loss(loss_module)
    exposed = set(loss_fn.tensordict_keys.keys())

    untested = exposed - expected
    assert untested == set()
648
+
549
649
550
650
class TestDQN :
551
651
seed = 0
0 commit comments