Stu

Reputation: 1643

Best way to map variables to various input layers

I have 25 mixed variables: some are binary, some continuous, and most are high-level factors to be embedded.

My deep learning model takes these inputs and builds an autoencoder around them; its summary is shown below.

My question is: how do I map the variables from my pandas data frame to the appropriate input layers? For instance, each high-level factor should go to its own embedding layer. My first thought is either to arrange the data set in the same order as the inputs before training, or to somehow map the variables by name (e.g. provID -> input_provid).

autoencoder.summary()
Model: "claims_ae"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_provid (InputLayer)       [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_pos_code (InputLayer)     [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_prindiag (InputLayer)     [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_billtype2 (InputLayer)    [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_lob (InputLayer)          [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_ppg_code (InputLayer)     [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_segment (InputLayer)      [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_dofr (InputLayer)         [(None, 1)]          0                                            
__________________________________________________________________________________________________
embedding_8 (Embedding)         (None, 1, 70)        337820      input_provid[0][0]               
__________________________________________________________________________________________________
embedding_9 (Embedding)         (None, 1, 6)         168         input_pos_code[0][0]             
__________________________________________________________________________________________________
embedding_10 (Embedding)        (None, 1, 77)        447524      input_prindiag[0][0]             
__________________________________________________________________________________________________
embedding_11 (Embedding)        (None, 1, 5)         95          input_billtype2[0][0]            
__________________________________________________________________________________________________
embedding_12 (Embedding)        (None, 1, 3)         24          input_lob[0][0]                  
__________________________________________________________________________________________________
embedding_13 (Embedding)        (None, 1, 10)        930         input_ppg_code[0][0]             
__________________________________________________________________________________________________
embedding_14 (Embedding)        (None, 1, 2)         8           input_segment[0][0]              
__________________________________________________________________________________________________
embedding_15 (Embedding)        (None, 1, 3)         21          input_dofr[0][0]                 
__________________________________________________________________________________________________
input_number_features (InputLay [(None, 4)]          0                                            
__________________________________________________________________________________________________
input_binary_features (InputLay [(None, 12)]         0                                            
__________________________________________________________________________________________________
reshape_8 (Reshape)             (None, 70)           0           embedding_8[0][0]                
__________________________________________________________________________________________________
reshape_9 (Reshape)             (None, 6)            0           embedding_9[0][0]                
__________________________________________________________________________________________________
reshape_10 (Reshape)            (None, 77)           0           embedding_10[0][0]               
__________________________________________________________________________________________________
reshape_11 (Reshape)            (None, 5)            0           embedding_11[0][0]               
__________________________________________________________________________________________________
reshape_12 (Reshape)            (None, 3)            0           embedding_12[0][0]               
__________________________________________________________________________________________________
reshape_13 (Reshape)            (None, 10)           0           embedding_13[0][0]               
__________________________________________________________________________________________________
reshape_14 (Reshape)            (None, 2)            0           embedding_14[0][0]               
__________________________________________________________________________________________________
reshape_15 (Reshape)            (None, 3)            0           embedding_15[0][0]               
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 192)          0           input_number_features[0][0]      
                                                                 input_binary_features[0][0]      
                                                                 reshape_8[0][0]                  
                                                                 reshape_9[0][0]                  
                                                                 reshape_10[0][0]                 
                                                                 reshape_11[0][0]                 
                                                                 reshape_12[0][0]                 
                                                                 reshape_13[0][0]                 
                                                                 reshape_14[0][0]                 
                                                                 reshape_15[0][0]                 
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 16)           3088        concatenate_1[0][0]              
__________________________________________________________________________________________________
dense_9 (Dense)                 (None, 8)            136         dense_8[0][0]                    
__________________________________________________________________________________________________
dense_10 (Dense)                (None, 4)            36          dense_9[0][0]                    
__________________________________________________________________________________________________
dense_11 (Dense)                (None, 2)            10          dense_10[0][0]                   
__________________________________________________________________________________________________
dense_12 (Dense)                (None, 4)            12          dense_11[0][0]                   
__________________________________________________________________________________________________
dense_13 (Dense)                (None, 8)            40          dense_12[0][0]                   
__________________________________________________________________________________________________
dense_14 (Dense)                (None, 16)           144         dense_13[0][0]                   
__________________________________________________________________________________________________
dense_15 (Dense)                (None, 192)          3264        dense_14[0][0]                   
==================================================================================================
Total params: 793,320
Trainable params: 793,320
Non-trainable params: 0

Upvotes: 1

Views: 126

Answers (1)

Georgios Livanos

Reputation: 536

From the tf.keras documentation for the x argument of Model.fit:

x: Input data. It could be:

  • A Numpy array (or array-like), or a list of arrays (in case the model has multiple inputs).
  • A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs).
  • A dict mapping input names to the corresponding array/tensors, if the model has named inputs.
  • A tf.data dataset. Should return a tuple of either (inputs, targets) or (inputs, targets, sample_weights).
  • A generator or keras.utils.Sequence returning (inputs, targets) or (inputs, targets, sample_weights).

So you can set the name argument on each Input layer and then pass the input data as a dict whose keys match those input names.
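As a rough sketch of that idea, the dict can be built straight from the data frame. The input names below come from the model summary in the question; every column name except provID is hypothetical, the helper build_model_inputs is made up for illustration, and the toy frame uses only a few of the model's inputs. Encoding the factors with pd.factorize is also an assumption; any consistent integer encoding would do.

```python
import numpy as np
import pandas as pd

# Input layer name -> data frame column.
# Only provID is taken from the question; the other column names are hypothetical.
CATEGORICAL_INPUTS = {
    "input_provid": "provID",
    "input_pos_code": "pos_code",
    "input_lob": "lob",
}
NUMBER_COLUMNS = ["num_a", "num_b"]    # hypothetical continuous columns
BINARY_COLUMNS = ["flag_a", "flag_b"]  # hypothetical 0/1 columns


def build_model_inputs(df):
    """Return a dict keyed by Input layer name, with one array per model input."""
    inputs = {}
    for input_name, column in CATEGORICAL_INPUTS.items():
        # Embedding layers expect integer indices, so map each level to 0..n-1.
        # With the default settings, missing values are encoded as -1 and
        # would need separate handling.
        codes, _ = pd.factorize(df[column])
        inputs[input_name] = codes.astype("int32").reshape(-1, 1)
    # Grouped inputs receive a 2-D array with one column per feature.
    inputs["input_number_features"] = df[NUMBER_COLUMNS].to_numpy(dtype="float32")
    inputs["input_binary_features"] = df[BINARY_COLUMNS].to_numpy(dtype="float32")
    return inputs


# Toy data, for illustration only.
df = pd.DataFrame({
    "provID": ["p1", "p2", "p1", "p3"],
    "pos_code": [11, 21, 11, 11],
    "lob": ["A", "B", "A", "A"],
    "num_a": [0.5, 1.0, 0.2, 0.9],
    "num_b": [10.0, 3.0, 7.5, 1.0],
    "flag_a": [0, 1, 1, 0],
    "flag_b": [1, 1, 0, 0],
})

model_inputs = build_model_inputs(df)
for name, array in model_inputs.items():
    print(name, array.shape, array.dtype)
```

A dict like this can be passed as x to fit (and to predict). Because the arrays are matched to the Input layers by name, the order of the columns in the data frame no longer matters. For real data, the same level-to-index mapping would have to be reused at prediction time rather than refactorizing each new frame.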

Upvotes: 3
