Skip to content

Commit 91a004b

Browse files
TroyGarden authored and facebook-github-bot committed
Test case for EBC key-order change (#2388)
Summary: Pull Request resolved: #2388 # context * [post](https://fb.workplace.com/groups/1028545332188949/permalink/1042204770823005/) * this test case mimics the EBC key-order change after sharding {F1864056306} # details * it's a very simple model: EBC ---> KTRegroupAsDict * we generate two EBCs: ebc1 and ebc2, such that the table orders are different: ``` ebc1 = EmbeddingBagCollection( tables=[tb1_config, tb2_config, tb3_config], is_weighted=False, ) ebc2 = EmbeddingBagCollection( tables=[tb1_config, tb3_config, tb2_config], is_weighted=False, ) ``` * we export the model with ebc1 and unflatten the model, and then swap with ebc2 (you can think of this as the sharding process resulting in a shardedEBC), so that we can mimic the key-order change as shown in the above graph * the test checks that the final results after KTRegroupAsDict are consistent with the original eager model Reviewed By: PaulZhang12 Differential Revision: D62604419
1 parent ab2a161 commit 91a004b

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed

Diff for: torchrec/ir/tests/test_serializer.py

+70
Original file line numberDiff line numberDiff line change
@@ -521,3 +521,73 @@ def forward(
521521
self.assertTrue(deserialized_model.regroup._is_inited)
522522
for key in eager_out.keys():
523523
self.assertEqual(deserialized_out[key].shape, eager_out[key].shape)
524+
525+
def test_key_order_with_ebc_and_regroup(self) -> None:
    """Mimic the EBC key-order change that happens after sharding.

    Exports a model built from ``ebc1`` (table order t1, t2, t3),
    unflattens the exported program, then swaps in ``ebc2`` (table order
    t1, t3, t2) which shares ``ebc1``'s weights — standing in for the
    sharding process producing a shardedEBC with a different key order —
    and checks that the ``KTRegroupAsDict`` outputs still match the
    original eager model.
    """
    tb1_config = EmbeddingBagConfig(
        name="t1",
        embedding_dim=3,
        num_embeddings=10,
        feature_names=["f1"],
    )
    tb2_config = EmbeddingBagConfig(
        name="t2",
        embedding_dim=4,
        num_embeddings=10,
        feature_names=["f2"],
    )
    tb3_config = EmbeddingBagConfig(
        name="t3",
        embedding_dim=5,
        num_embeddings=10,
        feature_names=["f3"],
    )
    # f4/f5 are extra features not owned by any table; the EBC should
    # ignore them.
    id_list_features = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f2", "f3", "f4", "f5"],
        values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 2]),
        offsets=torch.tensor(
            [0, 2, 2, 3, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15]
        ),
    )
    # Same tables, different order: ebc2 stands in for a post-sharding EBC
    # whose key order differs from the EBC the model was exported with.
    ebc1 = EmbeddingBagCollection(
        tables=[tb1_config, tb2_config, tb3_config],
        is_weighted=False,
    )
    ebc2 = EmbeddingBagCollection(
        tables=[tb1_config, tb3_config, tb2_config],
        is_weighted=False,
    )
    # Share weights so any output difference comes from key order alone.
    ebc2.load_state_dict(ebc1.state_dict())
    regroup = KTRegroupAsDict([["f1", "f3"], ["f2"]], ["odd", "even"])

    class MyModel(nn.Module):
        """Minimal pipeline: EBC ---> KTRegroupAsDict."""

        def __init__(self, ebc: nn.Module, regroup: nn.Module) -> None:
            super().__init__()
            self.ebc = ebc
            self.regroup = regroup

        def forward(
            self,
            features: KeyedJaggedTensor,
        ) -> Dict[str, torch.Tensor]:
            return self.regroup([self.ebc(features)])

    model = MyModel(ebc1, regroup)
    eager_out = model(id_list_features)

    model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
    ep = torch.export.export(
        model,
        (id_list_features,),
        {},
        strict=False,
        # Allows KJT to not be unflattened and run a forward on unflattened EP
        preserve_module_call_signature=tuple(sparse_fqns),
    )
    unflatten_ep = torch.export.unflatten(ep)
    deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
    # We export the model with ebc1 and unflatten the model, and then swap
    # with ebc2 (think of this as the sharding process resulting in a
    # shardedEBC), so that we can mimic the key-order change.
    deserialized_model.ebc = ebc2

    deserialized_out = deserialized_model(id_list_features)
    for key in eager_out.keys():
        torch.testing.assert_close(deserialized_out[key], eager_out[key])

0 commit comments

Comments
 (0)