I'm training Gaussian process classifiers on a dataset of interest. I expected that moving from scikit-learn to skorch (which wraps GPyTorch and so allows GPU acceleration) would significantly speed up training, and possibly improve accuracy, given that many researchers now seem to prefer GPyTorch over scikit-learn. Instead, accuracy was similar, but training with skorch was more than an order of magnitude slower, even on GPU rather than CPU.
Am I doing something wrong in my use of skorch that makes it slower? Or is it just that the different numerical approximations GPyTorch and scikit-learn use for GP classifier training perform differently depending on the dataset?
Here are my GPyTorch module and the model hyperparameters (tuned empirically, in a somewhat ad hoc way) that I use with skorch:
    import gpytorch
    from gpytorch.models import ApproximateGP
    from gpytorch.variational import CholeskyVariationalDistribution, UnwhitenedVariationalStrategy
    from skorch.probabilistic import GPBinaryClassifier

    class GPClassificationModel(ApproximateGP):
        def __init__(self, train_x, init_lengthscale=3.0):
            # Every training point is used as an inducing point, with fixed locations
            variational_distribution = CholeskyVariationalDistribution(train_x.size(0))
            variational_strategy = UnwhitenedVariationalStrategy(
                self, train_x, variational_distribution, learn_inducing_locations=False
            )
            super().__init__(variational_strategy)
            self.mean_module = gpytorch.means.ConstantMean()
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
            self.covar_module.base_kernel.lengthscale = init_lengthscale

        def forward(self, x):
            mean_x = self.mean_module(x)
            covar_x = self.covar_module(x)
            latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
            # print(latent_pred.mean)
            return latent_pred

    # X_train (a torch tensor) and device are defined elsewhere in my script
    model = GPBinaryClassifier(
        GPClassificationModel,
        module__train_x=X_train, criterion__num_data=len(X_train),
        device=device, batch_size=32, lr=0.5, max_epochs=400,
    )
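For a sense of scale, here is a rough count of the optimizer steps these settings imply. It is only a sketch: it assumes skorch's default batching (one optimizer step per mini-batch, last partial batch kept) and uses the two approximate training-set sizes from this question, 600 and 6000. The helper name `optimizer_steps` is made up for illustration.

```python
import math


def optimizer_steps(n_train, batch_size=32, max_epochs=400):
    """Mini-batches per epoch (rounded up) times the number of epochs."""
    return math.ceil(n_train / batch_size) * max_epochs


# Approximate training-set sizes mentioned in the question
for n_train in (600, 6000):
    print(n_train, optimizer_steps(n_train))
# 600 7600
# 6000 75200
```

Because the module uses every training point as an inducing point, each of those steps also involves a variational covariance of size n_train x n_train.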
For scikit-learn, I'm just using the GaussianProcessClassifier model class with an RBF kernel.
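For reference, the scikit-learn side looks roughly like the following. This is a minimal sketch rather than my exact script: the synthetic `X_demo` and `y_demo` arrays stand in for my data, and the initial length scale of 3.0 is an assumption chosen to mirror `init_lengthscale` (scikit-learn then optimizes the kernel hyperparameters during `fit`).

```python
import time

import numpy as np
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF

# Synthetic stand-in data (assumption): 100 points, 640 continuous features
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(100, 640))
y_demo = (X_demo[:, 0] > 0).astype(int)  # binary labels

# RBF kernel; the initial length scale is illustrative and gets optimized in fit()
clf = GaussianProcessClassifier(kernel=RBF(length_scale=3.0), random_state=0)

start = time.perf_counter()
clf.fit(X_demo, y_demo)
elapsed = time.perf_counter() - start

train_acc = clf.score(X_demo, y_demo)
print(f"fit time: {elapsed:.2f} s, training accuracy: {train_acc:.3f}")
```

Wrapping the skorch `model.fit(...)` call in the same `time.perf_counter()` pattern gives a like-for-like wall-clock comparison.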
In case it helps: my inputs are vectors of 640 continuous features, and the output is a binary class label. I've tried one dataset with ~600 training points and another with ~6000, and saw the same speed disparity with both. I'd be really grateful for any advice on how to speed up training with GPyTorch, and/or an explanation of why scikit-learn might be faster in some situations.