Alex

Reputation: 320

Load a Python XGBoost model as a SparkXGBoost model?

Is there a way to take a model object trained in base XGBoost and load it as a SparkXGBoost model? The docs aren't super clear on this split. I've tried:

from xgboost.spark import SparkXGBClassifierModel
model2 = SparkXGBClassifierModel.load("xgboost-model")

I'm getting the following error:

Input path does not exist: /xgboost-model/metadata

I'm assuming this means the on-disk format is different from what I'd get if the model had originally been trained as a SparkXGBoost model.

Upvotes: 1

Views: 1581

Answers (1)

Kevin Kho

Reputation: 687

This is a common scenario where you train a model on a smaller dataset and want to scale inference to significantly more data. As long as you have scikit-learn or xgboost installed on the workers, you can just use the original model with a pandas UDF; there shouldn't be a need to rewrite the training code.

I will use Fugue, which provides a simple interface, to demo the concept. I will use LinearRegression just because it's a lot easier to test with, but any model that is serializable should work. Models that aren't serializable (like some deep learning models) can also work with a bit of tweaking: you just load the model on the worker instead (see the sketch after the predict function below).

Some setup:

import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression

X = pd.DataFrame({"x_1": [1, 1, 2, 2], "x_2":[1, 2, 2, 3]})
y = np.dot(X, np.array([1, 2])) + 3
reg = LinearRegression().fit(X, y)

Next, we wrap the prediction logic in a function:

# define our predict function. you can also load the model here
def predict(df: pd.DataFrame, model: LinearRegression) -> pd.DataFrame:
    return df.assign(predicted=model.predict(df))

# create test data
input_df = pd.DataFrame({"x_1": [3, 4, 6, 6], "x_2":[3, 3, 6, 6]})

# test the predict function
predict(input_df, reg)
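
If your model isn't serializable, the same pattern works with worker-side loading: take a path instead of a model object and deserialize inside the function, so each worker loads its own copy. A minimal sketch, assuming the model was saved with joblib to a path the workers can read (the path and joblib usage here are illustrative):

import joblib

# variant: deserialize on the worker instead of shipping the object
def predict_from_path(df: pd.DataFrame, model_path: str) -> pd.DataFrame:
    model = joblib.load(model_path)  # hypothetical path on shared storage
    return df.assign(predicted=model.predict(df))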

Now, to bring it to Spark, all we need is Fugue's transform() function. If you don't want to use Fugue, you can use pandas UDFs directly (see the sketch at the end of this answer); it will just be more code.

import fugue.api as fa
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# you can load your data into Spark directly here

sdf = spark.createDataFrame(input_df)
res = fa.transform(
    sdf,
    predict,
    schema="*, predicted:double",
    params=dict(model=reg),
)

# res is a Spark DataFrame
res.show()

And this will execute on Spark.
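
For reference, here is what the non-Fugue route looks like with a plain pandas UDF, reusing reg and sdf from above. This is a sketch of the same idea in raw PySpark; the function name and column handling are just for illustration:

from pyspark.sql.functions import pandas_udf

@pandas_udf("double")
def predict_udf(x_1: pd.Series, x_2: pd.Series) -> pd.Series:
    # rebuild the feature frame with the training column names
    features = pd.DataFrame({"x_1": x_1, "x_2": x_2})
    return pd.Series(reg.predict(features))

sdf.withColumn("predicted", predict_udf("x_1", "x_2")).show()

Either way, the only requirement is that the model's dependencies (scikit-learn here, xgboost in your case) are installed on the workers.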

Upvotes: 1
