user3476463

Reputation: 4575

Matrix factorization model returns a much smaller DataFrame after predicting ratings in PySpark

I'm trying to build a product recommender with the code below, using matrix factorization (ALS) from Spark ML. My data has a customer_id, a product_id, and a numeric rating that has been normalized, so all rating values are between 0 and 1. The dataset has almost 10M records, with no null product_id or customer_id. Yet after I train the model and try to make predictions for the entire dataset, I get back only a very small number of predictions. The returned test_pred dataframe is about the same size whether I set coldStartStrategy to "nan" or "drop". Does anyone see what the issue might be, and can you suggest how to fix it?

code:

print(data.count())

print(data[['customer_id']].distinct().count())

print(data[['product_id']].distinct().count())

output:

9943626
1715292
308792

code:

from pyspark.ml.recommendation import ALS

(training, test) = data.randomSplit([0.7, 0.3])

als = ALS(implicitPrefs=True,
          maxIter=5,
          regParam=0.01,
          userCol="customer_id_index",
          itemCol="product_id_index",
          ratingCol="rating",
          coldStartStrategy="nan")
model = als.fit(training)

test_pred = model.transform(data)

print(test_pred.count())

print(test_pred[['customer_id']].distinct().count())

print(test_pred[['product_id']].distinct().count())

output:

3346
522
760

Upvotes: 0

Views: 205

Answers (1)

user3476463

Reputation: 4575

It looks like the issue was that I needed to save my data right after I initially converted one of the id fields to a string index. It seems that somewhere between training the model and making predictions, the PySpark code must have changed the indexes for the records, so the model couldn't find the same customer-product pairs to make predictions for. Once I saved the data out after the initial string indexing, read it back in, and applied transform to it with the trained model, I got a much higher record count in the resulting dataframe.

Upvotes: 1
