Reputation: 734
I'm adding type hints to my Python code and was wondering what the proper way is to type-hint a loaded YAML file, since it's a dictionary containing any number of nested dictionaries.
Is there a better way to type-hint the return value of a function that loads a YAML file than Dict[str, Dict[str, Any]]?
Here's the function:
from typing import Any, Dict
import yaml
def load_yaml(yaml_in: str) -> Dict[str, Dict[str, Any]]:
    return yaml.load(open(yaml_in), Loader=yaml.FullLoader)
Here's an example of the YAML file being loaded:
VariableMap:
    var1: 'time'
    var2: 'param_name'
GlobalVariables:
    limits:
        x-min:
        x-max:
        y-min:
        y-max:
Plots:
    plot1:
        file:
        x_data: 'date'
        y_data: [{param: 'param1', label: "param1", color: 'red', linestyle: '-'},
                 {param: 'param2', label: "param2", color: 'black', linestyle: '--'}]
        labels:
            title: {label: 'title', fontsize: '9'}
            x-axis: {xlabel: 'x-label', fontsize: '9'}
            y-axis: {ylabel: 'y-label', fontsize: '9'}
        limits:
            x-min: 0
            x-max: 100
            y-min:
            y-max:
Figures:
    fig1:
        shape: [1, 1]
        size: [6, 8]
        plots: ['plot1']
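For reference, "a dictionary of any number of dictionaries" can also be described with a recursive type alias instead of a fixed nesting depth. This is a minimal stdlib-only sketch, assuming the file contains only string-keyed mappings, lists and plain scalars (YAML timestamps, which PyYAML loads as datetime objects, are not covered). `YAMLValue` and `get_nested` are hypothetical names; recent versions of mypy and pyright accept recursive aliases written with string forward references like this.

```python
from typing import Dict, List, Union

# Hypothetical alias: a YAML value is a scalar, a list of YAML values,
# or a mapping from string keys to YAML values (recursive via strings).
YAMLValue = Union[None, bool, int, float, str,
                  List["YAMLValue"], Dict[str, "YAMLValue"]]


def get_nested(data: Dict[str, YAMLValue], *keys: str) -> YAMLValue:
    """Walk nested mappings by key, e.g. get_nested(d, "Plots", "plot1")."""
    current: YAMLValue = data
    for key in keys:
        if not isinstance(current, dict):
            raise TypeError(f"cannot look up {key!r} in {type(current).__name__}")
        current = current[key]
    return current


# A fragment of the example file, in the shape yaml.safe_load returns it
# (empty values such as `x-min:` load as None).
config: Dict[str, YAMLValue] = {
    "GlobalVariables": {"limits": {"x-min": None, "x-max": None}},
    "Figures": {"fig1": {"shape": [1, 1], "plots": ["plot1"]}},
}
print(get_nested(config, "Figures", "fig1", "shape"))  # [1, 1]
```

The trade-off is that the alias says nothing about which keys exist, so a type checker still cannot catch a typo such as "Figure" instead of "Figures".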
Upvotes: 7
Views: 7339
Reputation: 11642
I would suggest looking into the dataclass-wizard library, as it might be helpful for this task. In particular, it provides the YAMLWizard mixin class, which simplifies working with YAML data.
First, define a variable holding the YAML data:
yaml_string = """
VariableMap:
    var1: 'time'
    var2: 'param_name'
GlobalVariables:
    limits:
        x-min:
        x-max:
        y-min:
        y-max:
Plots:
    plot1:
        file:
        x_data: 'date'
        y_data: [{param: 'param1', label: "param1", color: 'red', linestyle: '-'},
                 {param: 'param2', label: "param2", color: 'black', linestyle: '--'}]
        labels:
            title: {label: 'title', fontsize: '9'}
            x-axis: {xlabel: 'x-label', fontsize: '9'}
            y-axis: {ylabel: 'y-label', fontsize: '9'}
        limits:
            x-min: 0
            x-max: 100
            y-min:
            y-max:
Figures:
    fig1:
        shape: [1, 1]
        size: [6, 8]
        plots: ['plot1']
"""
Now you can use the library's CLI utility (the wizard_cli module) to generate a rough dataclass schema, like so:
import json
import yaml
import dataclass_wizard.wizard_cli as cli
data = yaml.safe_load(yaml_string)
print(cli.PyCodeGenerator(file_contents=json.dumps(data), experimental=True).py_code)
This outputs something like the schema below. Note that I've gone ahead and cleaned up some issues, such as duplicate classes and "unknown" types (for example, for y-min and y-max).
The full dataclass schema:
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

from dataclass_wizard import YAMLWizard


@dataclass
class Data(YAMLWizard):
    """
    Data dataclass
    """
    variable_map: VariableMap
    global_variables: GlobalVariables
    plots: Plots
    figures: Figures


@dataclass
class VariableMap:
    """
    VariableMap dataclass
    """
    var1: str
    var2: str


@dataclass
class GlobalVariables:
    """
    GlobalVariables dataclass
    """
    limits: Limits


@dataclass
class Plots:
    """
    Plots dataclass
    """
    plot1: Plot


@dataclass
class Plot:
    """
    Plot dataclass
    """
    file: Any
    x_data: str
    y_data: list[YDatum]
    labels: Labels
    limits: Limits


@dataclass
class YDatum:
    """
    YDatum dataclass
    """
    param: str
    label: str
    color: str
    linestyle: str


@dataclass
class Labels:
    """
    Labels dataclass
    """
    title: Title
    x_axis: XAxis
    y_axis: YAxis


@dataclass
class Title:
    """
    Title dataclass
    """
    label: str
    fontsize: int | str


@dataclass
class XAxis:
    """
    XAxis dataclass
    """
    xlabel: str
    fontsize: int | str


@dataclass
class YAxis:
    """
    YAxis dataclass
    """
    ylabel: str
    fontsize: int | str


@dataclass
class Limits:
    """
    Limits dataclass
    """
    x_min: int
    x_max: int
    y_min: int
    y_max: int


@dataclass
class Figures:
    """
    Figures dataclass
    """
    fig1: Fig1


@dataclass
class Fig1:
    """
    Fig1 dataclass
    """
    shape: list[int]
    size: list[int]
    plots: list[str]
Now we can load the YAML string into a nested Data object, as below:
import pprint
data = Data.from_yaml(yaml_string)
pprint.pprint(data)
Output:
Data(variable_map=VariableMap(var1='time', var2='param_name'),
     global_variables=GlobalVariables(limits=Limits(x_min=0,
                                                    x_max=0,
                                                    y_min=0,
                                                    y_max=0)),
     plots=Plots(plot1=Plot(file=None,
                            x_data='date',
                            y_data=[YDatum(param='param1',
                                           label='param1',
                                           color='red',
                                           linestyle='-'),
                                    YDatum(param='param2',
                                           label='param2',
                                           color='black',
                                           linestyle='--')],
                            labels=Labels(title=Title(label='title',
                                                      fontsize='9'),
                                          x_axis=XAxis(xlabel='x-label',
                                                       fontsize='9'),
                                          y_axis=YAxis(ylabel='y-label',
                                                       fontsize='9')),
                            limits=Limits(x_min=0, x_max=100, y_min=0, y_max=0))),
     figures=Figures(fig1=Fig1(shape=[1, 1], size=[6, 8], plots=['plot1'])))
I noticed that some classes, in particular XAxis and YAxis, contain essentially the same fields but use slightly different field names in the YAML data.
If desired, such classes can be merged into a single class declaration. Then, we can use a key-mapping approach, such as json_field(), to define alias key names to use when loading the YAML data.
For example:
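from __future__ import annotations
from dataclasses import dataclass
from dataclass_wizard import json_field

@dataclass
class Axis:
    # Load either the 'xlabel' or the 'ylabel' key into `label`
    # noinspection PyDataclass
    label: str = json_field(('xlabel', 'ylabel'))
    fontsize: int | str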
Upvotes: 2
Reputation: 3565
I would use dataclasses together with the dataclasses-json library for this. Then the return type can be an actual class. You can also use the LetterCase option to map the key formats in your YAML file (kebab-case and PascalCase) to Python's default naming scheme, snake_case.
Something like this:
from dataclasses import dataclass
from dataclasses_json import dataclass_json, LetterCase, Undefined
from typing import Optional


@dataclass_json
@dataclass
class VariableMap:
    var1: str
    var2: str


# YAML keys are kebab-case here: x-min -> x_min, etc.
@dataclass_json(letter_case=LetterCase.KEBAB)
@dataclass
class Limits:
    x_min: Optional[int]
    x_max: Optional[int]
    y_min: Optional[int]
    y_max: Optional[int]


@dataclass_json
@dataclass
class GlobalVariables:
    limits: Limits


@dataclass_json
@dataclass
class TitleLabel:
    label: str
    fontsize: str


@dataclass_json
@dataclass
class XAxisLabel:
    xlabel: str
    fontsize: str


@dataclass_json
@dataclass
class YAxisLabel:
    ylabel: str
    fontsize: str


# x-axis / y-axis -> x_axis / y_axis
@dataclass_json(letter_case=LetterCase.KEBAB)
@dataclass
class Labels:
    title: TitleLabel
    x_axis: XAxisLabel
    y_axis: YAxisLabel


@dataclass_json
@dataclass
class Param:
    param: str
    label: str
    color: str
    linestyle: str


# Undefined.EXCLUDE ignores any keys that have no matching field
@dataclass_json(undefined=Undefined.EXCLUDE)
@dataclass
class Plot:
    file: Optional[str]
    x_data: str
    y_data: list[Param]
    labels: Labels
    limits: Limits


@dataclass_json
@dataclass
class Figure:
    shape: list[int]
    size: list[int]
    plots: list[str]


@dataclass_json
@dataclass
class Plots:
    plot1: Plot


@dataclass_json
@dataclass
class Figures:
    fig1: Figure


# Top-level keys are PascalCase: VariableMap -> variable_map, etc.
@dataclass_json(letter_case=LetterCase.PASCAL, undefined=Undefined.EXCLUDE)
@dataclass
class YamlRoot:
    variable_map: VariableMap
    global_variables: GlobalVariables
    plots: Plots
    figures: Figures


import yaml


def load_yaml(yaml_content: str) -> YamlRoot:
    d = yaml.safe_load(yaml_content)
    # infer_missing=True sets fields whose keys are missing to None
    return YamlRoot.from_dict(d, infer_missing=True)


if __name__ == "__main__":
    yaml_text = """
VariableMap:
    var1: 'time'
    var2: 'param_name'
GlobalVariables:
    ...
"""
    print(load_yaml(yaml_text))

"""
Result (cleaned for readability):
YamlRoot(
    variable_map=VariableMap(var1='time', var2='param_name'),
    global_variables=GlobalVariables(
        limits=Limits(x_min=None, x_max=None, y_min=None, y_max=None)
    ),
    plots=Plots(
        plot1=Plot(
            file=None,
            x_data='date',
            y_data=[
                Param(param='param1', label='param1', color='red', linestyle='-'),
                Param(param='param2', label='param2', color='black', linestyle='--')
            ],
            labels=Labels(
                title=TitleLabel(label='title', fontsize='9'),
                x_axis=XAxisLabel(xlabel='x-label', fontsize='9'),
                y_axis=YAxisLabel(ylabel='y-label', fontsize='9')
            ),
            limits=Limits(x_min=0, x_max=100, y_min=None, y_max=None)
        )
    ),
    figures=Figures(
        fig1=Figure(shape=[1, 1], size=[6, 8], plots=['plot1'])
    )
)
"""
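To see what the LetterCase options used above correspond to, here is a small stdlib-only sketch that converts snake_case field names into the kebab-case and PascalCase spellings found in the YAML file. The function names are hypothetical, and this only illustrates the naming conventions; it is not how dataclasses-json implements LetterCase.

```python
def to_kebab(field_name: str) -> str:
    """snake_case -> kebab-case, e.g. 'x_min' -> 'x-min'."""
    return field_name.replace("_", "-")


def to_pascal(field_name: str) -> str:
    """snake_case -> PascalCase, e.g. 'variable_map' -> 'VariableMap'."""
    return "".join(part.capitalize() for part in field_name.split("_"))


# Field names from the classes above and the YAML keys they should match.
for name in ("x_min", "y_axis", "variable_map", "global_variables"):
    print(f"{name}: kebab={to_kebab(name)} pascal={to_pascal(name)}")
```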
Upvotes: 1