diff --git a/pydantic_xml/serializers/factories/union.py b/pydantic_xml/serializers/factories/union.py
index 248a9c0..5490526 100644
--- a/pydantic_xml/serializers/factories/union.py
+++ b/pydantic_xml/serializers/factories/union.py
@@ -1,3 +1,4 @@
+from copy import deepcopy
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import pydantic as pd
@@ -143,7 +144,8 @@ def deserialize(
def from_core_schema(schema: pcs.UnionSchema, ctx: Serializer.Context) -> Serializer:
choice_families: Set[SchemaTypeFamily] = set()
- for choice_schema in schema['choices']:
+ flattened_schema = _flatten_choice_schemas(deepcopy(schema), ctx)
+ for choice_schema in flattened_schema['choices']:
if isinstance(choice_schema, tuple):
choice_schema, label = choice_schema
@@ -163,8 +165,35 @@ def from_core_schema(schema: pcs.UnionSchema, ctx: Serializer.Context) -> Serial
choice_family = choice_families.pop()
if choice_family is SchemaTypeFamily.MODEL:
- return ModelSerializer.from_core_schema(schema, ctx)
+ return ModelSerializer.from_core_schema(flattened_schema, ctx)
elif choice_family is SchemaTypeFamily.PRIMITIVE:
- return PrimitiveTypeSerializer.from_core_schema(schema, ctx)
+ return PrimitiveTypeSerializer.from_core_schema(flattened_schema, ctx)
else:
raise AssertionError("unreachable")
+
+
+def _flatten_choice_schemas(schema: pcs.UnionSchema, ctx: Serializer.Context) -> pcs.UnionSchema:
+ """
+ Flatten nested union choice_schemas into their components, leave others as they are
+ """
+ choice_schemas = schema['choices']
+ flattened_schemas = []
+ seen_refs = set()
+ while choice_schemas:
+ choice_schema = original_schema = choice_schemas.pop()
+ if isinstance(choice_schema, tuple):
+ choice_schema, label = choice_schema
+ ref = choice_schema.get('schema_ref', choice_schema.get('ref'))
+ if ref in seen_refs:
+ continue
+ if ref:
+ seen_refs.add(ref)
+
+ if choice_schema['type'] == 'definition-ref':
+ choice_schema = ctx.definitions.get(choice_schema['schema_ref'])
+ if choice_schema['type'] == 'union':
+ choice_schemas.extend(choice_schema['choices'])
+ else:
+ flattened_schemas.append(original_schema)
+ schema['choices'] = flattened_schemas
+ return schema
diff --git a/tests/test_unions.py b/tests/test_unions.py
index b55f6d1..d6c1173 100644
--- a/tests/test_unions.py
+++ b/tests/test_unions.py
@@ -1,9 +1,9 @@
import sys
-from typing import List, Literal, Tuple, Union
+from typing import List, Literal, Tuple, Union, Annotated
import pytest
from helpers import assert_xml_equal
-from pydantic import Field
+from pydantic import Field, BeforeValidator
from pydantic_xml import BaseXmlModel, RootXmlModel, attr, element
@@ -453,3 +453,77 @@ class TestModel(BaseXmlModel, tag='model'):
actual_xml = actual_obj.to_xml()
assert_xml_equal(actual_xml, xml)
+
+
+type Literals1 = Literal[1, 2]
+type Literals2 = Literal[5, 6]
+type Literals3 = Literal[0]
+type LiteralUnion1 = Literals1 | Literals2
+type LiteralUnion2 = LiteralUnion1 | Literals3
+type LiteralUnion3 = LiteralUnion2 | LiteralUnion1
+
+def test_nested_primitive_union():
+ class Model(BaseXmlModel, tag='model'):
+ element1: Annotated[LiteralUnion3, BeforeValidator(int)] = element(tag='testLiteral')
+
+ xml = '''
+
+ 5
+
+ '''
+
+ actual = Model.from_xml(xml)
+ expected = Model(element1=5)
+ assert actual == expected
+
+ actual_xml = actual.to_xml()
+ assert_xml_equal(actual_xml, xml)
+
+
+class UnionTestModel1(BaseXmlModel, tag='model1'):
+ element1: int = element(tag='element1')
+
+class UnionTestModel2(BaseXmlModel, tag='model2'):
+ element2: str = element(tag='element2')
+
+type UnionTestModelUnion = UnionTestModel1 | UnionTestModel2
+
+def test_nested_model_union():
+ class UnionTestModel3(BaseXmlModel, tag='model3'):
+ element3: str = element(tag='element3')
+
+ class TopLevelModel(BaseXmlModel, tag='topLevel'):
+ top_element1: UnionTestModel3 | UnionTestModelUnion = element()
+
+ xml_1 = """
+
+ 1
+
+ """
+ actual_1 = TopLevelModel.from_xml(xml_1)
+ expected_1 = TopLevelModel(top_element1=UnionTestModel1(element1=1))
+ assert actual_1 == expected_1
+ actual_1_xml = actual_1.to_xml()
+ assert_xml_equal(actual_1_xml, xml_1)
+
+ xml_2 = """
+
+ element 2
+
+ """
+ actual_2 = TopLevelModel.from_xml(xml_2)
+ expected_2 = TopLevelModel(top_element1=UnionTestModel2(element2='element 2'))
+ assert actual_2 == expected_2
+ actual_2_xml = actual_2.to_xml()
+ assert_xml_equal(actual_2_xml, xml_2)
+
+ xml_3 = """
+
+ element 3
+
+ """
+ actual_3 = TopLevelModel.from_xml(xml_3)
+ expected_3 = TopLevelModel(top_element1=UnionTestModel3(element3='element 3'))
+ assert actual_3 == expected_3
+ actual_3_xml = actual_3.to_xml()
+ assert_xml_equal(actual_3_xml, xml_3)