CNN输出2类的百分比

时间:2020-07-16 15:22:58

标签: tensorflow keras cnn transfer-learning

我是争论的初学者。我有这个问题:我必须在视频的每一帧中对2类的百分比进行分类。 我创建了一个包含约500张图像(每个类别250张)的小型数据集,以及一个包含以下图层的CNN:

model = tf.models.Sequential()
model.add(tf.layers.Conv2D(32, kernel_size=(3, 3), activation='relu',input_shape=(224,224,3)))
model.add(tf.layers.MaxPooling2D((2, 2)))
model.add(tf.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.layers.MaxPooling2D((2, 2)))
model.add(tf.layers.Conv2D(128, kernel_size=(3, 3), activation='relu'))
model.add(tf.layers.MaxPooling2D((2, 2)))
model.add(tf.layers.Conv2D(256, kernel_size=(3, 3), activation='relu'))
model.add(tf.layers.MaxPooling2D((2, 2)))
model.add(tf.layers.Flatten())
model.add(tf.layers.Dense(512, activation='relu'))
model.add(tf.layers.Dropout(0.2))
model.add(tf.layers.Dense(2,activation='sigmoid'))
model.summary()
model.compile(loss='binary_crossentropy', optimizer=tf.optimizers.Adam(learning_rate=0.00001), metrics=['accuracy'])

1)最好使用 binary_crossentropy + Sigmoid binary_crossentropy + softmax

2)那么最好使用传输学习/微调还是像这样从头开始构建CNN?

3)我正在使用ImageDataGenerator进行 DataAugmentation ,因为数据集很小,对吗?

4)我可以将哪些值用于batch_size,steps_per_epochs,learning_rate ...我注意到,使用val_accuracy可以使模型精度提前到1.0,并且在预测中不能返回每个类的正确百分比,但是可以返回值像[9.999e-1 4.444e-5]

1 个答案:

答案 0 :(得分:0)

  1. 因为您的分类是二进制的,所以请使用Sigmoid。 Softmax适用于多类(> 2)。
  2. 使用转移学习总是更好。与VGG16,ResNet,Inception等一起使用。
  3. 是的,在数据​​集较小的情况下,数据增强有很大帮助。
  4. 您需要在最后一层而不是2中使用一个神经元。因为,在一个神经元中,如果值大于0.5,则将其视为1类,否则将视为0。如果要坚持使用两个神经元,则,要获得答案,您应该进行{ "hands" : [ { "first" : { "value" : 12, "suit" : { "name" : "heart" } }, "Second" : { "value" : 12, "suit" : { "name" : "spade" } } }, { "first" : { "value" : 8, "suit" : { "name" : "club" } }, "second" : { "value" : 9, "suit" : { "name" : "club" } } } ]} 的预测,在给出的示例Resolved [org.springframework.http.converter.HttpMessageNotReadableException: Could not read JSON: java.lang.IllegalStateException: Expected a string but was BEGIN_OBJECT at line 4 column 15 path $.hands[0].first; nested exception is com.google.gson.JsonSyntaxException: java.lang.IllegalStateException: Expected a string but was BEGIN_OBJECT at line 4 column 15 path $.hands[0].first] 中,预测类别为0,如pred [0]> pred [1]。
相关问题