Skip to content

Commit 0fe3a05

Browse files
Raahul Kalyaan Jakkafacebook-github-bot
Raahul Kalyaan Jakka
authored andcommitted
Expose SE/DESE support to EmbeddingRocksDBWrapper for training pipeline (#4227)
Summary: X-link: facebookresearch/FBGEMM#1304 Design doc: https://docs.google.com/document/d/149LdAEHOLP7ei4hwVVkAFXGa4N9uLs1J7efxfBZp3dY/edit?tab=t.0#heading=h.49t3yfaqmt54 Context: We are enabling the usage of rocksDB checkpoint feature in KVTensorWrapper. This allows us to create checkpoints of the embedding tables in SSD. Later, these checkpoints are used by the checkpointing component to create a checkpoint and upload it it to the manifold In this diff: 1. We expose serialization and deserialization support through EmbeddingRocksDBWrapper 2. Added a function to generate logs when reading through EmbeddingRocksDB/ReadOnlyEmbeddingKVDB Reviewed By: duduyi2013 Differential Revision: D75489895
1 parent 5c6ac8f commit 0fe3a05

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -886,7 +886,17 @@ static auto kv_tensor_wrapper =
886886
&KVTensorWrapper::sizes,
887887
std::string(
888888
"Returns the shape of the original tensor. Only the narrowed part is materialized."))
889-
.def_property("strides", &KVTensorWrapper::strides);
889+
.def_property("strides", &KVTensorWrapper::strides)
890+
.def_pickle(
891+
// __getstate__
892+
[](const c10::intrusive_ptr<KVTensorWrapper>& self) -> std::string {
893+
return self->serialize();
894+
},
895+
// __setstate__
896+
[](std::string data) -> c10::intrusive_ptr<KVTensorWrapper> {
897+
return c10::make_intrusive<KVTensorWrapper>(data);
898+
})
899+
.def("logs", &KVTensorWrapper::logs, "");
890900

891901
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
892902
m.def(

0 commit comments

Comments
 (0)