Skip to content

[CDRIVER-4454] Azure auto-KMS testing #1104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3,806 changes: 3,806 additions & 0 deletions build/bottle.py

Large diffs are not rendered by default.

149 changes: 149 additions & 0 deletions build/fake_azure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from __future__ import annotations

import sys
import time
import json
import traceback
import functools
import bottle
from bottle import Bottle, HTTPResponse, request

imds = Bottle(autojson=True)
"""An Azure IMDS server"""

from typing import TYPE_CHECKING, Any, Callable, Iterable, overload

if TYPE_CHECKING:
from typing import Protocol

class _RequestParams(Protocol):

def __getitem__(self, key: str) -> str:
...

@overload
def get(self, key: str) -> str | None:
...

@overload
def get(self, key: str, default: str) -> str:
...

class _HeadersDict(dict[str, str]):

def raw(self, key: str) -> bytes | None:
...

class _Request(Protocol):
query: _RequestParams
params: _RequestParams
headers: _HeadersDict

request: _Request


def parse_qs(qs: str) -> dict[str, str]:
return dict(bottle._parse_qsl(qs)) # type: ignore


def require(cond: bool, message: str):
if not cond:
print(f'REQUIREMENT FAILED: {message}')
raise bottle.HTTPError(400, message)


_HandlerFuncT = Callable[
[],
'None|str|bytes|dict[str, Any]|bottle.BaseResponse|Iterable[bytes|str]']


def handle_asserts(fn: _HandlerFuncT) -> _HandlerFuncT:

@functools.wraps(fn)
def wrapped():
try:
return fn()
except AssertionError as e:
traceback.print_exc()
return bottle.HTTPResponse(status=400,
body=json.dumps({'error':
list(e.args)}))

return wrapped


def test_flags() -> dict[str, str]:
return parse_qs(request.headers.get('X-MongoDB-HTTP-TestParams', ''))


def maybe_pause():
pause = int(test_flags().get('pause', '0'))
if pause:
print(f'Pausing for {pause} seconds')
time.sleep(pause)


@imds.get('/metadata/identity/oauth2/token')
@handle_asserts
def get_oauth2_token():
api_version = request.query['api-version']
assert api_version == '2018-02-01', 'Only api-version=2018-02-01 is supported'
resource = request.query['resource']
assert resource == 'https://vault.azure.net', 'Only https://vault.azure.net is supported'

flags = test_flags()
maybe_pause()

case = flags.get('case')
print('Case is:', case)
if case == '404':
return HTTPResponse(status=404)

if case == '500':
return HTTPResponse(status=500)

if case == 'bad-json':
return b'{"key": }'

if case == 'empty-json':
return b'{}'

if case == 'giant':
return _gen_giant()

if case == 'slow':
return _slow()

assert case is None or case == '', f'Unknown HTTP test case "{case}"'

return {
'access_token': 'magic-cookie',
'expires_in': '60',
'token_type': 'Bearer',
'resource': 'https://vault.azure.net',
}


def _gen_giant() -> Iterable[bytes]:
yield b'{ "item": ['
for _ in range(1024 * 256):
yield (b'null, null, null, null, null, null, null, null, null, null, '
b'null, null, null, null, null, null, null, null, null, null, '
b'null, null, null, null, null, null, null, null, null, null, '
b'null, null, null, null, null, null, null, null, null, null, ')
yield b' null ] }'
yield b'\n'


def _slow() -> Iterable[bytes]:
yield b'{ "item": ['
for _ in range(1000):
yield b'null, '
time.sleep(1)
yield b' null ] }'


if __name__ == '__main__':
print(f'RECOMMENDED: Run this script using bottle.py in the same '
f'directory (e.g. [{sys.executable} bottle.py fake_azure:imds])')
imds.run()
174 changes: 89 additions & 85 deletions src/libmongoc/doc/errors.rst

Large diffs are not rendered by default.

87 changes: 76 additions & 11 deletions src/libmongoc/src/mongoc/mcd-azure.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,42 @@ static const char *const DEFAULT_METADATA_PATH =
"&resource=https%3A%2F%2Fvault.azure.net";

void
mcd_azure_imds_request_init (mcd_azure_imds_request *req)
mcd_azure_imds_request_init (mcd_azure_imds_request *req,
const char *const opt_imds_host,
int opt_port,
const char *const opt_extra_headers)
{
BSON_ASSERT_PARAM (req);
_mongoc_http_request_init (&req->req);
// The HTTP host of the IMDS server
req->req.host = "169.254.169.254";
req->req.port = 80;
req->req.host = req->_owned_host =
bson_strdup (opt_imds_host ? opt_imds_host : "169.254.169.254");
if (opt_port) {
req->req.port = opt_port;
} else {
req->req.port = 80;
}
// No body
req->req.body = "";
// We GET
req->req.method = "GET";
// 'Metadata: true' is required
req->req.extra_headers = "Metadata: true\r\n"
"Accept: application/json\r\n";
req->req.extra_headers = req->_owned_headers =
bson_strdup_printf ("Metadata: true\r\n"
"Accept: application/json\r\n%s",
opt_extra_headers ? opt_extra_headers : "");
// The default path is suitable. In the future, we may want to add query
// parameters to disambiguate a managed identity.
req->req.path = bson_strdup (DEFAULT_METADATA_PATH);
req->req.path = req->_owned_path = bson_strdup (DEFAULT_METADATA_PATH);
}

