计算多项式logit模型的预测概率

时间:2020-01-27 20:38:51

标签: python pandas numpy mlogit

请尝试提供参数化解决方案(有三个以上的替代方案)。

我有一个包含Beta值的字典:

{'B_X1': 2.0, 'B_X2': -3.0}

此数据框:

 X1_123  X1_456  X1_789  X2_123  X2_456  X2_789
   6.75    4.69    9.59    5.52    9.69    7.40
   7.46    4.94    3.01    1.78    1.38    4.68
   2.05    7.30    4.08    7.02    8.24    8.49
   5.60    7.88    8.11    5.98    4.60    1.39
   1.80    8.28    9.16    7.34    7.69    6.16
   3.73    6.93    8.93    2.58    3.48    6.04
   8.06    8.88    7.06    6.76    4.68    7.82
   5.00    7.29    5.86    3.92    5.67    4.10
   2.49    2.55    4.66    7.15    6.26    7.87
   1.50    3.35    5.70    9.86    4.83    1.17
   8.19    7.72    9.56    6.61    4.15    3.64
   2.43    9.54    9.15    4.41    9.18    7.85
   2.71    3.24    4.56    6.22    7.89    9.93
   5.96    4.34    5.26    8.63    9.81    9.40

123456789是替代方案。

我想使用以下公式计算预测概率: enter image description here

jks是上述选择。

预期结果:

 X1_123  X1_456  X1_789  X2_123  X2_456  X2_789  P_123  P_456  P_789
   6.75    4.69    9.59    5.52    9.69    7.40  0.490  0.000  0.510
   7.46    4.94    3.01    1.78    1.38    4.68  0.979  0.021  0.000
   2.05    7.30    4.08    7.02    8.24    8.49  0.001  0.998  0.001
   5.60    7.88    8.11    5.98    4.60    1.39  0.000  0.000  1.000
   1.80    8.28    9.16    7.34    7.69    6.16  0.000  0.002  0.998
   3.73    6.93    8.93    2.58    3.48    6.04  0.024  0.952  0.024
   8.06    8.88    7.06    6.76    4.68    7.82  0.000  1.000  0.000
   5.00    7.29    5.86    3.92    5.67    4.10  0.210  0.107  0.683
   2.49    2.55    4.66    7.15    6.26    7.87  0.038  0.623  0.339
   1.50    3.35    5.70    9.86    4.83    1.17  0.000  0.000  1.000
   8.19    7.72    9.56    6.61    4.15    3.64  0.000  0.005  0.995
   2.43    9.54    9.15    4.41    9.18    7.85  0.041  0.037  0.922
   2.71    3.24    4.56    6.22    7.89    9.93  0.981  0.019  0.001
   5.96    4.34    5.26    8.63    9.81    9.40  0.975  0.001  0.024

每行的概率总和应为1。

请尝试提供参数化解决方案(有三个以上的替代方案)。

每个替代方案具有恒定的预期结果: {'B_X1': 2.0, 'B_X2': -3.0, 'B_123': 0.1, 'B_456': 0.2, 'B_789': 0.3}

 X1_123  X1_456  X1_789  X2_123  X2_456  X2_789  P_123  P_456  P_789
   6.75    4.69    9.59    5.52    9.69    7.40  0.440  0.000  0.560
   7.46    4.94    3.01    1.78    1.38    4.68  0.977  0.023  0.000
   2.05    7.30    4.08    7.02    8.24    8.49  0.001  0.998  0.001
   5.60    7.88    8.11    5.98    4.60    1.39  0.000  0.000  1.000
   1.80    8.28    9.16    7.34    7.69    6.16  0.000  0.002  0.998
   3.73    6.93    8.93    2.58    3.48    6.04  0.021  0.952  0.027
   8.06    8.88    7.06    6.76    4.68    7.82  0.000  1.000  0.000
   5.00    7.29    5.86    3.92    5.67    4.10  0.180  0.102  0.717
   2.49    2.55    4.66    7.15    6.26    7.87  0.034  0.604  0.363
   1.50    3.35    5.70    9.86    4.83    1.17  0.000  0.000  1.000
   8.19    7.72    9.56    6.61    4.15    3.64  0.000  0.005  0.995
   2.43    9.54    9.15    4.41    9.18    7.85  0.034  0.034  0.932
   2.71    3.24    4.56    6.22    7.89    9.93  0.978  0.021  0.001
   5.96    4.34    5.26    8.63    9.81    9.40  0.970  0.001  0.029

