diff --git a/packages/pg/lib/sasl.js b/packages/pg/lib/sasl.js index 22abf5c4a..c61804750 100644 --- a/packages/pg/lib/sasl.js +++ b/packages/pg/lib/sasl.js @@ -20,19 +20,27 @@ function continueSession(session, password, serverData) { if (session.message !== 'SASLInitialResponse') { throw new Error('SASL: Last message was not SASLInitialResponse') } + if (typeof password !== 'string') { + throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: client password must be a string') + } + if (typeof serverData !== 'string') { + throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: serverData must be a string') + } - const sv = extractVariablesFromFirstServerMessage(serverData) + const sv = parseServerFirstMessage(serverData) if (!sv.nonce.startsWith(session.clientNonce)) { throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: server nonce does not start with client nonce') + } else if (sv.nonce.length === session.clientNonce.length) { + throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: server nonce is too short') } var saltBytes = Buffer.from(sv.salt, 'base64') var saltedPassword = Hi(password, saltBytes, sv.iteration) - var clientKey = createHMAC(saltedPassword, 'Client Key') - var storedKey = crypto.createHash('sha256').update(clientKey).digest() + var clientKey = hmacSha256(saltedPassword, 'Client Key') + var storedKey = sha256(clientKey) var clientFirstMessageBare = 'n=*,r=' + session.clientNonce var serverFirstMessage = 'r=' + sv.nonce + ',s=' + sv.salt + ',i=' + sv.iteration @@ -41,12 +49,12 @@ function continueSession(session, password, serverData) { var authMessage = clientFirstMessageBare + ',' + serverFirstMessage + ',' + clientFinalMessageWithoutProof - var clientSignature = createHMAC(storedKey, authMessage) + var clientSignature = hmacSha256(storedKey, authMessage) var clientProofBytes = xorBuffers(clientKey, clientSignature) var clientProof = clientProofBytes.toString('base64') - var serverKey = createHMAC(saltedPassword, 'Server Key') - var serverSignatureBytes = createHMAC(serverKey, authMessage) + var serverKey = hmacSha256(saltedPassword, 'Server Key') + var serverSignatureBytes = hmacSha256(serverKey, authMessage) session.message = 'SASLResponse' session.serverSignature = serverSignatureBytes.toString('base64') @@ -57,54 +65,87 @@ function finalizeSession(session, serverData) { if (session.message !== 'SASLResponse') { throw new Error('SASL: Last message was not SASLResponse') } + if (typeof serverData !== 'string') { + throw new Error('SASL: SCRAM-SERVER-FINAL-MESSAGE: serverData must be a string') + } - var serverSignature - - String(serverData) - .split(',') - .forEach(function (part) { - switch (part[0]) { - case 'v': - serverSignature = part.substr(2) - break - } - }) + const { serverSignature } = parseServerFinalMessage(serverData) if (serverSignature !== session.serverSignature) { throw new Error('SASL: SCRAM-SERVER-FINAL-MESSAGE: server signature does not match') } } -function extractVariablesFromFirstServerMessage(data) { - var nonce, salt, iteration - - String(data) - .split(',') - .forEach(function (part) { - switch (part[0]) { - case 'r': - nonce = part.substr(2) - break - case 's': - salt = part.substr(2) - break - case 'i': - iteration = parseInt(part.substr(2), 10) - break +/** + * printable = %x21-2B / %x2D-7E + * ;; Printable ASCII except ",". + * ;; Note that any "printable" is also + * ;; a valid "value". + */ +function isPrintableChars(text) { + if (typeof text !== 'string') { + throw new TypeError('SASL: text must be a string') + } + return text + .split('') + .map((_, i) => text.charCodeAt(i)) + .every((c) => (c >= 0x21 && c <= 0x2b) || (c >= 0x2d && c <= 0x7e)) +} + +/** + * base64-char = ALPHA / DIGIT / "/" / "+" + * + * base64-4 = 4base64-char + * + * base64-3 = 3base64-char "=" + * + * base64-2 = 2base64-char "==" + * + * base64 = *base64-4 [base64-3 / base64-2] + */ +function isBase64(text) { + return /^(?:[a-zA-Z0-9+/]{4})*(?:[a-zA-Z0-9+/]{2}==|[a-zA-Z0-9+/]{3}=)?$/.test(text) +} + +function parseAttributePairs(text) { + if (typeof text !== 'string') { + throw new TypeError('SASL: attribute pairs text must be a string') + } + + return new Map( + text.split(',').map((attrValue) => { + if (!/^.=/.test(attrValue)) { + throw new Error('SASL: Invalid attribute pair entry') } + const name = attrValue[0] + const value = attrValue.substring(2) + return [name, value] }) + ) +} +function parseServerFirstMessage(data) { + const attrPairs = parseAttributePairs(data) + + const nonce = attrPairs.get('r') if (!nonce) { throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: nonce missing') + } else if (!isPrintableChars(nonce)) { + throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: nonce must only contain printable characters') } - + const salt = attrPairs.get('s') if (!salt) { throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: salt missing') + } else if (!isBase64(salt)) { + throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: salt must be base64') } - - if (!iteration) { + const iterationText = attrPairs.get('i') + if (!iterationText) { throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: iteration missing') + } else if (!/^[1-9][0-9]*$/.test(iterationText)) { + throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: invalid iteration count') } + const iteration = parseInt(iterationText, 10) return { nonce, @@ -113,31 +154,48 @@ function extractVariablesFromFirstServerMessage(data) { } } +function parseServerFinalMessage(serverData) { + const attrPairs = parseAttributePairs(serverData) + const serverSignature = attrPairs.get('v') + if (!serverSignature) { + throw new Error('SASL: SCRAM-SERVER-FINAL-MESSAGE: server signature is missing') + } else if (!isBase64(serverSignature)) { + throw new Error('SASL: SCRAM-SERVER-FINAL-MESSAGE: server signature must be base64') + } + return { + serverSignature, + } +} + function xorBuffers(a, b) { - if (!Buffer.isBuffer(a)) a = Buffer.from(a) - if (!Buffer.isBuffer(b)) b = Buffer.from(b) - var res = [] - if (a.length > b.length) { - for (var i = 0; i < b.length; i++) { - res.push(a[i] ^ b[i]) - } - } else { - for (var j = 0; j < a.length; j++) { - res.push(a[j] ^ b[j]) - } - } - return Buffer.from(res) + if (!Buffer.isBuffer(a)) { + throw new TypeError('first argument must be a Buffer') + } + if (!Buffer.isBuffer(b)) { + throw new TypeError('second argument must be a Buffer') + } + if (a.length !== b.length) { + throw new Error('Buffer lengths must match') + } + if (a.length === 0) { + throw new Error('Buffers cannot be empty') + } + return Buffer.from(a.map((_, i) => a[i] ^ b[i])) +} + +function sha256(text) { + return crypto.createHash('sha256').update(text).digest() } -function createHMAC(key, msg) { +function hmacSha256(key, msg) { return crypto.createHmac('sha256', key).update(msg).digest() } function Hi(password, saltBytes, iterations) { - var ui1 = createHMAC(password, Buffer.concat([saltBytes, Buffer.from([0, 0, 0, 1])])) + var ui1 = hmacSha256(password, Buffer.concat([saltBytes, Buffer.from([0, 0, 0, 1])])) var ui = ui1 for (var i = 0; i < iterations - 1; i++) { - ui1 = createHMAC(password, ui1) + ui1 = hmacSha256(password, ui1) ui = xorBuffers(ui, ui1) } diff --git a/packages/pg/test/test-helper.js b/packages/pg/test/test-helper.js index 319b8ee79..5999ea98f 100644 --- a/packages/pg/test/test-helper.js +++ b/packages/pg/test/test-helper.js @@ -111,16 +111,6 @@ assert.success = function (callback) { } } -assert.throws = function (offender) { - try { - offender() - } catch (e) { - assert.ok(e instanceof Error, 'Expected ' + offender + ' to throw instances of Error') - return - } - assert.ok(false, 'Expected ' + offender + ' to throw exception') -} - assert.lengthIs = function (actual, expectedLength) { assert.equal(actual.length, expectedLength) } diff --git a/packages/pg/test/unit/client/sasl-scram-tests.js b/packages/pg/test/unit/client/sasl-scram-tests.js index f60c8c4c9..e53448bdf 100644 --- a/packages/pg/test/unit/client/sasl-scram-tests.js +++ b/packages/pg/test/unit/client/sasl-scram-tests.js @@ -38,7 +38,7 @@ test('sasl/scram', function () { test('fails when last session message was not SASLInitialResponse', function () { assert.throws( function () { - sasl.continueSession({}) + sasl.continueSession({}, '', '') }, { message: 'SASL: Last message was not SASLInitialResponse', @@ -53,6 +53,7 @@ test('sasl/scram', function () { { message: 'SASLInitialResponse', }, + 'bad-password', 's=1,i=1' ) }, @@ -69,6 +70,7 @@ test('sasl/scram', function () { { message: 'SASLInitialResponse', }, + 'bad-password', 'r=1,i=1' ) }, @@ -85,7 +87,8 @@ test('sasl/scram', function () { { message: 'SASLInitialResponse', }, - 'r=1,s=1' + 'bad-password', + 'r=1,s=abcd' ) }, { @@ -102,7 +105,8 @@ test('sasl/scram', function () { message: 'SASLInitialResponse', clientNonce: '2', }, - 'r=1,s=1,i=1' + 'bad-password', + 'r=1,s=abcd,i=1' ) }, { @@ -117,12 +121,12 @@ test('sasl/scram', function () { clientNonce: 'a', } - sasl.continueSession(session, 'password', 'r=ab,s=x,i=1') + sasl.continueSession(session, 'password', 'r=ab,s=abcd,i=1') assert.equal(session.message, 'SASLResponse') - assert.equal(session.serverSignature, 'TtywIrpWDJ0tCSXM2mjkyiaa8iGZsZG7HllQxr8fYAo=') + assert.equal(session.serverSignature, 'jwt97IHWFn7FEqHykPTxsoQrKGOMXJl/PJyJ1JXTBKc=') - assert.equal(session.response, 'c=biws,r=ab,p=KAEPBUTjjofB0IM5UWcZApK1dSzFE0o5vnbWjBbvFHA=') + assert.equal(session.response, 'c=biws,r=ab,p=mU8grLfTjDrJer9ITsdHk0igMRDejG10EJPFbIBL3D0=') }) }) @@ -138,15 +142,32 @@ test('sasl/scram', function () { ) }) + test('fails when server signature is not valid base64', function () { + assert.throws( + function () { + sasl.finalizeSession( + { + message: 'SASLResponse', + serverSignature: 'abcd', + }, + 'v=x1' // Purposefully invalid base64 + ) + }, + { + message: 'SASL: SCRAM-SERVER-FINAL-MESSAGE: server signature must be base64', + } + ) + }) + test('fails when server signature does not match', function () { assert.throws( function () { sasl.finalizeSession( { message: 'SASLResponse', - serverSignature: '3', + serverSignature: 'abcd', }, - 'v=4' + 'v=xyzq' ) }, { @@ -159,9 +180,9 @@ test('sasl/scram', function () { sasl.finalizeSession( { message: 'SASLResponse', - serverSignature: '5', + serverSignature: 'abcd', }, - 'v=5' + 'v=abcd' ) }) })