void
mcd_azure_imds_request_destroy (mcd_azure_imds_request *req)
{
BSON_ASSERT_PARAM (req);
bson_free ((void *) req->req.path);
bson_free (req->_owned_path);
bson_free (req->_owned_host);
bson_free (req->_owned_headers);
*req = (mcd_azure_imds_request){0};
}

Expand Down Expand Up @@ -97,8 +109,8 @@ mcd_azure_access_token_try_init_from_json_str (mcd_azure_access_token *out,
if (!(access_token && resource && token_type && expires_in_str)) {
bson_set_error (
error,
MONGOC_ERROR_PROTOCOL_ERROR,
64,
MONGOC_ERROR_AZURE,
MONGOC_ERROR_AZURE_BAD_JSON,
"One or more required JSON properties are missing/invalid: data: %.*s",
len,
json);
Expand All @@ -118,8 +130,8 @@ mcd_azure_access_token_try_init_from_json_str (mcd_azure_access_token *out,
// Did not parse the entire string. Bad
bson_set_error (
error,
MONGOC_ERROR_PROTOCOL,
65,
MONGOC_ERROR_AZURE,
MONGOC_ERROR_AZURE_BAD_JSON,
"Invalid 'expires_in' string \"%.*s\" from IMDS server",
expires_in_len,
expires_in_str);
Expand All @@ -144,3 +156,56 @@ mcd_azure_access_token_destroy (mcd_azure_access_token *c)
c->resource = NULL;
c->token_type = NULL;
}


bool
mcd_azure_access_token_from_imds (mcd_azure_access_token *const out,
const char *const opt_imds_host,
int opt_port,
const char *opt_extra_headers,
bson_error_t *error)
{
BSON_ASSERT_PARAM (out);

bool okay = false;

// Clear the output
*out = (mcd_azure_access_token){0};

mongoc_http_response_t resp;
_mongoc_http_response_init (&resp);

mcd_azure_imds_request req = {0};
mcd_azure_imds_request_init (
&req, opt_imds_host, opt_port, opt_extra_headers);

if (!_mongoc_http_send (&req.req, 3 * 1000, false, NULL, &resp, error)) {
_mongoc_http_response_cleanup (&resp);
goto fail;
}

// We only accept an HTTP 200 as a success
if (resp.status != 200) {
bson_set_error (error,
MONGOC_ERROR_AZURE,
MONGOC_ERROR_AZURE_HTTP,
"Error from Azure IMDS server while looking for "
"Managed Identity access token: %.*s",
resp.body_len,
resp.body);
goto fail;
}

// Parse the token from the response JSON
if (!mcd_azure_access_token_try_init_from_json_str (
out, resp.body, resp.body_len, error)) {
goto fail;
}

okay = true;

fail:
mcd_azure_imds_request_destroy (&req);
_mongoc_http_response_cleanup (&resp);
return okay;
}
34 changes: 33 additions & 1 deletion src/libmongoc/src/mongoc/mcd-azure.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,27 @@ mcd_azure_access_token_destroy (mcd_azure_access_token *token);
typedef struct mcd_azure_imds_request {
/// The underlying HTTP request object to be sent
mongoc_http_request_t req;
char *_owned_path;
char *_owned_host;
char *_owned_headers;
} mcd_azure_imds_request;

/**
* @brief Initialize a new IMDS HTTP request
*
* @param out The object to initialize
* @param opt_imds_host (Optional) the IP host of the IMDS server
* @param opt_port (Optional) The port of the IMDS HTTP server (default is 80)
* @param opt_extra_headers (Optional) Set extra HTTP headers for the request
*
* @note the request must later be destroyed with mcd_azure_imds_request_destroy
* @note Currently only supports the vault.azure.net resource
*/
void
mcd_azure_imds_request_init (mcd_azure_imds_request *out);
mcd_azure_imds_request_init (mcd_azure_imds_request *out,
const char *opt_imds_host,
int opt_port,
const char *opt_extra_headers);

/**
* @brief Destroy an IMDS request created with mcd_azure_imds_request_init()
Expand All @@ -95,4 +105,26 @@ mcd_azure_imds_request_init (mcd_azure_imds_request *out);
void
mcd_azure_imds_request_destroy (mcd_azure_imds_request *req);

/**
* @brief Attempt to obtain a new OAuth2 access token from an Azure IMDS HTTP
* server.
*
* @param out The output parameter for the obtained token. Must later be
* destroyed
* @param opt_imds_host (Optional) Override the IP host of the IMDS server
* @param opt_port (Optional) The port of the IMDS HTTP server (default is 80)
* @param opt_extra_headers (Optional) Set extra HTTP headers for the request
* @param error Output parameter for errors
* @retval true Upon success
* @retval false Otherwise. Sets an error via `error`
*
* @note Currently only supports the vault.azure.net resource
*/
bool
mcd_azure_access_token_from_imds (mcd_azure_access_token *out,
const char *opt_imds_host,
int opt_port,
const char *opt_extra_headers,
bson_error_t *error);

#endif // MCD_AZURE_H_INCLUDED
Loading