Skip to content

Api Gateway authorizer improvements #827

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions examples/http-axum-apigw-authorizer/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "http-axum-apigw-authorizer"
version = "0.1.0"
edition = "2021"

[dependencies]
axum = "0.7"
lambda_http = { path = "../../lambda-http" }
lambda_runtime = { path = "../../lambda-runtime" }
serde = "1.0.196"
serde_json = "1.0"
tokio = { version = "1", features = ["macros"] }
tracing = { version = "0.1", features = ["log"] }
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt"] }
13 changes: 13 additions & 0 deletions examples/http-axum-apigw-authorizer/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Axum example that integrates with Api Gateway authorizers

This example shows how to extract information from the Api Gateway Request Authorizer in an Axum handler.

## Build & Deploy

1. Install [cargo-lambda](https://github.com/cargo-lambda/cargo-lambda#installation)
2. Build the function with `cargo lambda build --release`
3. Deploy the function to AWS Lambda with `cargo lambda deploy --iam-role YOUR_ROLE`

## Build for ARM 64

Build the function with `cargo lambda build --release --arm64`
80 changes: 80 additions & 0 deletions examples/http-axum-apigw-authorizer/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use axum::{
async_trait,
extract::{FromRequest, Request},
http::StatusCode,
response::Json,
routing::get,
Router,
};
use lambda_http::{run, Error, RequestExt};
use serde_json::{json, Value};
use std::{collections::HashMap, env::set_var};

struct AuthorizerField(String);
struct AuthorizerFields(HashMap<String, serde_json::Value>);

#[async_trait]
impl<S> FromRequest<S> for AuthorizerField
where
S: Send + Sync,
{
type Rejection = (StatusCode, &'static str);

async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
req.request_context_ref()
.and_then(|r| r.authorizer())
.and_then(|a| a.fields.get("field_name"))
.and_then(|f| f.as_str())
.map(|v| Self(v.to_string()))
.ok_or_else(|| (StatusCode::BAD_REQUEST, "`field_name` authorizer field is missing"))
}
}

#[async_trait]
impl<S> FromRequest<S> for AuthorizerFields
where
S: Send + Sync,
{
type Rejection = (StatusCode, &'static str);

async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
req.request_context_ref()
.and_then(|r| r.authorizer())
.map(|a| Self(a.fields.clone()))
.ok_or_else(|| (StatusCode::BAD_REQUEST, "authorizer is missing"))
}
}

async fn extract_field(AuthorizerField(field): AuthorizerField) -> Json<Value> {
Json(json!({ "field extracted": field }))
}

async fn extract_all_fields(AuthorizerFields(fields): AuthorizerFields) -> Json<Value> {
Json(json!({ "authorizer fields": fields }))
}

#[tokio::main]
async fn main() -> Result<(), Error> {
// If you use API Gateway stages, the Rust Runtime will include the stage name
// as part of the path that your application receives.
// Setting the following environment variable, you can remove the stage from the path.
// This variable only applies to API Gateway stages,
// you can remove it if you don't use them.
// i.e with: `GET /test-stage/todo/id/123` without: `GET /todo/id/123`
set_var("AWS_LAMBDA_HTTP_IGNORE_STAGE_IN_PATH", "true");

// required to enable CloudWatch error logging by the runtime
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
// disable printing the name of the module in every log line.
.with_target(false)
// disabling time is handy because CloudWatch will add the ingestion time.
.without_time()
.init();

let app = Router::new()
.route("/extract-field", get(extract_field))
.route("/extract-all-fields", get(extract_all_fields));

run(app).await
}
133 changes: 80 additions & 53 deletions lambda-events/src/event/apigw/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,14 @@ use crate::custom_serde::{
use crate::encodings::Body;
use http::{HeaderMap, Method};
use query_map::QueryMap;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use serde::{de::DeserializeOwned, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
use serde_json::Value;
use std::collections::HashMap;

/// `ApiGatewayProxyRequest` contains data coming from the API Gateway proxy
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiGatewayProxyRequest<T1 = Value>
where
T1: DeserializeOwned + Default,
T1: Serialize,
{
pub struct ApiGatewayProxyRequest {
/// The resource path defined in API Gateway
#[serde(default)]
pub resource: Option<String>,
Expand All @@ -44,7 +39,7 @@ where
#[serde(default)]
pub stage_variables: HashMap<String, String>,
#[serde(bound = "")]
pub request_context: ApiGatewayProxyRequestContext<T1>,
pub request_context: ApiGatewayProxyRequestContext,
#[serde(default)]
pub body: Option<String>,
#[serde(default, deserialize_with = "deserialize_nullish_boolean")]
Expand Down Expand Up @@ -72,11 +67,7 @@ pub struct ApiGatewayProxyResponse {
/// Lambda function. It also includes Cognito identity information for the caller.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiGatewayProxyRequestContext<T1 = Value>
where
T1: DeserializeOwned,
T1: Serialize,
{
pub struct ApiGatewayProxyRequestContext {
#[serde(default)]
pub account_id: Option<String>,
#[serde(default)]
Expand All @@ -99,10 +90,13 @@ where
pub resource_path: Option<String>,
#[serde(default)]
pub path: Option<String>,
#[serde(deserialize_with = "deserialize_lambda_map")]
#[serde(default)]
#[serde(bound = "")]
pub authorizer: HashMap<String, T1>,
#[serde(
default,
deserialize_with = "deserialize_authorizer_fields",
serialize_with = "serialize_authorizer_fields",
skip_serializing_if = "ApiGatewayRequestAuthorizer::is_empty"
)]
pub authorizer: ApiGatewayRequestAuthorizer,
#[serde(with = "http_method")]
pub http_method: Method,
#[serde(default)]
Expand Down Expand Up @@ -168,11 +162,7 @@ pub struct ApiGatewayV2httpRequest {
/// `ApiGatewayV2httpRequestContext` contains the information to identify the AWS account and resources invoking the Lambda function.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiGatewayV2httpRequestContext<T1 = Value>
where
T1: DeserializeOwned,
T1: Serialize,
{
pub struct ApiGatewayV2httpRequestContext {
#[serde(default)]
pub route_key: Option<String>,
#[serde(default)]
Expand All @@ -181,9 +171,9 @@ where
pub stage: Option<String>,
#[serde(default)]
pub request_id: Option<String>,
#[serde(bound = "", default)]
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub authorizer: Option<ApiGatewayV2httpRequestContextAuthorizerDescription<T1>>,
pub authorizer: Option<ApiGatewayRequestAuthorizer>,
/// The API Gateway HTTP API Id
#[serde(default)]
#[serde(rename = "apiId")]
Expand All @@ -203,19 +193,17 @@ where

/// `ApiGatewayV2httpRequestContextAuthorizerDescription` contains authorizer information for the request context.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiGatewayV2httpRequestContextAuthorizerDescription<T1 = Value>
where
T1: DeserializeOwned,
T1: Serialize,
{
pub struct ApiGatewayV2httpRequestContextAuthorizerDescription {
#[serde(skip_serializing_if = "Option::is_none")]
pub jwt: Option<ApiGatewayV2httpRequestContextAuthorizerJwtDescription>,
#[serde(deserialize_with = "deserialize_lambda_map")]
#[serde(default)]
#[serde(bound = "")]
#[serde(skip_serializing_if = "HashMap::is_empty")]
pub lambda: HashMap<String, T1>,
#[serde(
bound = "",
rename = "lambda",
default,
skip_serializing_if = "HashMap::is_empty",
deserialize_with = "deserialize_lambda_map"
)]
pub fields: HashMap<String, Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub iam: Option<ApiGatewayV2httpRequestContextAuthorizerIamDescription>,
}
Expand Down Expand Up @@ -332,13 +320,7 @@ pub struct ApiGatewayRequestIdentity {
/// `ApiGatewayWebsocketProxyRequest` contains data coming from the API Gateway proxy
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiGatewayWebsocketProxyRequest<T1 = Value, T2 = Value>
where
T1: DeserializeOwned + Default,
T1: Serialize,
T2: DeserializeOwned + Default,
T2: Serialize,
{
pub struct ApiGatewayWebsocketProxyRequest {
/// The resource path defined in API Gateway
#[serde(default)]
pub resource: Option<String>,
Expand Down Expand Up @@ -367,7 +349,7 @@ where
#[serde(default)]
pub stage_variables: HashMap<String, String>,
#[serde(bound = "")]
pub request_context: ApiGatewayWebsocketProxyRequestContext<T1, T2>,
pub request_context: ApiGatewayWebsocketProxyRequestContext,
#[serde(default)]
pub body: Option<String>,
#[serde(default, deserialize_with = "deserialize_nullish_boolean")]
Expand All @@ -379,13 +361,7 @@ where
/// Cognito identity information for the caller.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiGatewayWebsocketProxyRequestContext<T1 = Value, T2 = Value>
where
T1: DeserializeOwned,
T1: Serialize,
T2: DeserializeOwned,
T2: Serialize,
{
pub struct ApiGatewayWebsocketProxyRequestContext {
#[serde(default)]
pub account_id: Option<String>,
#[serde(default)]
Expand All @@ -398,8 +374,13 @@ where
pub identity: ApiGatewayRequestIdentity,
#[serde(default)]
pub resource_path: Option<String>,
#[serde(bound = "")]
pub authorizer: Option<T1>,
#[serde(
default,
deserialize_with = "deserialize_authorizer_fields",
serialize_with = "serialize_authorizer_fields",
skip_serializing_if = "ApiGatewayRequestAuthorizer::is_empty"
)]
pub authorizer: ApiGatewayRequestAuthorizer,
#[serde(deserialize_with = "http_method::deserialize_optional")]
#[serde(serialize_with = "http_method::serialize_optional")]
#[serde(skip_serializing_if = "Option::is_none")]
Expand All @@ -425,7 +406,7 @@ where
#[serde(default)]
pub message_direction: Option<String>,
#[serde(bound = "")]
pub message_id: Option<T2>,
pub message_id: Option<String>,
#[serde(default)]
pub request_time: Option<String>,
pub request_time_epoch: i64,
Expand Down Expand Up @@ -768,6 +749,40 @@ fn default_http_method() -> Method {
Method::GET
}

/// `ApiGatewayRequestAuthorizer` is a type alias for `ApiGatewayV2httpRequestContextAuthorizerDescription`.
/// This type is used by all events that receive request authorizer information.
pub type ApiGatewayRequestAuthorizer = ApiGatewayV2httpRequestContextAuthorizerDescription;

impl ApiGatewayRequestAuthorizer {
fn is_empty(&self) -> bool {
self.fields.is_empty()
}
}

fn deserialize_authorizer_fields<'de, D>(deserializer: D) -> Result<ApiGatewayRequestAuthorizer, D::Error>
where
D: Deserializer<'de>,
{
let fields: Option<HashMap<String, Value>> = Option::deserialize(deserializer)?;
let mut authorizer = ApiGatewayRequestAuthorizer::default();
if let Some(fields) = fields {
authorizer.fields = fields;
}

Ok(authorizer)
}

pub fn serialize_authorizer_fields<S: Serializer>(
authorizer: &ApiGatewayRequestAuthorizer,
ser: S,
) -> Result<S::Ok, S::Error> {
let mut map = ser.serialize_map(Some(authorizer.fields.len()))?;
for (k, v) in &authorizer.fields {
map.serialize_entry(k, v)?;
}
map.end()
}

#[cfg(test)]
mod test {
use super::*;
Expand Down Expand Up @@ -991,4 +1006,16 @@ mod test {
let reparsed: ApiGatewayProxyRequest = serde_json::from_slice(output.as_bytes()).unwrap();
assert_eq!(parsed, reparsed);
}

#[test]
#[cfg(feature = "apigw")]
fn example_apigw_request_authorizer_fields() {
let data = include_bytes!("../../fixtures/example-apigw-request.json");
let parsed: ApiGatewayProxyRequest = serde_json::from_slice(data).unwrap();

let fields = parsed.request_context.authorizer.fields;
assert_eq!(Some("admin"), fields.get("principalId").unwrap().as_str());
assert_eq!(Some(1), fields.get("clientId").unwrap().as_u64());
assert_eq!(Some("Exata"), fields.get("clientName").unwrap().as_str());
}
}
Loading