如何使用自己的RBF函数替换sklearn中的svm.SVR'rbf'内核?

时间:2019-04-20 06:31:16

标签: python matrix scikit-learn svm

我已经开发了以下代码来启动svm方法的项目:

import numpy as np
import pandas as pd

from sklearn import svm
from sklearn.datasets import load_boston
from sklearn.metrics import mean_absolute_error

housing = load_boston()
df = pd.DataFrame(np.c_[housing['data'], housing['target']],
              columns= np.append(housing['feature_names'], ['target']))

features = df.columns.tolist()
label = features[-1]
features = features[:-1]

x_train = df[features].iloc[:400]
y_train = df[label].iloc[:400]

x_test = df[features].iloc[400:]
y_test = df[label].iloc[400:]

svr = svm.SVR(kernel='rbf')
svr.fit(x_train, y_train)
y_pred = svr.predict(x_test)

print(mean_absolute_error(y_pred, y_test))

现在,我要使用定制的rbf内核:

def my_rbf(feat, lbl):
#feat = feat.values
    #lbl = lbl.values
    ans = np.array([])
    gamma = 0.000005
    for i in range(len(feat)):
        ans = np.append(ans, np.exp(-gamma * np.dot(feat[i]-lbl[i], feat[i]-lbl[i])))

    return ans

然后我更改了svm.SVR(kernel=my_rbf),但是在以任何方式对其进行修改时都会遇到很多错误。我还尝试使用像np.dot(feat-lbl,feat-lbl)这样的简单函数,该函数在SVR.fit方法中运行良好,但是在svr.predict中发生了一些错误,该错误表示输入矩阵的形状必须类似于[n_samples_test,n_samples_train] 。

我被迫处理这些错误。谁能帮助我使此代码正常工作?

2 个答案:

答案 0 :(得分:1)

您编写的自定义内核方法my_rbf同时使用X(功能)和y(标签)。由于无法访问标签,因此无法在预测期间评估此功能。自定义内核有缺陷。

背景

RBF函数定义如下(来自wiki

enter image description here

其中xx'是两个特征(X)向量。

H(X)是一个函数,用于将向量X转换为其他维度(通常转换为非常高的维度)。 SVM需要计算特征向量的所有组合(即所有H(X))之间的点积。因此,如果H(X1) . H(X2) = K(X1, X2),则K被称为H的内核函数或内核化。因此,X1无需直接从X2K进行计算,而不是将点X1X2转换成非常高的尺寸并在那里计算点积。

结论 my_rbf不是有效的内核函数,仅因为它使用标签(Y s)。它应该仅在特征向量上。

答案 1 :(得分:0)

根据this source,我正在寻找的RBF函数(将训练特征作为X并将测试特征作为X'作为输入)并输出[n_training_samples,n_testing_samples],如docs中所述,是这样的:

def my_kernel(X,Y):
    K = np.zeros((X.shape[0],Y.shape[0]))
    for i,x in enumerate(X):
        for j,y in enumerate(Y):
            K[i,j] = np.exp(-1*np.linalg.norm(x-y)**2)
    return K

clf=SVR(kernel=my_kernel)

结果完全等于:

clf=SVR(kernel="rbf",gamma=1)

在速度方面,它缺乏像默认svm库rbf一样有效的性能。将cython库的static typing用于索引,并将memory-views用于numpy数组可能会有所帮助。