Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions pydantic_xml/serializers/factories/union.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import pydantic as pd
Expand Down Expand Up @@ -143,7 +144,8 @@ def deserialize(

def from_core_schema(schema: pcs.UnionSchema, ctx: Serializer.Context) -> Serializer:
choice_families: Set[SchemaTypeFamily] = set()
for choice_schema in schema['choices']:
flattened_schema = _flatten_choice_schemas(deepcopy(schema), ctx)
for choice_schema in flattened_schema['choices']:
if isinstance(choice_schema, tuple):
choice_schema, label = choice_schema

Expand All @@ -163,8 +165,35 @@ def from_core_schema(schema: pcs.UnionSchema, ctx: Serializer.Context) -> Serial

choice_family = choice_families.pop()
if choice_family is SchemaTypeFamily.MODEL:
return ModelSerializer.from_core_schema(schema, ctx)
return ModelSerializer.from_core_schema(flattened_schema, ctx)
elif choice_family is SchemaTypeFamily.PRIMITIVE:
return PrimitiveTypeSerializer.from_core_schema(schema, ctx)
return PrimitiveTypeSerializer.from_core_schema(flattened_schema, ctx)
else:
raise AssertionError("unreachable")


def _flatten_choice_schemas(schema: pcs.UnionSchema, ctx: Serializer.Context) -> pcs.UnionSchema:
"""
Flatten nested union choice_schemas into their components, leave others as they are
"""
choice_schemas = schema['choices']
flattened_schemas = []
seen_refs = set()
while choice_schemas:
choice_schema = original_schema = choice_schemas.pop()
if isinstance(choice_schema, tuple):
choice_schema, label = choice_schema
ref = choice_schema.get('schema_ref', choice_schema.get('ref'))
if ref in seen_refs:
continue
if ref:
seen_refs.add(ref)

if choice_schema['type'] == 'definition-ref':
choice_schema = ctx.definitions.get(choice_schema['schema_ref'])
if choice_schema['type'] == 'union':
choice_schemas.extend(choice_schema['choices'])
else:
flattened_schemas.append(original_schema)
schema['choices'] = flattened_schemas
return schema
78 changes: 76 additions & 2 deletions tests/test_unions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import sys
from typing import List, Literal, Tuple, Union
from typing import List, Literal, Tuple, Union, Annotated

import pytest
from helpers import assert_xml_equal
from pydantic import Field
from pydantic import Field, BeforeValidator

from pydantic_xml import BaseXmlModel, RootXmlModel, attr, element

Expand Down Expand Up @@ -453,3 +453,77 @@ class TestModel(BaseXmlModel, tag='model'):

actual_xml = actual_obj.to_xml()
assert_xml_equal(actual_xml, xml)


type Literals1 = Literal[1, 2]
type Literals2 = Literal[5, 6]
type Literals3 = Literal[0]
type LiteralUnion1 = Literals1 | Literals2
type LiteralUnion2 = LiteralUnion1 | Literals3
type LiteralUnion3 = LiteralUnion2 | LiteralUnion1

def test_nested_primitive_union():
class Model(BaseXmlModel, tag='model'):
element1: Annotated[LiteralUnion3, BeforeValidator(int)] = element(tag='testLiteral')

xml = '''
<model>
<testLiteral>5</testLiteral>
</model>
'''

actual = Model.from_xml(xml)
expected = Model(element1=5)
assert actual == expected

actual_xml = actual.to_xml()
assert_xml_equal(actual_xml, xml)


class UnionTestModel1(BaseXmlModel, tag='model1'):
element1: int = element(tag='element1')

class UnionTestModel2(BaseXmlModel, tag='model2'):
element2: str = element(tag='element2')

type UnionTestModelUnion = UnionTestModel1 | UnionTestModel2

def test_nested_model_union():
class UnionTestModel3(BaseXmlModel, tag='model3'):
element3: str = element(tag='element3')

class TopLevelModel(BaseXmlModel, tag='topLevel'):
top_element1: UnionTestModel3 | UnionTestModelUnion = element()

xml_1 = """
<topLevel>
<model1><element1>1</element1></model1>
</topLevel>
"""
actual_1 = TopLevelModel.from_xml(xml_1)
expected_1 = TopLevelModel(top_element1=UnionTestModel1(element1=1))
assert actual_1 == expected_1
actual_1_xml = actual_1.to_xml()
assert_xml_equal(actual_1_xml, xml_1)

xml_2 = """
<topLevel>
<model2><element2>element 2</element2></model2>
</topLevel>
"""
actual_2 = TopLevelModel.from_xml(xml_2)
expected_2 = TopLevelModel(top_element1=UnionTestModel2(element2='element 2'))
assert actual_2 == expected_2
actual_2_xml = actual_2.to_xml()
assert_xml_equal(actual_2_xml, xml_2)

xml_3 = """
<topLevel>
<model3><element3>element 3</element3></model3>
</topLevel>
"""
actual_3 = TopLevelModel.from_xml(xml_3)
expected_3 = TopLevelModel(top_element1=UnionTestModel3(element3='element 3'))
assert actual_3 == expected_3
actual_3_xml = actual_3.to_xml()
assert_xml_equal(actual_3_xml, xml_3)