使用神经网络进行分类

时间:2018-09-17 16:44:53

标签: r neural-network rstudio

我正在尝试将神经网络与neuralnet包一起使用,以作为具有二进制响应的基础。问题在于,显然仅适用于数字响应。

# Load data
data(cats,package = "MASS")
str(cats)
'data.frame':   144 obs. of  3 variables:
 $ Sex: Factor w/ 2 levels "F","M": 1 1 1 1 1 1 1 1 1 1 ...
 $ Bwt: num  2 2 2 2.1 2.1 2.1 2.1 2.1 2.1 2.1 ...
 $ Hwt: num  7 7.4 9.5 7.2 7.3 7.6 8.1 8.2 8.3 8.5 ...

适合神经元网模型

library(neuralnet)
nn <- neuralnet(formula = Sex ~ Bwt + Hwt, data = cats)
Error in neurons[[i]] %*% weights[[i]] : 
  requires numeric/complex matrix/vector arguments

一些建议使用neuralnet软件包来调整和预测具有binara响应的变量。

2 个答案:

答案 0 :(得分:1)

您可以将因子转换为二进制数据:

cats$Sex.binary <- as.numeric(cats$Sex) - 1
table(cats$Sex.binary)
 0  1
47 97

nn <- neuralnet(formula = Sex.binary ~ Bwt + Hwt, data = cats)

然后使用模型进行预测:

new.cats.data <- data.frame(Bwt=2, Hwt=2)
nn.pred <- compute(nn, new.cats.data)
nn.pred$net.result
ifelse(nn.pred$net.result > 0.5, 1, 0)

请注意,0.5可能不是此数据的最佳分类阈值。

答案 1 :(得分:1)

显然,一种方法是返回数字或整数类型的变量,问题是在进行预测时,它不会抛出整数。但是,可以使用ifelse重新计算预测,以获得适当的结果。

cats$Sex <- as.integer(cats$Sex)-1
nn <- neuralnet(formula = Sex ~ Bwt + Hwt, data = cats, hidden=3)
plot(nn)

enter image description here

pred.nn <- compute(nn, cats[,-1])
res <- ifelse(pred.nn$net.result > 0.5,1,0)
caret::confusionMatrix(as.factor(res),as.factor(cats$Sex))
Confusion Matrix and Statistics

          Reference
Prediction  0  1
         0 31 11
         1 16 86

               Accuracy : 0.8125                
                 95% CI : (0.7390483, 0.8726502)
    No Information Rate : 0.6736111             
    P-Value [Acc > NIR] : 0.0001470219          

                  Kappa : 0.5615697             
 Mcnemar's Test P-Value : 0.4414183268          

            Sensitivity : 0.6595745             
            Specificity : 0.8865979             
         Pos Pred Value : 0.7380952             
         Neg Pred Value : 0.8431373             
             Prevalence : 0.3263889             
         Detection Rate : 0.2152778             
   Detection Prevalence : 0.2916667             
      Balanced Accuracy : 0.7730862             

       'Positive' Class : 0