Skip to content

Commit c358b86

Browse files
Raahul Kalyaan Jakkafacebook-github-bot
Raahul Kalyaan Jakka
authored and committed
Creating ReadOnlyEmbeddingKVDB class and necessary functions (#4225)
Summary: Pull Request resolved: #4225 X-link: facebookresearch/FBGEMM#1301 Design doc: https://docs.google.com/document/d/149LdAEHOLP7ei4hwVVkAFXGa4N9uLs1J7efxfBZp3dY/edit?tab=t.0#heading=h.49t3yfaqmt54 Context: We are enabling the usage of the RocksDB checkpoint feature in KVTensorWrapper. This allows us to create checkpoints of the embedding tables in SSD. Later, these checkpoints are used by the checkpointing component to create a checkpoint and upload it to the Manifold. In this diff: The primary objective of adding the checkpoint handle is to allow multiple processes to read through the KVTensor. To enable this, we need to create a read-only KVTensor object that can be read concurrently. To support this, we introduce a ReadOnlyEmbeddingKVDB class, which is a read-only implementation of the EmbeddingKVDB class. We have added a new constructor to KVTensorWrapper which takes in serialized KVTensor metadata. When deserializing, we create a ReadOnlyEmbeddingKVDB for the KVTensorWrapper object. Reviewed By: duduyi2013 Differential Revision: D75489873
1 parent 14cef3f commit c358b86

File tree

3 files changed

+475
-2
lines changed

3 files changed

+475
-2
lines changed

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
6666
c10::intrusive_ptr<RocksdbCheckpointHandleWrapper> checkpoint_handle =
6767
c10::intrusive_ptr<RocksdbCheckpointHandleWrapper>(nullptr));
6868

69+
explicit KVTensorWrapper(const std::string& serialized);
70+
6971
at::Tensor narrow(int64_t dim, int64_t start, int64_t length);
7072

7173
/// @brief if the backend storage is SSD, use this function
@@ -108,6 +110,16 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
108110

109111
std::string layout_str();
110112

113+
std::string serialize() const;
114+
115+
// ONLY FOR DEBUGGING PURPOSES, Please don't use this function in production
116+
std::string logs() const;
117+
118+
void deserialize(const std::string& serialized);
119+
120+
friend void to_json(json& j, const KVTensorWrapper& kvt);
121+
friend void from_json(const json& j, KVTensorWrapper& kvt);
122+
111123
private:
112124
std::shared_ptr<kv_db::EmbeddingKVDB> db_;
113125
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> snapshot_handle_;
@@ -119,6 +131,26 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
119131
int64_t width_offset_;
120132
std::mutex mtx;
121133
c10::intrusive_ptr<RocksdbCheckpointHandleWrapper> checkpoint_handle_;
134+
// Used for initializing a readonly rocksdb instance, that we will use for
135+
// cross process async read
136+
std::shared_ptr<ReadOnlyEmbeddingKVDB> readonly_db_;
137+
// below are variables that are used to hold ReadOnlyEmbeddingKVDB constructor
138+
// arguments, they will be filled up when serialize happens and will be used
139+
// to construct ReadOnlyEmbeddingKVDB instance later after deserialization
140+
//
141+
// we don't do ReadOnlyEmbeddingKVDB construction upon KVTensorWrapper
142+
// construction, because one ReadOnlyEmbeddingKVDB(rdb checkpoint) could store
143+
// table shards for multiple tables, they should share the same underlying
144+
// ReadOnlyEmbeddingKVDB instance to easily manage rdb checkpoint lifetime.
145+
std::vector<std::string> rdb_shard_checkpoint_paths;
146+
std::string tbe_uuid;
147+
int64_t num_shards{};
148+
int64_t num_threads{};
149+
int64_t max_D{};
150+
std::string checkpoint_uuid;
122151
};
123152

153+
void to_json(json& j, const KVTensorWrapper& kvt);
154+
void from_json(const json& j, KVTensorWrapper& kvt);
155+
124156
} // namespace ssd

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <c10/core/ScalarTypeToTypeMeta.h>
1212
#include <torch/library.h>
1313

