Simon1

Reputation: 734

What is the proper way to add type hints after loading a YAML file?

I'm adding type hints to my Python code and was wondering what the proper way is to type-hint a loaded YAML file, since it's a dictionary containing an arbitrary number of nested dictionaries.

Is there a better way to type-hint returning a loaded YAML file than Dict[str, Dict[str, Any]]?

Here's the function:

def load_yaml(yaml_in: str) -> Dict[str, Dict[str, Any]]:
    with open(yaml_in) as f:  # closes the file once loading is done
        return yaml.load(f, Loader=yaml.FullLoader)
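For reference, the closest standard-library type for arbitrary `yaml.safe_load` output is a recursive alias rather than a fixed two-level `Dict`. The sketch below assumes plain YAML only (scalars, lists, and string-keyed mappings; no timestamps or other tagged types), uses the hypothetical names `YAMLValue` and `count_leaves`, and relies on recursive aliases, which need a reasonably recent mypy or pyright. With it, the return type of the function above could be written as `Dict[str, YAMLValue]`.

```python
from typing import Dict, List, Union

# Hypothetical recursive alias for plain YAML data: scalars, lists and
# string-keyed mappings nested to any depth. Tagged types such as
# timestamps (datetime.date) are not covered in this sketch.
YAMLValue = Union[None, bool, int, float, str,
                  List["YAMLValue"], Dict[str, "YAMLValue"]]


def count_leaves(node: YAMLValue) -> int:
    """Hypothetical helper: count scalar values at any nesting depth."""
    if isinstance(node, dict):
        # isinstance() narrows `node` to Dict[str, YAMLValue] for the checker
        return sum(count_leaves(value) for value in node.values())
    if isinstance(node, list):
        return sum(count_leaves(item) for item in node)
    return 1  # None, bool, int, float or str
```

The catch is that code consuming such a value must still narrow it with `isinstance()` before indexing, so the alias documents "arbitrary nesting" accurately but does not tell the checker what keys exist.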

Here's an example of the YAML file being loaded:

VariableMap:
    var1: 'time'
    var2: 'param_name'

GlobalVariables:
    limits:
        x-min:
        x-max:
        y-min:
        y-max:

Plots:
    plot1:
        file: 
        x_data: 'date'
        y_data: [{param: 'param1', label: "param1", color: 'red', linestyle: '-'},
                 {param: 'param2', label: "param2", color: 'black', linestyle: '--'}]
        labels:
            title: {label: 'title', fontsize: '9'}
            x-axis: {xlabel: 'x-label', fontsize: '9'}
            y-axis: {ylabel: 'y-label', fontsize: '9'}
        limits:
            x-min: 0
            x-max: 100
            y-min:
            y-max:

Figures:
    fig1:
        shape: [1, 1]
        size: [6, 8]
        plots: ['plot1']
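Because this particular file has a fixed layout, `typing.TypedDict` (Python 3.8+) can describe it more precisely than nested `Dict[str, Any]`, without any third-party library. The following is a minimal sketch covering two of the sections; all class and function names are hypothetical. Keys such as `x-min` are not valid Python identifiers, so that part uses the functional `TypedDict` syntax.

```python
from typing import Any, Dict, Optional, TypedDict, cast

# Keys like "x-min" are not valid identifiers, so this TypedDict
# uses the functional syntax. Empty YAML values load as None.
LimitsDict = TypedDict(
    "LimitsDict",
    {"x-min": Optional[int], "x-max": Optional[int],
     "y-min": Optional[int], "y-max": Optional[int]},
)


class VariableMapDict(TypedDict):
    var1: str
    var2: str


class GlobalVariablesDict(TypedDict):
    limits: LimitsDict


class ConfigDict(TypedDict):
    VariableMap: VariableMapDict
    GlobalVariables: GlobalVariablesDict
    # Left loose in this sketch; these could be typed the same way.
    Plots: Dict[str, Dict[str, Any]]
    Figures: Dict[str, Dict[str, Any]]


def as_config(data: Any) -> ConfigDict:
    # No runtime check happens here: cast() only tells the type checker
    # to trust that `data` (e.g. the result of yaml.safe_load) has this shape.
    return cast(ConfigDict, data)
```

A loader would then return `as_config(yaml.safe_load(f))`. Note that a `TypedDict` is an ordinary `dict` at runtime, so a malformed file is not detected until the bad key is accessed.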

Upvotes: 7

Views: 7339

Answers (2)

Wizard.Ritvik

Reputation: 11642

I would suggest looking into the dataclass-wizard library, which might be helpful for this task. In particular, it provides a YAMLWizard mixin class that simplifies working with YAML data.

First, store the YAML data in a string variable:

yaml_string = """
VariableMap:
    var1: 'time'
    var2: 'param_name'

GlobalVariables:
    limits:
        x-min:
        x-max:
        y-min:
        y-max:

Plots:
    plot1:
        file:
        x_data: 'date'
        y_data: [{param: 'param1', label: "param1", color: 'red', linestyle: '-'},
                 {param: 'param2', label: "param2", color: 'black', linestyle: '--'}]
        labels:
            title: {label: 'title', fontsize: '9'}
            x-axis: {xlabel: 'x-label', fontsize: '9'}
            y-axis: {ylabel: 'y-label', fontsize: '9'}
        limits:
            x-min: 0
            x-max: 100
            y-min:
            y-max:

Figures:
    fig1:
        shape: [1, 1]
        size: [6, 8]
        plots: ['plot1']
"""

Now you can use the code generator behind the library's CLI utility to produce a very rough dataclass schema, like so:

import json
import yaml

import dataclass_wizard.wizard_cli as cli


data = yaml.safe_load(yaml_string)

print(cli.PyCodeGenerator(file_contents=json.dumps(data), experimental=True).py_code)

This outputs something like the schema below. Note that I've gone ahead and cleaned up some issues, such as duplicate classes and "unknown" types (for example, for y-min and y-max).

The full dataclass schema:

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

from dataclass_wizard import YAMLWizard


@dataclass
class Data(YAMLWizard):
    """
    Data dataclass

    """
    variable_map: VariableMap
    global_variables: GlobalVariables
    plots: Plots
    figures: Figures


@dataclass
class VariableMap:
    """
    VariableMap dataclass

    """
    var1: str
    var2: str


@dataclass
class GlobalVariables:
    """
    GlobalVariables dataclass

    """
    limits: Limits


@dataclass
class Plots:
    """
    Plots dataclass

    """
    plot1: Plot


@dataclass
class Plot:
    """
    Plot dataclass

    """
    file: Any
    x_data: str
    y_data: list[YDatum]
    labels: Labels
    limits: Limits


@dataclass
class YDatum:
    """
    YDatum dataclass

    """
    param: str
    label: str
    color: str
    linestyle: str


@dataclass
class Labels:
    """
    Labels dataclass

    """
    title: Title
    x_axis: XAxis
    y_axis: YAxis


@dataclass
class Title:
    """
    Title dataclass

    """
    label: str
    fontsize: int | str


@dataclass
class XAxis:
    """
    XAxis dataclass

    """
    xlabel: str
    fontsize: int | str


@dataclass
class YAxis:
    """
    YAxis dataclass

    """
    ylabel: str
    fontsize: int | str


@dataclass
class Limits:
    """
    Limits dataclass

    """
    x_min: int
    x_max: int
    y_min: int
    y_max: int


@dataclass
class Figures:
    """
    Figures dataclass

    """
    fig1: Fig1


@dataclass
class Fig1:
    """
    Fig1 dataclass

    """
    shape: list[int]
    size: list[int]
    plots: list[str]

Now we can load the YAML string into a nested Data object:

import pprint

data = Data.from_yaml(yaml_string)
pprint.pprint(data)

Output:

Data(variable_map=VariableMap(var1='time', var2='param_name'),
     global_variables=GlobalVariables(limits=Limits(x_min=0,
                                                    x_max=0,
                                                    y_min=0,
                                                    y_max=0)),
     plots=Plots(plot1=Plot(file=None,
                            x_data='date',
                            y_data=[YDatum(param='param1',
                                           label='param1',
                                           color='red',
                                           linestyle='-'),
                                    YDatum(param='param2',
                                           label='param2',
                                           color='black',
                                           linestyle='--')],
                            labels=Labels(title=Title(label='title',
                                                      fontsize='9'),
                                          x_axis=XAxis(xlabel='x-label',
                                                       fontsize='9'),
                                          y_axis=YAxis(ylabel='y-label',
                                                       fontsize='9')),
                            limits=Limits(x_min=0, x_max=100, y_min=0, y_max=0))),
     figures=Figures(fig1=Fig1(shape=[1, 1], size=[6, 8], plots=['plot1'])))

Observations

I noticed that some classes, in particular XAxis and YAxis, contain essentially the same fields but use slightly different field names in the YAML data.

If desired, such classes can be merged into a single class declaration, using a key-mapping helper such as json_field() to define the alias key names to accept when loading the YAML data.

For example:

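from __future__ import annotations  # lets `int | str` be written on Python < 3.10, as in the schema above

from dataclasses import dataclass

from dataclass_wizard import json_field


@dataclass
class Axis:
    # Either YAML key ("xlabel" or "ylabel") is loaded into `label`, so
    # Labels could declare both `x_axis: Axis` and `y_axis: Axis`.
    # noinspection PyDataclass
    label: str = json_field(('xlabel', 'ylabel'))
    fontsize: int | str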
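If you would rather not depend on a library for this one case, the same aliasing idea can be done by hand with a plain dataclass. This is a swapped-in, stdlib-only technique, not how json_field() works internally, and `from_dict` is a hypothetical helper name.

```python
from dataclasses import dataclass
from typing import Any, Dict, Union


@dataclass
class Axis:
    label: str
    fontsize: Union[int, str]

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "Axis":
        """Hypothetical helper: accept either "xlabel" or "ylabel" as the label key."""
        for key in ("xlabel", "ylabel"):
            if key in d:
                return cls(label=d[key], fontsize=d["fontsize"])
        raise KeyError("expected an 'xlabel' or 'ylabel' key")
```

The trade-off is that every alias has to be handled manually in `from_dict`, which is exactly the bookkeeping the library's key mapping is meant to remove.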

Upvotes: 2

blueteeth

Reputation: 3565

I would use dataclasses together with the dataclasses-json library (its dataclass_json decorator) for this. Then the return type can be an actual class. You can also use the LetterCase option to map the key formats in your YAML file (PascalCase, kebab-case) to the default Python naming scheme, snake_case.

Something like this:

from dataclasses import dataclass
from dataclasses_json import dataclass_json, LetterCase, Undefined
from typing import Optional

@dataclass_json
@dataclass
class VariableMap:
    var1: str
    var2: str


@dataclass_json(letter_case=LetterCase.KEBAB)
@dataclass
class Limits:
    x_min: Optional[int]
    x_max: Optional[int]
    y_min: Optional[int]
    y_max: Optional[int]

@dataclass_json
@dataclass
class GlobalVariables:
    limits: Limits

@dataclass_json
@dataclass
class TitleLabel:
    label: str
    fontsize: str

@dataclass_json
@dataclass
class XAxisLabel:
    xlabel: str
    fontsize: str

@dataclass_json
@dataclass
class YAxisLabel:
    ylabel: str
    fontsize: str

@dataclass_json(letter_case=LetterCase.KEBAB)
@dataclass
class Labels:
    title: TitleLabel
    x_axis: XAxisLabel
    y_axis: YAxisLabel


@dataclass_json
@dataclass
class Param:
    param: str
    label: str
    color: str
    linestyle: str

@dataclass_json(undefined=Undefined.EXCLUDE)
@dataclass
class Plot:
    file: Optional[str]
    x_data: str
    y_data: list[Param]
    labels: Labels
    limits: Limits

@dataclass_json
@dataclass
class Figure:
    shape: list[int]
    size: list[int]
    plots: list[str]

@dataclass_json
@dataclass
class Plots:
    plot1: Plot

@dataclass_json
@dataclass
class Figures:
    fig1: Figure

@dataclass_json(letter_case=LetterCase.PASCAL, undefined=Undefined.EXCLUDE)
@dataclass
class YamlRoot:
    variable_map: VariableMap
    global_variables: GlobalVariables
    plots: Plots
    figures: Figures


import yaml

def load_yaml(yaml_content: str) -> YamlRoot:
    d = yaml.safe_load(yaml_content)
    return YamlRoot.from_dict(d, infer_missing=True)

if __name__ == "__main__":
    # Abbreviated: paste the full YAML from the question in place of "...".
    yaml_text = """
VariableMap:
    var1: 'time'
    var2: 'param_name'

GlobalVariables:
    ...
"""

    print(load_yaml(yaml_text))


"""
Result (cleaned for readability):

YamlRoot(
  variable_map=VariableMap(var1='time', var2='param_name'),
  global_variables=GlobalVariables(
    limits=Limits(x_min=None, x_max=None, y_min=None, y_max=None)
  ),
  plots=Plots(
    plot1=Plot(
      file=None,
      x_data='date',
      y_data=[
        Param(param='param1', label='param1', color='red', linestyle='-'),
        Param(param='param2', label='param2', color='black', linestyle='--')
      ],
      labels=Labels(
        title=TitleLabel(label='title', fontsize='9'),
        x_axis=XAxisLabel(xlabel='x-label', fontsize='9'),
        y_axis=YAxisLabel(ylabel='y-label', fontsize='9')
      ),
      limits=Limits(x_min=0, x_max=100, y_min=None, y_max=None)
    )
  ),
  figures=Figures(
    fig1=Figure(shape=[1, 1], size=[6, 8], plots=['plot1'])
  )
)
"""
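The letter_case options are what line the YAML keys up with the snake_case field names here: PASCAL on YamlRoot for keys like VariableMap, and KEBAB on Limits and Labels for keys like x-min and x-axis. As a rough illustration of that correspondence, here are two hand-written helpers with hypothetical names; this is not the dataclasses-json implementation, which may treat edge cases such as acronyms differently.

```python
import re


def pascal_to_snake(name: str) -> str:
    """Illustrative only: 'VariableMap' -> 'variable_map'."""
    # Insert "_" before every uppercase letter except the first, then lowercase.
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()


def kebab_to_snake(name: str) -> str:
    """Illustrative only: 'x-min' -> 'x_min'."""
    return name.replace("-", "_")
```

Applied to the question's file, pascal_to_snake("GlobalVariables") gives "global_variables" and kebab_to_snake("y-axis") gives "y_axis", matching the field names declared above.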

Upvotes: 1
