diff --git a/core/src/main/kotlin/com/avsystem/justworks/core/gen/ApiClientBaseGenerator.kt b/core/src/main/kotlin/com/avsystem/justworks/core/gen/ApiClientBaseGenerator.kt index 8c953e7..f3407d6 100644 --- a/core/src/main/kotlin/com/avsystem/justworks/core/gen/ApiClientBaseGenerator.kt +++ b/core/src/main/kotlin/com/avsystem/justworks/core/gen/ApiClientBaseGenerator.kt @@ -1,5 +1,7 @@ package com.avsystem.justworks.core.gen +import com.avsystem.justworks.core.model.ApiKeyLocation +import com.avsystem.justworks.core.model.SecurityScheme import com.squareup.kotlinpoet.CodeBlock import com.squareup.kotlinpoet.ContextParameter import com.squareup.kotlinpoet.ExperimentalKotlinPoetApi @@ -30,7 +32,7 @@ object ApiClientBaseGenerator { private const val SUCCESS_BODY = "successBody" private const val SERIALIZERS_MODULE_PARAM = "serializersModule" - fun generate(): FileSpec { + fun generate(securitySchemes: List): FileSpec { val t = TypeVariableName("T").copy(reified = true) return FileSpec @@ -39,7 +41,7 @@ object ApiClientBaseGenerator { .addFunction(buildMapToResult(t)) .addFunction(buildToResult(t)) .addFunction(buildToEmptyResult()) - .addType(buildApiClientBaseClass()) + .addType(buildApiClientBaseClass(securitySchemes)) .build() } @@ -103,26 +105,33 @@ object ApiClientBaseGenerator { .addStatement("return %L { Unit }", MAP_TO_RESULT) .build() - private fun buildApiClientBaseClass(): TypeSpec { + private fun buildApiClientBaseClass(securitySchemes: List): TypeSpec { val tokenType = LambdaTypeName.get(returnType = STRING) + val authParams = buildAuthConstructorParams(securitySchemes) - val constructor = FunSpec + val constructorBuilder = FunSpec .constructorBuilder() .addParameter(BASE_URL, STRING) - .addParameter(TOKEN, tokenType) - .build() + + val propertySpecs = mutableListOf() val baseUrlProp = PropertySpec .builder(BASE_URL, STRING) .initializer(BASE_URL) .addModifiers(KModifier.PROTECTED) .build() + propertySpecs.add(baseUrlProp) - val tokenProp = PropertySpec - .builder(TOKEN, tokenType) - .initializer(TOKEN) - .addModifiers(KModifier.PRIVATE) - .build() + for ((paramName, _) in authParams) { + constructorBuilder.addParameter(paramName, tokenType) + propertySpecs.add( + PropertySpec + .builder(paramName, tokenType) + .initializer(paramName) + .addModifiers(KModifier.PRIVATE) + .build(), + ) + } val clientProp = PropertySpec .builder(CLIENT, HTTP_CLIENT) @@ -135,32 +144,125 @@ object ApiClientBaseGenerator { .addStatement("$CLIENT.close()") .build() - return TypeSpec + val classBuilder = TypeSpec .classBuilder(API_CLIENT_BASE) .addModifiers(KModifier.ABSTRACT) .addSuperinterface(CLOSEABLE) - .primaryConstructor(constructor) - .addProperty(baseUrlProp) - .addProperty(tokenProp) + .primaryConstructor(constructorBuilder.build()) + + for (prop in propertySpecs) { + classBuilder.addProperty(prop) + } + + return classBuilder .addProperty(clientProp) .addFunction(closeFun) - .addFunction(buildApplyAuth()) + .addFunction(buildApplyAuth(securitySchemes)) .addFunction(buildSafeCall()) .addFunction(buildCreateHttpClient()) .build() } - private fun buildApplyAuth(): FunSpec = FunSpec - .builder(APPLY_AUTH) - .addModifiers(KModifier.PROTECTED) - .receiver(HTTP_REQUEST_BUILDER) - .beginControlFlow("%M", HEADERS_FUN) - .addStatement( - "append(%T.Authorization, %P)", - HTTP_HEADERS, - CodeBlock.of($$"Bearer ${'$'}{$$TOKEN()}"), - ).endControlFlow() - .build() + /** + * Builds the list of auth-related constructor parameter names based on security schemes. + * Returns pairs of (paramName, schemeType) for each scheme. + */ + internal fun buildAuthConstructorParams(securitySchemes: List): List> = + securitySchemes.flatMap { scheme -> + when (scheme) { + is SecurityScheme.Bearer -> { + val isSingleBearer = + securitySchemes.size == 1 && securitySchemes.first() is SecurityScheme.Bearer + + val paramName = if (isSingleBearer) TOKEN else "${scheme.name.toCamelCase()}Token" + listOf(paramName to scheme) + } + + is SecurityScheme.ApiKey -> { + listOf("${scheme.name.toCamelCase()}Key" to scheme) + } + + is SecurityScheme.Basic -> { + listOf( + "${scheme.name.toCamelCase()}Username" to scheme, + "${scheme.name.toCamelCase()}Password" to scheme, + ) + } + } + } + + private fun buildApplyAuth(securitySchemes: List): FunSpec { + val builder = FunSpec + .builder(APPLY_AUTH) + .addModifiers(KModifier.PROTECTED) + .receiver(HTTP_REQUEST_BUILDER) + + if (securitySchemes.isEmpty()) return builder.build() + + val headerSchemes = securitySchemes.filter { + it is SecurityScheme.Bearer || + it is SecurityScheme.Basic || + (it is SecurityScheme.ApiKey && it.location == ApiKeyLocation.HEADER) + } + val querySchemes = securitySchemes + .filterIsInstance() + .filter { it.location == ApiKeyLocation.QUERY } + + if (headerSchemes.isNotEmpty()) { + builder.beginControlFlow("%M", HEADERS_FUN) + for (scheme in headerSchemes) { + when (scheme) { + is SecurityScheme.Bearer -> { + val isSingleBearer = + securitySchemes.size == 1 && securitySchemes.first() is SecurityScheme.Bearer + + val paramName = if (isSingleBearer) TOKEN else "${scheme.name.toCamelCase()}Token" + builder.addStatement( + "append(%T.Authorization, %P)", + HTTP_HEADERS, + CodeBlock.of("Bearer \${$paramName()}"), + ) + } + + is SecurityScheme.Basic -> { + val usernameParam = "${scheme.name.toCamelCase()}Username" + val passwordParam = "${scheme.name.toCamelCase()}Password" + builder.addStatement( + "append(%T.Authorization, %P)", + HTTP_HEADERS, + CodeBlock.of( + "Basic \${%T.getEncoder().encodeToString(\"${'$'}{$usernameParam()}:${'$'}{$passwordParam()}\".toByteArray())}", + BASE64_CLASS, + ), + ) + } + + is SecurityScheme.ApiKey -> { + val paramName = "${scheme.name.toCamelCase()}Key" + builder.addStatement( + "append(%S, $paramName())", + scheme.parameterName, + ) + } + } + } + builder.endControlFlow() + } + + if (querySchemes.isNotEmpty()) { + builder.beginControlFlow("url") + for (scheme in querySchemes) { + val paramName = "${scheme.name.toCamelCase()}Key" + builder.addStatement( + "parameters.append(%S, $paramName())", + scheme.parameterName, + ) + } + builder.endControlFlow() + } + + return builder.build() + } private fun buildSafeCall(): FunSpec = FunSpec .builder(SAFE_CALL) diff --git a/core/src/main/kotlin/com/avsystem/justworks/core/gen/ClientGenerator.kt b/core/src/main/kotlin/com/avsystem/justworks/core/gen/ClientGenerator.kt index e182e6f..75463bd 100644 --- a/core/src/main/kotlin/com/avsystem/justworks/core/gen/ClientGenerator.kt +++ b/core/src/main/kotlin/com/avsystem/justworks/core/gen/ClientGenerator.kt @@ -5,6 +5,7 @@ import com.avsystem.justworks.core.model.Endpoint import com.avsystem.justworks.core.model.HttpMethod import com.avsystem.justworks.core.model.Parameter import com.avsystem.justworks.core.model.ParameterLocation +import com.avsystem.justworks.core.model.SecurityScheme import com.avsystem.justworks.core.model.TypeRef import com.squareup.kotlinpoet.ClassName import com.squareup.kotlinpoet.CodeBlock @@ -34,13 +35,16 @@ private const val API_SUFFIX = "Api" class ClientGenerator(private val apiPackage: String, private val modelPackage: String) { fun generate(spec: ApiSpec, hasPolymorphicTypes: Boolean = false): List { val grouped = spec.endpoints.groupBy { it.tags.firstOrNull() ?: DEFAULT_TAG } - return grouped.map { (tag, endpoints) -> generateClientFile(tag, endpoints, hasPolymorphicTypes) } + return grouped.map { (tag, endpoints) -> + generateClientFile(tag, endpoints, hasPolymorphicTypes, spec.securitySchemes) + } } private fun generateClientFile( tag: String, endpoints: List, hasPolymorphicTypes: Boolean = false, + securitySchemes: List, ): FileSpec { val className = ClassName(apiPackage, "${tag.toPascalCase()}$API_SUFFIX") @@ -52,12 +56,21 @@ class ClientGenerator(private val apiPackage: String, private val modelPackage: } val tokenType = LambdaTypeName.get(returnType = STRING) + val authParams = ApiClientBaseGenerator.buildAuthConstructorParams(securitySchemes) - val primaryConstructor = FunSpec + val constructorBuilder = FunSpec .constructorBuilder() .addParameter(BASE_URL, STRING) - .addParameter(TOKEN, tokenType) - .build() + + val classBuilder = TypeSpec + .classBuilder(className) + .superclass(API_CLIENT_BASE) + .addSuperclassConstructorParameter(BASE_URL) + + for ((paramName, _) in authParams) { + constructorBuilder.addParameter(paramName, tokenType) + classBuilder.addSuperclassConstructorParameter(paramName) + } val httpClientProperty = PropertySpec .builder(CLIENT, HTTP_CLIENT) @@ -65,12 +78,8 @@ class ClientGenerator(private val apiPackage: String, private val modelPackage: .initializer(clientInitializer) .build() - val classBuilder = TypeSpec - .classBuilder(className) - .superclass(API_CLIENT_BASE) - .addSuperclassConstructorParameter(BASE_URL) - .addSuperclassConstructorParameter(TOKEN) - .primaryConstructor(primaryConstructor) + classBuilder + .primaryConstructor(constructorBuilder.build()) .addProperty(httpClientProperty) classBuilder.addFunctions(endpoints.map(::generateEndpointFunction)) diff --git a/core/src/main/kotlin/com/avsystem/justworks/core/gen/CodeGenerator.kt b/core/src/main/kotlin/com/avsystem/justworks/core/gen/CodeGenerator.kt index 10c6715..a29dc1e 100644 --- a/core/src/main/kotlin/com/avsystem/justworks/core/gen/CodeGenerator.kt +++ b/core/src/main/kotlin/com/avsystem/justworks/core/gen/CodeGenerator.kt @@ -14,7 +14,7 @@ object CodeGenerator { spec: ApiSpec, modelPackage: String, apiPackage: String, - outputDir: File + outputDir: File, ): Result { val modelFiles = ModelGenerator(modelPackage).generate(spec) modelFiles.forEach { it.writeTo(outputDir) } @@ -27,8 +27,10 @@ object CodeGenerator { return Result(modelFiles.size, clientFiles.size) } - fun generateSharedTypes(outputDir: File): Int { - val files = ApiResponseGenerator.generate() + ApiClientBaseGenerator.generate() + fun generateSharedTypes(outputDir: File, specs: List = emptyList()): Int { + val securitySchemes = specs.flatMap { it.securitySchemes } + + val files = ApiResponseGenerator.generate() + ApiClientBaseGenerator.generate(securitySchemes) files.forEach { it.writeTo(outputDir) } return files.size } diff --git a/core/src/main/kotlin/com/avsystem/justworks/core/gen/Names.kt b/core/src/main/kotlin/com/avsystem/justworks/core/gen/Names.kt index e2166db..dd95922 100644 --- a/core/src/main/kotlin/com/avsystem/justworks/core/gen/Names.kt +++ b/core/src/main/kotlin/com/avsystem/justworks/core/gen/Names.kt @@ -82,6 +82,7 @@ val HTTP_SUCCESS = ClassName("com.avsystem.justworks", "HttpSuccess") // Kotlin stdlib // ============================================================================ +val BASE64_CLASS = ClassName("java.util", "Base64") val CLOSEABLE = ClassName("java.io", "Closeable") val IO_EXCEPTION = ClassName("java.io", "IOException") val HTTP_REQUEST_TIMEOUT_EXCEPTION = ClassName("io.ktor.client.plugins", "HttpRequestTimeoutException") diff --git a/core/src/main/kotlin/com/avsystem/justworks/core/model/ApiSpec.kt b/core/src/main/kotlin/com/avsystem/justworks/core/model/ApiSpec.kt index fdcc056..f197bf4 100644 --- a/core/src/main/kotlin/com/avsystem/justworks/core/model/ApiSpec.kt +++ b/core/src/main/kotlin/com/avsystem/justworks/core/model/ApiSpec.kt @@ -7,12 +7,29 @@ package com.avsystem.justworks.core.model * code generators. Bridges the raw Swagger Parser OAS model and the generated * Kotlin client/model source files. */ +sealed interface SecurityScheme { + val name: String + + data class Bearer(override val name: String) : SecurityScheme + + data class ApiKey( + override val name: String, + val parameterName: String, + val location: ApiKeyLocation, + ) : SecurityScheme + + data class Basic(override val name: String) : SecurityScheme +} + +enum class ApiKeyLocation { HEADER, QUERY } + data class ApiSpec( val title: String, val version: String, val endpoints: List, val schemas: List, val enums: List, + val securitySchemes: List, ) data class Endpoint( diff --git a/core/src/main/kotlin/com/avsystem/justworks/core/parser/SpecParser.kt b/core/src/main/kotlin/com/avsystem/justworks/core/parser/SpecParser.kt index 795396b..a613798 100644 --- a/core/src/main/kotlin/com/avsystem/justworks/core/parser/SpecParser.kt +++ b/core/src/main/kotlin/com/avsystem/justworks/core/parser/SpecParser.kt @@ -12,6 +12,7 @@ import arrow.core.raise.iorNel import arrow.core.raise.nullable import com.avsystem.justworks.core.Issue import com.avsystem.justworks.core.Warnings +import com.avsystem.justworks.core.model.ApiKeyLocation import com.avsystem.justworks.core.model.ApiSpec import com.avsystem.justworks.core.model.Discriminator import com.avsystem.justworks.core.model.Endpoint @@ -25,16 +26,21 @@ import com.avsystem.justworks.core.model.PropertyModel import com.avsystem.justworks.core.model.RequestBody import com.avsystem.justworks.core.model.Response import com.avsystem.justworks.core.model.SchemaModel +import com.avsystem.justworks.core.model.SecurityScheme import com.avsystem.justworks.core.model.TypeRef import com.avsystem.justworks.core.warn import io.swagger.parser.OpenAPIParser import io.swagger.v3.oas.models.OpenAPI import io.swagger.v3.oas.models.PathItem import io.swagger.v3.oas.models.media.Schema +import io.swagger.v3.oas.models.security.SecurityRequirement import io.swagger.v3.parser.core.models.ParseOptions import java.io.File import java.util.IdentityHashMap +import kotlin.apply +import kotlin.collections.map import io.swagger.v3.oas.models.parameters.Parameter as SwaggerParameter +import io.swagger.v3.oas.models.security.SecurityScheme as SwaggerSecurityScheme /** * Result of parsing an OpenAPI specification file. @@ -43,7 +49,7 @@ import io.swagger.v3.oas.models.parameters.Parameter as SwaggerParameter * ```kotlin * when (val result = SpecParser.parse(file)) { * is ParseResult.Success -> result.apiSpec - * is ParseResult.Failure -> handleErrors(result.error) + * is ParseResult.Failure -> handleErrors(result.errors) * } * ``` * @@ -114,6 +120,11 @@ object SpecParser { private fun OpenAPI.toApiSpec(): ApiSpec { val allSchemas = components?.schemas.orEmpty() + val securitySchemes = extractSecuritySchemes( + components?.securitySchemes.orEmpty(), + security.orEmpty(), + ) + val componentSchemaIdentity = ComponentSchemaIdentity(allSchemas.size).apply { allSchemas.forEach { (name, schema) -> this[schema] = name } } @@ -157,10 +168,47 @@ object SpecParser { endpoints = endpoints, schemas = schemaModels + syntheticModels, enums = enumModels, + securitySchemes = securitySchemes, ) } } + @OptIn(ExperimentalRaiseAccumulateApi::class) + context(_: Warnings) + private fun extractSecuritySchemes( + definitions: Map, + requirements: List, + ): List { + val referencedNames = requirements.flatMap { it.keys }.toSet() + return referencedNames.mapNotNull { name -> + definitions[name]?.toSecurityScheme(name) + ?: warn("Security requirement references undefined scheme '$name'") + } + } + + context(_: Warnings) + private fun SwaggerSecurityScheme.toSecurityScheme(name: String): SecurityScheme? = when (type) { + SwaggerSecurityScheme.Type.HTTP -> { + when (scheme?.lowercase()) { + "bearer" -> SecurityScheme.Bearer(name) + "basic" -> SecurityScheme.Basic(name) + else -> warn("Unsupported HTTP auth scheme '$scheme' for '$name'") + } + } + + SwaggerSecurityScheme.Type.APIKEY -> { + when (`in`) { + SwaggerSecurityScheme.In.HEADER -> SecurityScheme.ApiKey(name, this.name, ApiKeyLocation.HEADER) + SwaggerSecurityScheme.In.QUERY -> SecurityScheme.ApiKey(name, this.name, ApiKeyLocation.QUERY) + else -> warn("Unsupported API key location '${`in`}' for '$name'") + } + } + + else -> { + warn("Unsupported security scheme type '$type' for '$name'") + } + } + context(_: ComponentSchemaIdentity, _: ComponentSchemas) private fun extractEndpoints(paths: Map): List = paths .asSequence() @@ -213,7 +261,7 @@ object SpecParser { } }.toList() - context(_: ComponentSchemaIdentity, _: ComponentSchemas) + context (_: ComponentSchemaIdentity, _: ComponentSchemas) private fun SwaggerParameter.toParameter(): Parameter = Parameter( name = name ?: "", location = ParameterLocation.parse(`in`) ?: ParameterLocation.QUERY, @@ -224,7 +272,7 @@ object SpecParser { // --- Schema extraction --- - context(_: Raise, _: ComponentSchemaIdentity, _: ComponentSchemas) + context (_: Raise, _: ComponentSchemaIdentity, _: ComponentSchemas) private fun extractSchemaModel(name: String, schema: Schema<*>): SchemaModel { val allOf = schema.allOf?.mapNotNull { it.resolveName() } @@ -285,7 +333,7 @@ object SpecParser { // --- allOf property merging --- - context(_: ComponentSchemaIdentity, _: ComponentSchemas) + context (_: ComponentSchemaIdentity, _: ComponentSchemas) private fun extractAllOfProperties(parentName: String, schema: Schema<*>): Pair, Set> { val topRequired = schema.required.orEmpty().toSet() val contextCreator: (String) -> String? = { propName -> "$parentName.${propName.toPascalCase()}" } @@ -305,7 +353,7 @@ object SpecParser { return finalProperties to required } - context(_: ComponentSchemaIdentity, componentSchemas: ComponentSchemas) + context (_: ComponentSchemaIdentity, componentSchemas: ComponentSchemas) private fun Schema<*>.resolveSubSchema(): Schema<*> = resolveName()?.let { componentSchemas[it] } ?: this /** @@ -321,7 +369,7 @@ object SpecParser { * * Returns: Pair of (unwrapped oneOf refs, synthetic discriminator) or null if pattern not matched. */ - context(componentSchemaIdentity: ComponentSchemaIdentity, componentSchemas: ComponentSchemas) + context (componentSchemaIdentity: ComponentSchemaIdentity, componentSchemas: ComponentSchemas) private fun detectAndUnwrapOneOfWrappers(schema: Schema<*>): Pair, Discriminator>? = nullable { ensure(!schema.oneOf.isNullOrEmpty() && schema.discriminator == null) @@ -354,14 +402,14 @@ object SpecParser { unwrapped.values.toList() to Discriminator(propertyName = "type", mapping = mapping) } - context(_: ComponentSchemaIdentity, _: ComponentSchemas) + context (_: ComponentSchemaIdentity, _: ComponentSchemas) private fun Schema<*>.toTypeRef(contextName: String? = null): TypeRef = contextName?.let { toInlineTypeRef(it) } ?: (resolveName() ?: allOf?.singleOrNull()?.resolveName())?.let(TypeRef::Reference) ?: TypeRef.Unknown.takeIf { (allOf?.size ?: 0) > 1 } ?: resolveByType(contextName) /** Resolves a [TypeRef] based on the schema's structural type/format, ignoring component identity. */ - context(_: ComponentSchemaIdentity, _: ComponentSchemas) + context (_: ComponentSchemaIdentity, _: ComponentSchemas) private fun Schema<*>.resolveByType(contextName: String? = null): TypeRef = when (type) { "string" -> STRING_FORMAT_MAP[format] ?: TypeRef.Primitive(PrimitiveType.STRING) @@ -382,7 +430,7 @@ object SpecParser { else -> TypeRef.Unknown } - context(_: ComponentSchemaIdentity, _: ComponentSchemas) + context (_: ComponentSchemaIdentity, _: ComponentSchemas) private fun Schema<*>.toInlineTypeRef(contextName: String): TypeRef? = takeIf { isInlineObject }?.let { val required = required.orEmpty().toSet() TypeRef.Inline( @@ -392,10 +440,10 @@ object SpecParser { ) } - context(componentSchemaIdentity: ComponentSchemaIdentity) + context (componentSchemaIdentity: ComponentSchemaIdentity) private fun Schema<*>.resolveName(): String? = `$ref`?.removePrefix(SCHEMA_PREFIX) ?: componentSchemaIdentity[this] - context(componentSchemaIdentity: ComponentSchemaIdentity) + context (componentSchemaIdentity: ComponentSchemaIdentity) private val Schema<*>.isInlineObject get(): Boolean = `$ref` == null && this !in componentSchemaIdentity && type == "object" && !properties.isNullOrEmpty() diff --git a/core/src/test/kotlin/com/avsystem/justworks/core/gen/ApiClientBaseGeneratorTest.kt b/core/src/test/kotlin/com/avsystem/justworks/core/gen/ApiClientBaseGeneratorTest.kt index b2f98a6..d735219 100644 --- a/core/src/test/kotlin/com/avsystem/justworks/core/gen/ApiClientBaseGeneratorTest.kt +++ b/core/src/test/kotlin/com/avsystem/justworks/core/gen/ApiClientBaseGeneratorTest.kt @@ -1,5 +1,7 @@ package com.avsystem.justworks.core.gen +import com.avsystem.justworks.core.model.ApiKeyLocation +import com.avsystem.justworks.core.model.SecurityScheme import com.squareup.kotlinpoet.ExperimentalKotlinPoetApi import com.squareup.kotlinpoet.FunSpec import com.squareup.kotlinpoet.KModifier @@ -12,14 +14,32 @@ import kotlin.test.assertTrue @OptIn(ExperimentalKotlinPoetApi::class) class ApiClientBaseGeneratorTest { - private val file = ApiClientBaseGenerator.generate() + private val file = ApiClientBaseGenerator.generate(emptyList()) private val classSpec: TypeSpec get() = file.members.filterIsInstance().first { it.name == "ApiClientBase" } private fun topLevelFun(name: String): FunSpec = file.members.filterIsInstance().first { it.name == name } - // -- ApiClientBase class -- + private fun classFor(schemes: List): TypeSpec { + val f = ApiClientBaseGenerator.generate(schemes) + return f.members.filterIsInstance().first { it.name == "ApiClientBase" } + } + + private fun applyAuthBody(schemes: List): String { + val cls = classFor(schemes) + return cls.funSpecs + .first { it.name == "applyAuth" } + .body + .toString() + } + + private fun constructorParamNames(schemes: List): List { + val cls = classFor(schemes) + return cls.primaryConstructor!!.parameters.map { it.name } + } + + // -- ApiClientBase class (no-arg backward compat) -- @Test fun `ApiClientBase is abstract`() { @@ -33,13 +53,10 @@ class ApiClientBaseGeneratorTest { } @Test - fun `ApiClientBase has constructor with baseUrl and token provider`() { + fun `ApiClientBase has constructor with only baseUrl when no schemes`() { val constructor = assertNotNull(classSpec.primaryConstructor) val paramNames = constructor.parameters.map { it.name } - assertTrue("baseUrl" in paramNames) - assertTrue("token" in paramNames) - val tokenParam = constructor.parameters.first { it.name == "token" } - assertEquals("() -> kotlin.String", tokenParam.type.toString(), "token should be a () -> String lambda") + assertEquals(listOf("baseUrl"), paramNames) } @Test @@ -58,14 +75,12 @@ class ApiClientBaseGeneratorTest { } @Test - fun `ApiClientBase has applyAuth function`() { + fun `ApiClientBase has empty applyAuth when no schemes`() { val applyAuth = classSpec.funSpecs.first { it.name == "applyAuth" } assertTrue(KModifier.PROTECTED in applyAuth.modifiers) assertNotNull(applyAuth.receiverType, "Expected HttpRequestBuilder receiver") val body = applyAuth.body.toString() - assertTrue(body.contains("Authorization"), "Expected Authorization header") - assertTrue(body.contains("Bearer"), "Expected Bearer prefix") - assertTrue(body.contains("token()"), "Expected token() invocation") + assertTrue(!body.contains("Authorization"), "Expected no Authorization header for empty schemes") } @Test @@ -131,4 +146,134 @@ class ApiClientBaseGeneratorTest { fun `generates single file named ApiClientBase`() { assertEquals("ApiClientBase", file.name) } + + // -- Security scheme: single Bearer (backward compat) -- + + @Test + fun `single Bearer scheme uses token param name for backward compat`() { + val params = constructorParamNames(listOf(SecurityScheme.Bearer("BearerAuth"))) + assertTrue("baseUrl" in params, "Expected baseUrl param") + assertTrue("token" in params, "Expected token param (single-bearer shorthand)") + } + + @Test + fun `single Bearer scheme generates Bearer auth in applyAuth`() { + val body = applyAuthBody(listOf(SecurityScheme.Bearer("BearerAuth"))) + assertTrue(body.contains("Authorization"), "Expected Authorization header") + assertTrue(body.contains("Bearer"), "Expected Bearer prefix") + assertTrue(body.contains("token()"), "Expected token() invocation") + } + + // -- Security scheme: ApiKey in header -- + + @Test + fun `ApiKey HEADER scheme generates constructor param with Key suffix`() { + val params = constructorParamNames( + listOf(SecurityScheme.ApiKey("ApiKeyHeader", "X-API-Key", ApiKeyLocation.HEADER)), + ) + assertTrue("baseUrl" in params, "Expected baseUrl param") + assertTrue("apiKeyHeaderKey" in params, "Expected apiKeyHeaderKey param") + } + + @Test + fun `ApiKey HEADER scheme generates header append in applyAuth`() { + val body = applyAuthBody( + listOf(SecurityScheme.ApiKey("ApiKeyHeader", "X-API-Key", ApiKeyLocation.HEADER)), + ) + assertTrue(body.contains("headers"), "Expected headers block") + assertTrue(body.contains("X-API-Key"), "Expected X-API-Key header name") + assertTrue(body.contains("apiKeyHeaderKey()"), "Expected apiKeyHeaderKey() invocation") + } + + // -- Security scheme: ApiKey in query -- + + @Test + fun `ApiKey QUERY scheme generates constructor param with Key suffix`() { + val params = constructorParamNames( + listOf(SecurityScheme.ApiKey("ApiKeyQuery", "api_key", ApiKeyLocation.QUERY)), + ) + assertTrue("baseUrl" in params, "Expected baseUrl param") + assertTrue("apiKeyQueryKey" in params, "Expected apiKeyQueryKey param") + } + + @Test + fun `ApiKey QUERY scheme generates query parameter in applyAuth`() { + val body = applyAuthBody( + listOf(SecurityScheme.ApiKey("ApiKeyQuery", "api_key", ApiKeyLocation.QUERY)), + ) + assertTrue(body.contains("url"), "Expected url block") + assertTrue(body.contains("parameters.append"), "Expected parameters.append") + assertTrue(body.contains("api_key"), "Expected api_key query param name") + assertTrue(body.contains("apiKeyQueryKey()"), "Expected apiKeyQueryKey() invocation") + } + + // -- Security scheme: HTTP Basic -- + + @Test + fun `Basic scheme generates username and password constructor params`() { + val params = constructorParamNames(listOf(SecurityScheme.Basic("BasicAuth"))) + assertTrue("baseUrl" in params, "Expected baseUrl param") + assertTrue("basicAuthUsername" in params, "Expected basicAuthUsername param") + assertTrue("basicAuthPassword" in params, "Expected basicAuthPassword param") + } + + @Test + fun `Basic scheme generates Base64 Authorization header in applyAuth`() { + val body = applyAuthBody(listOf(SecurityScheme.Basic("BasicAuth"))) + assertTrue(body.contains("Authorization"), "Expected Authorization header") + assertTrue(body.contains("Basic"), "Expected Basic prefix") + assertTrue(body.contains("Base64"), "Expected Base64 encoding") + assertTrue(body.contains("basicAuthUsername()"), "Expected basicAuthUsername() invocation") + assertTrue(body.contains("basicAuthPassword()"), "Expected basicAuthPassword() invocation") + } + + // -- Multiple schemes -- + + @Test + fun `multiple schemes generate all constructor params`() { + val params = constructorParamNames( + listOf( + SecurityScheme.Bearer("BearerAuth"), + SecurityScheme.ApiKey("ApiKeyHeader", "X-API-Key", ApiKeyLocation.HEADER), + SecurityScheme.Basic("BasicAuth"), + ), + ) + assertTrue("baseUrl" in params, "Expected baseUrl param") + assertTrue("bearerAuthToken" in params, "Expected bearerAuthToken param (multi-scheme uses full name)") + assertTrue("apiKeyHeaderKey" in params, "Expected apiKeyHeaderKey param") + assertTrue("basicAuthUsername" in params, "Expected basicAuthUsername param") + assertTrue("basicAuthPassword" in params, "Expected basicAuthPassword param") + } + + @Test + fun `multiple schemes generate all auth types in applyAuth`() { + val body = applyAuthBody( + listOf( + SecurityScheme.Bearer("BearerAuth"), + SecurityScheme.ApiKey("ApiKeyHeader", "X-API-Key", ApiKeyLocation.HEADER), + SecurityScheme.ApiKey("ApiKeyQuery", "api_key", ApiKeyLocation.QUERY), + SecurityScheme.Basic("BasicAuth"), + ), + ) + assertTrue(body.contains("Bearer"), "Expected Bearer in applyAuth") + assertTrue(body.contains("X-API-Key"), "Expected X-API-Key in applyAuth") + assertTrue(body.contains("api_key"), "Expected api_key query param in applyAuth") + assertTrue(body.contains("Basic"), "Expected Basic in applyAuth") + assertTrue(body.contains("Base64"), "Expected Base64 in applyAuth") + } + + // -- Empty schemes (spec with no security) -- + + @Test + fun `empty schemes list generates only baseUrl constructor param`() { + val params = constructorParamNames(emptyList()) + assertEquals(listOf("baseUrl"), params, "Expected only baseUrl param when no security schemes") + } + + @Test + fun `empty schemes list generates empty applyAuth body`() { + val body = applyAuthBody(emptyList()) + assertTrue(!body.contains("headers"), "Expected no headers block for empty schemes") + assertTrue(!body.contains("url"), "Expected no url block for empty schemes") + } } diff --git a/core/src/test/kotlin/com/avsystem/justworks/core/gen/ClientGeneratorTest.kt b/core/src/test/kotlin/com/avsystem/justworks/core/gen/ClientGeneratorTest.kt index 3f60c32..4ef599b 100644 --- a/core/src/test/kotlin/com/avsystem/justworks/core/gen/ClientGeneratorTest.kt +++ b/core/src/test/kotlin/com/avsystem/justworks/core/gen/ClientGeneratorTest.kt @@ -1,5 +1,6 @@ package com.avsystem.justworks.core.gen +import com.avsystem.justworks.core.model.ApiKeyLocation import com.avsystem.justworks.core.model.ApiSpec import com.avsystem.justworks.core.model.Endpoint import com.avsystem.justworks.core.model.HttpMethod @@ -8,6 +9,7 @@ import com.avsystem.justworks.core.model.ParameterLocation import com.avsystem.justworks.core.model.PrimitiveType import com.avsystem.justworks.core.model.RequestBody import com.avsystem.justworks.core.model.Response +import com.avsystem.justworks.core.model.SecurityScheme import com.avsystem.justworks.core.model.TypeRef import com.squareup.kotlinpoet.ExperimentalKotlinPoetApi import com.squareup.kotlinpoet.KModifier @@ -23,12 +25,13 @@ class ClientGeneratorTest { private val modelPackage = "com.example.model" private val generator = ClientGenerator(apiPackage, modelPackage) - private fun spec(endpoints: List) = ApiSpec( + private fun spec(endpoints: List, securitySchemes: List = emptyList()) = ApiSpec( title = "Test", version = "1.0", endpoints = endpoints, schemas = emptyList(), enums = emptyList(), + securitySchemes = securitySchemes, ) private fun endpoint( @@ -53,8 +56,8 @@ class ClientGeneratorTest { responses = responses, ) - private fun clientClass(endpoints: List): TypeSpec { - val files = generator.generate(spec(endpoints)) + private fun clientClass(endpoints: List, securitySchemes: List = emptyList()): TypeSpec { + val files = generator.generate(spec(endpoints, securitySchemes)) return files .first() .members @@ -294,14 +297,14 @@ class ClientGeneratorTest { assertEquals("kotlin.String", baseUrl.type.toString()) } - // -- AUTH-01: Client constructor has token parameter -- + // -- No security: constructor has only baseUrl -- @Test - fun `client constructor has token provider parameter`() { + fun `no security schemes generates constructor with only baseUrl`() { val cls = clientClass(listOf(endpoint())) val constructor = assertNotNull(cls.primaryConstructor) - val token = constructor.parameters.first { it.name == "token" } - assertEquals("() -> kotlin.String", token.type.toString(), "token should be a () -> String lambda") + val paramNames = constructor.parameters.map { it.name } + assertEquals(listOf("baseUrl"), paramNames) } // -- Pitfall 3: Untagged endpoints go to DefaultClient -- @@ -425,4 +428,98 @@ class ClientGeneratorTest { val body = funSpec.body.toString() assertTrue(body.contains("toEmptyResult"), "Expected toEmptyResult call") } + + // -- SECU: Security-aware constructor generation -- + + @Test + fun `no securitySchemes generates constructor with only baseUrl`() { + val cls = clientClass(listOf(endpoint())) + val constructor = assertNotNull(cls.primaryConstructor) + val paramNames = constructor.parameters.map { it.name } + assertEquals(listOf("baseUrl"), paramNames) + } + + @Test + fun `ApiKey HEADER scheme generates constructor with baseUrl and apiKey param`() { + val cls = clientClass( + listOf(endpoint()), + listOf(SecurityScheme.ApiKey("ApiKeyHeader", "X-API-Key", ApiKeyLocation.HEADER)), + ) + val constructor = assertNotNull(cls.primaryConstructor) + val paramNames = constructor.parameters.map { it.name } + assertTrue("baseUrl" in paramNames, "Expected baseUrl param") + assertTrue("apiKeyHeaderKey" in paramNames, "Expected apiKeyHeaderKey param") + } + + @Test + fun `Basic scheme generates constructor with baseUrl, username, and password`() { + val cls = clientClass( + listOf(endpoint()), + listOf(SecurityScheme.Basic("BasicAuth")), + ) + val constructor = assertNotNull(cls.primaryConstructor) + val paramNames = constructor.parameters.map { it.name } + assertTrue("baseUrl" in paramNames, "Expected baseUrl param") + assertTrue("basicAuthUsername" in paramNames, "Expected basicAuthUsername param") + assertTrue("basicAuthPassword" in paramNames, "Expected basicAuthPassword param") + } + + @Test + fun `multiple schemes generate all constructor params and pass all to super`() { + val cls = clientClass( + listOf(endpoint()), + listOf( + SecurityScheme.Bearer("BearerAuth"), + SecurityScheme.ApiKey("ApiKeyHeader", "X-API-Key", ApiKeyLocation.HEADER), + ), + ) + val constructor = assertNotNull(cls.primaryConstructor) + val paramNames = constructor.parameters.map { it.name } + assertTrue("baseUrl" in paramNames, "Expected baseUrl param") + assertTrue("bearerAuthToken" in paramNames, "Expected bearerAuthToken param") + assertTrue("apiKeyHeaderKey" in paramNames, "Expected apiKeyHeaderKey param") + + // Verify superclass constructor params match + val superParams = cls.superclassConstructorParameters.map { it.toString().trim() } + assertTrue(superParams.contains("baseUrl"), "Expected baseUrl passed to super") + assertTrue(superParams.contains("bearerAuthToken"), "Expected bearerAuthToken passed to super") + assertTrue(superParams.contains("apiKeyHeaderKey"), "Expected apiKeyHeaderKey passed to super") + } + + @Test + fun `explicit empty securitySchemes generates constructor with only baseUrl`() { + // Explicit empty securitySchemes = spec has security: [] (no auth required) + val spec = ApiSpec( + title = "Test", + version = "1.0", + endpoints = listOf(endpoint()), + schemas = emptyList(), + enums = emptyList(), + securitySchemes = emptyList(), + ) + val files = generator.generate(spec) + val cls = files + .first() + .members + .filterIsInstance() + .first() + val constructor = assertNotNull(cls.primaryConstructor) + val paramNames = constructor.parameters.map { it.name } + assertEquals( + listOf("baseUrl"), + paramNames, + "Expected only baseUrl param when security is explicitly empty", + ) + } + + @Test + fun `single Bearer scheme uses token param name as shorthand`() { + val cls = clientClass( + listOf(endpoint()), + listOf(SecurityScheme.Bearer("BearerAuth")), + ) + val constructor = assertNotNull(cls.primaryConstructor) + val paramNames = constructor.parameters.map { it.name } + assertTrue("token" in paramNames, "Expected token param (single-bearer shorthand)") + } } diff --git a/core/src/test/kotlin/com/avsystem/justworks/core/gen/IntegrationTest.kt b/core/src/test/kotlin/com/avsystem/justworks/core/gen/IntegrationTest.kt index cbf97de..4316169 100644 --- a/core/src/test/kotlin/com/avsystem/justworks/core/gen/IntegrationTest.kt +++ b/core/src/test/kotlin/com/avsystem/justworks/core/gen/IntegrationTest.kt @@ -78,7 +78,7 @@ class IntegrationTest { val spec = parseSpec(fixture).apiSpec if (spec.endpoints.isEmpty()) continue - val apiClientBaseFile = ApiClientBaseGenerator.generate() + val apiClientBaseFile = ApiClientBaseGenerator.generate(spec.securitySchemes) assertNotNull(apiClientBaseFile, "$fixture: ApiClientBaseGenerator should produce output") val source = apiClientBaseFile.toString() @@ -111,7 +111,7 @@ class IntegrationTest { ) } - val apiClientBaseFile = ApiClientBaseGenerator.generate() + val apiClientBaseFile = ApiClientBaseGenerator.generate(spec.securitySchemes) assertNotNull(apiClientBaseFile, "$fixture: ApiClientBaseGenerator should produce output") } } diff --git a/core/src/test/kotlin/com/avsystem/justworks/core/gen/ModelGeneratorPolymorphicTest.kt b/core/src/test/kotlin/com/avsystem/justworks/core/gen/ModelGeneratorPolymorphicTest.kt index 649d8f5..62768e7 100644 --- a/core/src/test/kotlin/com/avsystem/justworks/core/gen/ModelGeneratorPolymorphicTest.kt +++ b/core/src/test/kotlin/com/avsystem/justworks/core/gen/ModelGeneratorPolymorphicTest.kt @@ -24,6 +24,7 @@ class ModelGeneratorPolymorphicTest { endpoints = emptyList(), schemas = schemas, enums = enums, + securitySchemes = emptyList(), ) private fun schema( diff --git a/core/src/test/kotlin/com/avsystem/justworks/core/gen/ModelGeneratorTest.kt b/core/src/test/kotlin/com/avsystem/justworks/core/gen/ModelGeneratorTest.kt index 9507b32..00fc0ed 100644 --- a/core/src/test/kotlin/com/avsystem/justworks/core/gen/ModelGeneratorTest.kt +++ b/core/src/test/kotlin/com/avsystem/justworks/core/gen/ModelGeneratorTest.kt @@ -29,6 +29,7 @@ class ModelGeneratorTest { endpoints = emptyList(), schemas = schemas, enums = enums, + securitySchemes = emptyList(), ) private val petSchema = @@ -1368,6 +1369,7 @@ class ModelGeneratorTest { endpoints = listOf(endpoint), schemas = emptyList(), enums = emptyList(), + securitySchemes = emptyList(), ) val files = generator.generate(apiSpec) val uuidSerializerFile = files.find { it.name == "UuidSerializer" } diff --git a/core/src/test/kotlin/com/avsystem/justworks/core/parser/SpecParserSecurityTest.kt b/core/src/test/kotlin/com/avsystem/justworks/core/parser/SpecParserSecurityTest.kt new file mode 100644 index 0000000..92d1c3a --- /dev/null +++ b/core/src/test/kotlin/com/avsystem/justworks/core/parser/SpecParserSecurityTest.kt @@ -0,0 +1,70 @@ +package com.avsystem.justworks.core.parser + +import com.avsystem.justworks.core.model.ApiKeyLocation +import com.avsystem.justworks.core.model.ApiSpec +import com.avsystem.justworks.core.model.SecurityScheme +import org.junit.jupiter.api.TestInstance +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertIs +import kotlin.test.assertTrue + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class SpecParserSecurityTest : SpecParserTestBase() { + private lateinit var apiSpec: ApiSpec + + @BeforeTest + fun setUp() { + if (!::apiSpec.isInitialized) { + apiSpec = parseSpec(loadResource("security-schemes-spec.yaml")) + } + } + + @Test + fun `parses exactly 4 security schemes from fixture`() { + assertEquals(4, apiSpec.securitySchemes.size) + } + + @Test + fun `parses Bearer security scheme`() { + val bearer = apiSpec.securitySchemes.filterIsInstance() + assertEquals(1, bearer.size) + assertEquals("BearerAuth", bearer.single().name) + } + + @Test + fun `parses ApiKey header security scheme`() { + val apiKeys = apiSpec.securitySchemes.filterIsInstance() + val header = apiKeys.single { it.location == ApiKeyLocation.HEADER } + assertEquals("ApiKeyHeader", header.name) + assertEquals("X-API-Key", header.parameterName) + } + + @Test + fun `parses ApiKey query security scheme`() { + val apiKeys = apiSpec.securitySchemes.filterIsInstance() + val query = apiKeys.single { it.location == ApiKeyLocation.QUERY } + assertEquals("ApiKeyQuery", query.name) + assertEquals("api_key", query.parameterName) + } + + @Test + fun `parses Basic security scheme`() { + val basic = apiSpec.securitySchemes.filterIsInstance() + assertEquals(1, basic.size) + assertEquals("BasicAuth", basic.single().name) + } + + @Test + fun `excludes unreferenced OAuth2 scheme`() { + val names = apiSpec.securitySchemes.map { it.name } + assertTrue("UnusedOAuth" !in names, "UnusedOAuth should not be in parsed schemes") + } + + @Test + fun `spec without security field produces empty securitySchemes`() { + val petstore = parseSpec(loadResource("petstore.yaml")) + assertTrue(petstore.securitySchemes.isEmpty(), "petstore should have no security schemes") + } +} diff --git a/core/src/test/resources/security-schemes-spec.yaml b/core/src/test/resources/security-schemes-spec.yaml new file mode 100644 index 0000000..029ee25 --- /dev/null +++ b/core/src/test/resources/security-schemes-spec.yaml @@ -0,0 +1,43 @@ +openapi: "3.0.3" +info: + title: Security Schemes Test API + version: "1.0.0" + +components: + securitySchemes: + BearerAuth: + type: http + scheme: bearer + ApiKeyHeader: + type: apiKey + in: header + name: X-API-Key + ApiKeyQuery: + type: apiKey + in: query + name: api_key + BasicAuth: + type: http + scheme: basic + UnusedOAuth: + type: oauth2 + flows: + implicit: + authorizationUrl: https://example.com/oauth/authorize + scopes: + read: Read access + +security: + - BearerAuth: [] + - ApiKeyHeader: [] + - ApiKeyQuery: [] + - BasicAuth: [] + +paths: + /health: + get: + operationId: getHealth + summary: Health check + responses: + "200": + description: OK diff --git a/plugin/src/functionalTest/kotlin/com/avsystem/justworks/gradle/JustworksPluginFunctionalTest.kt b/plugin/src/functionalTest/kotlin/com/avsystem/justworks/gradle/JustworksPluginFunctionalTest.kt index c95eeb3..6347cf9 100644 --- a/plugin/src/functionalTest/kotlin/com/avsystem/justworks/gradle/JustworksPluginFunctionalTest.kt +++ b/plugin/src/functionalTest/kotlin/com/avsystem/justworks/gradle/JustworksPluginFunctionalTest.kt @@ -495,6 +495,125 @@ class JustworksPluginFunctionalTest { assertTrue(result.output.contains("Invalid spec name 'pet-store'")) } + @Test + fun `spec with security schemes generates ApiClientBase with applyAuth body`() { + writeFile( + "api/secured.yaml", + """ + openapi: '3.0.0' + info: + title: Secured API + version: '1.0' + paths: + /data: + get: + operationId: getData + summary: Get data + tags: + - data + responses: + '200': + description: OK + content: + application/json: + schema: + type: object + properties: + value: + type: string + components: + securitySchemes: + ApiKeyAuth: + type: apiKey + in: header + name: X-API-Key + BasicAuth: + type: http + scheme: basic + security: + - ApiKeyAuth: [] + - BasicAuth: [] + """.trimIndent(), + ) + + writeFile( + "build.gradle.kts", + """ + plugins { + kotlin("jvm") version "2.3.0" + kotlin("plugin.serialization") version "2.3.0" + id("com.avsystem.justworks") + } + + repositories { + mavenCentral() + } + + dependencies { + implementation("org.jetbrains.kotlinx:kotlinx-serialization-core:1.8.1") + implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.8.1") + implementation("io.ktor:ktor-client-core:3.1.1") + implementation("io.ktor:ktor-client-content-negotiation:3.1.1") + implementation("io.ktor:ktor-serialization-kotlinx-json:3.1.1") + implementation("io.arrow-kt:arrow-core:2.2.1.1") + } + + kotlin { + compilerOptions { + freeCompilerArgs.add("-Xcontext-parameters") + } + } + + justworks { + specs { + register("secured") { + specFile = file("api/secured.yaml") + packageName = "com.example.secured" + } + } + } + """.trimIndent(), + ) + + val result = runner("justworksGenerateSecured").build() + + assertEquals( + TaskOutcome.SUCCESS, + result.task(":justworksGenerateSecured")?.outcome, + ) + + val apiClientBase = projectDir + .resolve("build/generated/justworks/shared/kotlin/com/avsystem/justworks/ApiClientBase.kt") + assertTrue(apiClientBase.exists(), "ApiClientBase.kt should exist") + + val content = apiClientBase.readText() + assertTrue(content.contains("apiKeyAuthKey"), "Should contain apiKeyAuthKey param") + assertTrue(content.contains("basicAuthUsername"), "Should contain basicAuthUsername param") + assertTrue(content.contains("basicAuthPassword"), "Should contain basicAuthPassword param") + assertTrue(content.contains("X-API-Key"), "Should contain X-API-Key header name") + assertTrue(content.contains("applyAuth"), "Should contain applyAuth method") + assertTrue(content.contains("Authorization"), "Should contain Authorization header for Basic auth") + assertFalse( + content.contains("token: () -> String"), + "Should NOT contain backward-compat token param when explicit security schemes present", + ) + } + + @Test + fun `spec without security schemes generates ApiClientBase with no auth params`() { + writeBuildFile() + + runner("justworksGenerateMain").build() + + val apiClientBase = projectDir + .resolve("build/generated/justworks/shared/kotlin/com/avsystem/justworks/ApiClientBase.kt") + assertTrue(apiClientBase.exists(), "ApiClientBase.kt should exist") + + val content = apiClientBase.readText() + assertTrue(!content.contains("token"), "Should NOT contain token param when no security schemes") + assertTrue(!content.contains("Bearer"), "Should NOT contain Bearer when no security schemes") + } + @Test fun `empty specs container logs warning`() { writeFile( diff --git a/plugin/src/main/kotlin/com/avsystem/justworks/gradle/JustworksPlugin.kt b/plugin/src/main/kotlin/com/avsystem/justworks/gradle/JustworksPlugin.kt index 707dbbd..92df1f9 100644 --- a/plugin/src/main/kotlin/com/avsystem/justworks/gradle/JustworksPlugin.kt +++ b/plugin/src/main/kotlin/com/avsystem/justworks/gradle/JustworksPlugin.kt @@ -60,6 +60,11 @@ class JustworksPlugin : Plugin { task.description = "Generate Kotlin client from '${spec.name}' OpenAPI spec" } + // Wire spec file into shared types task for security scheme extraction + sharedTypesTask.configure { task -> + task.specFiles.from(spec.specFile) + } + // Wire spec task into aggregate task generateAllTask.configure { it.dependsOn(specTask) } } diff --git a/plugin/src/main/kotlin/com/avsystem/justworks/gradle/JustworksSharedTypesTask.kt b/plugin/src/main/kotlin/com/avsystem/justworks/gradle/JustworksSharedTypesTask.kt index a6b915e..7ce74f2 100644 --- a/plugin/src/main/kotlin/com/avsystem/justworks/gradle/JustworksSharedTypesTask.kt +++ b/plugin/src/main/kotlin/com/avsystem/justworks/gradle/JustworksSharedTypesTask.kt @@ -1,18 +1,33 @@ package com.avsystem.justworks.gradle import com.avsystem.justworks.core.gen.CodeGenerator +import com.avsystem.justworks.core.parser.ParseResult +import com.avsystem.justworks.core.parser.SpecParser import org.gradle.api.DefaultTask +import org.gradle.api.file.ConfigurableFileCollection import org.gradle.api.file.DirectoryProperty import org.gradle.api.tasks.CacheableTask +import org.gradle.api.tasks.InputFiles import org.gradle.api.tasks.OutputDirectory +import org.gradle.api.tasks.PathSensitive +import org.gradle.api.tasks.PathSensitivity import org.gradle.api.tasks.TaskAction /** - * Gradle task that generates shared types (HttpError, Success) once + * Gradle task that generates shared types (HttpError, Success, ApiClientBase) once * to a fixed output directory shared across all spec configurations. + * + * When [specFiles] are configured, the task parses them to extract security schemes + * and passes them to ApiClientBase generation so the generated auth code reflects + * the spec's security configuration. */ @CacheableTask abstract class JustworksSharedTypesTask : DefaultTask() { + /** All configured spec files — used to extract security schemes. */ + @get:InputFiles + @get:PathSensitive(PathSensitivity.RELATIVE) + abstract val specFiles: ConfigurableFileCollection + /** Output directory for shared type files. */ @get:OutputDirectory abstract val outputDir: DirectoryProperty @@ -21,8 +36,17 @@ abstract class JustworksSharedTypesTask : DefaultTask() { fun generate() { val outDir = outputDir.get().asFile.recreateDirectory() - val count = CodeGenerator.generateSharedTypes(outDir) + val specs = specFiles.files.sortedBy { it.path }.mapNotNull { file -> + when (val result = SpecParser.parse(file)) { + is ParseResult.Success -> result.apiSpec + is ParseResult.Failure -> { + logger.warn("Failed to parse spec '${file.name}': ${result.error}") + null + } + } + } + val count = CodeGenerator.generateSharedTypes(outDir, specs) logger.lifecycle("Generated $count shared type files") } }