cs0815

Reputation: 17388

Data format for predicting with a model fitted via SageMaker's built-in XGBoost algorithm and training container

Looking at the following code, taken from here, I wonder what format dtest is (sorry, I could not glean this from the post):

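import pickle as pkl
import tarfile

# unpack the model artifact produced by the training job
with tarfile.open('model.tar.gz', 'r:gz') as t:
    t.extractall()

# close the file handle once the model is loaded
with open(model_file_path, 'rb') as f:
    model = pkl.load(f)

# prediction with test data
pred = model.predict(dtest)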
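
Since model_file_path is not defined in the snippet, here is a minimal sketch of the loading step that first lists what is inside the archive and then unpickles one member directly, without extracting to disk. The helper names are hypothetical, the default member name "xgboost-model" is an assumption about what the training job wrote, and the sketch assumes the artifact really is a pickle, as in the code above.

```python
import pickle
import tarfile


def list_members(tar_path):
    """Return the names of all entries inside a .tar.gz archive."""
    with tarfile.open(tar_path, "r:gz") as tar:
        return tar.getnames()


def load_pickled_model(tar_path, member_name="xgboost-model"):
    """Unpickle one member of the archive without extracting it to disk.

    member_name is an assumption; check list_members() for the real name.
    """
    with tarfile.open(tar_path, "r:gz") as tar:
        # extractfile raises KeyError if the name is missing and
        # returns None for entries that are not regular files.
        handle = tar.extractfile(member_name)
        if handle is None:
            raise ValueError(f"{member_name!r} is not a regular file")
        with handle:
            return pickle.load(handle)
```

Calling list_members('model.tar.gz') first shows which name to pass as member_name. As with any pickle, only load archives from a source you trust.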

In my case, the training and validation data are in CSV format and come from an S3 bucket:

content_type = "csv"
train_input = TrainingInput("s3://{}/{}/{}/".format(bucket, prefix, 'train'), content_type=content_type)

So ideally, I would also like to use the same format for scoring/prediction/inference.

PS:

This little function appears to work fine:

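import pickle as pkl
import pandas as pd
import xgboost as xgb

def write_prediction_data(data_file_name, target_name, model_file_name, output_file_name):
    # load the pickled model and close the file handle afterwards
    with open(model_file_name, 'rb') as f:
        model = pkl.load(f)

    # read the CSV once; the target column stays in the output
    data = pd.read_csv(data_file_name)

    # the label is not needed for prediction, so only the features go into the DMatrix
    features = data.drop(columns=[target_name])
    xgb_data = xgb.DMatrix(features.values)

    data['Prediction'] = model.predict(xgb_data)
    data.to_csv(output_file_name, index=False)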

Improvement suggestions always welcome (-:

Upvotes: 0

Views: 351

Answers (1)

CrzyFella

Reputation: 191

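The "dtest" data can come from a CSV without any label column, but it has to be wrapped in an xgboost.DMatrix before it is passed to predict (as your own function already does), since the unpickled model is presumably an xgboost Booster. Beyond that, there is no specific format requirement, as long as the feature columns match the ones used for training, in the same order.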
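
A minimal sketch of that idea, assuming the unpickled model is an xgboost Booster (whose predict method takes a DMatrix) and that the CSV is numeric with no header row. Both function names and the label_first flag are hypothetical; set label_first=True to reuse a training-style file whose first column is the label.

```python
import csv


def read_feature_rows(csv_path, label_first=False):
    """Read a header-less numeric CSV into a list of float rows.

    If label_first is True, the first column is treated as the label
    and dropped, so only feature values remain.
    """
    rows = []
    with open(csv_path, newline="") as f:
        for record in csv.reader(f):
            if not record:
                continue  # skip blank lines
            values = record[1:] if label_first else record
            rows.append([float(v) for v in values])
    return rows


def predict_from_csv(model, csv_path, label_first=False):
    """Score a CSV with a Booster by wrapping the features in a DMatrix."""
    # Imported here so the CSV helper above also works without these packages.
    import numpy as np
    import xgboost as xgb

    features = np.asarray(read_feature_rows(csv_path, label_first), dtype=float)
    return model.predict(xgb.DMatrix(features))
```

With the model loaded as in the question, pred = predict_from_csv(model, 'test.csv') would return one prediction per row of the file.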

Upvotes: 1
