diff --git a/core/src/main/kotlin/com/avsystem/justworks/core/Memo.kt b/core/src/main/kotlin/com/avsystem/justworks/core/Memo.kt new file mode 100644 index 0000000..7b671e4 --- /dev/null +++ b/core/src/main/kotlin/com/avsystem/justworks/core/Memo.kt @@ -0,0 +1,29 @@ +package com.avsystem.justworks.core + +import arrow.atomic.Atomic +import arrow.atomic.update + +@JvmInline +value class MemoScope private constructor(private val memos: Atomic>>) { + constructor() : this(Atomic(emptySet())) + + fun register(memo: Memo<*>) { + memos.update { it + memo } + } + + fun reset() { + memos.get().forEach { it.reset() } + } +} + +class Memo(private val compute: () -> T) { + private val holder = Atomic(lazy(compute)) + + operator fun getValue(thisRef: Any?, property: Any?): T = holder.get().value + + fun reset() { + holder.set(lazy(compute)) + } +} + +fun memoized(memoScope: MemoScope, compute: () -> T): Memo = Memo(compute).also(memoScope::register) diff --git a/core/src/main/kotlin/com/avsystem/justworks/core/Utils.kt b/core/src/main/kotlin/com/avsystem/justworks/core/Utils.kt index 9547a06..9ec4311 100644 --- a/core/src/main/kotlin/com/avsystem/justworks/core/Utils.kt +++ b/core/src/main/kotlin/com/avsystem/justworks/core/Utils.kt @@ -2,4 +2,6 @@ package com.avsystem.justworks.core import kotlin.enums.enumEntries +internal const val SCHEMA_PREFIX = "#/components/schemas/" + inline fun > String.toEnumOrNull(): T? = enumEntries().find { it.name.equals(this, true) } diff --git a/core/src/main/kotlin/com/avsystem/justworks/core/gen/CodeGenerator.kt b/core/src/main/kotlin/com/avsystem/justworks/core/gen/CodeGenerator.kt index 9c86b35..64f79ee 100644 --- a/core/src/main/kotlin/com/avsystem/justworks/core/gen/CodeGenerator.kt +++ b/core/src/main/kotlin/com/avsystem/justworks/core/gen/CodeGenerator.kt @@ -19,17 +19,20 @@ object CodeGenerator { modelPackage: String, apiPackage: String, outputDir: File, - ): Result = context(ModelPackage(modelPackage), ApiPackage(apiPackage)) { - val modelRegistry = NameRegistry() - val apiRegistry = NameRegistry() + ): Result { + val hierarchy = Hierarchy(ModelPackage(modelPackage)).apply { addSchemas(spec.schemas) } - val (modelFiles, resolvedSpec) = ModelGenerator.generateWithResolvedSpec(spec, modelRegistry) + val (modelFiles, resolvedSpec) = context(hierarchy, NameRegistry()) { + ModelGenerator.generateWithResolvedSpec(spec) + } modelFiles.forEach { it.writeTo(outputDir) } val hasPolymorphicTypes = modelFiles.any { it.name == SERIALIZERS_MODULE.simpleName } - val clientFiles = ClientGenerator.generate(resolvedSpec, hasPolymorphicTypes, apiRegistry) + val clientFiles = context(hierarchy, ApiPackage(apiPackage), NameRegistry()) { + ClientGenerator.generate(resolvedSpec, hasPolymorphicTypes) + } clientFiles.forEach { it.writeTo(outputDir) } diff --git a/core/src/main/kotlin/com/avsystem/justworks/core/gen/Hierarchy.kt b/core/src/main/kotlin/com/avsystem/justworks/core/gen/Hierarchy.kt new file mode 100644 index 0000000..6767c0d --- /dev/null +++ b/core/src/main/kotlin/com/avsystem/justworks/core/gen/Hierarchy.kt @@ -0,0 +1,80 @@ +package com.avsystem.justworks.core.gen + +import com.avsystem.justworks.core.MemoScope +import com.avsystem.justworks.core.memoized +import com.avsystem.justworks.core.model.SchemaModel +import com.avsystem.justworks.core.model.TypeRef +import com.squareup.kotlinpoet.ClassName + +internal class Hierarchy(val modelPackage: ModelPackage) { + private val schemas = mutableSetOf() + + /** + * Updates the underlying schemas and invalidates all cached derived views. + * This is necessary when schemas are updated (e.g., after inlining types). + */ + private val memoScope = MemoScope() + + fun addSchemas(newSchemas: List) { + schemas += newSchemas + memoScope.reset() + } + + /** All schemas indexed by name for quick lookup. */ + val schemasById: Map by memoized(memoScope) { + schemas.associateBy { it.name } + } + + /** Schemas that define polymorphic variants via oneOf or anyOf. */ + private val polymorphicSchemas: List by memoized(memoScope) { + schemas.filterNot { it.variants().isNullOrEmpty() } + } + + /** Maps parent schema name to its variant schema names (for both oneOf and anyOf). */ + val sealedHierarchies: Map> by memoized(memoScope) { + polymorphicSchemas + .associate { schema -> + schema.name to schema + .variants() + ?.filterIsInstance() + ?.map { it.schemaName } + .orEmpty() + } + } + + /** Parent schema names that use anyOf without a discriminator (JsonContentPolymorphicSerializer pattern). */ + val anyOfWithoutDiscriminator: Set by memoized(memoScope) { + polymorphicSchemas + .asSequence() + .filter { !it.anyOf.isNullOrEmpty() && it.discriminator == null } + .map { it.name } + .toSet() + } + + /** Inverse of [sealedHierarchies] for anyOf-without-discriminator: variant name to its parent names. */ + val anyOfParents: Map> by memoized(memoScope) { + sealedHierarchies + .asSequence() + .filter { (parent, _) -> parent in anyOfWithoutDiscriminator } + .flatMap { (parent, variants) -> variants.map { it to parent } } + .groupBy({ it.first }, { it.second }) + .mapValues { (_, parents) -> parents.toSet() } + } + + /** Maps schema name to its [ClassName], using nested class for discriminated hierarchy variants. */ + private val lookup: Map by memoized(memoScope) { + sealedHierarchies + .asSequence() + .filterNot { (parent, _) -> parent in anyOfWithoutDiscriminator } + .flatMap { (parent, variants) -> + val parentClass = ClassName(modelPackage, parent) + variants.map { variant -> variant to parentClass.nestedClass(variant) } + + (parent to parentClass) + }.toMap() + } + + /** Resolves a schema name to its [ClassName], falling back to a flat top-level class. */ + operator fun get(name: String): ClassName = lookup[name] ?: ClassName(modelPackage, name) +} + +private fun SchemaModel.variants() = oneOf ?: anyOf diff --git a/core/src/main/kotlin/com/avsystem/justworks/core/gen/Utils.kt b/core/src/main/kotlin/com/avsystem/justworks/core/gen/Utils.kt index 7391e51..9d326eb 100644 --- a/core/src/main/kotlin/com/avsystem/justworks/core/gen/Utils.kt +++ b/core/src/main/kotlin/com/avsystem/justworks/core/gen/Utils.kt @@ -1,11 +1,12 @@ package com.avsystem.justworks.core.gen +import com.avsystem.justworks.core.SCHEMA_PREFIX import com.avsystem.justworks.core.model.PrimitiveType import com.avsystem.justworks.core.model.PropertyModel +import com.avsystem.justworks.core.model.SchemaModel import com.avsystem.justworks.core.model.TypeRef import com.squareup.kotlinpoet.BOOLEAN import com.squareup.kotlinpoet.BYTE_ARRAY -import com.squareup.kotlinpoet.ClassName import com.squareup.kotlinpoet.DOUBLE import com.squareup.kotlinpoet.FLOAT import com.squareup.kotlinpoet.INT @@ -28,7 +29,7 @@ internal val TypeRef.requiredProperties: Set is TypeRef.Array, is TypeRef.Map, is TypeRef.Primitive, is TypeRef.Reference, TypeRef.Unknown -> emptySet() } -context(modelPackage: ModelPackage) +context(hierarchy: Hierarchy) internal fun TypeRef.toTypeName(): TypeName = when (this) { is TypeRef.Primitive -> { when (type) { @@ -54,7 +55,7 @@ internal fun TypeRef.toTypeName(): TypeName = when (this) { } is TypeRef.Reference -> { - ClassName(modelPackage, schemaName) + hierarchy[schemaName] } is TypeRef.Inline -> { @@ -67,3 +68,13 @@ internal fun TypeRef.toTypeName(): TypeName = when (this) { } internal fun TypeRef.isBinaryUpload(): Boolean = this is TypeRef.Primitive && this.type == PrimitiveType.BYTE_ARRAY + +/** + * Resolves the @SerialName value for a variant within a oneOf schema. + */ +internal fun SchemaModel.resolveSerialName(variantSchemaName: String): String = discriminator + ?.mapping + ?.firstNotNullOfOrNull { (serialName, refPath) -> + serialName.takeIf { refPath.removePrefix(SCHEMA_PREFIX) == variantSchemaName } + } + ?: variantSchemaName diff --git a/core/src/main/kotlin/com/avsystem/justworks/core/gen/client/ClientGenerator.kt b/core/src/main/kotlin/com/avsystem/justworks/core/gen/client/ClientGenerator.kt index 30e18a2..fe34689 100644 --- a/core/src/main/kotlin/com/avsystem/justworks/core/gen/client/ClientGenerator.kt +++ b/core/src/main/kotlin/com/avsystem/justworks/core/gen/client/ClientGenerator.kt @@ -9,7 +9,7 @@ import com.avsystem.justworks.core.gen.GENERATED_SERIALIZERS_MODULE import com.avsystem.justworks.core.gen.HTTP_CLIENT import com.avsystem.justworks.core.gen.HTTP_ERROR import com.avsystem.justworks.core.gen.HTTP_SUCCESS -import com.avsystem.justworks.core.gen.ModelPackage +import com.avsystem.justworks.core.gen.Hierarchy import com.avsystem.justworks.core.gen.NameRegistry import com.avsystem.justworks.core.gen.RAISE import com.avsystem.justworks.core.gen.TOKEN @@ -51,29 +51,22 @@ internal object ClientGenerator { private const val DEFAULT_TAG = "Default" private const val API_SUFFIX = "Api" - context(_: ModelPackage, _: ApiPackage) - fun generate( - spec: ApiSpec, - hasPolymorphicTypes: Boolean, - nameRegistry: NameRegistry, - ): List { + context(_: Hierarchy, _: ApiPackage, _: NameRegistry) + fun generate(spec: ApiSpec, hasPolymorphicTypes: Boolean): List { val grouped = spec.endpoints.groupBy { it.tags.firstOrNull() ?: DEFAULT_TAG } - return grouped.map { (tag, endpoints) -> - generateClientFile(tag, endpoints, hasPolymorphicTypes, nameRegistry) - } + return grouped.map { (tag, endpoints) -> generateClientFile(tag, endpoints, hasPolymorphicTypes) } } - context(modelPackage: ModelPackage, apiPackage: ApiPackage) + context(hierarchy: Hierarchy, apiPackage: ApiPackage, nameRegistry: NameRegistry) private fun generateClientFile( tag: String, endpoints: List, hasPolymorphicTypes: Boolean, - nameRegistry: NameRegistry, ): FileSpec { val className = ClassName(apiPackage, nameRegistry.register("${tag.toPascalCase()}$API_SUFFIX")) val clientInitializer = if (hasPolymorphicTypes) { - val generatedSerializersModule = MemberName(modelPackage, GENERATED_SERIALIZERS_MODULE) + val generatedSerializersModule = MemberName(hierarchy.modelPackage, GENERATED_SERIALIZERS_MODULE) CodeBlock.of("${CREATE_HTTP_CLIENT}(%M)", generatedSerializersModule) } else { CodeBlock.of("${CREATE_HTTP_CLIENT}()") @@ -101,8 +94,9 @@ internal object ClientGenerator { .primaryConstructor(primaryConstructor) .addProperty(httpClientProperty) - val methodRegistry = NameRegistry() - classBuilder.addFunctions(endpoints.map { generateEndpointFunction(it, methodRegistry) }) + context(NameRegistry()) { + classBuilder.addFunctions(endpoints.map { generateEndpointFunction(it) }) + } return FileSpec .builder(className) @@ -110,8 +104,8 @@ internal object ClientGenerator { .build() } - context(_: ModelPackage) - private fun generateEndpointFunction(endpoint: Endpoint, methodRegistry: NameRegistry): FunSpec { + context(_: Hierarchy, methodRegistry: NameRegistry) + private fun generateEndpointFunction(endpoint: Endpoint): FunSpec { val functionName = methodRegistry.register(endpoint.operationId.toCamelCase()) val returnBodyType = resolveReturnType(endpoint) val returnType = HTTP_SUCCESS.parameterizedBy(returnBodyType) @@ -166,7 +160,7 @@ internal object ClientGenerator { return funBuilder.build() } - context(_: ModelPackage) + context(_: Hierarchy) private fun resolveReturnType(endpoint: Endpoint): TypeName = endpoint.responses.entries .asSequence() .filter { it.key.startsWith("2") } diff --git a/core/src/main/kotlin/com/avsystem/justworks/core/gen/client/ParametersGenerator.kt b/core/src/main/kotlin/com/avsystem/justworks/core/gen/client/ParametersGenerator.kt index 845b13f..dbec75f 100644 --- a/core/src/main/kotlin/com/avsystem/justworks/core/gen/client/ParametersGenerator.kt +++ b/core/src/main/kotlin/com/avsystem/justworks/core/gen/client/ParametersGenerator.kt @@ -3,7 +3,7 @@ package com.avsystem.justworks.core.gen.client import com.avsystem.justworks.core.gen.BODY import com.avsystem.justworks.core.gen.CHANNEL_PROVIDER import com.avsystem.justworks.core.gen.CONTENT_TYPE_CLASS -import com.avsystem.justworks.core.gen.ModelPackage +import com.avsystem.justworks.core.gen.Hierarchy import com.avsystem.justworks.core.gen.isBinaryUpload import com.avsystem.justworks.core.gen.properties import com.avsystem.justworks.core.gen.requiredProperties @@ -16,7 +16,7 @@ import com.squareup.kotlinpoet.ParameterSpec import com.squareup.kotlinpoet.STRING internal object ParametersGenerator { - context(_: ModelPackage) + context(_: Hierarchy) fun buildMultipartParameters(requestBody: RequestBody): List = requestBody.schema.properties.flatMap { prop -> val name = prop.name.toCamelCase() @@ -33,13 +33,13 @@ internal object ParametersGenerator { } } - context(_: ModelPackage) + context(_: Hierarchy) fun buildFormParameters(requestBody: RequestBody): List = requestBody.schema.properties.map { prop -> val isRequired = requestBody.required && prop.name in requestBody.schema.requiredProperties buildNullableParameter(prop.type, prop.name, isRequired) } - context(_: ModelPackage) + context(_: Hierarchy) fun buildNullableParameter( typeRef: TypeRef, name: String, @@ -50,7 +50,7 @@ internal object ParametersGenerator { return builder.build() } - context(_: ModelPackage) + context(_: Hierarchy) fun buildBodyParams(requestBody: RequestBody) = when (requestBody.contentType) { ContentType.MULTIPART_FORM_DATA -> buildMultipartParameters(requestBody) ContentType.FORM_URL_ENCODED -> buildFormParameters(requestBody) diff --git a/core/src/main/kotlin/com/avsystem/justworks/core/gen/model/ModelGenerator.kt b/core/src/main/kotlin/com/avsystem/justworks/core/gen/model/ModelGenerator.kt index 4281efe..95ced11 100644 --- a/core/src/main/kotlin/com/avsystem/justworks/core/gen/model/ModelGenerator.kt +++ b/core/src/main/kotlin/com/avsystem/justworks/core/gen/model/ModelGenerator.kt @@ -5,6 +5,7 @@ import com.avsystem.justworks.core.gen.DECODER import com.avsystem.justworks.core.gen.ENCODER import com.avsystem.justworks.core.gen.EXPERIMENTAL_SERIALIZATION_API import com.avsystem.justworks.core.gen.EXPERIMENTAL_UUID_API +import com.avsystem.justworks.core.gen.Hierarchy import com.avsystem.justworks.core.gen.INSTANT import com.avsystem.justworks.core.gen.InlineSchemaKey import com.avsystem.justworks.core.gen.JSON_CLASS_DISCRIMINATOR @@ -13,7 +14,6 @@ import com.avsystem.justworks.core.gen.JSON_ELEMENT import com.avsystem.justworks.core.gen.JSON_OBJECT_EXT import com.avsystem.justworks.core.gen.K_SERIALIZER import com.avsystem.justworks.core.gen.LOCAL_DATE -import com.avsystem.justworks.core.gen.ModelPackage import com.avsystem.justworks.core.gen.NameRegistry import com.avsystem.justworks.core.gen.OPT_IN import com.avsystem.justworks.core.gen.PRIMITIVE_KIND @@ -28,6 +28,7 @@ import com.avsystem.justworks.core.gen.UUID_SERIALIZER import com.avsystem.justworks.core.gen.UUID_TYPE import com.avsystem.justworks.core.gen.invoke import com.avsystem.justworks.core.gen.resolveInlineTypes +import com.avsystem.justworks.core.gen.resolveSerialName import com.avsystem.justworks.core.gen.resolveTypeRef import com.avsystem.justworks.core.gen.sanitizeKdoc import com.avsystem.justworks.core.gen.shared.SerializersModuleGenerator @@ -59,22 +60,21 @@ import kotlinx.datetime.LocalDate import kotlin.time.Instant /** - * Generates KotlinPoet [com.squareup.kotlinpoet.FileSpec] instances from an [com.avsystem.justworks.core.model.ApiSpec]. + * Generates KotlinPoet [FileSpec] instances from an [ApiSpec]. * - * Produces one file per [com.avsystem.justworks.core.model.SchemaModel] (data class, sealed interface, or allOf composed class) - * and one file per [com.avsystem.justworks.core.model.EnumModel] (enum class), all annotated with kotlinx.serialization annotations. + * Produces one file per [SchemaModel] (data class, sealed class hierarchy, or allOf composed class) + * and one file per [EnumModel] (enum class), all annotated with kotlinx.serialization annotations. */ internal object ModelGenerator { data class GenerateResult(val files: List, val resolvedSpec: ApiSpec) - context(_: ModelPackage) - fun generate(spec: ApiSpec, nameRegistry: NameRegistry): List = - generateWithResolvedSpec(spec, nameRegistry).files + context(_: Hierarchy, _: NameRegistry) + fun generate(spec: ApiSpec): List = generateWithResolvedSpec(spec).files - context(modelPackage: ModelPackage) - fun generateWithResolvedSpec(spec: ApiSpec, nameRegistry: NameRegistry): GenerateResult { + context(hierarchy: Hierarchy, nameRegistry: NameRegistry) + fun generateWithResolvedSpec(spec: ApiSpec): GenerateResult { ensureReserved(spec, nameRegistry) - val (inlineSchemas, nameMap) = collectAllInlineSchemas(spec, nameRegistry) + val (inlineSchemas, nameMap) = collectAllInlineSchemas(spec) val resolvedSpec = spec.resolveInlineTypes(nameMap) val resolvedInlineSchemas = inlineSchemas.map { schema -> @@ -85,62 +85,38 @@ internal object ModelGenerator { ) } - val files = context(buildHierarchyInfo(resolvedSpec.schemas)) { - val schemaFiles = resolvedSpec.schemas.flatMap { generateSchemaFiles(it) } + hierarchy.addSchemas(resolvedSpec.schemas + resolvedInlineSchemas) - val inlineSchemaFiles = resolvedInlineSchemas.map { generateDataClass(it) } - - val enumFiles = resolvedSpec.enums.map { generateEnumClass(it) } - - val serializersModuleFile = SerializersModuleGenerator.generate() + val nestedVariantNames = hierarchy.sealedHierarchies + .asSequence() + .filterNot { (key, _) -> key in hierarchy.anyOfWithoutDiscriminator } + .flatMap { (_, names) -> names } + .toSet() - val uuidSerializerFile = if (resolvedSpec.usesUuid()) generateUuidSerializer() else null + val schemaFiles = resolvedSpec.schemas + .asSequence() + .filterNot { it.name in nestedVariantNames } + .flatMap { generateSchemaFiles(it) } + .toList() - schemaFiles + inlineSchemaFiles + enumFiles + listOfNotNull(serializersModuleFile, uuidSerializerFile) + val inlineSchemaFiles = resolvedInlineSchemas.map { + if (it.isNested) { + generateNestedInlineClass(it) + } else { + generateDataClass(it) + } } - return GenerateResult(files, resolvedSpec) - } - - data class HierarchyInfo( - val sealedHierarchies: Map>, - val variantParents: Map>, - val anyOfWithoutDiscriminator: Set, - val schemas: List, - ) - - context(modelPackage: ModelPackage) - private fun buildHierarchyInfo(schemas: List): HierarchyInfo { - fun SchemaModel.variants() = oneOf ?: anyOf ?: emptyList() - - val polymorphicSchemas = schemas.filter { it.variants().isNotEmpty() } + val enumFiles = resolvedSpec.enums.map { generateEnumClass(it) } - val sealedHierarchies = polymorphicSchemas.associate { schema -> - schema.name to schema - .variants() - .asSequence() - .filterIsInstance() - .map { it.schemaName } - .toList() - } + val serializersModuleFile = SerializersModuleGenerator.generate() - val variantParents = polymorphicSchemas - .asSequence() - .flatMap { schema -> - val parentClass = ClassName(modelPackage, schema.name) - schema.variants().filterIsInstance().map { ref -> - ref.schemaName to (parentClass to resolveSerialName(schema, ref.schemaName)) - } - }.groupBy({ it.first }, { it.second }) - .mapValues { (_, entries) -> entries.toMap() } + val uuidSerializerFile = if (resolvedSpec.usesUuid()) generateUuidSerializer() else null - val anyOfWithoutDiscriminator = polymorphicSchemas - .asSequence() - .filter { !it.anyOf.isNullOrEmpty() && it.discriminator == null } - .map { it.name } - .toSet() + val files = + schemaFiles + inlineSchemaFiles + enumFiles + listOfNotNull(serializersModuleFile, uuidSerializerFile) - return HierarchyInfo(sealedHierarchies, variantParents, anyOfWithoutDiscriminator, schemas) + return GenerateResult(files, resolvedSpec) } /** @@ -155,10 +131,8 @@ internal object ModelGenerator { nameRegistry.reserve(SERIALIZERS_MODULE.simpleName) } - private fun collectAllInlineSchemas( - spec: ApiSpec, - nameRegistry: NameRegistry, - ): Pair, Map> { + context(nameRegistry: NameRegistry) + private fun collectAllInlineSchemas(spec: ApiSpec): Pair, Map> { val endpointRefs = spec.endpoints.flatMap { endpoint -> val requestRef = endpoint.requestBody?.schema val responseRefs = endpoint.responses.values.map { it.schema } @@ -192,13 +166,16 @@ internal object ModelGenerator { return schemas to nameMap } - context(hierarchy: HierarchyInfo, _: ModelPackage) + context(hierarchy: Hierarchy) private fun generateSchemaFiles(schema: SchemaModel): List = when { !schema.anyOf.isNullOrEmpty() || !schema.oneOf.isNullOrEmpty() -> { if (schema.name in hierarchy.anyOfWithoutDiscriminator) { - listOf(generateSealedInterface(schema), generatePolymorphicSerializer(schema)) + listOf( + generateSealedInterface(schema), + generatePolymorphicSerializer(schema), + ) } else { - listOf(generateSealedInterface(schema)) + listOf(generateSealedHierarchy(schema)) } } @@ -213,30 +190,17 @@ internal object ModelGenerator { } /** - * Generates a sealed interface for a oneOf/anyOf schema. - * - anyOf without discriminator: @Serializable(with = XxxSerializer::class) - * - oneOf or anyOf with discriminator: plain @Serializable + @JsonClassDiscriminator + * Generates a sealed class with nested data class subtypes for oneOf or anyOf-with-discriminator schemas. */ - context(hierarchy: HierarchyInfo, modelPackage: ModelPackage) - private fun generateSealedInterface(schema: SchemaModel): FileSpec { - val className = ClassName(modelPackage, schema.name) + context(hierarchy: Hierarchy) + private fun generateSealedHierarchy(schema: SchemaModel): FileSpec { + val className = ClassName(hierarchy.modelPackage, schema.name) - val typeSpec = TypeSpec.interfaceBuilder(className).addModifiers(KModifier.SEALED) - - if (schema.name in hierarchy.anyOfWithoutDiscriminator) { - val serializerClassName = ClassName(modelPackage, "${schema.name}Serializer") - typeSpec.addAnnotation( - AnnotationSpec - .builder(SERIALIZABLE) - .addMember("with = %T::class", serializerClassName) - .build(), - ) - } else { - typeSpec.addAnnotation(SERIALIZABLE) - } + val parentBuilder = TypeSpec.classBuilder(className).addModifiers(KModifier.SEALED) + parentBuilder.addAnnotation(SERIALIZABLE) if (schema.discriminator != null) { - typeSpec.addAnnotation( + parentBuilder.addAnnotation( AnnotationSpec .builder(JSON_CLASS_DISCRIMINATOR) .addMember("%S", schema.discriminator.propertyName) @@ -245,10 +209,19 @@ internal object ModelGenerator { } if (schema.description != null) { - typeSpec.addKdoc("%L", schema.description.sanitizeKdoc()) + parentBuilder.addKdoc("%L", schema.description.sanitizeKdoc()) } - val fileBuilder = FileSpec.builder(className).addType(typeSpec.build()) + // Generate nested subtypes + val variants = hierarchy.sealedHierarchies[schema.name] + variants?.forEach { variantName -> + val variantSchema = hierarchy.schemasById[variantName] + val serialName = schema.resolveSerialName(variantName) + val nestedType = buildNestedVariant(variantSchema, variantName, className, serialName) + parentBuilder.addType(nestedType) + } + + val fileBuilder = FileSpec.builder(className).addType(parentBuilder.build()) if (schema.discriminator != null) { fileBuilder.addAnnotation( @@ -262,23 +235,119 @@ internal object ModelGenerator { return fileBuilder.build() } + /** + * Builds a nested data class TypeSpec for a variant inside a sealed class hierarchy. + */ + context(hierarchy: Hierarchy) + private fun buildNestedVariant( + variantSchema: SchemaModel?, + variantName: String, + parentClassName: ClassName, + serialName: String, + ): TypeSpec { + val variantClassName = parentClassName.nestedClass(variantName) + val builder = TypeSpec.classBuilder(variantClassName).addModifiers(KModifier.DATA) + builder.superclass(parentClassName) + // sealed class has no constructor params, but KotlinPoet requires this call to emit `Shape()` + builder.addSuperclassConstructorParameter("") + builder.addAnnotation(SERIALIZABLE) + builder.addAnnotation(AnnotationSpec.builder(SERIAL_NAME).addMember("%S", serialName).build()) + + if (variantSchema != null) { + buildConstructorAndProperties(variantSchema, builder) + } else { + builder.primaryConstructor(FunSpec.constructorBuilder().build()) + } + + return builder.build() + } + + /** + * Builds primary constructor and data class properties from a schema's property list. + * Shared by [generateDataClass] and [buildNestedVariant]. + */ + context(_: Hierarchy) + private fun buildConstructorAndProperties(schema: SchemaModel, builder: TypeSpec.Builder) { + val sortedProps = schema.properties.sortedBy { prop -> + when { + prop.name in schema.requiredProperties && prop.defaultValue == null -> 1 + prop.defaultValue != null -> 2 + else -> 3 + } + } + + val constructorBuilder = FunSpec.constructorBuilder() + val propertySpecs = sortedProps.map { prop -> + val type = prop.type.toTypeName().copy(nullable = prop.nullable) + val kotlinName = prop.name.toCamelCase() + + val paramBuilder = ParameterSpec.builder(kotlinName, type) + when { + prop.nullable -> paramBuilder.defaultValue(CodeBlock.of("null")) + prop.defaultValue != null -> paramBuilder.defaultValue(formatDefaultValue(prop)) + } + constructorBuilder.addParameter(paramBuilder.build()) + + PropertySpec + .builder(kotlinName, type) + .initializer(kotlinName) + .addAnnotation(AnnotationSpec.builder(SERIAL_NAME).addMember("%S", prop.name).build()) + .apply { prop.description?.let { addKdoc("%L", it.sanitizeKdoc()) } } + .build() + } + + builder.primaryConstructor(constructorBuilder.build()) + builder.addProperties(propertySpecs) + + if (schema.description != null) { + builder.addKdoc("%L", schema.description.sanitizeKdoc()) + } + } + + /** + * Generates a sealed interface for anyOf without discriminator schemas. + * Only used for the JsonContentPolymorphicSerializer pattern. + */ + context(hierarchy: Hierarchy) + private fun generateSealedInterface(schema: SchemaModel): FileSpec { + val className = ClassName(hierarchy.modelPackage, schema.name) + + val typeSpec = TypeSpec.interfaceBuilder(className).addModifiers(KModifier.SEALED) + + val serializerClassName = ClassName(hierarchy.modelPackage, "${schema.name}Serializer") + typeSpec.addAnnotation( + AnnotationSpec + .builder(SERIALIZABLE) + .addMember("with = %T::class", serializerClassName) + .build(), + ) + + if (schema.description != null) { + typeSpec.addKdoc("%L", schema.description) + } + + return FileSpec.builder(className).addType(typeSpec.build()).build() + } + /** * Generates a JsonContentPolymorphicSerializer object for an anyOf schema without discriminator. */ - context(hierarchy: HierarchyInfo, modelPackage: ModelPackage) + context(hierarchy: Hierarchy) private fun generatePolymorphicSerializer(schema: SchemaModel): FileSpec { - val sealedClassName = ClassName(modelPackage, schema.name) - val serializerClassName = ClassName(modelPackage, "${schema.name}Serializer") - - val schemasById = hierarchy.schemas.associateBy { it.name } + val sealedClassName = ClassName(hierarchy.modelPackage, schema.name) + val serializerClassName = ClassName(hierarchy.modelPackage, "${schema.name}Serializer") val variantProperties = schema.anyOf .orEmpty() .asSequence() .filterIsInstance() .associate { ref -> - val propNames = schemasById[ref.schemaName]?.properties?.map { it.name }?.toSet() ?: emptySet() - ref.schemaName to propNames + val props = hierarchy.schemasById[ref.schemaName] + ?.properties + ?.map { it.name } + ?.toSet() + .orEmpty() + ref.schemaName to props } val allFields = variantProperties.values @@ -321,7 +390,7 @@ internal object ModelGenerator { /** * Builds the body code for selectDeserializer using field-presence heuristics. */ - context(modelPackage: ModelPackage) + context(hierarchy: Hierarchy) private fun buildSelectDeserializerBody( parentName: String, uniqueFieldsPerVariant: Map, @@ -335,7 +404,7 @@ internal object ModelGenerator { "%SĀ·inĀ·element.%M -> %T.serializer()", uniqueField, JSON_OBJECT_EXT, - ClassName(modelPackage, variantName), + hierarchy[variantName], ) null } else { @@ -363,55 +432,22 @@ internal object ModelGenerator { /** * Generates a data class FileSpec, with superinterfaces and @SerialName resolved from hierarchy. + * Used for: standalone schemas, allOf composed classes, and anyOf-without-discriminator variants. */ - context(hierarchy: HierarchyInfo, modelPackage: ModelPackage) + context(hierarchy: Hierarchy) private fun generateDataClass(schema: SchemaModel): FileSpec { - val className = ClassName(modelPackage, schema.name) - - val parentEntries = hierarchy.variantParents[schema.name].orEmpty() - val serialName = parentEntries.values.firstOrNull() - val superinterfaces = parentEntries.keys + val className = ClassName(hierarchy.modelPackage, schema.name) - val sortedProps = schema.properties.sortedBy { prop -> - when { - prop.name in schema.requiredProperties && prop.defaultValue == null -> 1 - prop.defaultValue != null -> 2 - else -> 3 - } - } - - val constructorBuilder = FunSpec.constructorBuilder() - val propertySpecs = sortedProps.map { prop -> - val type = prop.type.toTypeName().copy(nullable = prop.nullable) - val kotlinName = prop.name.toCamelCase() - - val paramBuilder = ParameterSpec.builder(kotlinName, type) - - when { - prop.nullable -> paramBuilder.defaultValue(CodeBlock.of("null")) - prop.defaultValue != null -> paramBuilder.defaultValue(formatDefaultValue(prop)) - } - - constructorBuilder.addParameter(paramBuilder.build()) - - val propBuilder = PropertySpec - .builder(kotlinName, type) - .initializer(kotlinName) - .addAnnotation( - AnnotationSpec - .builder(SERIAL_NAME) - .addMember("%S", prop.name) - .build(), - ).apply { prop.description?.let { addKdoc("%L", it.sanitizeKdoc()) } } - - propBuilder.build() + // For anyOf-without-discriminator variants: find parent interfaces and serialName + val parentNames = hierarchy.anyOfParents[schema.name].orEmpty() + val superinterfaces = parentNames.map { ClassName(hierarchy.modelPackage, it) } + val serialName = parentNames.firstOrNull()?.let { parentName -> + hierarchy.schemasById[parentName]?.resolveSerialName(schema.name) } val typeSpec = TypeSpec .classBuilder(className) .addModifiers(KModifier.DATA) - .primaryConstructor(constructorBuilder.build()) - .addProperties(propertySpecs) .addAnnotation(SERIALIZABLE) .addSuperinterfaces(superinterfaces) @@ -424,9 +460,7 @@ internal object ModelGenerator { ) } - if (schema.description != null) { - typeSpec.addKdoc("%L", schema.description.sanitizeKdoc()) - } + buildConstructorAndProperties(schema, typeSpec) val fileBuilder = FileSpec.builder(className).addType(typeSpec.build()) @@ -453,7 +487,7 @@ internal object ModelGenerator { * Formats a default value from a PropertyModel for use in KotlinPoet ParameterSpec.defaultValue(). */ - context(modelPackage: ModelPackage) + context(hierarchy: Hierarchy) private fun formatDefaultValue(prop: PropertyModel): CodeBlock = when (prop.type) { is TypeRef.Primitive -> { when (prop.type.type) { @@ -492,7 +526,7 @@ internal object ModelGenerator { is TypeRef.Reference -> { val constantName = prop.defaultValue.toString().toEnumConstantName() - CodeBlock.of("%T.%L", ClassName(modelPackage, prop.type.schemaName), constantName) + CodeBlock.of("%T.%L", ClassName(hierarchy.modelPackage, prop.type.schemaName), constantName) } else -> { @@ -500,21 +534,9 @@ internal object ModelGenerator { } } - /** - * Resolves the @SerialName value for a variant within a oneOf schema. - */ - private fun resolveSerialName(parentSchema: SchemaModel, variantSchemaName: String): String = - parentSchema.discriminator - ?.mapping - .orEmpty() - .firstNotNullOfOrNull { (serialName, refPath) -> - serialName.takeIf { refPath.removePrefix("#/components/schemas/") == variantSchemaName } - } - ?: variantSchemaName - - context(modelPackage: ModelPackage) + context(hierarchy: Hierarchy) private fun generateEnumClass(enum: EnumModel): FileSpec { - val className = ClassName(modelPackage, enum.name) + val className = ClassName(hierarchy.modelPackage, enum.name) val typeSpec = TypeSpec.enumBuilder(className).addAnnotation(SERIALIZABLE) @@ -569,6 +591,21 @@ internal object ModelGenerator { return visited.toList() } + context(hierarchy: Hierarchy) + private fun generateNestedInlineClass(schema: SchemaModel): FileSpec { + val flatName = schema.name.toInlinedName() + val className = ClassName(hierarchy.modelPackage, flatName) + + val typeSpec = TypeSpec + .classBuilder(className) + .addModifiers(KModifier.DATA) + .addAnnotation(SERIALIZABLE) + + buildConstructorAndProperties(schema, typeSpec) + + return FileSpec.builder(className).addType(typeSpec.build()).build() + } + private val SchemaModel.isPrimitiveOnly: Boolean get() = properties.isEmpty() && allOf == null && oneOf == null && anyOf == null @@ -637,9 +674,9 @@ internal object ModelGenerator { .build() } - context(modelPackage: ModelPackage) + context(hierarchy: Hierarchy) private fun generateTypeAlias(schema: SchemaModel, primitiveType: TypeName): FileSpec { - val className = ClassName(modelPackage, schema.name) + val className = ClassName(hierarchy.modelPackage, schema.name) val typeAlias = TypeAliasSpec.builder(schema.name, primitiveType) diff --git a/core/src/main/kotlin/com/avsystem/justworks/core/gen/shared/SerializersModuleGenerator.kt b/core/src/main/kotlin/com/avsystem/justworks/core/gen/shared/SerializersModuleGenerator.kt index c4db460..ee22018 100644 --- a/core/src/main/kotlin/com/avsystem/justworks/core/gen/shared/SerializersModuleGenerator.kt +++ b/core/src/main/kotlin/com/avsystem/justworks/core/gen/shared/SerializersModuleGenerator.kt @@ -1,12 +1,11 @@ package com.avsystem.justworks.core.gen.shared import com.avsystem.justworks.core.gen.GENERATED_SERIALIZERS_MODULE -import com.avsystem.justworks.core.gen.ModelPackage +import com.avsystem.justworks.core.gen.Hierarchy import com.avsystem.justworks.core.gen.POLYMORPHIC_FUN import com.avsystem.justworks.core.gen.SERIALIZERS_MODULE import com.avsystem.justworks.core.gen.SUBCLASS_FUN import com.avsystem.justworks.core.gen.invoke -import com.avsystem.justworks.core.gen.model.ModelGenerator import com.squareup.kotlinpoet.ClassName import com.squareup.kotlinpoet.CodeBlock import com.squareup.kotlinpoet.FileSpec @@ -24,7 +23,7 @@ internal object SerializersModuleGenerator { * Returns null if the hierarchy has no sealed types to register. */ - context(hierarchy: ModelGenerator.HierarchyInfo, modelPackage: ModelPackage) + context(hierarchy: Hierarchy) fun generate(): FileSpec? { // anyOf hierarchies without a discriminator use JsonContentPolymorphicSerializer // with custom deserialization logic, so they don't need SerializersModule registration. @@ -36,10 +35,10 @@ internal object SerializersModuleGenerator { val code = CodeBlock.builder().beginControlFlow("%T", SERIALIZERS_MODULE) for ((parent, variants) in discriminatorHierarchies) { - val parentClass = ClassName(modelPackage, parent) + val parentClass = ClassName(hierarchy.modelPackage, parent) code.beginControlFlow("%M(%T::class)", POLYMORPHIC_FUN, parentClass) for (variant in variants) { - val variantClass = ClassName(modelPackage, variant) + val variantClass = parentClass.nestedClass(variant) code.addStatement("%M(%T::class)", SUBCLASS_FUN, variantClass) } code.endControlFlow() @@ -53,7 +52,7 @@ internal object SerializersModuleGenerator { .build() return FileSpec - .builder(modelPackage.name, SERIALIZERS_MODULE.simpleName) + .builder(hierarchy.modelPackage.name, SERIALIZERS_MODULE.simpleName) .addProperty(prop) .build() } diff --git a/core/src/main/kotlin/com/avsystem/justworks/core/parser/SpecParser.kt b/core/src/main/kotlin/com/avsystem/justworks/core/parser/SpecParser.kt index 4ae1c88..131443e 100644 --- a/core/src/main/kotlin/com/avsystem/justworks/core/parser/SpecParser.kt +++ b/core/src/main/kotlin/com/avsystem/justworks/core/parser/SpecParser.kt @@ -12,6 +12,7 @@ import arrow.core.raise.iorNel import arrow.core.raise.nullable import arrow.core.toNonEmptyListOrNull import com.avsystem.justworks.core.Issue +import com.avsystem.justworks.core.SCHEMA_PREFIX import com.avsystem.justworks.core.Warnings import com.avsystem.justworks.core.model.ApiSpec import com.avsystem.justworks.core.model.ContentType @@ -463,8 +464,6 @@ object SpecParser { private fun String.toPascalCase(): String = split("-", "_", ".").joinToString("") { part -> part.replaceFirstChar { it.uppercase() } } - private const val SCHEMA_PREFIX = "#/components/schemas/" - private val STRING_FORMAT_MAP = mapOf( "byte" to TypeRef.Primitive(PrimitiveType.BYTE_ARRAY), "binary" to TypeRef.Primitive(PrimitiveType.BYTE_ARRAY), diff --git a/core/src/test/kotlin/com/avsystem/justworks/core/MemoTest.kt b/core/src/test/kotlin/com/avsystem/justworks/core/MemoTest.kt new file mode 100644 index 0000000..1098606 --- /dev/null +++ b/core/src/test/kotlin/com/avsystem/justworks/core/MemoTest.kt @@ -0,0 +1,76 @@ +package com.avsystem.justworks.core + +import java.util.concurrent.atomic.AtomicInteger +import kotlin.concurrent.thread +import kotlin.test.Test +import kotlin.test.assertEquals + +class MemoTest { + @Test + fun `Memoized should compute only once`() { + val counter = AtomicInteger(0) + val memo = Memo { counter.incrementAndGet() } + + assertEquals(1, memo.getValue(null, null)) + assertEquals(1, memo.getValue(null, null)) + assertEquals(1, counter.get()) + } + + @Test + fun `Memoized reset should force recompute`() { + val counter = AtomicInteger(0) + val memo = Memo { counter.incrementAndGet() } + + assertEquals(1, memo.getValue(null, null)) + memo.reset() + assertEquals(2, memo.getValue(null, null)) + assertEquals(2, counter.get()) + } + + @Test + fun `Memoized should be thread safe`() { + val counter = AtomicInteger(0) + val memo = Memo { + Thread.sleep(10) + counter.incrementAndGet() + } + + val threads = List(10) { + thread { + memo.getValue(null, null) + } + } + threads.forEach { it.join() } + + assertEquals(1, counter.get()) + } + + @Test + fun `CacheGroup should reset all memoized instances`() { + val counter1 = AtomicInteger(0) + val counter2 = AtomicInteger(0) + val memoScope = MemoScope() + + val m1 = memoized(memoScope) { counter1.incrementAndGet() } + val m2 = memoized(memoScope) { counter2.incrementAndGet() } + + assertEquals(1, m1.getValue(null, null)) + assertEquals(1, m2.getValue(null, null)) + + memoScope.reset() + + assertEquals(2, m1.getValue(null, null)) + assertEquals(2, m2.getValue(null, null)) + } + + @Test + fun `memoized helper should add to CacheGroup`() { + val memoScope = MemoScope() + val counter = AtomicInteger(0) + val m = memoized(memoScope) { counter.incrementAndGet() } + + assertEquals(1, m.getValue(null, null)) + memoScope.reset() + assertEquals(2, m.getValue(null, null)) + } +} diff --git a/core/src/test/kotlin/com/avsystem/justworks/core/gen/ClientGeneratorTest.kt b/core/src/test/kotlin/com/avsystem/justworks/core/gen/ClientGeneratorTest.kt index 6e353b2..b2d8c01 100644 --- a/core/src/test/kotlin/com/avsystem/justworks/core/gen/ClientGeneratorTest.kt +++ b/core/src/test/kotlin/com/avsystem/justworks/core/gen/ClientGeneratorTest.kt @@ -27,10 +27,15 @@ class ClientGeneratorTest { private val apiPackage = "com.example.api" private val modelPackage = "com.example.model" - private fun generate(spec: ApiSpec, hasPolymorphicTypes: Boolean = false): List = - context(ModelPackage(modelPackage), ApiPackage(apiPackage)) { - ClientGenerator.generate(spec, hasPolymorphicTypes, NameRegistry()) - } + private fun generate(spec: ApiSpec, hasPolymorphicTypes: Boolean = false): List = context( + Hierarchy(ModelPackage(modelPackage)).apply { + addSchemas(spec.schemas) + }, + ApiPackage(apiPackage), + NameRegistry(), + ) { + ClientGenerator.generate(spec, hasPolymorphicTypes) + } private fun spec(vararg endpoints: Endpoint) = ApiSpec( title = "Test", diff --git a/core/src/test/kotlin/com/avsystem/justworks/core/gen/IntegrationTest.kt b/core/src/test/kotlin/com/avsystem/justworks/core/gen/IntegrationTest.kt index cc618b1..f07b7e3 100644 --- a/core/src/test/kotlin/com/avsystem/justworks/core/gen/IntegrationTest.kt +++ b/core/src/test/kotlin/com/avsystem/justworks/core/gen/IntegrationTest.kt @@ -41,16 +41,25 @@ class IntegrationTest { } private fun generateModel(spec: ApiSpec): List = - context(ModelPackage(modelPackage)) { ModelGenerator.generate(spec, NameRegistry()) } + context(Hierarchy(ModelPackage(modelPackage)).apply { addSchemas(spec.schemas) }, NameRegistry()) { + ModelGenerator.generate(spec) + } private fun generateModelWithResolvedSpec(spec: ApiSpec): ModelGenerator.GenerateResult = - context(ModelPackage(modelPackage)) { ModelGenerator.generateWithResolvedSpec(spec, NameRegistry()) } - - private fun generateClient(spec: ApiSpec, hasPolymorphicTypes: Boolean = false): List = - context(ModelPackage(modelPackage), ApiPackage(apiPackage)) { - ClientGenerator.generate(spec, hasPolymorphicTypes, NameRegistry()) + context(Hierarchy(ModelPackage(modelPackage)).apply { addSchemas(spec.schemas) }, NameRegistry()) { + ModelGenerator.generateWithResolvedSpec(spec) } + private fun generateClient(spec: ApiSpec, hasPolymorphicTypes: Boolean = false): List = context( + Hierarchy(ModelPackage(modelPackage)).apply { + addSchemas(spec.schemas) + }, + ApiPackage(apiPackage), + NameRegistry(), + ) { + ClientGenerator.generate(spec, hasPolymorphicTypes) + } + @Test fun `real-world specs generate compilable enum code without class body conflicts`() { for (fixture in SPEC_FIXTURES) { diff --git a/core/src/test/kotlin/com/avsystem/justworks/core/gen/ModelGeneratorPolymorphicTest.kt b/core/src/test/kotlin/com/avsystem/justworks/core/gen/ModelGeneratorPolymorphicTest.kt index c7c083a..c4b2c20 100644 --- a/core/src/test/kotlin/com/avsystem/justworks/core/gen/ModelGeneratorPolymorphicTest.kt +++ b/core/src/test/kotlin/com/avsystem/justworks/core/gen/ModelGeneratorPolymorphicTest.kt @@ -8,6 +8,7 @@ import com.avsystem.justworks.core.model.PrimitiveType import com.avsystem.justworks.core.model.PropertyModel import com.avsystem.justworks.core.model.SchemaModel import com.avsystem.justworks.core.model.TypeRef +import com.squareup.kotlinpoet.ClassName import com.squareup.kotlinpoet.KModifier import com.squareup.kotlinpoet.TypeSpec import kotlin.test.Test @@ -18,8 +19,13 @@ import kotlin.test.assertTrue class ModelGeneratorPolymorphicTest { private val modelPackage = "com.example.model" - private fun generate(spec: ApiSpec) = context(ModelPackage(modelPackage)) { - ModelGenerator.generate(spec, NameRegistry()) + private fun generate(spec: ApiSpec) = context( + Hierarchy(ModelPackage(modelPackage)).apply { + addSchemas(spec.schemas) + }, + NameRegistry(), + ) { + ModelGenerator.generate(spec) } private fun spec(schemas: List = emptyList(), enums: List = emptyList()) = ApiSpec( @@ -49,9 +55,21 @@ class ModelGeneratorPolymorphicTest { discriminator = discriminator, ) + /** + * Recursively searches for a TypeSpec by name in files and nested types. + */ private fun findType(files: List, name: String): TypeSpec { + fun searchIn(types: List): TypeSpec? { + for (type in types) { + if (type.name == name) return type + val nested = searchIn(type.typeSpecs) + if (nested != null) return nested + } + return null + } + for (file in files) { - val found = file.members.filterIsInstance().find { it.name == name } + val found = searchIn(file.members.filterIsInstance()) if (found != null) return found } throw AssertionError("TypeSpec '$name' not found in generated files") @@ -67,10 +85,10 @@ class ModelGeneratorPolymorphicTest { throw AssertionError("FileSpec containing '$typeName' not found") } - // -- POLY-01: Sealed interface from oneOf -- + // -- POLY-01: Sealed class from oneOf (nested subtypes) -- @Test - fun `oneOf schema generates sealed interface with SEALED modifier`() { + fun `oneOf schema generates sealed class with SEALED modifier`() { val shapeSchema = schema( name = "Shape", @@ -93,7 +111,64 @@ class ModelGeneratorPolymorphicTest { val shapeType = findType(files, "Shape") assertTrue(KModifier.SEALED in shapeType.modifiers, "Expected SEALED modifier on Shape") - assertEquals(TypeSpec.Kind.INTERFACE, shapeType.kind, "Expected INTERFACE kind") + assertEquals(TypeSpec.Kind.CLASS, shapeType.kind, "Expected CLASS kind (not INTERFACE)") + } + + @Test + fun `oneOf subtypes are nested inside parent sealed class`() { + val shapeSchema = + schema( + name = "Shape", + oneOf = listOf(TypeRef.Reference("Circle"), TypeRef.Reference("Square")), + ) + val circleSchema = + schema( + name = "Circle", + properties = listOf(PropertyModel("radius", TypeRef.Primitive(PrimitiveType.DOUBLE), null, false)), + requiredProperties = setOf("radius"), + ) + val squareSchema = + schema( + name = "Square", + properties = listOf(PropertyModel("sideLength", TypeRef.Primitive(PrimitiveType.DOUBLE), null, false)), + requiredProperties = setOf("sideLength"), + ) + + val files = generate(spec(schemas = listOf(shapeSchema, circleSchema, squareSchema))) + val shapeType = findType(files, "Shape") + + val nestedNames = shapeType.typeSpecs.map { it.name } + assertTrue("Circle" in nestedNames, "Circle should be nested inside Shape. Nested: $nestedNames") + assertTrue("Square" in nestedNames, "Square should be nested inside Shape. Nested: $nestedNames") + } + + @Test + fun `oneOf hierarchy produces single file`() { + val shapeSchema = + schema( + name = "Shape", + oneOf = listOf(TypeRef.Reference("Circle"), TypeRef.Reference("Square")), + ) + val circleSchema = + schema( + name = "Circle", + properties = listOf(PropertyModel("radius", TypeRef.Primitive(PrimitiveType.DOUBLE), null, false)), + requiredProperties = setOf("radius"), + ) + val squareSchema = + schema( + name = "Square", + properties = listOf(PropertyModel("sideLength", TypeRef.Primitive(PrimitiveType.DOUBLE), null, false)), + requiredProperties = setOf("sideLength"), + ) + + val files = generate(spec(schemas = listOf(shapeSchema, circleSchema, squareSchema))) + + // No separate Circle.kt or Square.kt files + val fileNames = files.map { it.name } + assertTrue("Circle" !in fileNames, "Circle should NOT have separate file. Files: $fileNames") + assertTrue("Square" !in fileNames, "Square should NOT have separate file. Files: $fileNames") + assertTrue("Shape" in fileNames, "Shape file should exist. Files: $fileNames") } @Test @@ -110,13 +185,13 @@ class ModelGeneratorPolymorphicTest { val shapeType = findType(files, "Shape") val annotations = shapeType.annotations.map { it.typeName.toString() } - assertTrue("kotlinx.serialization.Serializable" in annotations, "Expected @Serializable on sealed interface") + assertTrue("kotlinx.serialization.Serializable" in annotations, "Expected @Serializable on sealed class") } - // -- POLY-02: Variant data classes implement sealed interface -- + // -- POLY-02: Variant subtypes extend sealed class -- @Test - fun `variant data class implements sealed interface`() { + fun `variant data class extends sealed class via superclass`() { val shapeSchema = schema( name = "Shape", @@ -132,10 +207,11 @@ class ModelGeneratorPolymorphicTest { val files = generate(spec(schemas = listOf(shapeSchema, circleSchema))) val circleType = findType(files, "Circle") - val superinterfaces = circleType.superinterfaces.keys.map { it.toString() } + // Should use superclass (not superinterfaces) since parent is sealed class + val superclass = circleType.superclass.toString() assertTrue( - "$modelPackage.Shape" in superinterfaces, - "Circle should implement Shape. Superinterfaces: $superinterfaces", + "$modelPackage.Shape" in superclass, + "Circle should extend Shape as superclass. Superclass: $superclass", ) } @@ -167,6 +243,30 @@ class ModelGeneratorPolymorphicTest { ) } + @Test + fun `variant data class has Serializable annotation`() { + val shapeSchema = + schema( + name = "Shape", + oneOf = listOf(TypeRef.Reference("Circle")), + ) + val circleSchema = + schema( + name = "Circle", + properties = listOf(PropertyModel("radius", TypeRef.Primitive(PrimitiveType.DOUBLE), null, false)), + requiredProperties = setOf("radius"), + ) + + val files = generate(spec(schemas = listOf(shapeSchema, circleSchema))) + val circleType = findType(files, "Circle") + + val annotations = circleType.annotations.map { it.typeName.toString() } + assertTrue( + "kotlinx.serialization.Serializable" in annotations, + "Nested variant should have @Serializable. Annotations: $annotations", + ) + } + // -- POLY-03: Discriminator -- @Test @@ -266,8 +366,6 @@ class ModelGeneratorPolymorphicTest { @Test fun `allOf schema produces data class with merged properties`() { - // SpecParser merges allOf properties before ModelGenerator sees them. - // So ExtendedDog already has all properties (from Dog + inline) in its SchemaModel. val dogSchema = schema( name = "Dog", @@ -295,7 +393,6 @@ class ModelGeneratorPolymorphicTest { val extendedDogType = findType(files, "ExtendedDog") val constructor = assertNotNull(extendedDogType.primaryConstructor, "Expected primary constructor") - // Should have all merged properties: name, breed from Dog + tricks from inline val paramNames = constructor.parameters.map { it.name } assertTrue("name" in paramNames, "Expected 'name' from Dog. Params: $paramNames") assertTrue("breed" in paramNames, "Expected 'breed' from Dog. Params: $paramNames") @@ -342,8 +439,7 @@ class ModelGeneratorPolymorphicTest { // -- POLY-07: oneOf with wrapper objects -- @Test - fun `oneOf with wrapper objects generates sealed interface with JsonClassDiscriminator`() { - // Create wrapper schema like AWS CloudControl's NetworkMeshDevice + fun `oneOf with wrapper objects generates sealed class with JsonClassDiscriminator`() { val extenderPropsSchema = schema( name = "ExtenderDeviceProperties", @@ -357,8 +453,6 @@ class ModelGeneratorPolymorphicTest { requiredProperties = setOf("macAddress"), ) - // Parent schema with oneOf pointing to wrapper variants - // Note: This test verifies the SpecParser has already unwrapped, so we pass the unwrapped form val networkMeshSchema = schema( name = "NetworkMeshDevice", @@ -383,8 +477,8 @@ class ModelGeneratorPolymorphicTest { ) val networkMeshType = findType(files, "NetworkMeshDevice") - // Verify sealed interface with discriminator assertTrue(KModifier.SEALED in networkMeshType.modifiers) + assertEquals(TypeSpec.Kind.CLASS, networkMeshType.kind, "Expected CLASS kind for discriminated oneOf") val discriminatorAnnotation = networkMeshType.annotations.find { it.typeName.toString() == "kotlinx.serialization.json.JsonClassDiscriminator" @@ -394,8 +488,7 @@ class ModelGeneratorPolymorphicTest { } @Test - fun `oneOf with wrapper objects generates correct SerialName on variants`() { - // Same setup as above + fun `oneOf with wrapper objects generates correct SerialName on nested variants`() { val extenderPropsSchema = schema( name = "ExtenderDeviceProperties", @@ -417,7 +510,6 @@ class ModelGeneratorPolymorphicTest { val files = generate(spec(schemas = listOf(networkMeshSchema, extenderPropsSchema))) val extenderType = findType(files, "ExtenderDeviceProperties") - // Verify @SerialName uses wrapper property name val serialNameAnnotation = extenderType.annotations.find { it.typeName.toString() == "kotlinx.serialization.SerialName" @@ -429,7 +521,7 @@ class ModelGeneratorPolymorphicTest { ) } - // -- POLY-08: anyOf without discriminator -> JsonContentPolymorphicSerializer -- + // -- POLY-08: anyOf without discriminator -> JsonContentPolymorphicSerializer (UNCHANGED) -- @Test fun `anyOf without discriminator generates sealed interface with Serializable(with) annotation`() { @@ -451,6 +543,9 @@ class ModelGeneratorPolymorphicTest { val files = generate(spec(schemas = listOf(unionSchema, creditCardSchema, bankTransferSchema))) val paymentType = findType(files, "Payment") + // anyOf without discriminator still uses sealed interface + assertEquals(TypeSpec.Kind.INTERFACE, paymentType.kind, "anyOf without discriminator should remain INTERFACE") + val serializableAnnotation = paymentType.annotations.find { it.typeName.toString() == "kotlinx.serialization.Serializable" } @@ -528,7 +623,6 @@ class ModelGeneratorPolymorphicTest { name = "Payment", anyOf = listOf(TypeRef.Reference("TypeA"), TypeRef.Reference("TypeB")), ) - // Both variants share the same field "amount" - no unique fields val typeASchema = schema( name = "TypeA", properties = listOf(PropertyModel("amount", TypeRef.Primitive(PrimitiveType.DOUBLE), null, false)), @@ -580,8 +674,7 @@ class ModelGeneratorPolymorphicTest { } @Test - fun `anyOf with discriminator NOT affected by JsonContentPolymorphicSerializer path`() { - // Ensure the discriminator-present anyOf still uses the old SerializersModule path + fun `anyOf with discriminator generates sealed class with nested subtypes`() { val shapeSchema = schema( name = "Shape", anyOf = listOf(TypeRef.Reference("Circle"), TypeRef.Reference("Square")), @@ -596,6 +689,9 @@ class ModelGeneratorPolymorphicTest { val files = generate(spec(schemas = listOf(shapeSchema, circleSchema, squareSchema))) val shapeType = findType(files, "Shape") + // Should be sealed class (not interface) for anyOf with discriminator + assertEquals(TypeSpec.Kind.CLASS, shapeType.kind, "Discriminated anyOf should be sealed CLASS") + // Should have plain @Serializable, NOT @Serializable(with = ...) val serializableAnnotation = shapeType.annotations.find { it.typeName.toString() == "kotlinx.serialization.Serializable" @@ -606,6 +702,11 @@ class ModelGeneratorPolymorphicTest { "Discriminated anyOf should use plain @Serializable, not @Serializable(with = ...). Members: ${serializableAnnotation.members}", ) + // Subtypes should be nested + val nestedNames = shapeType.typeSpecs.map { it.name } + assertTrue("Circle" in nestedNames, "Circle should be nested inside Shape. Nested: $nestedNames") + assertTrue("Square" in nestedNames, "Square should be nested inside Shape. Nested: $nestedNames") + // ShapeSerializer should NOT be generated val serializerTypes = files.flatMap { it.members.filterIsInstance() } val shapeSerializerType = serializerTypes.find { it.name == "ShapeSerializer" } @@ -619,7 +720,7 @@ class ModelGeneratorPolymorphicTest { // -- CEM-01: boolean discriminator names (KotlinPoet handles escaping) -- @Test - fun `boolean discriminator names produce valid data classes`() { + fun `boolean discriminator names produce valid nested data classes`() { val deviceStatusSchema = schema( name = "DeviceStatus", oneOf = listOf( @@ -659,21 +760,21 @@ class ModelGeneratorPolymorphicTest { val falseType = findType(files, "false") assertTrue(KModifier.DATA in falseType.modifiers, "'false' should be data class") - // Both implement DeviceStatus sealed interface - val trueSuperinterfaces = trueType.superinterfaces.keys.map { it.toString() } + // Both should extend DeviceStatus sealed class as superclass + val trueSuperclass = trueType.superclass.toString() assertTrue( - "$modelPackage.DeviceStatus" in trueSuperinterfaces, - "'true' should implement DeviceStatus. Superinterfaces: $trueSuperinterfaces", + "$modelPackage.DeviceStatus" in trueSuperclass, + "'true' should extend DeviceStatus. Superclass: $trueSuperclass", ) - val falseSuperinterfaces = falseType.superinterfaces.keys.map { it.toString() } + val falseSuperclass = falseType.superclass.toString() assertTrue( - "$modelPackage.DeviceStatus" in falseSuperinterfaces, - "'false' should implement DeviceStatus. Superinterfaces: $falseSuperinterfaces", + "$modelPackage.DeviceStatus" in falseSuperclass, + "'false' should extend DeviceStatus. Superclass: $falseSuperclass", ) } @Test - fun `all oneOf variant schemas generate data classes even with many subtypes`() { + fun `all oneOf variant schemas generate nested data classes even with many subtypes`() { val variantNames = listOf( "ExtenderDevice", "EthernetDevice", @@ -706,21 +807,19 @@ class ModelGeneratorPolymorphicTest { spec(schemas = listOf(networkMeshSchema) + variantSchemas), ) - // All 6 variants generated + // Parent should be sealed class + val networkMeshType = findType(files, "NetworkMeshDevice") + assertEquals(TypeSpec.Kind.CLASS, networkMeshType.kind, "Expected sealed CLASS") + + // All 6 variants nested inside parent + val nestedNames = networkMeshType.typeSpecs.map { it.name } for (name in variantNames) { - val variantType = findType(files, name) - assertTrue( - KModifier.DATA in variantType.modifiers, - "$name should be a data class", - ) - val superinterfaces = variantType.superinterfaces.keys.map { it.toString() } - assertTrue( - "$modelPackage.NetworkMeshDevice" in superinterfaces, - "$name should implement NetworkMeshDevice. Superinterfaces: $superinterfaces", - ) + assertTrue(name in nestedNames, "$name should be nested inside NetworkMeshDevice. Nested: $nestedNames") + val variantType = networkMeshType.typeSpecs.find { it.name == name }!! + assertTrue(KModifier.DATA in variantType.modifiers, "$name should be a data class") } - // SerializersModule contains all variants + // SerializersModule contains all variants with nested references val serializersModuleFile = files.find { it.name == "SerializersModule" } assertNotNull(serializersModuleFile, "SerializersModule file should be generated") val moduleCode = serializersModuleFile.toString() @@ -733,7 +832,7 @@ class ModelGeneratorPolymorphicTest { } @Test - fun `SerializersModule includes boolean variant names`() { + fun `SerializersModule includes boolean variant names with nested references`() { val deviceStatusSchema = schema( name = "DeviceStatus", oneOf = listOf( @@ -783,7 +882,7 @@ class ModelGeneratorPolymorphicTest { // -- POLY-06: allOf with sealed parent -- @Test - fun `allOf referencing oneOf parent adds superinterface`() { + fun `allOf referencing oneOf parent - variant is nested in parent`() { val petSchema = schema( name = "Pet", @@ -803,10 +902,49 @@ class ModelGeneratorPolymorphicTest { val files = generate(spec(schemas = listOf(petSchema, dogSchema))) val dogType = findType(files, "Dog") - val superinterfaces = dogType.superinterfaces.keys.map { it.toString() } + // Dog should extend Pet (sealed class) as superclass + val superclass = dogType.superclass.toString() assertTrue( - "$modelPackage.Pet" in superinterfaces, - "Dog should have Pet as superinterface. Superinterfaces: $superinterfaces", + "$modelPackage.Pet" in superclass, + "Dog should have Pet as superclass. Superclass: $superclass", + ) + } + + // -- toTypeName with Hierarchy -- + + @Test + fun `toTypeName resolves variant to nested ClassName via Hierarchy`() { + val shapeSchema = schema( + name = "Shape", + oneOf = listOf(TypeRef.Reference("Circle"), TypeRef.Reference("Square")), + ) + val circleSchema = schema(name = "Circle") + val squareSchema = schema(name = "Square") + val hierarchy = Hierarchy(ModelPackage(modelPackage)).apply { + addSchemas(listOf(shapeSchema, circleSchema, squareSchema)) + } + + val result = context(hierarchy) { + TypeRef.Reference("Circle").toTypeName() + } + assertEquals( + ClassName(modelPackage, "Shape", "Circle"), + result, + "Should resolve Circle to Shape.Circle via Hierarchy", + ) + } + + @Test + fun `toTypeName falls back to flat ClassName for non-variant`() { + val hierarchy = Hierarchy(ModelPackage(modelPackage)) + + val result = context(hierarchy) { + TypeRef.Reference("Circle").toTypeName() + } + assertEquals( + ClassName(modelPackage, "Circle"), + result, + "Should resolve Circle to flat ClassName without hierarchy entry", ) } } diff --git a/core/src/test/kotlin/com/avsystem/justworks/core/gen/ModelGeneratorRegressionTest.kt b/core/src/test/kotlin/com/avsystem/justworks/core/gen/ModelGeneratorRegressionTest.kt new file mode 100644 index 0000000..de95681 --- /dev/null +++ b/core/src/test/kotlin/com/avsystem/justworks/core/gen/ModelGeneratorRegressionTest.kt @@ -0,0 +1,75 @@ +package com.avsystem.justworks.core.gen + +import com.avsystem.justworks.core.gen.model.ModelGenerator +import com.avsystem.justworks.core.model.ApiSpec +import com.avsystem.justworks.core.model.PrimitiveType +import com.avsystem.justworks.core.model.PropertyModel +import com.avsystem.justworks.core.model.SchemaModel +import com.avsystem.justworks.core.model.TypeRef +import kotlin.test.Test + +class ModelGeneratorRegressionTest { + private val modelPackage = "com.example.model" + + private fun generate(spec: ApiSpec) = context( + Hierarchy(ModelPackage(modelPackage)).apply { + addSchemas(spec.schemas) + }, + NameRegistry(), + ) { + ModelGenerator.generate(spec) + } + + private fun spec(schemas: List = emptyList()) = ApiSpec( + title = "Test", + version = "1.0", + endpoints = emptyList(), + schemas = schemas, + enums = emptyList(), + ) + + private fun schema( + name: String, + properties: List = emptyList(), + requiredProperties: Set = emptySet(), + oneOf: List? = null, + ) = SchemaModel( + name = name, + description = null, + properties = properties, + requiredProperties = requiredProperties, + allOf = null, + oneOf = oneOf, + anyOf = null, + discriminator = null, + ) + + @Test + fun `reproduce issue where variant with inline property causes crash`() { + val shapeSchema = schema( + name = "Shape", + oneOf = listOf(TypeRef.Reference("Circle")), + ) + val circleSchema = schema( + name = "Circle", + properties = listOf( + PropertyModel( + name = "config", + type = TypeRef.Inline( + properties = listOf( + PropertyModel("radius", TypeRef.Primitive(PrimitiveType.DOUBLE), null, false), + ), + requiredProperties = setOf("radius"), + contextHint = "Circle_config", + ), + description = null, + nullable = false, + ), + ), + requiredProperties = setOf("config"), + ) + + // This should not crash + generate(spec(schemas = listOf(shapeSchema, circleSchema))) + } +} diff --git a/core/src/test/kotlin/com/avsystem/justworks/core/gen/ModelGeneratorTest.kt b/core/src/test/kotlin/com/avsystem/justworks/core/gen/ModelGeneratorTest.kt index 417d3d1..5dd0536 100644 --- a/core/src/test/kotlin/com/avsystem/justworks/core/gen/ModelGeneratorTest.kt +++ b/core/src/test/kotlin/com/avsystem/justworks/core/gen/ModelGeneratorTest.kt @@ -23,8 +23,13 @@ import kotlin.test.assertTrue class ModelGeneratorTest { private val modelPackage = "com.example.model" - private fun generate(spec: ApiSpec) = context(ModelPackage(modelPackage)) { - ModelGenerator.generate(spec, NameRegistry()) + private fun generate(spec: ApiSpec) = context( + Hierarchy(ModelPackage(modelPackage)).apply { + addSchemas(spec.schemas) + }, + NameRegistry(), + ) { + ModelGenerator.generate(spec) } private fun spec(schemas: List = emptyList(), enums: List = emptyList()) = ApiSpec( diff --git a/core/src/test/kotlin/com/avsystem/justworks/core/gen/SerializersModuleGeneratorTest.kt b/core/src/test/kotlin/com/avsystem/justworks/core/gen/SerializersModuleGeneratorTest.kt index e2e7fdb..21233bd 100644 --- a/core/src/test/kotlin/com/avsystem/justworks/core/gen/SerializersModuleGeneratorTest.kt +++ b/core/src/test/kotlin/com/avsystem/justworks/core/gen/SerializersModuleGeneratorTest.kt @@ -1,7 +1,8 @@ package com.avsystem.justworks.core.gen -import com.avsystem.justworks.core.gen.model.ModelGenerator import com.avsystem.justworks.core.gen.shared.SerializersModuleGenerator +import com.avsystem.justworks.core.model.SchemaModel +import com.avsystem.justworks.core.model.TypeRef import com.squareup.kotlinpoet.FileSpec import com.squareup.kotlinpoet.PropertySpec import kotlin.test.Test @@ -12,23 +13,44 @@ import kotlin.test.assertTrue class SerializersModuleGeneratorTest { private val modelPackage = ModelPackage("com.example.model") - private fun hierarchyInfo( + private fun emptySchema( + name: String, + oneOf: List? = null, + anyOf: List? = null, + ) = SchemaModel( + name = name, + description = null, + properties = emptyList(), + requiredProperties = emptySet(), + oneOf = oneOf, + anyOf = anyOf, + allOf = null, + discriminator = null, + ) + + private fun buildHierarchy( sealedHierarchies: Map>, anyOfWithoutDiscriminator: Set = emptySet(), - ) = ModelGenerator.HierarchyInfo( - sealedHierarchies = sealedHierarchies, - variantParents = emptyMap(), - anyOfWithoutDiscriminator = anyOfWithoutDiscriminator, - schemas = emptyList(), - ) + ): Hierarchy { + val schemas = sealedHierarchies.flatMap { (parent, variants) -> + val refs = variants.map { TypeRef.Reference(it) } + val parentSchema = emptySchema( + name = parent, + oneOf = if (parent !in anyOfWithoutDiscriminator) refs else null, + anyOf = if (parent in anyOfWithoutDiscriminator) refs else null, + ) + val variantSchemas = variants.map { emptySchema(it) } + listOf(parentSchema) + variantSchemas + } + return Hierarchy(modelPackage).apply { addSchemas(schemas) } + } - private fun generate(info: ModelGenerator.HierarchyInfo): FileSpec? = - context(info, modelPackage) { SerializersModuleGenerator.generate() } + private fun generate(hierarchy: Hierarchy): FileSpec? = context(hierarchy) { SerializersModuleGenerator.generate() } @Test fun `generates SerializersModule with polymorphic registration`() { val hierarchies = mapOf("Shape" to listOf("Circle", "Square")) - val fileSpec = generate(hierarchyInfo(hierarchies)) + val fileSpec = generate(buildHierarchy(hierarchies)) assertNotNull(fileSpec, "Should generate a FileSpec for non-empty hierarchies") @@ -47,7 +69,7 @@ class SerializersModuleGeneratorTest { "Shape" to listOf("Circle", "Square"), "Animal" to listOf("Cat", "Dog"), ) - val fileSpec = generate(hierarchyInfo(hierarchies)) + val fileSpec = generate(buildHierarchy(hierarchies)) assertNotNull(fileSpec) val initializer = @@ -66,7 +88,7 @@ class SerializersModuleGeneratorTest { @Test fun `returns null for empty hierarchies`() { - val result = generate(hierarchyInfo(emptyMap())) + val result = generate(buildHierarchy(emptyMap>())) assertNull(result, "Should return null for empty hierarchies") } @@ -76,7 +98,7 @@ class SerializersModuleGeneratorTest { "Shape" to listOf("Circle", "Square"), "Pet" to listOf("Cat", "Dog"), ) - val info = hierarchyInfo(hierarchies, anyOfWithoutDiscriminator = setOf("Pet")) + val info = buildHierarchy(hierarchies, anyOfWithoutDiscriminator = setOf("Pet")) val fileSpec = generate(info) assertNotNull(fileSpec) @@ -93,7 +115,7 @@ class SerializersModuleGeneratorTest { @Test fun `returns null when all hierarchies are anyOf without discriminator`() { val hierarchies = mapOf("Pet" to listOf("Cat", "Dog")) - val info = hierarchyInfo(hierarchies, anyOfWithoutDiscriminator = setOf("Pet")) + val info = buildHierarchy(hierarchies, anyOfWithoutDiscriminator = setOf("Pet")) val result = generate(info) assertNull(result, "Should return null when only non-discriminator anyOf hierarchies exist") diff --git a/core/src/test/kotlin/com/avsystem/justworks/core/gen/TypeMappingTest.kt b/core/src/test/kotlin/com/avsystem/justworks/core/gen/TypeMappingTest.kt index a56be4f..5c78420 100644 --- a/core/src/test/kotlin/com/avsystem/justworks/core/gen/TypeMappingTest.kt +++ b/core/src/test/kotlin/com/avsystem/justworks/core/gen/TypeMappingTest.kt @@ -10,8 +10,9 @@ import kotlin.test.assertFailsWith class TypeMappingTest { private val pkg = ModelPackage("com.example.model") + private val hierarchy = Hierarchy(modelPackage = pkg) - private fun map(typeRef: TypeRef): TypeName = context(pkg) { + private fun map(typeRef: TypeRef): TypeName = context(hierarchy) { typeRef.toTypeName() } diff --git a/plugin/src/functionalTest/kotlin/com/avsystem/justworks/gradle/JustworksPluginFunctionalTest.kt b/plugin/src/functionalTest/kotlin/com/avsystem/justworks/gradle/JustworksPluginFunctionalTest.kt index c95eeb3..2e71d7c 100644 --- a/plugin/src/functionalTest/kotlin/com/avsystem/justworks/gradle/JustworksPluginFunctionalTest.kt +++ b/plugin/src/functionalTest/kotlin/com/avsystem/justworks/gradle/JustworksPluginFunctionalTest.kt @@ -300,18 +300,18 @@ class JustworksPluginFunctionalTest { val outputDir = projectDir.resolve("build/generated/justworks/main/com/example/model") - // Sealed interface file exists + // Sealed class file exists with nested subtypes val shapeFile = outputDir.resolve("Shape.kt") assertTrue(shapeFile.exists(), "Shape.kt should exist") val shapeContent = shapeFile.readText() - assertTrue(shapeContent.contains("sealed interface"), "Shape.kt should contain sealed interface") + assertTrue(shapeContent.contains("sealed class"), "Shape.kt should contain sealed class") assertTrue(shapeContent.contains("JsonClassDiscriminator"), "Shape.kt should contain @JsonClassDiscriminator") - // Variant data class file exists and implements sealed interface + // Variant subtypes are nested inside Shape.kt, no separate files val circleFile = outputDir.resolve("Circle.kt") - assertTrue(circleFile.exists(), "Circle.kt should exist") - val circleContent = circleFile.readText() - assertTrue(circleContent.contains(": Shape"), "Circle.kt should implement Shape") + assertFalse(circleFile.exists(), "Circle.kt should NOT exist as a separate file") + assertTrue(shapeContent.contains("data class Circle"), "Shape.kt should contain nested Circle") + assertTrue(shapeContent.contains("data class Square"), "Shape.kt should contain nested Square") // SerializersModule file exists val moduleFile = outputDir.resolve("SerializersModule.kt")