31
31
#include " parquet/encryption/openssl_internal.h"
32
32
#include " parquet/exception.h"
33
33
34
+ using ::arrow::util::span;
34
35
using parquet::ParquetException;
35
36
36
37
namespace parquet ::encryption {
@@ -315,8 +316,8 @@ class AesDecryptor::AesDecryptorImpl {
315
316
316
317
~AesDecryptorImpl () { WipeOut (); }
317
318
318
- int Decrypt (const uint8_t * ciphertext, int ciphertext_len, const uint8_t * key,
319
- int key_len, const uint8_t * aad, int aad_len, uint8_t * plaintext);
319
+ int Decrypt (span< const uint8_t > ciphertext, span< const uint8_t > key,
320
+ span< const uint8_t > aad, span< uint8_t > plaintext);
320
321
321
322
void WipeOut () {
322
323
if (nullptr != ctx_) {
@@ -325,25 +326,46 @@ class AesDecryptor::AesDecryptorImpl {
325
326
}
326
327
}
327
328
328
- int ciphertext_size_delta () { return ciphertext_size_delta_; }
329
+ [[nodiscard]] int PlaintextLength (int ciphertext_len) const {
330
+ if (ciphertext_len < ciphertext_size_delta_) {
331
+ std::stringstream ss;
332
+ ss << " Ciphertext length " << ciphertext_len << " is invalid, expected at least "
333
+ << ciphertext_size_delta_;
334
+ throw ParquetException (ss.str ());
335
+ }
336
+ return ciphertext_len - ciphertext_size_delta_;
337
+ }
338
+
339
+ [[nodiscard]] int CiphertextLength (int plaintext_len) const {
340
+ if (plaintext_len < 0 ) {
341
+ std::stringstream ss;
342
+ ss << " Negative plaintext length " << plaintext_len;
343
+ throw ParquetException (ss.str ());
344
+ }
345
+ return plaintext_len + ciphertext_size_delta_;
346
+ }
329
347
330
348
private:
331
349
EVP_CIPHER_CTX* ctx_;
332
350
int aes_mode_;
333
351
int key_length_;
334
352
int ciphertext_size_delta_;
335
353
int length_buffer_length_;
336
- int GcmDecrypt (const uint8_t * ciphertext, int ciphertext_len, const uint8_t * key,
337
- int key_len, const uint8_t * aad, int aad_len, uint8_t * plaintext);
338
354
339
- int CtrDecrypt (const uint8_t * ciphertext, int ciphertext_len, const uint8_t * key,
340
- int key_len, uint8_t * plaintext);
355
+ // / Get the actual ciphertext length, inclusive of the length buffer length,
356
+ // / and validate that the provided buffer size is large enough.
357
+ [[nodiscard]] int GetCiphertextLength (span<const uint8_t > ciphertext) const ;
358
+
359
+ int GcmDecrypt (span<const uint8_t > ciphertext, span<const uint8_t > key,
360
+ span<const uint8_t > aad, span<uint8_t > plaintext);
361
+
362
+ int CtrDecrypt (span<const uint8_t > ciphertext, span<const uint8_t > key,
363
+ span<uint8_t > plaintext);
341
364
};
342
365
343
- int AesDecryptor::Decrypt (const uint8_t * plaintext, int plaintext_len, const uint8_t * key,
344
- int key_len, const uint8_t * aad, int aad_len,
345
- uint8_t * ciphertext) {
346
- return impl_->Decrypt (plaintext, plaintext_len, key, key_len, aad, aad_len, ciphertext);
366
+ int AesDecryptor::Decrypt (span<const uint8_t > ciphertext, span<const uint8_t > key,
367
+ span<const uint8_t > aad, span<uint8_t > plaintext) {
368
+ return impl_->Decrypt (ciphertext, key, aad, plaintext);
347
369
}
348
370
349
371
void AesDecryptor::WipeOut () { impl_->WipeOut (); }
@@ -438,56 +460,105 @@ std::shared_ptr<AesDecryptor> AesDecryptor::Make(
438
460
return decryptor;
439
461
}
440
462
441
- int AesDecryptor::CiphertextSizeDelta () { return impl_->ciphertext_size_delta (); }
442
-
443
- int AesDecryptor::AesDecryptorImpl::GcmDecrypt (const uint8_t * ciphertext,
444
- int ciphertext_len, const uint8_t * key,
445
- int key_len, const uint8_t * aad,
446
- int aad_len, uint8_t * plaintext) {
447
- int len;
448
- int plaintext_len;
463
+ int AesDecryptor::PlaintextLength (int ciphertext_len) const {
464
+ return impl_->PlaintextLength (ciphertext_len);
465
+ }
449
466
450
- uint8_t tag[kGcmTagLength ];
451
- memset (tag, 0 , kGcmTagLength );
452
- uint8_t nonce[kNonceLength ];
453
- memset (nonce, 0 , kNonceLength );
467
+ int AesDecryptor::CiphertextLength (int plaintext_len) const {
468
+ return impl_->CiphertextLength (plaintext_len);
469
+ }
454
470
471
+ int AesDecryptor::AesDecryptorImpl::GetCiphertextLength (
472
+ span<const uint8_t > ciphertext) const {
455
473
if (length_buffer_length_ > 0 ) {
474
+ // Note: length_buffer_length_ must be either 0 or kBufferSizeLength
475
+ if (ciphertext.size () < static_cast <size_t >(kBufferSizeLength )) {
476
+ std::stringstream ss;
477
+ ss << " Ciphertext buffer length " << ciphertext.size ()
478
+ << " is insufficient to read the ciphertext length."
479
+ << " At least " << kBufferSizeLength << " bytes are required." ;
480
+ throw ParquetException (ss.str ());
481
+ }
482
+
456
483
// Extract ciphertext length
457
484
int written_ciphertext_len = ((ciphertext[3 ] & 0xff ) << 24 ) |
458
485
((ciphertext[2 ] & 0xff ) << 16 ) |
459
486
((ciphertext[1 ] & 0xff ) << 8 ) | ((ciphertext[0 ] & 0xff ));
460
487
461
- if (ciphertext_len > 0 &&
462
- ciphertext_len != (written_ciphertext_len + length_buffer_length_)) {
463
- throw ParquetException (" Wrong ciphertext length" );
488
+ if (written_ciphertext_len < 0 ) {
489
+ std::stringstream ss;
490
+ ss << " Negative ciphertext length " << written_ciphertext_len;
491
+ throw ParquetException (ss.str ());
492
+ } else if (ciphertext.size () <
493
+ static_cast <size_t >(written_ciphertext_len) + length_buffer_length_) {
494
+ std::stringstream ss;
495
+ ss << " Serialized ciphertext length "
496
+ << (written_ciphertext_len + length_buffer_length_)
497
+ << " is greater than the provided ciphertext buffer length "
498
+ << ciphertext.size ();
499
+ throw ParquetException (ss.str ());
464
500
}
465
- ciphertext_len = written_ciphertext_len + length_buffer_length_;
501
+
502
+ return written_ciphertext_len + length_buffer_length_;
466
503
} else {
467
- if (ciphertext_len == 0 ) {
468
- throw ParquetException (" Zero ciphertext length" );
504
+ if (ciphertext.size () > static_cast <size_t >(std::numeric_limits<int32_t >::max ())) {
505
+ std::stringstream ss;
506
+ ss << " Ciphertext buffer length " << ciphertext.size () << " overflows int32" ;
507
+ throw ParquetException (ss.str ());
469
508
}
509
+ return static_cast <int >(ciphertext.size ());
510
+ }
511
+ }
512
+
513
+ int AesDecryptor::AesDecryptorImpl::GcmDecrypt (span<const uint8_t > ciphertext,
514
+ span<const uint8_t > key,
515
+ span<const uint8_t > aad,
516
+ span<uint8_t > plaintext) {
517
+ int len;
518
+ int plaintext_len;
519
+
520
+ uint8_t tag[kGcmTagLength ];
521
+ memset (tag, 0 , kGcmTagLength );
522
+ uint8_t nonce[kNonceLength ];
523
+ memset (nonce, 0 , kNonceLength );
524
+
525
+ int ciphertext_len = GetCiphertextLength (ciphertext);
526
+
527
+ if (plaintext.size () < static_cast <size_t >(ciphertext_len) - ciphertext_size_delta_) {
528
+ std::stringstream ss;
529
+ ss << " Plaintext buffer length " << plaintext.size () << " is insufficient "
530
+ << " for ciphertext length " << ciphertext_len;
531
+ throw ParquetException (ss.str ());
532
+ }
533
+
534
+ if (ciphertext_len < length_buffer_length_ + kNonceLength + kGcmTagLength ) {
535
+ std::stringstream ss;
536
+ ss << " Invalid ciphertext length " << ciphertext_len << " . Expected at least "
537
+ << length_buffer_length_ + kNonceLength + kGcmTagLength << " \n " ;
538
+ throw ParquetException (ss.str ());
470
539
}
471
540
472
541
// Extracting IV and tag
473
- std::copy (ciphertext + length_buffer_length_,
474
- ciphertext + length_buffer_length_ + kNonceLength , nonce);
475
- std::copy (ciphertext + ciphertext_len - kGcmTagLength , ciphertext + ciphertext_len ,
476
- tag);
542
+ std::copy (ciphertext. begin () + length_buffer_length_,
543
+ ciphertext. begin () + length_buffer_length_ + kNonceLength , nonce);
544
+ std::copy (ciphertext. begin () + ciphertext_len - kGcmTagLength ,
545
+ ciphertext. begin () + ciphertext_len, tag);
477
546
478
547
// Setting key and IV
479
- if (1 != EVP_DecryptInit_ex (ctx_, nullptr , nullptr , key, nonce)) {
548
+ if (1 != EVP_DecryptInit_ex (ctx_, nullptr , nullptr , key. data () , nonce)) {
480
549
throw ParquetException (" Couldn't set key and IV" );
481
550
}
482
551
483
552
// Setting additional authenticated data
484
- if ((nullptr != aad) && (1 != EVP_DecryptUpdate (ctx_, nullptr , &len, aad, aad_len))) {
553
+ if ((!aad.empty ()) && (1 != EVP_DecryptUpdate (ctx_, nullptr , &len, aad.data (),
554
+ static_cast <int >(aad.size ())))) {
485
555
throw ParquetException (" Couldn't set AAD" );
486
556
}
487
557
488
558
// Decryption
489
559
if (!EVP_DecryptUpdate (
490
- ctx_, plaintext, &len, ciphertext + length_buffer_length_ + kNonceLength ,
560
+ ctx_, plaintext.data (), &len,
561
+ ciphertext.data () + length_buffer_length_ + kNonceLength ,
491
562
ciphertext_len - length_buffer_length_ - kNonceLength - kGcmTagLength )) {
492
563
throw ParquetException (" Failed decryption update" );
493
564
}
@@ -500,87 +571,86 @@ int AesDecryptor::AesDecryptorImpl::GcmDecrypt(const uint8_t* ciphertext,
500
571
}
501
572
502
573
// Finalization
503
- if (1 != EVP_DecryptFinal_ex (ctx_, plaintext + len, &len)) {
574
+ if (1 != EVP_DecryptFinal_ex (ctx_, plaintext. data () + len, &len)) {
504
575
throw ParquetException (" Failed decryption finalization" );
505
576
}
506
577
507
578
plaintext_len += len;
508
579
return plaintext_len;
509
580
}
510
581
511
- int AesDecryptor::AesDecryptorImpl::CtrDecrypt (const uint8_t * ciphertext,
512
- int ciphertext_len, const uint8_t * key,
513
- int key_len, uint8_t * plaintext) {
582
+ int AesDecryptor::AesDecryptorImpl::CtrDecrypt (span< const uint8_t > ciphertext,
583
+ span< const uint8_t > key,
584
+ span< uint8_t > plaintext) {
514
585
int len;
515
586
int plaintext_len;
516
587
517
588
uint8_t iv[kCtrIvLength ];
518
589
memset (iv, 0 , kCtrIvLength );
519
590
520
- if (length_buffer_length_ > 0 ) {
521
- // Extract ciphertext length
522
- int written_ciphertext_len = ((ciphertext[3 ] & 0xff ) << 24 ) |
523
- ((ciphertext[2 ] & 0xff ) << 16 ) |
524
- ((ciphertext[1 ] & 0xff ) << 8 ) | ((ciphertext[0 ] & 0xff ));
591
+ int ciphertext_len = GetCiphertextLength (ciphertext);
525
592
526
- if (ciphertext_len > 0 &&
527
- ciphertext_len != (written_ciphertext_len + length_buffer_length_)) {
528
- throw ParquetException (" Wrong ciphertext length" );
529
- }
530
- ciphertext_len = written_ciphertext_len;
531
- } else {
532
- if (ciphertext_len == 0 ) {
533
- throw ParquetException (" Zero ciphertext length" );
534
- }
593
+ if (plaintext.size () < static_cast <size_t >(ciphertext_len) - ciphertext_size_delta_) {
594
+ std::stringstream ss;
595
+ ss << " Plaintext buffer length " << plaintext.size () << " is insufficient "
596
+ << " for ciphertext length " << ciphertext_len;
597
+ throw ParquetException (ss.str ());
598
+ }
599
+
600
+ if (ciphertext_len < length_buffer_length_ + kNonceLength ) {
601
+ std::stringstream ss;
602
+ ss << " Invalid ciphertext length " << ciphertext_len << " . Expected at least "
603
+ << length_buffer_length_ + kNonceLength << " \n " ;
604
+ throw ParquetException (ss.str ());
535
605
}
536
606
537
607
// Extracting nonce
538
- std::copy (ciphertext + length_buffer_length_,
539
- ciphertext + length_buffer_length_ + kNonceLength , iv);
608
+ std::copy (ciphertext. begin () + length_buffer_length_,
609
+ ciphertext. begin () + length_buffer_length_ + kNonceLength , iv);
540
610
// Parquet CTR IVs are comprised of a 12-byte nonce and a 4-byte initial
541
611
// counter field.
542
612
// The first 31 bits of the initial counter field are set to 0, the last bit
543
613
// is set to 1.
544
614
iv[kCtrIvLength - 1 ] = 1 ;
545
615
546
616
// Setting key and IV
547
- if (1 != EVP_DecryptInit_ex (ctx_, nullptr , nullptr , key, iv)) {
617
+ if (1 != EVP_DecryptInit_ex (ctx_, nullptr , nullptr , key. data () , iv)) {
548
618
throw ParquetException (" Couldn't set key and IV" );
549
619
}
550
620
551
621
// Decryption
552
- if (!EVP_DecryptUpdate (ctx_, plaintext, &len,
553
- ciphertext + length_buffer_length_ + kNonceLength ,
554
- ciphertext_len - kNonceLength )) {
622
+ if (!EVP_DecryptUpdate (ctx_, plaintext. data () , &len,
623
+ ciphertext. data () + length_buffer_length_ + kNonceLength ,
624
+ ciphertext_len - length_buffer_length_ - kNonceLength )) {
555
625
throw ParquetException (" Failed decryption update" );
556
626
}
557
627
558
628
plaintext_len = len;
559
629
560
630
// Finalization
561
- if (1 != EVP_DecryptFinal_ex (ctx_, plaintext + len, &len)) {
631
+ if (1 != EVP_DecryptFinal_ex (ctx_, plaintext. data () + len, &len)) {
562
632
throw ParquetException (" Failed decryption finalization" );
563
633
}
564
634
565
635
plaintext_len += len;
566
636
return plaintext_len;
567
637
}
568
638
569
- int AesDecryptor::AesDecryptorImpl::Decrypt (const uint8_t * ciphertext, int ciphertext_len ,
570
- const uint8_t * key, int key_len ,
571
- const uint8_t * aad, int aad_len ,
572
- uint8_t * plaintext) {
573
- if (key_length_ != key_len ) {
639
+ int AesDecryptor::AesDecryptorImpl::Decrypt (span< const uint8_t > ciphertext,
640
+ span< const uint8_t > key,
641
+ span< const uint8_t > aad,
642
+ span< uint8_t > plaintext) {
643
+ if (static_cast < size_t >( key_length_) != key. size () ) {
574
644
std::stringstream ss;
575
- ss << " Wrong key length " << key_len << " . Should be " << key_length_;
645
+ ss << " Wrong key length " << key. size () << " . Should be " << key_length_;
576
646
throw ParquetException (ss.str ());
577
647
}
578
648
579
649
if (kGcmMode == aes_mode_) {
580
- return GcmDecrypt (ciphertext, ciphertext_len, key, key_len, aad, aad_len , plaintext);
650
+ return GcmDecrypt (ciphertext, key, aad, plaintext);
581
651
}
582
652
583
- return CtrDecrypt (ciphertext, ciphertext_len, key, key_len , plaintext);
653
+ return CtrDecrypt (ciphertext, key, plaintext);
584
654
}
585
655
586
656
static std::string ShortToBytesLe (int16_t input) {
0 commit comments