qwerty

Reputation: 887

Calculating multinomial logit model prediction probabilities

Please give a parameterized solution (in practice there are more than three alternatives).

I have a dict with beta values:

{'B_X1': 2.0, 'B_X2': -3.0}

And this data frame:

 X1_123  X1_456  X1_789  X2_123  X2_456  X2_789
   6.75    4.69    9.59    5.52    9.69    7.40
   7.46    4.94    3.01    1.78    1.38    4.68
   2.05    7.30    4.08    7.02    8.24    8.49
   5.60    7.88    8.11    5.98    4.60    1.39
   1.80    8.28    9.16    7.34    7.69    6.16
   3.73    6.93    8.93    2.58    3.48    6.04
   8.06    8.88    7.06    6.76    4.68    7.82
   5.00    7.29    5.86    3.92    5.67    4.10
   2.49    2.55    4.66    7.15    6.26    7.87
   1.50    3.35    5.70    9.86    4.83    1.17
   8.19    7.72    9.56    6.61    4.15    3.64
   2.43    9.54    9.15    4.41    9.18    7.85
   2.71    3.24    4.56    6.22    7.89    9.93
   5.96    4.34    5.26    8.63    9.81    9.40

123, 456, and 789 are the alternatives.

I want to calculate the prediction probability using the multinomial logit formula [formula image not available].

j, k, and s are the mentioned alternatives.
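The formula was posted as an image, so the following is only a reconstruction. It assumes the standard multinomial logit form with a utility that is linear in the betas; the symbols `V_a` and `a` are introduced here for illustration and are not from the original post.

```latex
P_j = \frac{e^{V_j}}{e^{V_j} + e^{V_k} + e^{V_s}},
\qquad
V_a = B_{X1}\, X1_a + B_{X2}\, X2_a \quad \text{for } a \in \{j, k, s\}
```

As a check against the first row of the data: V_123 = 2(6.75) - 3(5.52) = -3.06, V_456 = 2(4.69) - 3(9.69) = -19.69, and V_789 = 2(9.59) - 3(7.40) = -3.02, which gives P_123 ≈ 0.490, P_456 ≈ 0.000, and P_789 ≈ 0.510, in line with the expected result below.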

Expected result:

 X1_123  X1_456  X1_789  X2_123  X2_456  X2_789  P_123  P_456  P_789
   6.75    4.69    9.59    5.52    9.69    7.40  0.490  0.000  0.510
   7.46    4.94    3.01    1.78    1.38    4.68  0.979  0.021  0.000
   2.05    7.30    4.08    7.02    8.24    8.49  0.001  0.998  0.001
   5.60    7.88    8.11    5.98    4.60    1.39  0.000  0.000  1.000
   1.80    8.28    9.16    7.34    7.69    6.16  0.000  0.002  0.998
   3.73    6.93    8.93    2.58    3.48    6.04  0.024  0.952  0.024
   8.06    8.88    7.06    6.76    4.68    7.82  0.000  1.000  0.000
   5.00    7.29    5.86    3.92    5.67    4.10  0.210  0.107  0.683
   2.49    2.55    4.66    7.15    6.26    7.87  0.038  0.623  0.339
   1.50    3.35    5.70    9.86    4.83    1.17  0.000  0.000  1.000
   8.19    7.72    9.56    6.61    4.15    3.64  0.000  0.005  0.995
   2.43    9.54    9.15    4.41    9.18    7.85  0.041  0.037  0.922
   2.71    3.24    4.56    6.22    7.89    9.93  0.981  0.019  0.001
   5.96    4.34    5.26    8.63    9.81    9.40  0.975  0.001  0.024

The probabilities should sum to 1 in every row.


Expected result with a constant for each alternative, using {'B_X1': 2.0, 'B_X2': -3.0, 'B_123': 0.1, 'B_456': 0.2, 'B_789': 0.3}:

 X1_123  X1_456  X1_789  X2_123  X2_456  X2_789  P_123  P_456  P_789
   6.75    4.69    9.59    5.52    9.69    7.40  0.440  0.000  0.560
   7.46    4.94    3.01    1.78    1.38    4.68  0.977  0.023  0.000
   2.05    7.30    4.08    7.02    8.24    8.49  0.001  0.998  0.001
   5.60    7.88    8.11    5.98    4.60    1.39  0.000  0.000  1.000
   1.80    8.28    9.16    7.34    7.69    6.16  0.000  0.002  0.998
   3.73    6.93    8.93    2.58    3.48    6.04  0.021  0.952  0.027
   8.06    8.88    7.06    6.76    4.68    7.82  0.000  1.000  0.000
   5.00    7.29    5.86    3.92    5.67    4.10  0.180  0.102  0.717
   2.49    2.55    4.66    7.15    6.26    7.87  0.034  0.604  0.363
   1.50    3.35    5.70    9.86    4.83    1.17  0.000  0.000  1.000
   8.19    7.72    9.56    6.61    4.15    3.64  0.000  0.005  0.995
   2.43    9.54    9.15    4.41    9.18    7.85  0.034  0.034  0.932
   2.71    3.24    4.56    6.22    7.89    9.93  0.978  0.021  0.001
   5.96    4.34    5.26    8.63    9.81    9.40  0.970  0.001  0.029

Upvotes: 2

Views: 395

Answers (1)

piRSquared

Reputation: 294248

IIUC:

Turn columns into a MultiIndex

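import numpy as np
import pandas as pd

# Split 'X1_123' into ('X1', '123') so the columns become a two-level MultiIndex
df = df.set_axis(df.columns.str.split('_', expand=True), axis=1)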
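To see what this split does, here is a small sketch on a two-row frame. The name `demo` and the choice of only two alternatives are made up for illustration; the values are taken from the first rows of the question's data.

```python
import pandas as pd

# Hypothetical small frame using the same 'variable_alternative' naming scheme
demo = pd.DataFrame(
    [[6.75, 4.69, 5.52, 9.69], [7.46, 4.94, 1.78, 1.38]],
    columns=['X1_123', 'X1_456', 'X2_123', 'X2_456'],
)

# str.split(..., expand=True) on an Index returns a MultiIndex
demo.columns = demo.columns.str.split('_', expand=True)
print(demo.columns.tolist())

# stack() moves the innermost column level (the alternative) into the row index,
# leaving one column per variable and one row per (observation, alternative)
long = demo.stack()
print(long)
```

In the long layout each beta lines up with exactly one column, which is what makes the multiplication in the next step work for any number of alternatives.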

Then define B so that its keys match the column prefixes in df:

B = {'X1': 2.0, 'X2': -3.0}

Then

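def f(b, x):
    # x has one row per (observation, alternative) and one column per variable;
    # b * x scales each variable column by its beta, and the row sum is the utility
    return np.exp((b * x).sum(1))

# stack() moves the alternative level into the index; unstack() brings it back as columns
parts = f(B, df.stack()).unstack()

# Normalize each row so the probabilities sum to 1
preds = parts.div(parts.sum(1), axis=0)

# Label the predictions 'P', join them on, and flatten the columns back to 'P_123' style
df.join(pd.concat({'P': preds}, axis=1).round(3)).pipe(
    lambda d: d.set_axis(list(map('_'.join, d.columns)), axis=1)
)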

    X1_123  X1_456  X1_789  X2_123  X2_456  X2_789  P_123  P_456  P_789
0     6.75    4.69    9.59    5.52    9.69    7.40  0.490  0.000  0.510
1     7.46    4.94    3.01    1.78    1.38    4.68  0.979  0.021  0.000
2     2.05    7.30    4.08    7.02    8.24    8.49  0.001  0.998  0.001
3     5.60    7.88    8.11    5.98    4.60    1.39  0.000  0.000  1.000
4     1.80    8.28    9.16    7.34    7.69    6.16  0.000  0.002  0.998
5     3.73    6.93    8.93    2.58    3.48    6.04  0.024  0.952  0.024
6     8.06    8.88    7.06    6.76    4.68    7.82  0.000  1.000  0.000
7     5.00    7.29    5.86    3.92    5.67    4.10  0.210  0.107  0.683
8     2.49    2.55    4.66    7.15    6.26    7.87  0.038  0.623  0.339
9     1.50    3.35    5.70    9.86    4.83    1.17  0.000  0.000  1.000
10    8.19    7.72    9.56    6.61    4.15    3.64  0.000  0.005  0.995
11    2.43    9.54    9.15    4.41    9.18    7.85  0.041  0.037  0.922
12    2.71    3.24    4.56    6.22    7.89    9.93  0.981  0.019  0.001
13    5.96    4.34    5.26    8.63    9.81    9.40  0.975  0.001  0.024

Wrapped in one pretty function

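def f(df, b):
    # Split 'X1_123' into ('X1', '123') to get MultiIndex columns
    d = df.set_axis(df.columns.str.split('_', expand=True), axis=1)
    # Exponentiated utility per (row, alternative)
    parts = np.exp(d.stack().mul(b).sum(1).unstack())
    # Normalize per row and label the block 'P'
    preds = pd.concat({'P': parts.div(parts.sum(1), axis=0)}, axis=1).round(3)
    d = d.join(preds)
    # Flatten the columns back to 'X1_123', 'P_123', ...
    d.columns = list(map('_'.join, d.columns))
    return d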

f(df, B)

    X1_123  X1_456  X1_789  X2_123  X2_456  X2_789  P_123  P_456  P_789
0     6.75    4.69    9.59    5.52    9.69    7.40  0.490  0.000  0.510
1     7.46    4.94    3.01    1.78    1.38    4.68  0.979  0.021  0.000
2     2.05    7.30    4.08    7.02    8.24    8.49  0.001  0.998  0.001
3     5.60    7.88    8.11    5.98    4.60    1.39  0.000  0.000  1.000
4     1.80    8.28    9.16    7.34    7.69    6.16  0.000  0.002  0.998
5     3.73    6.93    8.93    2.58    3.48    6.04  0.024  0.952  0.024
6     8.06    8.88    7.06    6.76    4.68    7.82  0.000  1.000  0.000
7     5.00    7.29    5.86    3.92    5.67    4.10  0.210  0.107  0.683
8     2.49    2.55    4.66    7.15    6.26    7.87  0.038  0.623  0.339
9     1.50    3.35    5.70    9.86    4.83    1.17  0.000  0.000  1.000
10    8.19    7.72    9.56    6.61    4.15    3.64  0.000  0.005  0.995
11    2.43    9.54    9.15    4.41    9.18    7.85  0.041  0.037  0.922
12    2.71    3.24    4.56    6.22    7.89    9.93  0.981  0.019  0.001
13    5.96    4.34    5.26    8.63    9.81    9.40  0.975  0.001  0.024
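The question also asks for a constant per alternative, which the code above does not cover. One possible extension is sketched below; it is not part of the original answer. The function name `mnl_probabilities`, the `constants` argument, and the convention that constants are keyed by the bare alternative label ('123' rather than 'B_123') are all assumptions made for illustration. The sketch also subtracts each row's maximum utility before exponentiating, which leaves the probabilities unchanged but guards against overflow in `np.exp`.

```python
import numpy as np
import pandas as pd


def mnl_probabilities(df, betas, constants=None):
    """Append one P_<alt> column per alternative (hypothetical helper)."""
    # Columns like 'X1_123' become ('X1', '123')
    wide = df.set_axis(df.columns.str.split('_', expand=True), axis=1)
    alts = wide.columns.get_level_values(1).unique()

    # Utility per alternative: sum over variables of beta * value
    utility = pd.DataFrame(0.0, index=df.index, columns=alts)
    for var, beta in betas.items():
        utility += beta * wide[var][alts]

    # Optional alternative-specific constants; missing ones count as 0
    if constants is not None:
        utility += pd.Series(constants).reindex(alts).fillna(0.0)

    # Shift by the row maximum so np.exp cannot overflow; the ratios are unchanged
    expu = np.exp(utility.sub(utility.max(axis=1), axis=0))
    probs = expu.div(expu.sum(axis=1), axis=0)
    return df.join(probs.add_prefix('P_'))


# First two rows of the question's data
df = pd.DataFrame({
    'X1_123': [6.75, 7.46], 'X1_456': [4.69, 4.94], 'X1_789': [9.59, 3.01],
    'X2_123': [5.52, 1.78], 'X2_456': [9.69, 1.38], 'X2_789': [7.40, 4.68],
})
B = {'X1': 2.0, 'X2': -3.0}
C = {'123': 0.1, '456': 0.2, '789': 0.3}

print(mnl_probabilities(df, B).round(3))
print(mnl_probabilities(df, B, C).round(3))
```

On these two rows the first call should agree with the question's first expected table and the second call with the table that includes constants. The alternatives are read from the column names, so nothing is hard-coded to three.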

Upvotes: 2
