Can I use sklearn with a multi-input model (Keras multi-input model)?

Posted: 2016-07-15 12:46:32

Tags: scikit-learn keras

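I want to train a multi-input Keras model using sklearn's cross-validation functions, but it fails. Does sklearn not support multi-input models? Is there another way to get around this?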

These are the input shapes:

# tower_dataset is a list of 2D NumPy arrays, one per model input
for sheet in tower_dataset:
    print(sheet.shape)

Output:

(126L, 3L)
(126L, 45L)
(126L, 148L)
(126L, 148L)
(126L, 148L)
(126L, 148L)
(126L, 148L)
(126L, 148L)
(126L, 148L)
(126L, 148L)
(126L, 100L)
(126L, 296L)
(126L, 296L)
(126L, 176L)
(126L, 31L)
(126L, 5L)
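All 16 arrays share the same 126 rows and differ only in width. The sketch below uses random stand-in data with these shapes (not the asker's data) to show the two counts involved, plus one possible workaround that is not taken from the post: stack the arrays column-wise into a single (126, 2136) matrix that sklearn can index by row, then cut it back into 16 inputs before they reach the multi-input Keras model. The helper names stack_inputs and split_inputs are hypothetical.

```python
import numpy as np

# Column widths of the 16 arrays listed above; every array has 126 rows.
WIDTHS = [3, 45] + [148] * 8 + [100, 296, 296, 176, 31, 5]
N_ROWS = 126

# Random stand-in for tower_dataset: same shapes, not the real data.
rng = np.random.RandomState(0)
tower_dataset = [rng.rand(N_ROWS, w) for w in WIDTHS]

# The outer list has 16 items (inputs), while each array has 126 rows (samples).
print(len(tower_dataset))                    # 16
print({a.shape[0] for a in tower_dataset})   # {126}


def stack_inputs(arrays):
    """Join the per-input arrays side by side into one (n_rows, sum(widths)) matrix."""
    return np.hstack(arrays)


def split_inputs(X, widths):
    """Inverse of stack_inputs: cut the columns back into one array per input."""
    # np.split expects the column offsets to cut at: running totals without the last one.
    return np.split(X, np.cumsum(widths)[:-1], axis=1)


X = stack_inputs(tower_dataset)
print(X.shape)  # (126, 2136)
```

With a single 2D X, sklearn's length checks see 126 samples. The split back into 16 inputs would still have to happen in a small custom estimator wrapped around the Keras model (hypothetical, not shown here), because KerasClassifier passes X straight to model.fit.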

What I am trying to do is train the Keras model using sklearn cross-validation:

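from keras.wrappers.scikit_learn import KerasClassifier
# sklearn < 0.18 API (the sklearn.cross_validation module shown in the traceback)
from sklearn.cross_validation import StratifiedKFold, cross_val_score

# create_model, tower_dataset, label and rand_seed are defined elsewhere in the notebook
epochs = 100
n_folds = 10
model = KerasClassifier(build_fn=create_model, nb_epoch=150, batch_size=10)
kfold = StratifiedKFold(y=label, n_folds=n_folds, shuffle=True, random_state=rand_seed)

results = cross_val_score(model, tower_dataset, label, cv=kfold)
print(results.mean())

Error output: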
ValueError                                Traceback (most recent call last)
<ipython-input-219-924ce2fda183> in <module>()
      6 kfold = StratifiedKFold(y=label, n_folds=n_folds, shuffle=True, random_state=rand_seed)
      7 
----> 8 results = cross_val_score(model, tower_dataset, label, cv=kfold)
      9 print(results.mean())
     10 

C:\Users\user\Anaconda2\lib\site-packages\sklearn\cross_validation.pyc in cross_val_score(estimator, X, y, scoring, cv, n_jobs, verbose, fit_params, pre_dispatch)
   1420         Array of scores of the estimator for each run of the cross validation.
   1421     """
-> 1422     X, y = indexable(X, y)
   1423 
   1424     cv = check_cv(cv, X, y, classifier=is_classifier(estimator))

C:\Users\user\Anaconda2\lib\site-packages\sklearn\utils\validation.pyc in indexable(*iterables)
    199         else:
    200             result.append(np.array(X))
--> 201     check_consistent_length(*result)
    202     return result
    203 

C:\Users\user\Anaconda2\lib\site-packages\sklearn\utils\validation.pyc in check_consistent_length(*arrays)
    174     if len(uniques) > 1:
    175         raise ValueError("Found arrays with inconsistent numbers of samples: "
--> 176                          "%s" % str(uniques))
    177 
    178 

ValueError: Found arrays with inconsistent numbers of samples: [ 16 126]
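The [ 16 126] in the error is len(tower_dataset), a list of 16 arrays, against the 126 labels: sklearn treats the outer list as X and counts its items as samples. One workaround that does not need sklearn to understand a multi-input X is to let sklearn produce only the fold indices and slice every input array with them. This is a minimal sketch assuming the newer sklearn.model_selection API, StratifiedKFold(n_splits=..., shuffle=..., random_state=...).split(X, y), which replaced the sklearn.cross_validation module used in the question (deprecated in 0.18, removed in 0.20). The function cross_val_multi_input and the fit_and_score callback are hypothetical; in practice the callback would build a fresh multi-input Keras model and call model.fit with a list of arrays, which Keras functional models accept.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold


def cross_val_multi_input(inputs, y, fit_and_score, n_splits=10, seed=0):
    """Stratified K-fold CV where X is a list of arrays that share the same rows.

    fit_and_score is a hypothetical user callback:
        fit_and_score(train_inputs, y_train, test_inputs, y_test) -> float
    For Keras it would build a new model, fit it on train_inputs
    (a list with one array per input) and return the score on the test fold.
    """
    y = np.asarray(y)
    n_samples = len(y)
    # Every input must have exactly one row per label.
    for arr in inputs:
        if arr.shape[0] != n_samples:
            raise ValueError("each input needs %d rows, got %d" % (n_samples, arr.shape[0]))

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    scores = []
    # Folds depend only on y, so a placeholder X with the right length is enough.
    for train_idx, test_idx in skf.split(np.zeros(n_samples), y):
        # Slice every input with the same row indices so the samples stay aligned.
        train_inputs = [arr[train_idx] for arr in inputs]
        test_inputs = [arr[test_idx] for arr in inputs]
        scores.append(fit_and_score(train_inputs, y[train_idx], test_inputs, y[test_idx]))
    return np.array(scores)
```

Building a new model inside the callback for every fold keeps trained weights from leaking between folds, which mirrors what KerasClassifier does when it calls build_fn inside fit.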
Thanks for your time and help.

0 answers