hoof_hearted

Reputation: 675

Import Only Works Inside Python Function

Background Info: I'm developing a model with scikit-learn. I'm splitting the data into separate training and testing sets using the sklearn.cross_validation module, as shown below:

def train_test_split(input_data):
    from sklearn.cross_validation import train_test_split

    ### STEP 1: Separate y variable and remove from X
    y = input_data['price']
    X = input_data.copy()
    X.drop('price', axis=1, inplace=True)

    ### STEP 2: Split into training & test sets
    X_train, X_test, y_train, y_test = \
        train_test_split(X, y, test_size=0.2, random_state=0)
    return X_train, X_test, y_train, y_test

My Question: When I move the import of train_test_split from sklearn.cross_validation outside of my function, like so, I get the following error:

from sklearn.cross_validation import train_test_split

def train_test_split(input_data):
    ### STEP 1: Separate y variable and remove from X
    y = input_data['price']
    X = input_data.copy()
    X.drop('price', axis=1, inplace=True)

    ### STEP 2: Split into training & test sets
    X_train, X_test, y_train, y_test = \
        train_test_split(X, y, test_size=0.2, random_state=0)
    return X_train, X_test, y_train, y_test

Error:

TypeError: train_test_split() got an unexpected keyword argument 'test_size'

Any idea why?

Upvotes: 2

Views: 386

Answers (1)

RedX

Reputation: 15175

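You are importing the function train_test_split from sklearn.cross_validation and then overriding that name with your own function, also called train_test_split. The def statement rebinds the module-level name, so when STEP 2 calls train_test_split(X, y, test_size=0.2, random_state=0), Python finds your one-argument function, which has no test_size parameter, hence the TypeError.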
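A minimal pure-Python sketch of the same shadowing, with no scikit-learn needed. The first train_test_split below is a hypothetical stand-in for the library function (simplified signature, not sklearn's real code); the second def rebinds the name, so the keyword call fails the same way as in the question.

```python
# Hypothetical stand-in for the imported library function: it accepts the
# keyword arguments the question passes (simplified, not sklearn's code).
def train_test_split(*arrays, test_size=None, random_state=None):
    return "library split"


library_version = train_test_split  # keep a handle on the original object


# A def with the same name rebinds the module-level name, just like the
# question's own function defined after the import.
def train_test_split(input_data):
    return "local split"


print(train_test_split is library_version)  # False: the name now refers to the local function

try:
    # The local function only takes input_data, so test_size is rejected.
    train_test_split([1, 2, 3], test_size=0.2, random_state=0)
except TypeError as exc:
    error_message = str(exc)
    print(error_message)
```

The original function object still exists (library_version), but nothing reachable through the name train_test_split points at it any more.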

Try:

# Alias the library function so the local def below cannot shadow it.
# (In scikit-learn 0.20 and later, import it from sklearn.model_selection instead.)
from sklearn.cross_validation import train_test_split as sk_train_test_split

def train_test_split(input_data):
    ### STEP 1: Separate y variable and remove from X
    y = input_data['price']
    X = input_data.copy()
    X.drop('price', axis=1, inplace=True)

    ### STEP 2: Split into training & test sets
    # use the imported (aliased) function instead of the local one
    X_train, X_test, y_train, y_test = \
        sk_train_test_split(X, y, test_size=0.2, random_state=0)
    return X_train, X_test, y_train, y_test
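The question's first version worked because an import inside a function binds the name in that function's local scope, and local names are looked up before module-level ones. A small sketch of that behavior, using the standard library's os.path.join in place of sklearn; the wrapper name join is chosen only to mirror the name clash in the question.

```python
import os.path


def join(*parts):
    # This import binds ``join`` as a *local* name for this function body only,
    # so the call below uses os.path.join instead of recursing into this wrapper.
    from os.path import join
    return join(*parts)


# At module level, ``join`` still refers to the wrapper defined above.
print(join is os.path.join)  # False
print(join("data", "train.csv"))  # same result as os.path.join("data", "train.csv")
```

Because the import makes the name local for the whole function body, it must run before the call; using the name earlier in the body raises UnboundLocalError.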

Upvotes: 4
