I am trying to build an API for predicting customer churn at a bank. I have finished the model and now want to serve it with FastAPI. My problem is converting the JSON data passed to the endpoint into a DataFrame so that I can run it through the model. Here is the code:
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
from pycaret.classification import *
import pandas as pd
import uvicorn # ASGI
import pickle
import pydantic
from pydantic import BaseModel

class customer_input(BaseModel):
    CLIENTNUM: int
    Customer_Age: int
    Gender: str
    Dependent_count: int
    Education_Level: str
    Marital_Status: str
    Income_Category: str
    Card_Category: str
    Months_on_book: int
    Total_Relationship_Count: int
    Months_Inactive_12_mon: int
    Contacts_Count_12_mon: int
    Credit_Limit: float
    Total_Revolving_Bal: int
    Avg_Open_To_Buy: float
    Total_Amt_Chng_Q4_Q1: float
    Total_Trans_Amt: int
    Total_Trans_Ct: int
    Total_Ct_Chng_Q4_Q1: float
    Avg_Utilization_Ratio: float

app = FastAPI()

# Loading the saved model from pycaret
model = load_model('BankChurnersCatboostModel25thDec2020')

origins = [
    '*'
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=['GET', 'POST'],
    allow_headers=['Content-Type', 'application/xml', 'application/json'],
)

@app.get("/")
def index():
    return {"Nothing to see here"}

@app.post("/predict")
def predict(data: customer_input):
    # Convert input data into a dictionary
    data = data.dict()
    # Convert the dictionary into a dataframe
    my_data = pd.DataFrame([data])
    # Predicting using pycaret
    prediction = predict_model(model, my_data)
    return prediction

# Only use below 2 lines when testing on localhost -- remove when deploying
if __name__ == '__main__':
    uvicorn.run(app, host='127.0.0.1', port=8000)
When I test this from the OpenAPI docs interface I get an Internal Server Error, so I checked the command prompt, and the error says:
ValueError: [TypeError("'numpy.int64' object is not iterable"), TypeError('vars() argument must have __dict__ attribute')]
How can I successfully convert the data passed into the predict function into a DataFrame? Thank you.
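For reference, the dict-to-DataFrame step on its own seems to behave as expected. This minimal sketch uses a hypothetical, trimmed-down payload (only four of the customer_input fields, with made-up values) to show that pd.DataFrame([data]) builds a one-row frame, and that a value read back out of it is a NumPy scalar rather than a built-in int. numpy.int64 is exactly the type named in the error message, which suggests the failure happens later, when the result is returned.

```python
import numpy as np
import pandas as pd

# Hypothetical, trimmed-down payload: only a few of the customer_input fields
data = {"CLIENTNUM": 768805383, "Customer_Age": 45, "Gender": "M", "Credit_Limit": 12691.0}

# Wrapping the dict in a list makes pandas build a single-row frame
my_data = pd.DataFrame([data])
print(my_data.dtypes)

# Reading a cell back out yields a NumPy integer scalar, not a Python int
age = my_data.loc[0, "Customer_Age"]
print(type(age), isinstance(age, int))
```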
OK, so I fixed this by changing the customer_input class: I changed every int type to float, and that fixed it. I don't understand why, though. Can anyone explain?
Fundamentally, those int values are only meant to be integers because they are all discrete values (e.g. the number of dependents a customer has), but I guess I could put a constraint on the front end.
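If the fields should stay int, an alternative is to leave customer_input unchanged and convert the prediction to built-in Python types before returning it. A minimal sketch, assuming a pandas to_json / json.loads round trip is acceptable for the response size: to_json_records is a hypothetical helper name, and the prediction frame below is a made-up stand-in for what predict_model returns (its column names are assumptions).

```python
import json

import pandas as pd


def to_json_records(df: pd.DataFrame) -> list:
    """Convert a DataFrame into a list of dicts holding only built-in Python types.

    DataFrame.to_json serialises the NumPy values itself, and json.loads
    parses the text back into plain int, float and str objects.
    """
    return json.loads(df.to_json(orient="records"))


# Hypothetical stand-in for the frame returned by predict_model
prediction = pd.DataFrame(
    {
        "Customer_Age": [45],
        "Dependent_count": [3],
        "Label": ["Existing Customer"],
        "Score": [0.93],
    }
)

records = to_json_records(prediction)
print(records)
```

For a single-row request, the endpoint would then end with return to_json_records(prediction)[0], and the int fields no longer need to be declared as float.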