diff --git a/Benchmarks/Benchmarks/BaseN/BaseN.swift b/Benchmarks/Benchmarks/BaseN/BaseN.swift index 87bb2d8..ea3ce01 100644 --- a/Benchmarks/Benchmarks/BaseN/BaseN.swift +++ b/Benchmarks/Benchmarks/BaseN/BaseN.swift @@ -33,6 +33,17 @@ let benchmarks = { } } + Benchmark("Base32.decodeIgnoreNullCharacters") { benchmark in + let bytes = Array(UInt8(0) ... UInt8(255)) + let base32 = Base32.encodeToString(bytes: bytes) + + benchmark.startMeasurement() + + for _ in benchmark.scaledIterations { + try blackHole(Base32.decode(string: base32, options: .allowNullCharacters)) + } + } + Benchmark("Base64.encode") { benchmark in let bytes = Array(UInt8(0) ... UInt8(255)) diff --git a/Sources/ExtrasBase64/Base32.swift b/Sources/ExtrasBase64/Base32.swift index c349f09..0cbaf0e 100644 --- a/Sources/ExtrasBase64/Base32.swift +++ b/Sources/ExtrasBase64/Base32.swift @@ -6,9 +6,9 @@ public extension String { self = Base32.encodeToString(bytes: bytes, options: options) } - /// Decode base32 encoded strin - func base32decoded() throws -> [UInt8] { - try Base32.decode(string: self) + /// Decode base32 encoded string + func base32decoded(options: Base32.DecodingOptions = []) throws -> [UInt8] { + try Base32.decode(string: self, options: options) } } @@ -22,8 +22,25 @@ public enum Base32 { public static let omitPaddingCharacter = EncodingOptions(rawValue: UInt(1 << 0)) } - public enum DecodingError: Swift.Error, Equatable { - case invalidCharacter(UInt8) + /// Options controlling how Base32 input is decoded + public struct DecodingOptions: OptionSet { + public let rawValue: UInt + public init(rawValue: UInt) { self.rawValue = rawValue } + + public static let allowNullCharacters = DecodingOptions(rawValue: UInt(1 << 0)) + } + + public struct DecodingError: Swift.Error, Equatable { + enum _Internal { + case invalidCharacter + } + + fileprivate let value: _Internal + init(_ value: _Internal) { + self.value = value + } + + public static var invalidCharacter: Self { .init(.invalidCharacter) } } /// Base32 Encode a buffer to 
an array of bytes @@ -70,7 +87,10 @@ public enum Base32 { } /// Base32 decode string - public static func decode(string encoded: String) throws -> [UInt8] { + public static func decode( + string encoded: String, + options: DecodingOptions = [] + ) throws -> [UInt8] { let decoded = try encoded.utf8.withContiguousStorageIfAvailable { characterPointer -> [UInt8] in guard characterPointer.count > 0 else { return [] @@ -80,7 +100,11 @@ public enum Base32 { return try characterPointer.withMemoryRebound(to: UInt8.self) { input -> [UInt8] in try [UInt8](unsafeUninitializedCapacity: capacity) { output, length in - length = try Self._decode(from: input, into: output) + if options.contains(.allowNullCharacters) { + length = try Self._decode(from: input[...], into: output[...]) + } else { + length = try Self._strictDecode(from: input, into: output) + } } } } @@ -95,7 +119,10 @@ public enum Base32 { } /// Base32 decode a buffer to an array of UInt8 - public static func decode(bytes: Buffer) throws -> [UInt8] where Buffer.Element == UInt8 { + public static func decode( + bytes: Buffer, + options: DecodingOptions = [] + ) throws -> [UInt8] where Buffer.Element == UInt8 { guard bytes.count > 0 else { return [] } @@ -104,7 +131,11 @@ public enum Base32 { let outputLength = ((input.count + 7) / 8) * 5 return try [UInt8](unsafeUninitializedCapacity: outputLength) { output, length in - length = try Self._decode(from: input, into: output) + if options.contains(.allowNullCharacters) { + length = try Self._decode(from: input[...], into: output[...]) + } else { + length = try Self._strictDecode(from: input, into: output) + } } } @@ -153,6 +184,42 @@ extension Base32 { /* F8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, ] + private static let strictDecodeTable: [UInt8] = [ + /* 00 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 08 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 10 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 18 */ 0x80, 0x80, 0x80, 0x80, 
0x80, 0x80, 0x80, 0x80, + /* 20 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 28 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 30 */ 0x80, 0x80, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, + /* 38 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0xC0, 0x80, 0x80, + /* 40 */ 0x80, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, + /* 48 */ 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, + /* 50 */ 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, + /* 58 */ 0x17, 0x18, 0x19, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 60 */ 0x80, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, + /* 68 */ 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, + /* 70 */ 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, + /* 78 */ 0x17, 0x18, 0x19, 0x80, 0x80, 0x80, 0x80, 0x80, + + /* 80 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 88 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 90 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 98 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* A0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* A8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* B0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* B8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* C0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* C8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* D0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* D8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* E0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* E8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* F0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* F8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + ] + private static let encodeTable: [UInt8] = [ /* 00 */ 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, /* 08 */ 0x49, 0x4A, 0x4B, 0x4C, 0x4D, 0x4E, 0x4F, 0x50, @@ -160,20 +227,55 @@ extension Base32 { /* 18 */ 0x59, 0x5A, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, ] - private static func _decode(from input: 
UnsafeBufferPointer, into output: UnsafeMutableBufferPointer) throws -> Int { + /// Decode Base32 assuming there are no null characters: full 8-character blocks take the fast path, the remainder goes through `_decode` + private static func _strictDecode(from input: UnsafeBufferPointer, into output: UnsafeMutableBufferPointer) throws -> Int { guard input.count != 0 else { return 0 } - var bitsLeft = 0 - var buffer: UInt32 = 0 var outputIndex = 0 + // work out how many blocks can go through the fast path. Last block + // should be passed to the slow path + let inputMinusLastBlock = (input.count - 1) & ~0x7 var i = 0 - loop: while i < input.count { + while i < inputMinusLastBlock { + let v1 = self.strictDecodeTable[Int(input[i])] + let v2 = self.strictDecodeTable[Int(input[i + 1])] + let v3 = self.strictDecodeTable[Int(input[i + 2])] + let v4 = self.strictDecodeTable[Int(input[i + 3])] + let v5 = self.strictDecodeTable[Int(input[i + 4])] + let v6 = self.strictDecodeTable[Int(input[i + 5])] + let v7 = self.strictDecodeTable[Int(input[i + 6])] + let v8 = self.strictDecodeTable[Int(input[i + 7])] + let vCombined = v1 | v2 | v3 | v4 | v5 | v6 | v7 | v8 + if (vCombined & ~0x1F) != 0 { + throw DecodingError.invalidCharacter + } + i += 8 + output[outputIndex] = (v1 << 3) | (v2 >> 2) + output[outputIndex + 1] = (v2 << 6) | (v3 << 1) | (v4 >> 4) + output[outputIndex + 2] = (v4 << 4) | (v5 >> 1) + output[outputIndex + 3] = (v5 << 7) | (v6 << 2) | (v7 >> 3) + output[outputIndex + 4] = (v7 << 5) | v8 + outputIndex += 5 + } + + return try self._decode(from: input[i...], into: output[outputIndex...]) + } + + /// Decode Base32 with the possibility of null characters or padding + private static func _decode(from input: UnsafeBufferPointer.SubSequence, into output: UnsafeMutableBufferPointer.SubSequence) throws -> Int { + guard input.count != 0 else { return output.startIndex } + var output = output + var bitsLeft = 0 + var buffer: UInt32 = 0 + var outputIndex = output.startIndex + var i = input.startIndex + loop: while i < input.endIndex { let index = 
Int(input[i]) i += 1 let v = self.decodeTable[index] switch v { case 0x80: - throw DecodingError.invalidCharacter(UInt8(index)) + throw DecodingError.invalidCharacter case 0x40: continue case 0xC0: @@ -191,9 +293,9 @@ extension Base32 { } } // Any characters left should be padding - while i < input.count { + while i < input.endIndex { let index = Int(input[i]) - guard self.decodeTable[index] == 0xC0 else { throw DecodingError.invalidCharacter(UInt8(index)) } + guard self.decodeTable[index] == 0xC0 else { throw DecodingError.invalidCharacter } i += 1 } return outputIndex diff --git a/Sources/ExtrasBase64/Base64.swift b/Sources/ExtrasBase64/Base64.swift index ba2a557..b75328b 100644 --- a/Sources/ExtrasBase64/Base64.swift +++ b/Sources/ExtrasBase64/Base64.swift @@ -334,11 +334,23 @@ extension Base64 { public static let omitPaddingCharacter = DecodingOptions(rawValue: UInt(1 << 1)) } - public enum DecodingError: Error, Equatable { - case invalidLength - case invalidCharacter(UInt8) - case unexpectedPaddingCharacter - case unexpectedEnd + public struct DecodingError: Error, Equatable { + fileprivate enum _Internal: Error, Equatable { + case invalidLength + case invalidCharacter(UInt8) + case unexpectedPaddingCharacter + case unexpectedEnd + } + + fileprivate let value: _Internal + fileprivate init(_ value: _Internal) { + self.value = value + } + + public static var invalidLength: Self { .init(.invalidLength) } + public static func invalidCharacter(_ character: UInt8) -> Self { .init(.invalidCharacter(character)) } + public static var unexpectedPaddingCharacter: Self { .init(.unexpectedPaddingCharacter) } + public static var unexpectedEnd: Self { .init(.unexpectedEnd) } } @inlinable diff --git a/Tests/ExtrasBase64Tests/Base32Tests.swift b/Tests/ExtrasBase64Tests/Base32Tests.swift index 72958fe..5ec3de3 100644 --- a/Tests/ExtrasBase64Tests/Base32Tests.swift +++ b/Tests/ExtrasBase64Tests/Base32Tests.swift @@ -54,9 +54,29 @@ class Base32Tests: XCTestCase { 
XCTAssertEqual(decoded, expected) } + func testBase32DecodingWithNullCharacters() { + let base32 = """ + AAAQEAYEAUDAOCAJBIFQYDIOB4IBCEQTCQKRMFYYDENBWHA5D + YPSAIJCEMSCKJRHFAUSUKZMFUXC6MBRGIZTINJWG44DSOR3HQ + 6T4P2AIFBEGRCFIZDUQSKKJNGE2TSPKBIVEU2UKVLFOWCZLJN + VYXK6L5QGCYTDMRSWMZ3INFVGW3DNNZXXA4LSON2HK5TXPB4X + U634PV7H7AEBQKBYJBMGQ6EITCULRSGY5D4QSGJJHFEVS2LZR + GM2TOOJ3HU7UCQ2FI5EUWTKPKFJVKV2ZLNOV6YLDMVTWS23NN + 5YXG5LXPF5X274BQOCYPCMLRWHZDE4VS6MZXHM7UGR2LJ5JVO + W27MNTWW33TO55X7A4HROHZHF43T6R2PK5PWO33XP6DY7F47U + 6X3PP6HZ7L57Z7P674 + """ + + let expected = Array(UInt8(0) ... UInt8(255)) + var decoded: [UInt8]? + XCTAssertNoThrow(decoded = try Base32.decode(bytes: base32.utf8, options: .allowNullCharacters)) + XCTAssertEqual(decoded, expected) + XCTAssertThrowsError(decoded = try Base32.decode(bytes: base32.utf8)) { _ in } + } + func testBase32DecodingWithPoop() { XCTAssertThrowsError(_ = try Base32.decode(bytes: "💩".utf8)) { error in - XCTAssertEqual(error as? Base32.DecodingError, .invalidCharacter(240)) + XCTAssertEqual(error as? 
Base32.DecodingError, .invalidCharacter) } } @@ -103,7 +123,6 @@ class Base32Tests: XCTestCase { } func testBase32EncodeFoobarWithPadding() { - XCTAssertEqual(String(base32Encoding: "".utf8), "") XCTAssertEqual(String(base32Encoding: "f".utf8), "MY======") XCTAssertEqual(String(base32Encoding: "fo".utf8), "MZXQ====") XCTAssertEqual(String(base32Encoding: "foo".utf8), "MZXW6===") @@ -111,4 +130,34 @@ class Base32Tests: XCTestCase { XCTAssertEqual(String(base32Encoding: "fooba".utf8), "MZXW6YTB") XCTAssertEqual(String(base32Encoding: "foobar".utf8), "MZXW6YTBOI======") } + + func testBase32DecodeFoobar() { + XCTAssertEqual(try "".base32decoded(), .init("".utf8)) + XCTAssertEqual(try "MY".base32decoded(), .init("f".utf8)) + XCTAssertEqual(try "MZXQ".base32decoded(), .init("fo".utf8)) + XCTAssertEqual(try "MZXW6".base32decoded(), .init("foo".utf8)) + XCTAssertEqual(try "MZXW6YQ".base32decoded(), .init("foob".utf8)) + XCTAssertEqual(try "MZXW6YTB".base32decoded(), .init("fooba".utf8)) + XCTAssertEqual(try "MZXW6YTBOI".base32decoded(), .init("foobar".utf8)) + } + + func testBase32DecodeFoobarWithPadding() { + XCTAssertEqual(try "MY======".base32decoded(), .init("f".utf8)) + XCTAssertEqual(try "MZXQ====".base32decoded(), .init("fo".utf8)) + XCTAssertEqual(try "MZXW6===".base32decoded(), .init("foo".utf8)) + XCTAssertEqual(try "MZXW6YQ=".base32decoded(), .init("foob".utf8)) + XCTAssertEqual(try "MZXW6YTB".base32decoded(), .init("fooba".utf8)) + XCTAssertEqual(try "MZXW6YTBOI======".base32decoded(), .init("foobar".utf8)) + } + + func testBase32EncodeDecode() throws { + for _ in 0 ..< 100 { + let buffer: [UInt8] = (0 ..< Int.random(in: 1 ..< 8192)).map { _ in UInt8.random(in: .min ... .max) } + let base32 = String(base32Encoding: buffer) + let buffer2 = try base32.base32decoded(options: .allowNullCharacters) + let buffer3 = try base32.base32decoded() + XCTAssertEqual(buffer, buffer2) + XCTAssertEqual(buffer, buffer3) + } + } }