14+
#include <nlohmann/json.hpp>
1415
#include <torch/custom_class.h>
1516
#include <mutex>
1617
#include "../dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h"
@@ -324,6 +325,10 @@ CheckpointHandle::CheckpointHandle(
324325
}
325326
}
326327

328+
std::vector<std::string> CheckpointHandle::get_shard_checkpoints() const {
329+
return shard_checkpoints_;
330+
}
331+
327332
EmbeddingSnapshotHandleWrapper::EmbeddingSnapshotHandleWrapper(
328333
const SnapshotHandle* handle,
329334
std::shared_ptr<EmbeddingRocksDB> db)
@@ -377,6 +382,64 @@ KVTensorWrapper::KVTensorWrapper(
377382
checkpoint_handle_ = checkpoint_handle;
378383
}
379384

385+
std::string KVTensorWrapper::serialize() const {
386+
// auto call to_json()
387+
ssd::json json_serialized = *this;
388+
return json_serialized.dump();
389+
}
390+
391+
std::string KVTensorWrapper::logs() const {
392+
std::stringstream ss;
393+
if (db_) {
394+
CHECK(readonly_db_ == nullptr) << "rdb logs, ro_rdb must be nullptr";
395+
ss << "from ckpt paths: " << std::endl;
396+
// Required to cast as the KVTensorWrapper.db_ is a pointer for the
397+
// EmbeddingKVDB class which is inherited by the EmbeddingRocksDB class
398+
auto* db = dynamic_cast<EmbeddingRocksDB*>(db_.get());
399+
auto ckpts = db->get_checkpoints(checkpoint_handle_->uuid);
400+
for (int i = 0; i < ckpts.size(); i++) {
401+
ss << " shard:" << i << ", ckpt_path:" << ckpts[i] << std::endl;
402+
}
403+
ss << " tbe_uuid: " << db->get_tbe_uuid() << std::endl;
404+
ss << " num_shards: " << db->num_shards() << std::endl;
405+
ss << " num_threads: " << db->num_threads() << std::endl;
406+
ss << " max_D: " << db->get_max_D() << std::endl;
407+
ss << " row_offset: " << row_offset_ << std::endl;
408+
ss << " shape: " << shape_ << std::endl;
409+
ss << " dtype: " << static_cast<int64_t>(options_.dtype().toScalarType())
410+
<< std::endl;
411+
ss << " checkpoint_uuid: " << checkpoint_handle_->uuid << std::endl;
412+
} else {
413+
CHECK(readonly_db_) << "ro_rdb logs, ro_rdb must be valid";
414+
ss << "from ckpt paths: " << std::endl;
415+
auto* db = dynamic_cast<ReadOnlyEmbeddingKVDB*>(readonly_db_.get());
416+
auto rdb_shard_checkpoint_paths = db->get_rdb_shard_checkpoint_paths();
417+
for (int i = 0; i < rdb_shard_checkpoint_paths.size(); i++) {
418+
ss << " shard:" << i << ", ckpt_path:" << rdb_shard_checkpoint_paths[i]
419+
<< std::endl;
420+
}
421+
ss << " tbe_uuid: " << db->get_tbe_uuid() << std::endl;
422+
ss << " num_shards: " << db->num_shards() << std::endl;
423+
ss << " num_threads: " << db->num_threads() << std::endl;
424+
ss << " max_D: " << db->get_max_D() << std::endl;
425+
ss << " row_offset: " << row_offset_ << std::endl;
426+
ss << " shape: " << shape_ << std::endl;
427+
ss << " dtype: " << static_cast<int64_t>(options_.dtype().toScalarType())
428+
<< std::endl;
429+
ss << " checkpoint_uuid: " << checkpoint_uuid << std::endl;
430+
}
431+
return ss.str();
432+
}
433+
434+
// Parses the JSON payload produced by serialize() and repopulates this
// wrapper (including constructing its ReadOnlyEmbeddingKVDB) via from_json().
void KVTensorWrapper::deserialize(const std::string& serialized) {
  from_json(ssd::json::parse(serialized), *this);
}
438+
439+
// Rebuilds a KVTensorWrapper from the JSON payload produced by serialize().
// All member initialization (including readonly_db_ construction) is
// delegated to deserialize() / from_json().
KVTensorWrapper::KVTensorWrapper(const std::string& serialized) {
  deserialize(serialized);
}
442+
380443
void KVTensorWrapper::set_embedding_rocks_dp_wrapper(
381444
c10::intrusive_ptr<EmbeddingRocksDBWrapper> db) {
382445
db_ = db->impl_;
@@ -454,6 +517,55 @@ void KVTensorWrapper::set_weights_and_ids(
454517
}
455518
}
456519

