DeepLearning4j和DataVec读取带标签的csv文件

时间:2017-08-07 08:00:28

标签: deeplearning4j

我已经建立了一个DL4j项目。如果我使用MNIST数据集,一切都很好:

    DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
    DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);

但是,我想用以下格式切换到我自己的csv文件:

A  |  B  |  C  |  X  |  Y
-------------------------
1  | 100 |  5  |  15 |  6
...

XY是结果(或标签)。由于我计划执行回归分析,因此XY都是实数。所以我使用以下代码读取了csv文件:

    RecordReader recordReaderTrain = new CSVRecordReader(1, ",");
    recordReaderTrain.initialize(new FileSplit(new File("src/main/resources/data/Data.csv")));
    DataSetIterator dataIterTrain = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 3, 2);
代码中的

3表示index of the labels2表示number of possible labels。关于这两个参数没有太多解释。我猜他们的意思是标签从第4列开始,有2个标签。

当我运行代码时,它显示以下异常:

Exception in thread "main" java.lang.ArrayIndexOutOfBoundsException: 14

我认为这是因为dl4j无法将15识别为标签。

所以我的问题是:如何才能正确读取csv文件以进行回归分析?

非常感谢。

1 个答案:

答案 0 :(得分:1)

是的,我们有回归的例子: https://github.com/deeplearning4j/dl4j-examples/tree/cc383de91bdf4e28e36859aa2e8749100cd63177/dl4j-examples/src/main/java/org/deeplearning4j/examples/feedforward/regression

您需要将回归true(它是构造函数的额外部分)传递给RecordReaderDataSetIterator。