Skip to content

Commit 1baa3b0

Browse files
adamreeve authored and raulcd committed
GH-43070: [C++][Parquet] Check for valid ciphertext length to prevent segfault (#43071)
### Rationale for this change

See #43070

### What changes are included in this PR?

Checks that the ciphertext length is at least enough to hold the length (if written), nonce and GCM tag for the GCM cipher type. Also enforces that the input ciphertext length parameter is provided (is > 0) and verifies that the ciphertext size read from the file isn't going to cause reads beyond the end of the ciphertext buffer.

### Are these changes tested?

Yes, I've added new unit tests for this.

### Are there any user-facing changes?

No

* GitHub Issue: #43070

Authored-by: Adam Reeve <[email protected]>
Signed-off-by: mwish <[email protected]>
1 parent 6920955 commit 1baa3b0

12 files changed

+349
-102
lines changed

cpp/src/parquet/CMakeLists.txt

+1
Original file line number | Diff line number | Diff line change
@@ -408,6 +408,7 @@ add_parquet_test(arrow-internals-test SOURCES arrow/path_internal_test.cc
408408
if(PARQUET_REQUIRE_ENCRYPTION)
409409
add_parquet_test(encryption-test
410410
SOURCES
411+
encryption/encryption_internal_test.cc
411412
encryption/write_configurations_test.cc
412413
encryption/read_configurations_test.cc
413414
encryption/properties_test.cc

cpp/src/parquet/column_reader.cc

+3-2
Original file line number | Diff line number | Diff line change
@@ -512,10 +512,11 @@ std::shared_ptr<Page> SerializedPageReader::NextPage() {
512512
// Decrypt it if we need to
513513
if (crypto_ctx_.data_decryptor != nullptr) {
514514
PARQUET_THROW_NOT_OK(decryption_buffer_->Resize(
515-
compressed_len - crypto_ctx_.data_decryptor->CiphertextSizeDelta(),
515+
crypto_ctx_.data_decryptor->PlaintextLength(compressed_len),
516516
/*shrink_to_fit=*/false));
517517
compressed_len = crypto_ctx_.data_decryptor->Decrypt(
518-
page_buffer->data(), compressed_len, decryption_buffer_->mutable_data());
518+
page_buffer->span_as<uint8_t>(),
519+
decryption_buffer_->mutable_span_as<uint8_t>());
519520

520521
page_buffer = decryption_buffer_;
521522
}

cpp/src/parquet/encryption/encryption.h

+8
Original file line number | Diff line number | Diff line change
@@ -89,6 +89,14 @@ inline const uint8_t* str2bytes(const std::string& str) {
8989
return reinterpret_cast<const uint8_t*>(cbytes);
9090
}
9191

92+
inline ::arrow::util::span<const uint8_t> str2span(const std::string& str) {
93+
if (str.empty()) {
94+
return {};
95+
}
96+
97+
return {reinterpret_cast<const uint8_t*>(str.data()), str.size()};
98+
}
99+
92100
class PARQUET_EXPORT ColumnEncryptionProperties {
93101
public:
94102
class PARQUET_EXPORT Builder {

cpp/src/parquet/encryption/encryption_internal.cc

+139-69
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
#include "parquet/encryption/openssl_internal.h"
3232
#include "parquet/exception.h"
3333

34+
using ::arrow::util::span;
3435
using parquet::ParquetException;
3536

3637
namespace parquet::encryption {
@@ -315,8 +316,8 @@ class AesDecryptor::AesDecryptorImpl {
315316

316317
~AesDecryptorImpl() { WipeOut(); }
317318

318-
int Decrypt(const uint8_t* ciphertext, int ciphertext_len, const uint8_t* key,
319-
int key_len, const uint8_t* aad, int aad_len, uint8_t* plaintext);
319+
int Decrypt(span<const uint8_t> ciphertext, span<const uint8_t> key,
320+
span<const uint8_t> aad, span<uint8_t> plaintext);
320321

321322
void WipeOut() {
322323
if (nullptr != ctx_) {
@@ -325,25 +326,46 @@ class AesDecryptor::AesDecryptorImpl {
325326
}
326327
}
327328

328-
int ciphertext_size_delta() { return ciphertext_size_delta_; }
329+
[[nodiscard]] int PlaintextLength(int ciphertext_len) const {
330+
if (ciphertext_len < ciphertext_size_delta_) {
331+
std::stringstream ss;
332+
ss << "Ciphertext length " << ciphertext_len << " is invalid, expected at least "
333+
<< ciphertext_size_delta_;
334+
throw ParquetException(ss.str());
335+
}
336+
return ciphertext_len - ciphertext_size_delta_;
337+
}
338+
339+
[[nodiscard]] int CiphertextLength(int plaintext_len) const {
340+
if (plaintext_len < 0) {
341+
std::stringstream ss;
342+
ss << "Negative plaintext length " << plaintext_len;
343+
throw ParquetException(ss.str());
344+
}
345+
return plaintext_len + ciphertext_size_delta_;
346+
}
329347

330348
private:
331349
EVP_CIPHER_CTX* ctx_;
332350
int aes_mode_;
333351
int key_length_;
334352
int ciphertext_size_delta_;
335353
int length_buffer_length_;
336-
int GcmDecrypt(const uint8_t* ciphertext, int ciphertext_len, const uint8_t* key,
337-
int key_len, const uint8_t* aad, int aad_len, uint8_t* plaintext);
338354

339-
int CtrDecrypt(const uint8_t* ciphertext, int ciphertext_len, const uint8_t* key,
340-
int key_len, uint8_t* plaintext);
355+
/// Get the actual ciphertext length, inclusive of the length buffer length,
356+
/// and validate that the provided buffer size is large enough.
357+
[[nodiscard]] int GetCiphertextLength(span<const uint8_t> ciphertext) const;
358+
359+
int GcmDecrypt(span<const uint8_t> ciphertext, span<const uint8_t> key,
360+
span<const uint8_t> aad, span<uint8_t> plaintext);
361+
362+
int CtrDecrypt(span<const uint8_t> ciphertext, span<const uint8_t> key,
363+
span<uint8_t> plaintext);
341364
};
342365

343-
int AesDecryptor::Decrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key,
344-
int key_len, const uint8_t* aad, int aad_len,
345-
uint8_t* ciphertext) {
346-
return impl_->Decrypt(plaintext, plaintext_len, key, key_len, aad, aad_len, ciphertext);
366+
int AesDecryptor::Decrypt(span<const uint8_t> ciphertext, span<const uint8_t> key,
367+
span<const uint8_t> aad, span<uint8_t> plaintext) {
368+
return impl_->Decrypt(ciphertext, key, aad, plaintext);
347369
}
348370

349371
void AesDecryptor::WipeOut() { impl_->WipeOut(); }
@@ -438,56 +460,105 @@ std::shared_ptr<AesDecryptor> AesDecryptor::Make(
438460
return decryptor;
439461
}
440462

441-
int AesDecryptor::CiphertextSizeDelta() { return impl_->ciphertext_size_delta(); }
442-
443-
int AesDecryptor::AesDecryptorImpl::GcmDecrypt(const uint8_t* ciphertext,
444-
int ciphertext_len, const uint8_t* key,
445-
int key_len, const uint8_t* aad,
446-
int aad_len, uint8_t* plaintext) {
447-
int len;
448-
int plaintext_len;
463+
int AesDecryptor::PlaintextLength(int ciphertext_len) const {
464+
return impl_->PlaintextLength(ciphertext_len);
465+
}
449466

450-
uint8_t tag[kGcmTagLength];
451-
memset(tag, 0, kGcmTagLength);
452-
uint8_t nonce[kNonceLength];
453-
memset(nonce, 0, kNonceLength);
467+
int AesDecryptor::CiphertextLength(int plaintext_len) const {
468+
return impl_->CiphertextLength(plaintext_len);
469+
}
454470

471+
int AesDecryptor::AesDecryptorImpl::GetCiphertextLength(
472+
span<const uint8_t> ciphertext) const {
455473
if (length_buffer_length_ > 0) {
474+
// Note: length_buffer_length_ must be either 0 or kBufferSizeLength
475+
if (ciphertext.size() < static_cast<size_t>(kBufferSizeLength)) {
476+
std::stringstream ss;
477+
ss << "Ciphertext buffer length " << ciphertext.size()
478+
<< " is insufficient to read the ciphertext length."
479+
<< " At least " << kBufferSizeLength << " bytes are required.";
480+
throw ParquetException(ss.str());
481+
}
482+
456483
// Extract ciphertext length
457484
int written_ciphertext_len = ((ciphertext[3] & 0xff) << 24) |
458485
((ciphertext[2] & 0xff) << 16) |
459486
((ciphertext[1] & 0xff) << 8) | ((ciphertext[0] & 0xff));
460487

461-
if (ciphertext_len > 0 &&
462-
ciphertext_len != (written_ciphertext_len + length_buffer_length_)) {
463-
throw ParquetException("Wrong ciphertext length");
488+
if (written_ciphertext_len < 0) {
489+
std::stringstream ss;
490+
ss << "Negative ciphertext length " << written_ciphertext_len;
491+
throw ParquetException(ss.str());
492+
} else if (ciphertext.size() <
493+
static_cast<size_t>(written_ciphertext_len) + length_buffer_length_) {
494+
std::stringstream ss;
495+
ss << "Serialized ciphertext length "
496+
<< (written_ciphertext_len + length_buffer_length_)
497+
<< " is greater than the provided ciphertext buffer length "
498+
<< ciphertext.size();
499+
throw ParquetException(ss.str());
464500
}
465-
ciphertext_len = written_ciphertext_len + length_buffer_length_;
501+
502+
return written_ciphertext_len + length_buffer_length_;
466503
} else {
467-
if (ciphertext_len == 0) {
468-
throw ParquetException("Zero ciphertext length");
504+
if (ciphertext.size() > static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
505+
std::stringstream ss;
506+
ss << "Ciphertext buffer length " << ciphertext.size() << " overflows int32";
507+
throw ParquetException(ss.str());
469508
}
509+
return static_cast<int>(ciphertext.size());
510+
}
511+
}
512+
513+
int AesDecryptor::AesDecryptorImpl::GcmDecrypt(span<const uint8_t> ciphertext,
514+
span<const uint8_t> key,
515+
span<const uint8_t> aad,
516+
span<uint8_t> plaintext) {
517+
int len;
518+
int plaintext_len;
519+
520+
uint8_t tag[kGcmTagLength];
521+
memset(tag, 0, kGcmTagLength);
522+
uint8_t nonce[kNonceLength];
523+
memset(nonce, 0, kNonceLength);
524+
525+
int ciphertext_len = GetCiphertextLength(ciphertext);
526+
527+
if (plaintext.size() < static_cast<size_t>(ciphertext_len) - ciphertext_size_delta_) {
528+
std::stringstream ss;
529+
ss << "Plaintext buffer length " << plaintext.size() << " is insufficient "
530+
<< "for ciphertext length " << ciphertext_len;
531+
throw ParquetException(ss.str());
532+
}
533+
534+
if (ciphertext_len < length_buffer_length_ + kNonceLength + kGcmTagLength) {
535+
std::stringstream ss;
536+
ss << "Invalid ciphertext length " << ciphertext_len << ". Expected at least "
537+
<< length_buffer_length_ + kNonceLength + kGcmTagLength << "\n";
538+
throw ParquetException(ss.str());
470539
}
471540

472541
// Extracting IV and tag
473-
std::copy(ciphertext + length_buffer_length_,
474-
ciphertext + length_buffer_length_ + kNonceLength, nonce);
475-
std::copy(ciphertext + ciphertext_len - kGcmTagLength, ciphertext + ciphertext_len,
476-
tag);
542+
std::copy(ciphertext.begin() + length_buffer_length_,
543+
ciphertext.begin() + length_buffer_length_ + kNonceLength, nonce);
544+
std::copy(ciphertext.begin() + ciphertext_len - kGcmTagLength,
545+
ciphertext.begin() + ciphertext_len, tag);
477546

478547
// Setting key and IV
479-
if (1 != EVP_DecryptInit_ex(ctx_, nullptr, nullptr, key, nonce)) {
548+
if (1 != EVP_DecryptInit_ex(ctx_, nullptr, nullptr, key.data(), nonce)) {
480549
throw ParquetException("Couldn't set key and IV");
481550
}
482551

483552
// Setting additional authenticated data
484-
if ((nullptr != aad) && (1 != EVP_DecryptUpdate(ctx_, nullptr, &len, aad, aad_len))) {
553+
if ((!aad.empty()) && (1 != EVP_DecryptUpdate(ctx_, nullptr, &len, aad.data(),
554+
static_cast<int>(aad.size())))) {
485555
throw ParquetException("Couldn't set AAD");
486556
}
487557

488558
// Decryption
489559
if (!EVP_DecryptUpdate(
490-
ctx_, plaintext, &len, ciphertext + length_buffer_length_ + kNonceLength,
560+
ctx_, plaintext.data(), &len,
561+
ciphertext.data() + length_buffer_length_ + kNonceLength,
491562
ciphertext_len - length_buffer_length_ - kNonceLength - kGcmTagLength)) {
492563
throw ParquetException("Failed decryption update");
493564
}
@@ -500,87 +571,86 @@ int AesDecryptor::AesDecryptorImpl::GcmDecrypt(const uint8_t* ciphertext,
500571
}
501572

502573
// Finalization
503-
if (1 != EVP_DecryptFinal_ex(ctx_, plaintext + len, &len)) {
574+
if (1 != EVP_DecryptFinal_ex(ctx_, plaintext.data() + len, &len)) {
504575
throw ParquetException("Failed decryption finalization");
505576
}
506577

507578
plaintext_len += len;
508579
return plaintext_len;
509580
}
510581

511-
int AesDecryptor::AesDecryptorImpl::CtrDecrypt(const uint8_t* ciphertext,
512-
int ciphertext_len, const uint8_t* key,
513-
int key_len, uint8_t* plaintext) {
582+
int AesDecryptor::AesDecryptorImpl::CtrDecrypt(span<const uint8_t> ciphertext,
583+
span<const uint8_t> key,
584+
span<uint8_t> plaintext) {
514585
int len;
515586
int plaintext_len;
516587

517588
uint8_t iv[kCtrIvLength];
518589
memset(iv, 0, kCtrIvLength);
519590

520-
if (length_buffer_length_ > 0) {
521-
// Extract ciphertext length
522-
int written_ciphertext_len = ((ciphertext[3] & 0xff) << 24) |
523-
((ciphertext[2] & 0xff) << 16) |
524-
((ciphertext[1] & 0xff) << 8) | ((ciphertext[0] & 0xff));
591+
int ciphertext_len = GetCiphertextLength(ciphertext);
525592

526-
if (ciphertext_len > 0 &&
527-
ciphertext_len != (written_ciphertext_len + length_buffer_length_)) {
528-
throw ParquetException("Wrong ciphertext length");
529-
}
530-
ciphertext_len = written_ciphertext_len;
531-
} else {
532-
if (ciphertext_len == 0) {
533-
throw ParquetException("Zero ciphertext length");
534-
}
593+
if (plaintext.size() < static_cast<size_t>(ciphertext_len) - ciphertext_size_delta_) {
594+
std::stringstream ss;
595+
ss << "Plaintext buffer length " << plaintext.size() << " is insufficient "
596+
<< "for ciphertext length " << ciphertext_len;
597+
throw ParquetException(ss.str());
598+
}
599+
600+
if (ciphertext_len < length_buffer_length_ + kNonceLength) {
601+
std::stringstream ss;
602+
ss << "Invalid ciphertext length " << ciphertext_len << ". Expected at least "
603+
<< length_buffer_length_ + kNonceLength << "\n";
604+
throw ParquetException(ss.str());
535605
}
536606

537607
// Extracting nonce
538-
std::copy(ciphertext + length_buffer_length_,
539-
ciphertext + length_buffer_length_ + kNonceLength, iv);
608+
std::copy(ciphertext.begin() + length_buffer_length_,
609+
ciphertext.begin() + length_buffer_length_ + kNonceLength, iv);
540610
// Parquet CTR IVs are comprised of a 12-byte nonce and a 4-byte initial
541611
// counter field.
542612
// The first 31 bits of the initial counter field are set to 0, the last bit
543613
// is set to 1.
544614
iv[kCtrIvLength - 1] = 1;
545615

546616
// Setting key and IV
547-
if (1 != EVP_DecryptInit_ex(ctx_, nullptr, nullptr, key, iv)) {
617+
if (1 != EVP_DecryptInit_ex(ctx_, nullptr, nullptr, key.data(), iv)) {
548618
throw ParquetException("Couldn't set key and IV");
549619
}
550620

551621
// Decryption
552-
if (!EVP_DecryptUpdate(ctx_, plaintext, &len,
553-
ciphertext + length_buffer_length_ + kNonceLength,
554-
ciphertext_len - kNonceLength)) {
622+
if (!EVP_DecryptUpdate(ctx_, plaintext.data(), &len,
623+
ciphertext.data() + length_buffer_length_ + kNonceLength,
624+
ciphertext_len - length_buffer_length_ - kNonceLength)) {
555625
throw ParquetException("Failed decryption update");
556626
}
557627

558628
plaintext_len = len;
559629

560630
// Finalization
561-
if (1 != EVP_DecryptFinal_ex(ctx_, plaintext + len, &len)) {
631+
if (1 != EVP_DecryptFinal_ex(ctx_, plaintext.data() + len, &len)) {
562632
throw ParquetException("Failed decryption finalization");
563633
}
564634

565635
plaintext_len += len;
566636
return plaintext_len;
567637
}
568638

569-
int AesDecryptor::AesDecryptorImpl::Decrypt(const uint8_t* ciphertext, int ciphertext_len,
570-
const uint8_t* key, int key_len,
571-
const uint8_t* aad, int aad_len,
572-
uint8_t* plaintext) {
573-
if (key_length_ != key_len) {
639+
int AesDecryptor::AesDecryptorImpl::Decrypt(span<const uint8_t> ciphertext,
640+
span<const uint8_t> key,
641+
span<const uint8_t> aad,
642+
span<uint8_t> plaintext) {
643+
if (static_cast<size_t>(key_length_) != key.size()) {
574644
std::stringstream ss;
575-
ss << "Wrong key length " << key_len << ". Should be " << key_length_;
645+
ss << "Wrong key length " << key.size() << ". Should be " << key_length_;
576646
throw ParquetException(ss.str());
577647
}
578648

579649
if (kGcmMode == aes_mode_) {
580-
return GcmDecrypt(ciphertext, ciphertext_len, key, key_len, aad, aad_len, plaintext);
650+
return GcmDecrypt(ciphertext, key, aad, plaintext);
581651
}
582652

583-
return CtrDecrypt(ciphertext, ciphertext_len, key, key_len, plaintext);
653+
return CtrDecrypt(ciphertext, key, plaintext);
584654
}
585655

586656
static std::string ShortToBytesLe(int16_t input) {

0 commit comments

Comments
 (0)