1 个答案:

答案 0 :(得分:2)

IIUC:

将列转换为MultiIndex

df = df.set_axis(df.columns.str.split('_', expand=True), axis=1, inplace=False)

并定义您的B,以使键与df中的前缀匹配

B = {'X1': 2.0, 'X2': -3.0}

然后

def f(b, x):
    return np.exp((b * x).sum(1))

parts = f(B, df.stack()).unstack()

preds = parts.div(parts.sum(1), axis=0)

df.join(pd.concat({'P': preds}, axis=1).round(3)).pipe(
    lambda d: d.set_axis(map('_'.join, d.columns), axis=1, inplace=False)
)

    X1_123  X1_456  X1_789  X2_123  X2_456  X2_789  P_123  P_456  P_789
0     6.75    4.69    9.59    5.52    9.69    7.40  0.490  0.000  0.510
1     7.46    4.94    3.01    1.78    1.38    4.68  0.979  0.021  0.000
2     2.05    7.30    4.08    7.02    8.24    8.49  0.001  0.998  0.001
3     5.60    7.88    8.11    5.98    4.60    1.39  0.000  0.000  1.000
4     1.80    8.28    9.16    7.34    7.69    6.16  0.000  0.002  0.998
5     3.73    6.93    8.93    2.58    3.48    6.04  0.024  0.952  0.024
6     8.06    8.88    7.06    6.76    4.68    7.82  0.000  1.000  0.000
7     5.00    7.29    5.86    3.92    5.67    4.10  0.210  0.107  0.683
8     2.49    2.55    4.66    7.15    6.26    7.87  0.038  0.623  0.339
9     1.50    3.35    5.70    9.86    4.83    1.17  0.000  0.000  1.000
10    8.19    7.72    9.56    6.61    4.15    3.64  0.000  0.005  0.995
11    2.43    9.54    9.15    4.41    9.18    7.85  0.041  0.037  0.922
12    2.71    3.24    4.56    6.22    7.89    9.93  0.981  0.019  0.001
13    5.96    4.34    5.26    8.63    9.81    9.40  0.975  0.001  0.024

包装成一个漂亮的功能

def f(df, b):
    d = df.set_axis(df.columns.str.split('_', expand=True), axis=1, inplace=False)
    parts = np.exp(d.stack().mul(b).sum(1).unstack())
    preds = pd.concat({'P': parts.div(parts.sum(1), axis=0)}, axis=1).round(3)
    d = d.join(preds)
    d.columns = list(map('_'.join, d.columns))
    return d

f(df, B)

    X1_123  X1_456  X1_789  X2_123  X2_456  X2_789  P_123  P_456  P_789
0     6.75    4.69    9.59    5.52    9.69    7.40  0.490  0.000  0.510
1     7.46    4.94    3.01    1.78    1.38    4.68  0.979  0.021  0.000
2     2.05    7.30    4.08    7.02    8.24    8.49  0.001  0.998  0.001
3     5.60    7.88    8.11    5.98    4.60    1.39  0.000  0.000  1.000
4     1.80    8.28    9.16    7.34    7.69    6.16  0.000  0.002  0.998
5     3.73    6.93    8.93    2.58    3.48    6.04  0.024  0.952  0.024
6     8.06    8.88    7.06    6.76    4.68    7.82  0.000  1.000  0.000
7     5.00    7.29    5.86    3.92    5.67    4.10  0.210  0.107  0.683
8     2.49    2.55    4.66    7.15    6.26    7.87  0.038  0.623  0.339
9     1.50    3.35    5.70    9.86    4.83    1.17  0.000  0.000  1.000
10    8.19    7.72    9.56    6.61    4.15    3.64  0.000  0.005  0.995
11    2.43    9.54    9.15    4.41    9.18    7.85  0.041  0.037  0.922
12    2.71    3.24    4.56    6.22    7.89    9.93  0.981  0.019  0.001
13    5.96    4.34    5.26    8.63    9.81    9.40  0.975  0.001  0.024