520+
// Serializes everything from_json() needs to rebuild a read-only view of
// this tensor in another process: the per-shard rocksdb checkpoint paths,
// the DB construction parameters, and the wrapper's own tensor metadata.
void to_json(ssd::json& j, const KVTensorWrapper& kvt) {
  // Required to cast as the KVTensorWrapper.db_ is a pointer for the
  // EmbeddingKVDB class which is inherited by the EmbeddingRocksDB class.
  // The cast yields nullptr for non-rocksdb backends, so verify before use.
  std::shared_ptr<EmbeddingRocksDB> db =
      std::dynamic_pointer_cast<EmbeddingRocksDB>(kvt.db_);
  CHECK(db) << "to_json requires a writable EmbeddingRocksDB backend";
  CHECK(kvt.checkpoint_handle_)
      << "to_json requires a valid rocksdb checkpoint handle";
  j = ssd::json{
      {"rdb_shard_checkpoint_paths",
       db->get_checkpoints(kvt.checkpoint_handle_->uuid)},
      {"tbe_uuid", db->get_tbe_uuid()},
      {"num_shards", db->num_shards()},
      {"num_threads", db->num_threads()},
      {"max_D", db->get_max_D()},
      {"row_offset", kvt.row_offset_},
      {"shape", kvt.shape_},
      {"dtype", static_cast<int64_t>(kvt.options_.dtype().toScalarType())},
      {"checkpoint_uuid", kvt.checkpoint_handle_->uuid}};
}
537+
538+
// Rebuilds a KVTensorWrapper from the JSON produced by to_json(): extracts
// the ReadOnlyEmbeddingKVDB constructor arguments, constructs the read-only
// DB, then restores the wrapper's tensor metadata. Throws (nlohmann
// json::out_of_range) if any expected key is missing.
void from_json(const ssd::json& j, KVTensorWrapper& kvt) {
  // Brace-init so the locals are never read uninitialized.
  std::vector<std::string> rdb_shard_checkpoint_paths;
  std::string tbe_uuid;
  int64_t num_shards{};
  int64_t num_threads{};
  int64_t max_D{};
  int64_t dtype{};
  j.at("rdb_shard_checkpoint_paths").get_to(rdb_shard_checkpoint_paths);
  j.at("tbe_uuid").get_to(tbe_uuid);
  j.at("num_shards").get_to(num_shards);
  j.at("num_threads").get_to(num_threads);
  j.at("max_D").get_to(max_D);
  j.at("dtype").get_to(dtype);

  // initialize ro rdb during KV tensor deserialization
  // one rdb checkpoint is related to # tables of KVT, this way each KVT will
  // hold their own rdb instance link to the same checkpoint during
  // destruction, they will delete the same checkpoint, but since ckpt path
  // has been opened during ro rdb init, OS will not delete the file until
  // all file handles are closed
  kvt.readonly_db_ = std::make_shared<ReadOnlyEmbeddingKVDB>(
      rdb_shard_checkpoint_paths, tbe_uuid, num_shards, num_threads, max_D);
  j.at("checkpoint_uuid").get_to(kvt.checkpoint_uuid);
  j.at("row_offset").get_to(kvt.row_offset_);
  j.at("shape").get_to(kvt.shape_);
  kvt.options_ = at::TensorOptions()
                     .dtype(static_cast<at::ScalarType>(dtype))
                     .device(at::kCPU)
                     .layout(at::kStrided);
}
568+
457569
at::Tensor KVTensorWrapper::get_weights_by_ids(const at::Tensor& ids) {
458570
CHECK_TRUE(db_ != nullptr);
459571
CHECK_GE(db_->get_max_D(), shape_[1]);

0 commit comments

Comments
 (0)