我正在使用SVM从嘈杂的输入中学习正弦波。我已经尝试了许多不同的超参数,但该模型似乎仍然过度拟合。我不确定我是否只是要求模型做更多的事情,或者我只是选择了错误的超参数。也许有一个更好的模型我可以用于这个任务?如果有帮助,我可以显示输入和输出的图形。这是我的代码:
import numpy as np
from sklearn.svm import SVR
import matplotlib.pyplot as plt
def svr(x, y, C=1, gamma=.1):
# train/test split
split = int(y.shape[0] * 0.8)
x_train, x_test = x[:split], x[split:]
y_train, y_test = y[:split], y[split:]
# fit SVM
svrm = SVR(kernel='rbf', C=C, gamma=gamma)
svr_fit = svrm.fit(x_train, y_train.flatten())
# make predictions on train and test sets
y_fitted = svr_fit.predict(x_train)
y_predicted = svr_fit.predict(x_test)
print('Done!')
return y_fitted, y_predicted
lin = np.linspace(0, 100, 5000)
rand = np.random.random(lin.shape)
sin = np.sin(lin)
x = (sin + rand/2 - 0.25).reshape((-1, 1))
y = sin.reshape((-1, 1))
print(x.shape, y.shape)
plt.plot(lin, sin)
plt.show()
plt.plot(lin[:500], x.flatten()[:500])
plt.show()
y_fit, y_pred = svr(x, y, C=1, gamma=.1)
plt.plot(lin[:len(y_fit)], y_fit)
plt.show()
plt.plot(lin[len(y_fit):], y_pred)
plt.show()