How can I detect directly nested fields within an XSD to correctly convert them to a Spark schema?

I'm trying to create a function to convert any XSD schema into a viable PySpark schema. Since I'm not an expert on XML and all the intricacies of choosing the right namespace, I defer to your wisdom on this matter. Here are the functions I have implemented so far:

from pyspark.sql.types import (
    ArrayType,
    DateType,
    IntegerType,
    StringType,
    StructField,
    StructType,
)


def convert_simple_xsd_type_to_pyspark(xsd_type):
    """
    Maps XML Schema Definition (XSD) types to corresponding PySpark data types.

    Arguments:
        xsd_type (str): The XSD type to be mapped.

    Returns:
        pyspark.sql.types.DataType: The corresponding PySpark data type for the given XSD type.
    """
    type_mapping = {
        "xs:string": StringType(),
        "xs:int": IntegerType(),
        "xs:date": DateType(),
        "AlphaNum_Type": StringType(),
        "Num_Type": IntegerType(),
        "Parution_Type": StringType(),
        # More mappings as per your schema definitions
    }
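    # Fall back to StringType for any XSD type not in the mapping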
    return type_mapping.get(xsd_type, StringType())


def recursively_create_struct_from_xsd(element):
    """
    Recursively creates a StructType schema from an XML schema element.

    Arguments:
        element (xml.etree.ElementTree.Element): The XML element to convert into a StructType.

    Returns:
        pyspark.sql.types.StructType: A StructType representing the structure of the XML schema.
    """
    fields = []
    for sub_element in element.findall("./{http://www.w3.org/2001/XMLSchema}element"):
        name = sub_element.get("name")
        max_occurs = sub_element.get("maxOccurs", "1")
        min_occurs = sub_element.get("minOccurs", "1")
        xsd_type = sub_element.get("type")
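        # minOccurs="0" means the element is optional, hence nullable in Spark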
        nullable = min_occurs == "0"

        if xsd_type:
            spark_type = convert_simple_xsd_type_to_pyspark(xsd_type)
        else:
            complex_type = sub_element.find("./{http://www.w3.org/2001/XMLSchema}complexType")
            # Element truthiness depends on child count, so compare to None explicitly
            if complex_type is not None:
                # Recursively handle complex types
                spark_type = recursively_create_struct_from_xsd(complex_type)
            else:
                spark_type = StringType()

        # Check if the element can occur multiple times and wrap in ArrayType if so
        if max_occurs == "unbounded":
            spark_type = ArrayType(spark_type, containsNull=True)

        fields.append(StructField(name, spark_type, nullable=nullable))

    return StructType(fields)

Here's the (initial) test I'm trying to pass using pytest:

import xml.etree.ElementTree as ET
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, DateType, IntegerType

from utils.xml.spark_schema import (
    recursively_create_struct_from_xsd,
)


def test_recursively_create_struct_from_fictional_xsd():
    # Define a simple XML schema as a string and parse it
    xml_schema = """
    <xs:schema xmlns:xs='http://www.w3.org/2001/XMLSchema'>
        <xs:element name='Book' type='xs:string' minOccurs='1' />
        <xs:element name='PublicationDate' type='xs:date' minOccurs='1' />
        <xs:element name='Authors'>
            <xs:complexType>
                <xs:sequence>
                    <xs:element name='Author' type='xs:string' minOccurs='1' maxOccurs='unbounded' />
                </xs:sequence>
            </xs:complexType>
        </xs:element>
    </xs:schema>
    """
    element = ET.fromstring(xml_schema)

    # Use the function to generate the schema
    schema = recursively_create_struct_from_xsd(element)

    # Define the expected schema explicitly naming kwargs
    expected_schema = StructType(
        [
            StructField(name="Book", dataType=StringType(), nullable=False),
            StructField(name="PublicationDate", dataType=DateType(), nullable=False),
            StructField(
                name="Authors",
                dataType=StructType(
                    [StructField(name="Author", dataType=ArrayType(StringType(), containsNull=True), nullable=False)]
                ),
                nullable=False,
            ),
        ]
    )

    # Assert to check schema correctness
    assert str(schema) == str(expected_schema)

Here's a sample of the XML data:

<BookDetails>
    <Book>The Brothers Karamazov</Book>
    <PublicationDate>1880-01-01</PublicationDate>
    <Authors>
        <Author>Fyodor Dostoevsky</Author>
        <Author>Constance Garnett (Translator)</Author>
        <Author>Richard Pevear (Translator)</Author>
        <Author>Larissa Volokhonsky (Translator)</Author>
    </Authors>
</BookDetails>

There's a problem with the element loop, element.findall(...): when I use "./{http://www.w3.org/2001/XMLSchema}element", the complex and nested elements get skipped and I end up with the following schema:

StructType([StructField('Book', StringType(), False), StructField('PublicationDate', DateType(), False), StructField('Authors', StructType([]), False)])
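
For reference, here's a minimal repro of what I think is happening (if I'm reading the tree correctly): the direct children of xs:complexType are wrapper tags such as xs:sequence, not xs:element, so the "./" search inside the recursion finds nothing.

import xml.etree.ElementTree as ET

XS = "{http://www.w3.org/2001/XMLSchema}"
authors = ET.fromstring(
    "<xs:element xmlns:xs='http://www.w3.org/2001/XMLSchema' name='Authors'>"
    "<xs:complexType><xs:sequence>"
    "<xs:element name='Author' type='xs:string' maxOccurs='unbounded'/>"
    "</xs:sequence></xs:complexType></xs:element>"
)
complex_type = authors.find("./" + XS + "complexType")
print([child.tag for child in complex_type])  # only the xs:sequence wrapper
print(complex_type.findall("./" + XS + "element"))  # [] -> empty StructType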

When I use ".//{http://www.w3.org/2001/XMLSchema}element", I end up with duplicates of the nested elements outside of their intended context:

StructType([StructField('Book', StringType(), False), StructField('PublicationDate', DateType(), False), StructField('Authors', StructType([StructField('Author', ArrayType(StringType(), True), False)]), False), StructField('Author', ArrayType(StringType(), True), False)])
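
And with ".//" the search matches every descendant, so the top-level loop over xs:schema also picks up Author directly (continuing from the snippet above, with xml_schema being the test schema string):

root = ET.fromstring(xml_schema)
print([e.get("name") for e in root.findall(".//" + XS + "element")])
# ['Book', 'PublicationDate', 'Authors', 'Author'] -- Author leaks to the top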

What do I use to correctly detect and process Author within Authors without duplicating it outside of its intended context (Authors)? Of course, the solution should work no matter the depth of the nesting.
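
For what it's worth, here's an untested sketch of one direction I've considered: instead of searching with a path expression, walk the direct children and only descend through the structural wrapper tags. The CONTAINERS set is just my guess at which wrappers matter; I don't know whether it covers every XSD construct (xs:group, xs:complexContent, etc.).

XS = "{http://www.w3.org/2001/XMLSchema}"
# Assumed set of wrapper tags to descend through -- may be incomplete
CONTAINERS = {XS + "sequence", XS + "choice", XS + "all"}


def iter_scoped_elements(node):
    """Yield the xs:element children of node, descending only through
    structural wrappers so each element stays inside its own scope."""
    for child in node:
        if child.tag == XS + "element":
            yield child
        elif child.tag in CONTAINERS:
            yield from iter_scoped_elements(child)

The idea would be to replace element.findall(...) with this generator so Author stays scoped to Authors, but I haven't verified it against deeper nesting.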
