Skip to content

Add Base32 decode fastpath #36

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions Benchmarks/Benchmarks/BaseN/BaseN.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ let benchmarks = {
}
}

Benchmark("Base32.decodeIgnoreNullCharacters") { benchmark in
let bytes = Array(UInt8(0) ... UInt8(255))
let base32 = Base32.encodeToString(bytes: bytes)

benchmark.startMeasurement()

for _ in benchmark.scaledIterations {
try blackHole(Base32.decode(string: base32, options: .allowNullCharacters))
}
}

Benchmark("Base64.encode") { benchmark in
let bytes = Array(UInt8(0) ... UInt8(255))

Expand Down
123 changes: 108 additions & 15 deletions Sources/ExtrasBase64/Base32.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ public extension String {
self = Base32.encodeToString(bytes: bytes, options: options)
}

/// Decode base32 encoded strin
func base32decoded() throws -> [UInt8] {
try Base32.decode(string: self)
/// Decode base32 encoded string
func base32decoded(options: Base32.DecodingOptions = []) throws -> [UInt8] {
try Base32.decode(string: self, options: options)
}
}

Expand All @@ -22,8 +22,16 @@ public enum Base32 {
public static let omitPaddingCharacter = EncodingOptions(rawValue: UInt(1 << 0))
}

/// Decoding options
public struct DecodingOptions: OptionSet {
public let rawValue: UInt
public init(rawValue: UInt) { self.rawValue = rawValue }

public static let allowNullCharacters = DecodingOptions(rawValue: UInt(1 << 0))
}

public enum DecodingError: Swift.Error, Equatable {
case invalidCharacter(UInt8)
case invalidCharacter
}

/// Base32 Encode a buffer to an array of bytes
Expand Down Expand Up @@ -70,7 +78,10 @@ public enum Base32 {
}

/// Base32 decode string
public static func decode(string encoded: String) throws -> [UInt8] {
public static func decode(
string encoded: String,
options: DecodingOptions = []
) throws -> [UInt8] {
let decoded = try encoded.utf8.withContiguousStorageIfAvailable { characterPointer -> [UInt8] in
guard characterPointer.count > 0 else {
return []
Expand All @@ -80,7 +91,11 @@ public enum Base32 {

return try characterPointer.withMemoryRebound(to: UInt8.self) { input -> [UInt8] in
try [UInt8](unsafeUninitializedCapacity: capacity) { output, length in
length = try Self._decode(from: input, into: output)
if options.contains(.allowNullCharacters) {
length = try Self._decode(from: input[...], into: output[...])
} else {
length = try Self._strictDecode(from: input, into: output)
}
}
}
}
Expand All @@ -95,7 +110,10 @@ public enum Base32 {
}

/// Base32 decode a buffer to an array of UInt8
public static func decode<Buffer: Collection>(bytes: Buffer) throws -> [UInt8] where Buffer.Element == UInt8 {
public static func decode<Buffer: Collection>(
bytes: Buffer,
options: DecodingOptions = []
) throws -> [UInt8] where Buffer.Element == UInt8 {
guard bytes.count > 0 else {
return []
}
Expand All @@ -104,7 +122,11 @@ public enum Base32 {
let outputLength = ((input.count + 7) / 8) * 5

return try [UInt8](unsafeUninitializedCapacity: outputLength) { output, length in
length = try Self._decode(from: input, into: output)
if options.contains(.allowNullCharacters) {
length = try Self._decode(from: input[...], into: output[...])
} else {
length = try Self._strictDecode(from: input, into: output)
}
}
}

Expand Down Expand Up @@ -153,27 +175,98 @@ extension Base32 {
/* F8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
]

private static let strictDecodeTable: [UInt8] = [
/* 00 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* 08 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* 10 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* 18 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* 20 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* 28 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* 30 */ 0x80, 0x80, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F,
/* 38 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0xC0, 0x80, 0x80,
/* 40 */ 0x80, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06,
/* 48 */ 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E,
/* 50 */ 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16,
/* 58 */ 0x17, 0x18, 0x19, 0x80, 0x80, 0x80, 0x80, 0x80,
/* 60 */ 0x80, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06,
/* 68 */ 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E,
/* 60 */ 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16,
/* 68 */ 0x17, 0x18, 0x19, 0x80, 0x80, 0x80, 0x80, 0x80,

/* 80 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* 88 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* 90 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* 98 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* A0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* A8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* B0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* B8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* C0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* C8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* D0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* D8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* E0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* E8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* F0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
/* F8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
]

private static let encodeTable: [UInt8] = [
/* 00 */ 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48,
/* 08 */ 0x49, 0x4A, 0x4B, 0x4C, 0x4D, 0x4E, 0x4F, 0x50,
/* 10 */ 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58,
/* 18 */ 0x59, 0x5A, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
]

private static func _decode(from input: UnsafeBufferPointer<UInt8>, into output: UnsafeMutableBufferPointer<UInt8>) throws -> Int {
/// Decode Base32 assuming there are no null characters
private static func _strictDecode(from input: UnsafeBufferPointer<UInt8>, into output: UnsafeMutableBufferPointer<UInt8>) throws -> Int {
guard input.count != 0 else { return 0 }

var bitsLeft = 0
var buffer: UInt32 = 0
var outputIndex = 0
// work out how many blocks can go through the fast path. Last block
// should be passed to the slow path
let inputMinusLastBlock = (input.count - 1) & ~0x7
var i = 0
loop: while i < input.count {
while i < inputMinusLastBlock {
let v1 = self.strictDecodeTable[Int(input[i])]
let v2 = self.strictDecodeTable[Int(input[i + 1])]
let v3 = self.strictDecodeTable[Int(input[i + 2])]
let v4 = self.strictDecodeTable[Int(input[i + 3])]
let v5 = self.strictDecodeTable[Int(input[i + 4])]
let v6 = self.strictDecodeTable[Int(input[i + 5])]
let v7 = self.strictDecodeTable[Int(input[i + 6])]
let v8 = self.strictDecodeTable[Int(input[i + 7])]
let vCombined = v1 | v2 | v3 | v4 | v5 | v6 | v7 | v8
if (vCombined & ~0x1F) != 0 {
throw DecodingError.invalidCharacter
}
i += 8
output[outputIndex] = (v1 << 3) | (v2 >> 2)
output[outputIndex + 1] = (v2 << 6) | (v3 << 1) | (v4 >> 4)
output[outputIndex + 2] = (v4 << 4) | (v5 >> 1)
output[outputIndex + 3] = (v5 << 7) | (v6 << 2) | (v7 >> 3)
output[outputIndex + 4] = (v7 << 5) | v8
outputIndex += 5
}

return try self._decode(from: input[i...], into: output[outputIndex...])
}

/// Decode Base32 with the possibility of null characters or padding
private static func _decode(from input: UnsafeBufferPointer<UInt8>.SubSequence, into output: UnsafeMutableBufferPointer<UInt8>.SubSequence) throws -> Int {
guard input.count != 0 else { return output.startIndex }
var output = output
var bitsLeft = 0
var buffer: UInt32 = 0
var outputIndex = output.startIndex
var i = input.startIndex
loop: while i < input.endIndex {
let index = Int(input[i])
i += 1
let v = self.decodeTable[index]
switch v {
case 0x80:
throw DecodingError.invalidCharacter(UInt8(index))
throw DecodingError.invalidCharacter
case 0x40:
continue
case 0xC0:
Expand All @@ -191,9 +284,9 @@ extension Base32 {
}
}
// Any characters left should be padding
while i < input.count {
while i < input.endIndex {
let index = Int(input[i])
guard self.decodeTable[index] == 0xC0 else { throw DecodingError.invalidCharacter(UInt8(index)) }
guard self.decodeTable[index] == 0xC0 else { throw DecodingError.invalidCharacter }
i += 1
}
return outputIndex
Expand Down
42 changes: 40 additions & 2 deletions Tests/ExtrasBase64Tests/Base32Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,29 @@ class Base32Tests: XCTestCase {
XCTAssertEqual(decoded, expected)
}

func testBase32DecodingWithNullCharacters() {
let base32 = """
AAAQEAYEAUDAOCAJBIFQYDIOB4IBCEQTCQKRMFYYDENBWHA5D
YPSAIJCEMSCKJRHFAUSUKZMFUXC6MBRGIZTINJWG44DSOR3HQ
6T4P2AIFBEGRCFIZDUQSKKJNGE2TSPKBIVEU2UKVLFOWCZLJN
VYXK6L5QGCYTDMRSWMZ3INFVGW3DNNZXXA4LSON2HK5TXPB4X
U634PV7H7AEBQKBYJBMGQ6EITCULRSGY5D4QSGJJHFEVS2LZR
GM2TOOJ3HU7UCQ2FI5EUWTKPKFJVKV2ZLNOV6YLDMVTWS23NN
5YXG5LXPF5X274BQOCYPCMLRWHZDE4VS6MZXHM7UGR2LJ5JVO
W27MNTWW33TO55X7A4HROHZHF43T6R2PK5PWO33XP6DY7F47U
6X3PP6HZ7L57Z7P674
"""

let expected = Array(UInt8(0) ... UInt8(255))
var decoded: [UInt8]?
XCTAssertNoThrow(decoded = try Base32.decode(bytes: base32.utf8, options: .allowNullCharacters))
XCTAssertEqual(decoded, expected)
XCTAssertThrowsError(decoded = try Base32.decode(bytes: base32.utf8)) { _ in }
}

func testBase32DecodingWithPoop() {
XCTAssertThrowsError(_ = try Base32.decode(bytes: "💩".utf8)) { error in
XCTAssertEqual(error as? Base32.DecodingError, .invalidCharacter(240))
XCTAssertEqual(error as? Base32.DecodingError, .invalidCharacter)
}
}

Expand Down Expand Up @@ -103,12 +123,30 @@ class Base32Tests: XCTestCase {
}

func testBase32EncodeFoobarWithPadding() {
XCTAssertEqual(String(base32Encoding: "".utf8), "")
XCTAssertEqual(String(base32Encoding: "f".utf8), "MY======")
XCTAssertEqual(String(base32Encoding: "fo".utf8), "MZXQ====")
XCTAssertEqual(String(base32Encoding: "foo".utf8), "MZXW6===")
XCTAssertEqual(String(base32Encoding: "foob".utf8), "MZXW6YQ=")
XCTAssertEqual(String(base32Encoding: "fooba".utf8), "MZXW6YTB")
XCTAssertEqual(String(base32Encoding: "foobar".utf8), "MZXW6YTBOI======")
}

func testBase32DecodeFoobar() {
XCTAssertEqual(try "".base32decoded(), .init("".utf8))
XCTAssertEqual(try "MY".base32decoded(), .init("f".utf8))
XCTAssertEqual(try "MZXQ".base32decoded(), .init("fo".utf8))
XCTAssertEqual(try "MZXW6".base32decoded(), .init("foo".utf8))
XCTAssertEqual(try "MZXW6YQ".base32decoded(), .init("foob".utf8))
XCTAssertEqual(try "MZXW6YTB".base32decoded(), .init("fooba".utf8))
XCTAssertEqual(try "MZXW6YTBOI".base32decoded(), .init("foobar".utf8))
}

func testBase32DecodeFoobarWithPadding() {
XCTAssertEqual(try "MY======".base32decoded(), .init("f".utf8))
XCTAssertEqual(try "MZXQ====".base32decoded(), .init("fo".utf8))
XCTAssertEqual(try "MZXW6===".base32decoded(), .init("foo".utf8))
XCTAssertEqual(try "MZXW6YQ=".base32decoded(), .init("foob".utf8))
XCTAssertEqual(try "MZXW6YTB".base32decoded(), .init("fooba".utf8))
XCTAssertEqual(try "MZXW6YTBOI======".base32decoded(), .init("foobar".utf8))
}
}
Loading