diff --git a/pydantic_xml/serializers/factories/union.py b/pydantic_xml/serializers/factories/union.py index 248a9c0..5490526 100644 --- a/pydantic_xml/serializers/factories/union.py +++ b/pydantic_xml/serializers/factories/union.py @@ -1,3 +1,4 @@ +from copy import deepcopy from typing import Any, Dict, List, Optional, Set, Tuple, Union import pydantic as pd @@ -143,7 +144,8 @@ def deserialize( def from_core_schema(schema: pcs.UnionSchema, ctx: Serializer.Context) -> Serializer: choice_families: Set[SchemaTypeFamily] = set() - for choice_schema in schema['choices']: + flattened_schema = _flatten_choice_schemas(deepcopy(schema), ctx) + for choice_schema in flattened_schema['choices']: if isinstance(choice_schema, tuple): choice_schema, label = choice_schema @@ -163,8 +165,35 @@ def from_core_schema(schema: pcs.UnionSchema, ctx: Serializer.Context) -> Serial choice_family = choice_families.pop() if choice_family is SchemaTypeFamily.MODEL: - return ModelSerializer.from_core_schema(schema, ctx) + return ModelSerializer.from_core_schema(flattened_schema, ctx) elif choice_family is SchemaTypeFamily.PRIMITIVE: - return PrimitiveTypeSerializer.from_core_schema(schema, ctx) + return PrimitiveTypeSerializer.from_core_schema(flattened_schema, ctx) else: raise AssertionError("unreachable") + + +def _flatten_choice_schemas(schema: pcs.UnionSchema, ctx: Serializer.Context) -> pcs.UnionSchema: + """ + Flatten nested union choice_schemas into their components, leave others as they are + """ + choice_schemas = schema['choices'] + flattened_schemas = [] + seen_refs = set() + while choice_schemas: + choice_schema = original_schema = choice_schemas.pop() + if isinstance(choice_schema, tuple): + choice_schema, label = choice_schema + ref = choice_schema.get('schema_ref', choice_schema.get('ref')) + if ref in seen_refs: + continue + if ref: + seen_refs.add(ref) + + if choice_schema['type'] == 'definition-ref': + choice_schema = ctx.definitions.get(choice_schema['schema_ref']) + if choice_schema['type'] == 'union': + choice_schemas.extend(choice_schema['choices']) + else: + flattened_schemas.append(original_schema) + schema['choices'] = flattened_schemas + return schema diff --git a/tests/test_unions.py b/tests/test_unions.py index b55f6d1..d6c1173 100644 --- a/tests/test_unions.py +++ b/tests/test_unions.py @@ -1,9 +1,9 @@ import sys -from typing import List, Literal, Tuple, Union +from typing import List, Literal, Tuple, Union, Annotated import pytest from helpers import assert_xml_equal -from pydantic import Field +from pydantic import Field, BeforeValidator from pydantic_xml import BaseXmlModel, RootXmlModel, attr, element @@ -453,3 +453,77 @@ class TestModel(BaseXmlModel, tag='model'): actual_xml = actual_obj.to_xml() assert_xml_equal(actual_xml, xml) + + +type Literals1 = Literal[1, 2] +type Literals2 = Literal[5, 6] +type Literals3 = Literal[0] +type LiteralUnion1 = Literals1 | Literals2 +type LiteralUnion2 = LiteralUnion1 | Literals3 +type LiteralUnion3 = LiteralUnion2 | LiteralUnion1 + +def test_nested_primitive_union(): + class Model(BaseXmlModel, tag='model'): + element1: Annotated[LiteralUnion3, BeforeValidator(int)] = element(tag='testLiteral') + + xml = ''' + + 5 + + ''' + + actual = Model.from_xml(xml) + expected = Model(element1=5) + assert actual == expected + + actual_xml = actual.to_xml() + assert_xml_equal(actual_xml, xml) + + +class UnionTestModel1(BaseXmlModel, tag='model1'): + element1: int = element(tag='element1') + +class UnionTestModel2(BaseXmlModel, tag='model2'): + element2: str = element(tag='element2') + +type UnionTestModelUnion = UnionTestModel1 | UnionTestModel2 + +def test_nested_model_union(): + class UnionTestModel3(BaseXmlModel, tag='model3'): + element3: str = element(tag='element3') + + class TopLevelModel(BaseXmlModel, tag='topLevel'): + top_element1: UnionTestModel3 | UnionTestModelUnion = element() + + xml_1 = """ + + 1 + + """ + actual_1 = TopLevelModel.from_xml(xml_1) + expected_1 = TopLevelModel(top_element1=UnionTestModel1(element1=1)) + assert actual_1 == expected_1 + actual_1_xml = actual_1.to_xml() + assert_xml_equal(actual_1_xml, xml_1) + + xml_2 = """ + + element 2 + + """ + actual_2 = TopLevelModel.from_xml(xml_2) + expected_2 = TopLevelModel(top_element1=UnionTestModel2(element2='element 2')) + assert actual_2 == expected_2 + actual_2_xml = actual_2.to_xml() + assert_xml_equal(actual_2_xml, xml_2) + + xml_3 = """ + + element 3 + + """ + actual_3 = TopLevelModel.from_xml(xml_3) + expected_3 = TopLevelModel(top_element1=UnionTestModel3(element3='element 3')) + assert actual_3 == expected_3 + actual_3_xml = actual_3.to_xml() + assert_xml_equal(actual_3_xml, xml_3)