Looking at the following code, taken from here, I wonder what format dtest is in (sorry, I could not glean this from the post):
import pickle as pkl
import tarfile
t = tarfile.open('model.tar.gz', 'r:gz')
t.extractall()
model = pkl.load(open(model_file_path, 'rb'))
# prediction with test data
pred = model.predict(dtest)
In my case, the training and validation data are in CSV format and come from an S3 bucket:
from sagemaker.inputs import TrainingInput
content_type = "csv"
train_input = TrainingInput("s3://{}/{}/{}/".format(bucket, prefix, 'train'), content_type=content_type)
So ideally, I would also like to use the same format for scoring/prediction/inference.
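As background for reusing the CSV format: SageMaker's built-in XGBoost algorithm expects CSV training data with the label in the first column and no header row, and CSV inference input with no label column and no header. The sketch below uses only the standard library and assumes a local CSV that does have a header and a named target column. It converts such data into either layout. The names to_sagemaker_csv and include_label are hypothetical, not part of the SageMaker SDK.

```python
import csv
import io


def to_sagemaker_csv(csv_text, target_name, include_label=True):
    """Hypothetical helper: rewrite a headered CSV into the headerless
    layout used by SageMaker's built-in XGBoost.

    include_label=True  -> label moved to the first column (training/validation)
    include_label=False -> label dropped entirely (inference payload)
    """
    reader = csv.DictReader(io.StringIO(csv_text))
    # every column except the target, in the original order
    feature_names = [name for name in reader.fieldnames if name != target_name]
    out = io.StringIO()
    writer = csv.writer(out, lineterminator="\n")
    for row in reader:
        features = [row[name] for name in feature_names]
        # no header row is written in either layout
        writer.writerow(([row[target_name]] + features) if include_label else features)
    return out.getvalue()
```

For example, with the sample "a,b,y" data used in a quick check, the training layout puts y first, while the inference layout simply omits it.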
PS:
This little function appears to work fine:
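import pickle as pkl
import pandas as pd
import xgboost as xgb

def write_prediction_data(data_file_name, target_name, model_file_name, output_file_name):
    # load the trained model from its pickle file
    with open(model_file_name, 'rb') as f:
        model = pkl.load(f)
    data = pd.read_csv(data_file_name)
    # features only: drop the label column before building the DMatrix
    features = data.drop(columns=[target_name])
    xgb_data = xgb.DMatrix(features.values, label=data[target_name].values)
    # append predictions to the original frame (label column kept), no need to re-read the file
    data['Prediction'] = model.predict(xgb_data)
    data.to_csv(output_file_name, index=False)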
Improvement suggestions always welcome (-:
Reputation: 191
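dtest is the test data without any label column, for example the rows of a CSV file with the label dropped. Beyond that there is no specific format, as long as the data is in a form the model can handle. If the pickled model is an XGBoost Booster (as the DMatrix-based function in the question suggests), its predict() expects an xgb.DMatrix.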
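To make this concrete for a CSV that has a header row, here is a minimal standard-library sketch. split_label is a hypothetical helper; it parses the rows, skips blank lines, drops the label column if one is named, and returns numeric feature rows plus the labels.

```python
import csv


def split_label(lines, target_name=None):
    """Hypothetical helper: parse a headered CSV and separate the label.

    `lines` is any iterable of CSV text lines, e.g. an open file object.
    Returns (features, labels): features is a list of float rows with the
    label column removed; labels is a list of floats, or None when
    target_name is None (the data already has no label column).
    """
    reader = csv.reader(lines)
    header = next(reader)
    label_idx = header.index(target_name) if target_name is not None else None
    features, labels = [], []
    for record in reader:
        if not record:  # csv.reader yields [] for blank lines
            continue
        values = [float(v) for v in record]
        if label_idx is not None:
            labels.append(values.pop(label_idx))
        features.append(values)
    return features, (labels if label_idx is not None else None)
```

With placeholder names, usage could look like `with open('test.csv', newline='') as f: features, _ = split_label(f, 'target')`, after which `dtest = xgb.DMatrix(np.array(features))` and `pred = model.predict(dtest)`. The column order must match the order used in training, because a DMatrix built from a plain array carries no column names.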