11 | 11 | #include <c10/core/ScalarTypeToTypeMeta.h>
12 | 12 | #include <torch/library.h>
13 | 13 |
   | 14 | +#include <nlohmann/json.hpp>
14 | 15 | #include <torch/custom_class.h>
15 | 16 | #include <mutex>
16 | 17 | #include "../dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h"
@@ -324,6 +325,10 @@ CheckpointHandle::CheckpointHandle(
324 | 325 |   }
325 | 326 | }
326 | 327 |
    | 328 | +std::vector<std::string> CheckpointHandle::get_shard_checkpoints() const {
    | 329 | +  return shard_checkpoints_;
    | 330 | +}
    | 331 | +
327 | 332 | EmbeddingSnapshotHandleWrapper::EmbeddingSnapshotHandleWrapper(
328 | 333 |     const SnapshotHandle* handle,
329 | 334 |     std::shared_ptr<EmbeddingRocksDB> db)
@@ -377,6 +382,64 @@ KVTensorWrapper::KVTensorWrapper(
377 | 382 |   checkpoint_handle_ = checkpoint_handle;
378 | 383 | }
379 | 384 |
    | 385 | +std::string KVTensorWrapper::serialize() const {
    | 386 | +  // Implicit conversion invokes to_json() via nlohmann's ADL serializer.
    | 387 | +  ssd::json json_serialized = *this;
    | 388 | +  return json_serialized.dump();
    | 389 | +}
    | 390 | +
    | 391 | +std::string KVTensorWrapper::logs() const {
    | 392 | +  std::stringstream ss;
    | 393 | +  if (db_) {
    | 394 | +    CHECK(readonly_db_ == nullptr) << "rdb logs, ro_rdb must be nullptr";
    | 395 | +    ss << "from ckpt paths: " << std::endl;
    | 396 | +    // The cast is required because KVTensorWrapper.db_ points to the
    | 397 | +    // EmbeddingKVDB base class, from which EmbeddingRocksDB derives.
    | 398 | +    auto* db = dynamic_cast<EmbeddingRocksDB*>(db_.get());
    | 399 | +    auto ckpts = db->get_checkpoints(checkpoint_handle_->uuid);
    | 400 | +    for (size_t i = 0; i < ckpts.size(); i++) {
    | 401 | +      ss << " shard:" << i << ", ckpt_path:" << ckpts[i] << std::endl;
    | 402 | +    }
    | 403 | +    ss << " tbe_uuid: " << db->get_tbe_uuid() << std::endl;
    | 404 | +    ss << " num_shards: " << db->num_shards() << std::endl;
    | 405 | +    ss << " num_threads: " << db->num_threads() << std::endl;
    | 406 | +    ss << " max_D: " << db->get_max_D() << std::endl;
    | 407 | +    ss << " row_offset: " << row_offset_ << std::endl;
    | 408 | +    ss << " shape: " << shape_ << std::endl;
    | 409 | +    ss << " dtype: " << static_cast<int64_t>(options_.dtype().toScalarType())
    | 410 | +       << std::endl;
    | 411 | +    ss << " checkpoint_uuid: " << checkpoint_handle_->uuid << std::endl;
    | 412 | +  } else {
    | 413 | +    CHECK(readonly_db_) << "ro_rdb logs, ro_rdb must be valid";
    | 414 | +    ss << "from ckpt paths: " << std::endl;
    | 415 | +    auto* db = dynamic_cast<ReadOnlyEmbeddingKVDB*>(readonly_db_.get());
    | 416 | +    auto rdb_shard_checkpoint_paths = db->get_rdb_shard_checkpoint_paths();
    | 417 | +    for (size_t i = 0; i < rdb_shard_checkpoint_paths.size(); i++) {
    | 418 | +      ss << " shard:" << i << ", ckpt_path:" << rdb_shard_checkpoint_paths[i]
    | 419 | +         << std::endl;
    | 420 | +    }
    | 421 | +    ss << " tbe_uuid: " << db->get_tbe_uuid() << std::endl;
    | 422 | +    ss << " num_shards: " << db->num_shards() << std::endl;
    | 423 | +    ss << " num_threads: " << db->num_threads() << std::endl;
    | 424 | +    ss << " max_D: " << db->get_max_D() << std::endl;
    | 425 | +    ss << " row_offset: " << row_offset_ << std::endl;
    | 426 | +    ss << " shape: " << shape_ << std::endl;
    | 427 | +    ss << " dtype: " << static_cast<int64_t>(options_.dtype().toScalarType())
    | 428 | +       << std::endl;
    | 429 | +    ss << " checkpoint_uuid: " << checkpoint_uuid << std::endl;
    | 430 | +  }
    | 431 | +  return ss.str();
    | 432 | +}
    | 433 | +
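
For reference, logs() emits a diagnostic block shaped like the following; all
values below are illustrative placeholders, not output from a real run:

    from ckpt paths:
     shard:0, ckpt_path:/tmp/tbe_ckpt/shard_0
     shard:1, ckpt_path:/tmp/tbe_ckpt/shard_1
     tbe_uuid: <tbe-uuid>
     num_shards: 2
     num_threads: 4
     max_D: 128
     row_offset: 0
     shape: [1000, 128]
     dtype: 6
     checkpoint_uuid: <checkpoint-uuid>

(dtype is the numeric c10::ScalarType value; 6 corresponds to float32.)
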
    | 434 | +void KVTensorWrapper::deserialize(const std::string& serialized) {
    | 435 | +  ssd::json json_serialized = ssd::json::parse(serialized);
    | 436 | +  from_json(json_serialized, *this);
    | 437 | +}
    | 438 | +
    | 439 | +KVTensorWrapper::KVTensorWrapper(const std::string& serialized) {
    | 440 | +  deserialize(serialized);
    | 441 | +}
    | 442 | +
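
Taken together, serialize() and the new string constructor form a cross-process
round trip. A minimal usage sketch, assuming kvt is an existing KVTensorWrapper
backed by an EmbeddingRocksDB with a checkpoint already attached (that setup is
not shown in this diff):

    // Writer side: capture metadata plus the shard checkpoint paths as JSON.
    std::string blob = kvt.serialize();

    // Reader side (possibly another process): rebuild a read-only view.
    // from_json() opens a ReadOnlyEmbeddingKVDB over the serialized
    // rdb_shard_checkpoint_paths, so no live EmbeddingRocksDB is required.
    KVTensorWrapper restored(blob);

The restored wrapper is read-only: db_ remains null and readonly_db_ is
populated instead, which is the branch logs() distinguishes above.
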
380 | 443 | void KVTensorWrapper::set_embedding_rocks_dp_wrapper(
381 | 444 |     c10::intrusive_ptr<EmbeddingRocksDBWrapper> db) {
382 | 445 |   db_ = db->impl_;
@@ -454,6 +517,55 @@ void KVTensorWrapper::set_weights_and_ids(
454 | 517 |   }
455 | 518 | }
456 | 519 |
    | 520 | +void to_json(ssd::json& j, const KVTensorWrapper& kvt) {
    | 521 | +  // The cast is required because KVTensorWrapper.db_ points to the
    | 522 | +  // EmbeddingKVDB base class, from which EmbeddingRocksDB derives.
    | 523 | +  std::shared_ptr<EmbeddingRocksDB> db =
    | 524 | +      std::dynamic_pointer_cast<EmbeddingRocksDB>(kvt.db_);
    | 525 | +  j = ssd::json{
    | 526 | +      {"rdb_shard_checkpoint_paths",
    | 527 | +       db->get_checkpoints(kvt.checkpoint_handle_->uuid)},
    | 528 | +      {"tbe_uuid", db->get_tbe_uuid()},
    | 529 | +      {"num_shards", db->num_shards()},
    | 530 | +      {"num_threads", db->num_threads()},
    | 531 | +      {"max_D", db->get_max_D()},
    | 532 | +      {"row_offset", kvt.row_offset_},
    | 533 | +      {"shape", kvt.shape_},
    | 534 | +      {"dtype", static_cast<int64_t>(kvt.options_.dtype().toScalarType())},
    | 535 | +      {"checkpoint_uuid", kvt.checkpoint_handle_->uuid}};
    | 536 | +}
    | 537 | +
    | 538 | +void from_json(const ssd::json& j, KVTensorWrapper& kvt) {
    | 539 | +  std::vector<std::string> rdb_shard_checkpoint_paths;
    | 540 | +  std::string tbe_uuid;
    | 541 | +  int64_t num_shards;
    | 542 | +  int64_t num_threads;
    | 543 | +  int64_t max_D;
    | 544 | +  int64_t dtype;
    | 545 | +  j.at("rdb_shard_checkpoint_paths").get_to(rdb_shard_checkpoint_paths);
    | 546 | +  j.at("tbe_uuid").get_to(tbe_uuid);
    | 547 | +  j.at("num_shards").get_to(num_shards);
    | 548 | +  j.at("num_threads").get_to(num_threads);
    | 549 | +  j.at("max_D").get_to(max_D);
    | 550 | +  j.at("dtype").get_to(dtype);
    | 551 | +
    | 552 | +  // Initialize the read-only rdb during KV tensor deserialization. One rdb
    | 553 | +  // checkpoint is shared across the KVTs of multiple tables, so each KVT
    | 554 | +  // holds its own read-only rdb instance linked to the same checkpoint. On
    | 555 | +  // destruction each instance deletes the same checkpoint path, but since
    | 556 | +  // the files were opened during read-only rdb init, the OS will not
    | 557 | +  // reclaim them until every file handle is closed.
    | 558 | +  kvt.readonly_db_ = std::make_shared<ReadOnlyEmbeddingKVDB>(
    | 559 | +      rdb_shard_checkpoint_paths, tbe_uuid, num_shards, num_threads, max_D);
    | 560 | +  j.at("checkpoint_uuid").get_to(kvt.checkpoint_uuid);
    | 561 | +  j.at("row_offset").get_to(kvt.row_offset_);
    | 562 | +  j.at("shape").get_to(kvt.shape_);
    | 563 | +  kvt.options_ = at::TensorOptions()
    | 564 | +                     .dtype(static_cast<at::ScalarType>(dtype))
    | 565 | +                     .device(at::kCPU)
    | 566 | +                     .layout(at::kStrided);
    | 567 | +}
    | 568 | +
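
The to_json()/from_json() pair follows nlohmann::json's ADL serializer
convention: free functions declared in the same namespace as the type are
found automatically for implicit conversions. A self-contained sketch of the
same pattern, where Point is a stand-in type and not part of this change:

    #include <cstdint>
    #include <string>
    #include <nlohmann/json.hpp>

    struct Point {
      int64_t x;
      int64_t y;
    };

    void to_json(nlohmann::json& j, const Point& p) {
      j = nlohmann::json{{"x", p.x}, {"y", p.y}};
    }

    void from_json(const nlohmann::json& j, Point& p) {
      j.at("x").get_to(p.x);
      j.at("y").get_to(p.y);
    }

    int main() {
      Point p{1, 2};
      nlohmann::json j = p;        // implicit conversion calls to_json()
      std::string s = j.dump();    // {"x":1,"y":2}
      // get<T>() calls from_json() and needs T to be default-constructible.
      Point q = nlohmann::json::parse(s).get<Point>();
      return (q.x == 1 && q.y == 2) ? 0 : 1;
    }

KVTensorWrapper::deserialize() calls from_json() explicitly rather than using
get<KVTensorWrapper>() because it deserializes into an already-constructed
object.
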
457 | 569 | at::Tensor KVTensorWrapper::get_weights_by_ids(const at::Tensor& ids) {
458 | 570 |   CHECK_TRUE(db_ != nullptr);
459 | 571 |   CHECK_GE(db_->get_max_D(), shape_[1]);