How do I fix my scikit-learn program?

Date: 2015-06-04 11:47:18

Tags: python python-2.7 pandas scikit-learn random-forest

A code snippet using RandomForestClassifier from scikit-learn, the Python machine learning library.

I am trying to give different classes different weights using the class_weight option of scikit-learn's RandomForestClassifier. Below is my code snippet, followed by the error I get.

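import csv
from sklearn.ensemble import RandomForestClassifier

# train_data, test_data and ids are built earlier in the notebook (not shown);
# column 0 of train_data is the label, the remaining columns are the features.
print 'Training...'
forest = RandomForestClassifier(n_estimators=500,class_weight= {0:1,1:1,2:1,3:1,4:1,5:1,6:1,7:4})
forest = forest.fit( train_data[0::,1::], train_data[0::,0] )

print 'Predicting...'
output = forest.predict(test_data).astype(int)

# Write one (PassengerId, Survived) row per test passenger
predictions_file = open("myfirstforest.csv", "wb")
open_file_object = csv.writer(predictions_file)
open_file_object.writerow(["PassengerId","Survived"])
open_file_object.writerows(zip(ids, output))
predictions_file.close()
print 'Done.'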

I get the following error:

Training...

IndexError                                Traceback (most recent call last)
<ipython-input-20-122f2e5a0d3b> in <module>()
 84 print 'Training...'
 85 forest = RandomForestClassifier(n_estimators=500,class_weight={0:1,1:1,2:1,3:1,4:1,5:1,6:1,7:4})
---> 86 forest = forest.fit( train_data[0::,1::], train_data[0::,0] )
 87 
 88 print 'Predicting...'

/home/rpota/anaconda/lib/python2.7/site-packages/sklearn/ensemble/forest.pyc in fit(self, X, y, sample_weight)
216         self.n_outputs_ = y.shape[1]
217 
--> 218         y, expanded_class_weight = self._validate_y_class_weight(y)
219 
220         if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:

/home/rpota/anaconda/lib/python2.7/site-packages/sklearn/ensemble/forest.pyc in _validate_y_class_weight(self, y)
433                     class_weight = self.class_weight
434                 expanded_class_weight = compute_sample_weight(class_weight,
--> 435                                                               y_original)
436 
437         return y, expanded_class_weight

/home/rpota/anaconda/lib/python2.7/site-packages/sklearn/utils/class_weight.pyc in compute_sample_weight(class_weight, y, indices)
150             weight_k = compute_class_weight(class_weight_k,
151                                             classes_full,
--> 152                                             y_full)
153 
154         weight_k = weight_k[np.searchsorted(classes_full, y_full)]

/home/rpota/anaconda/lib/python2.7/site-packages/sklearn/utils/class_weight.pyc in compute_class_weight(class_weight, classes, y)
 58         for c in class_weight:
 59             i = np.searchsorted(classes, c)
---> 60             if classes[i] != c:
 61                 raise ValueError("Class label %d not present." % c)
 62             else:

IndexError: index 2 is out of bounds for axis 0 with size 2
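The last frame shows where the lookup breaks. In the scikit-learn version in this traceback, compute_class_weight takes each key c of class_weight, finds its position with np.searchsorted(classes, c), and then reads classes[i]. The minimal NumPy sketch below replays that lookup outside scikit-learn, assuming (from "size 2" in the message) that the target column holds exactly the two labels 0 and 1. Any key greater than 1 makes searchsorted return 2, which equals len(classes), so the "2" in the message is that insertion position rather than necessarily the key 2. The guard that would raise ValueError("Class label %d not present.") is never reached, because the indexing fails first.

```python
import numpy as np

# Sorted class labels found in y. "size 2" in the traceback implies two
# classes; the values 0 and 1 are an assumption (e.g. a binary target).
classes = np.array([0, 1])

# The class_weight dict from the question: keys 0..7.
class_weight = {0: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 4}

# Unguarded lookup, as in the failing frame:
# searchsorted(classes, 2) returns 2 == len(classes), so classes[2] is out of range.
try:
    classes[np.searchsorted(classes, 2)]
    msg = None
except IndexError as exc:
    msg = str(exc)
print(msg)  # index 2 is out of bounds for axis 0 with size 2

# Guarded version: collect every key that is not a label present in y.
missing = []
for c in sorted(class_weight):
    i = np.searchsorted(classes, c)
    if i >= len(classes) or classes[i] != c:
        missing.append(c)
print(missing)  # [2, 3, 4, 5, 6, 7]
```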

Please help!
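class_weight weights the target classes, that is, the distinct values of the y passed to fit (here train_data[0::,0]); it does not weight feature columns. A pre-fit check can be sketched as follows, under stated assumptions: check_class_weight is a hypothetical helper, y is a toy binary stand-in for train_data[0::,0], and the weight 4 on class 1 is purely illustrative because the question does not say which class should be up-weighted.

```python
import numpy as np


def check_class_weight(class_weight, y):
    """Hypothetical helper: fail early if class_weight has keys that are not labels in y.

    Returns the sorted list of labels found in y.
    """
    labels = sorted(set(np.unique(y).tolist()))
    extra = sorted(set(class_weight) - set(labels))
    if extra:
        raise ValueError("class_weight keys not found in y: %s; labels in y: %s"
                         % (extra, labels))
    return labels


# Toy stand-in for train_data[0::, 0]: a binary target (assumed).
y = np.array([0, 1, 1, 0, 0, 1, 0])

# The dict from the question, keyed 0..7.
bad = {0: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 4}
try:
    check_class_weight(bad, y)
    error = None
except ValueError as exc:
    error = str(exc)
print(error)  # class_weight keys not found in y: [2, 3, 4, 5, 6, 7]; labels in y: [0, 1]

# A dict keyed only by labels that occur in y; the weights are illustrative.
good = {0: 1, 1: 4}
print(check_class_weight(good, y))  # [0, 1]
```

With a label-keyed dict, the call from the question would look like RandomForestClassifier(n_estimators=500, class_weight={0: 1, 1: 4}), with the weights chosen to suit the data. If the keys 0 to 7 were meant to refer to feature columns, class_weight cannot express that.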