Skip to content

Fix broken security definition reference for OAuth2 #220

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ class RestdocsOpenApiTaskTest : RestdocsOpenApiTaskTestBase() {

override fun thenSecurityDefinitionsFoundInOutputFile() {
with(JsonPath.parse(outputFolder.resolve("$outputFileNamePrefix.$format").readText())) {
then(read<String>("securityDefinitions.oauth2_accessCode.scopes.prod:r")).isEqualTo("Some text")
then(read<String>("securityDefinitions.oauth2_accessCode.type")).isEqualTo("oauth2")
then(read<String>("securityDefinitions.oauth2_accessCode.tokenUrl")).isNotEmpty()
then(read<String>("securityDefinitions.oauth2_accessCode.authorizationUrl")).isNotEmpty()
then(read<String>("securityDefinitions.oauth2_accessCode.flow")).isNotEmpty()
then(read<String>("securityDefinitions.oauth2.scopes.prod:r")).isEqualTo("Some text")
then(read<String>("securityDefinitions.oauth2.type")).isEqualTo("oauth2")
then(read<String>("securityDefinitions.oauth2.tokenUrl")).isNotEmpty()
then(read<String>("securityDefinitions.oauth2.authorizationUrl")).isNotEmpty()
then(read<String>("securityDefinitions.oauth2.flow")).isNotEmpty()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ open class Oauth2Configuration(
var flows: Array<String> = arrayOf(),
var scopes: Map<String, String> = mapOf()
) {
fun securitySchemeName(flow: String) = "oauth2_$flow"
fun securitySchemeName() = "oauth2"
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ object OpenApi20Generator {

private const val API_KEY_SECURITY_NAME = "api_key"
private const val BASIC_SECURITY_NAME = "basic"
private const val OAUTH2_SECURITY_NAME = "oauth2"
private val PATH_PARAMETER_PATTERN = """\{([^/}]+)}""".toRegex()
internal fun generate(
resources: List<ResourceModel>,
Expand Down Expand Up @@ -323,14 +324,7 @@ object OpenApi20Generator {
val securityRequirements = firstModelForPathAndMethod.request.securityRequirements
if (securityRequirements != null) {
when (securityRequirements.type) {
SecurityType.OAUTH2 -> oauth2SecuritySchemeDefinition?.flows?.map {
addSecurity(
oauth2SecuritySchemeDefinition.securitySchemeName(it),
securityRequirements2ScopesList(
securityRequirements
)
)
}
SecurityType.OAUTH2 -> addSecurity(OAUTH2_SECURITY_NAME, securityRequirements2ScopesList(securityRequirements))
SecurityType.BASIC -> addSecurity(BASIC_SECURITY_NAME, null)
SecurityType.API_KEY -> addSecurity(API_KEY_SECURITY_NAME, null)
}
Expand Down Expand Up @@ -372,7 +366,7 @@ object OpenApi20Generator {
addScope(it, scopeAndDescriptions.getOrDefault(it, "No description"))
}
}
openApi.addSecurityDefinition(oauth2SecuritySchemeDefinition.securitySchemeName(flow), oauth2Definition)
openApi.addSecurityDefinition(oauth2SecuritySchemeDefinition.securitySchemeName(), oauth2Definition)
}
if (hasAnyOperationWithSecurityName(
openApi,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ class OpenApi20GeneratorTest {
val openapi = whenOpenApiObjectGenerated(api)

with(openapi.securityDefinitions) {
then(this.containsKey("oauth2_accessCode"))
then(this["oauth2_accessCode"])
then(this.containsKey("oauth2"))
then(this["oauth2"])
.isEqualToComparingFieldByField(
OAuth2Definition().accessCode("http://example.com/authorize", "http://example.com/token")
.apply { addScope("prod:r", "No description") }
Expand Down Expand Up @@ -356,12 +356,12 @@ class OpenApi20GeneratorTest {
then(productPath.get.operationId).isNotEmpty()
then(productPath.get.consumes).contains(successfulGetProductModel.request.contentType)

then(productPath.get.security).hasSize(2)
then(productPath.get.security).hasSize(1)

then(productPath.get.tags).containsOnly("tag1", "tag2")

val combined = productPath.get.security.reduce { map1, map2 -> map1 + map2 }
then(combined).containsOnlyKeys("oauth2_application", "oauth2_accessCode")
then(combined).containsOnlyKeys("oauth2")
then(combined.values).containsOnly(listOf("prod:r"))

then(successfulGetResponse).isNotNull
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ object OpenApi3Generator {
)
}
)
}.apply { addSecurityItemFromSecurityRequirements(firstModelForPathAndMethod.request.securityRequirements, oauth2SecuritySchemeDefinition) }
}.apply { addSecurityItemFromSecurityRequirements(firstModelForPathAndMethod.request.securityRequirements) }
}

private fun operationId(operationIds: List<String>): String {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ internal object SecuritySchemeGenerator {
private const val API_KEY_SECURITY_NAME = "api_key"
private const val BASIC_SECURITY_NAME = "basic"
private const val JWT_BEARER_SECURITY_NAME = "bearerAuthJWT"
private const val OAUTH2_SECURITY_NAME = "oauth2"

fun OpenAPI.addSecurityDefinitions(oauth2SecuritySchemeDefinition: Oauth2Configuration?) {
if (oauth2SecuritySchemeDefinition?.flows?.isNotEmpty() == true) {
val flows = OAuthFlows()
components.addSecuritySchemes(
"oauth2",
OAUTH2_SECURITY_NAME,
SecurityScheme().apply {
type = SecurityScheme.Type.OAUTH2
this.flows = flows
Expand Down Expand Up @@ -90,17 +91,10 @@ internal object SecuritySchemeGenerator {
}
}

fun Operation.addSecurityItemFromSecurityRequirements(securityRequirements: SecurityRequirements?, oauth2SecuritySchemeDefinition: Oauth2Configuration?) {
fun Operation.addSecurityItemFromSecurityRequirements(securityRequirements: SecurityRequirements?) {
if (securityRequirements != null) {
when (securityRequirements.type) {
SecurityType.OAUTH2 -> oauth2SecuritySchemeDefinition?.flows?.map {
addSecurityItem(
SecurityRequirement().addList(
oauth2SecuritySchemeDefinition.securitySchemeName(it),
securityRequirements2ScopesList(securityRequirements)
)
)
}
SecurityType.OAUTH2 -> addSecurityItem(SecurityRequirement().addList(OAUTH2_SECURITY_NAME, securityRequirements2ScopesList(securityRequirements)))
SecurityType.BASIC -> addSecurityItem(SecurityRequirement().addList(BASIC_SECURITY_NAME))
SecurityType.API_KEY -> addSecurityItem(SecurityRequirement().addList(API_KEY_SECURITY_NAME))
SecurityType.JWT_BEARER -> addSecurityItem(SecurityRequirement().addList(JWT_BEARER_SECURITY_NAME))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -501,8 +501,7 @@ class OpenApi3GeneratorTest {
then(openApiJsonPathContext.read<Any>("$productGetByIdPath.responses.200.content.application/json.schema.\$ref")).isNotNull()
then(openApiJsonPathContext.read<Any>("$productGetByIdPath.responses.200.content.application/json.examples.test.value")).isNotNull()

then(openApiJsonPathContext.read<List<List<String>>>("$productGetByIdPath.security[*].oauth2_clientCredentials").flatMap { it }).containsOnly("prod:r")
then(openApiJsonPathContext.read<List<List<String>>>("$productGetByIdPath.security[*].oauth2_authorizationCode").flatMap { it }).containsOnly("prod:r")
then(openApiJsonPathContext.read<List<List<String>>>("$productGetByIdPath.security[*].oauth2").flatMap { it }).containsOnly("prod:r")
}

private fun thenMultiplePathParametersExist() {
Expand Down