Reputation: 1
I am trying to save the output images (graphs) I get when I use explain() on H2O models. Currently I am only saving the SHAP output, using model.shap_summary_plot(test, save_plot_path="shap_summary.png"). There is no save_plot_path for explain().
```python
import h2o
from h2o.automl import H2OAutoML

h2o.init()

df = h2o.import_file("https://h2o-public-test-data.s3.amazonaws.com/smalldata/wine/winequality-redwhite-no-BOM.csv")
response = "quality"
predictors = [
    "fixed acidity", "volatile acidity", "citric acid", "residual sugar", "chlorides", "free sulfur dioxide",
    "total sulfur dioxide", "density", "pH", "sulphates", "alcohol", "type"
]
train, test = df.split_frame(seed=1)

aml = H2OAutoML(max_runtime_secs=120, seed=1)
aml.train(x=predictors, y=response, training_frame=train)
leader_model = aml.leader
leader_model.explain(test)  # save this output
```
However, I want to save all the graphs generated by explain() instead of creating them individually. I also want this to run as a script, not in a Jupyter notebook.
The sample code above is an edited version of the Explain-wine-example.
Upvotes: 0
Views: 821
Reputation: 356
In H2O-3, model.explain() returns an h2o.explanation._explain.H2OExplanation object. You can iterate through it to save your plots.
The key is the render parameter: if True, the model explanations are rendered; otherwise they are just returned. Passing render=False therefore gives you the object to work with.
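Before saving anything, it can help to see which sections of the returned object actually contain plots. The helper below, plot_index, is a hypothetical name (not part of H2O) and is only a sketch; it assumes the same dict-like layout the function below relies on, where each section may hold a "plots" mapping of plot name to plot object.

```python
def plot_index(explanation):
    """Map each section name to the names of the plots it contains.

    Assumption: each section is dict-like (has .get) and may contain a
    "plots" entry; sections without plots, such as a leaderboard table,
    are skipped.
    """
    index = {}
    for section, content in explanation.items():
        getter = getattr(content, "get", None)
        if not callable(getter):
            continue  # plain strings or other non-dict values hold no plots
        plots = getter("plots")
        if plots:
            index[section] = list(plots.keys())
    return index
```

Usage (against a live H2O cluster): obj = leader_model.explain(test, render=False) followed by print(plot_index(obj)).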
I was able to do it with the function below (tested with h2o version 3.36.1.2):
```python
import os


def save_explain_plots(model, data):
    # render=False returns the H2OExplanation object instead of displaying it
    obj = model.explain(data, render=False)
    for key in obj.keys():
        print(f"saving {key} plots")
        # Some sections (e.g. the leaderboard) contain no plots
        plots = obj.get(key).get("plots")
        if not plots:
            continue
        os.makedirs(f"./images/{key}", exist_ok=True)
        for plot in plots.keys():
            # .figure() gives the matplotlib figure behind each plot entry
            fig = plots.get(plot).figure()
            fig.savefig(f"./images/{key}/{plot}.png")
```
Upvotes: 1