Shubham Goel

How can I save the explain() output for H2O models?

I am trying to save the output images (graphs) that I get when I call explain() on H2O models. Currently I only save the SHAP summary plot, using model.shap_summary_plot(test, save_plot_path="shap_summary.png"), but explain() has no save_plot_path parameter.

import h2o
from h2o.automl import H2OAutoML

h2o.init()

df = h2o.import_file("https://h2o-public-test-data.s3.amazonaws.com/smalldata/wine/winequality-redwhite-no-BOM.csv")

response = "quality"

predictors = [
  "fixed acidity", "volatile acidity", "citric acid", "residual sugar", "chlorides", "free sulfur dioxide",
  "total sulfur dioxide", "density", "pH", "sulphates", "alcohol", "type"
]


train, test = df.split_frame(seed=1)

aml = H2OAutoML(max_runtime_secs=120, seed=1)
aml.train(x=predictors, y=response, training_frame=train)

leader_model = aml.leader 

leader_model.explain(test) # save this output

However, I want to save all of the graphs generated by explain() rather than creating each one individually. I also want this to run as a script, not in a Jupyter notebook.

The sample code above is an edited version of the H2O Explain wine example (see also the H2O explain docs).

Answers (1)

Mathanraj-Sharma

In H2O-3, model.explain() returns an h2o.explanation._explain.H2OExplanation object. You can iterate through it to save your plots.

Pass render=False: according to the docs, if render is True the model explanations are rendered, otherwise they are just returned, so you can process them yourself.
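Before writing any files, it can help to see which sections of the explanation actually contain plots. The sketch below assumes, as described above, that the object returned by explain(render=False) behaves like a dict of sections, each of which may hold a "plots" mapping. summarize_explanation is a hypothetical helper written for illustration, not part of the h2o API.

```python
from collections.abc import Mapping


def summarize_explanation(explanation):
    """Return {section_name: [plot_name, ...]} for every section that has plots.

    Assumes an H2OExplanation-like object: a mapping of section names to
    section mappings, where a section may contain a "plots" mapping.
    """
    summary = {}
    for section_name, section in explanation.items():
        # Skip entries that are not section mappings
        if not isinstance(section, Mapping):
            continue
        plots = section.get("plots")
        # Sections without plots (e.g. ones holding only data) are left out
        if plots:
            summary[section_name] = list(plots.keys())
    return summary
```

For example, print(summarize_explanation(leader_model.explain(test, render=False))) would list the section and plot names that a save loop is going to write out.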

I was able to do it with the following function (tested with h2o version 3.36.1.2):

import os


def save_explain_plots(model, data):
    # render=False returns the H2OExplanation object instead of displaying it
    obj = model.explain(data, render=False)
    for key, section in obj.items():
        plots = section.get("plots")
        # skip sections that contain no plots
        if not plots:
            continue
        print(f"saving {key} plots")
        os.makedirs(f"./images/{key}", exist_ok=True)
        for plot_name, plot in plots.items():
            # .figure() gives the matplotlib figure that can be saved
            fig = plot.figure()
            fig.savefig(f"./images/{key}/{plot_name}.png")
