tf.keras API with a tf.data Dataset: problem with the steps_per_epoch argument

Time: 2019-06-13 12:30:05

Tags: python tensorflow deep-learning tensorflow-datasets

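When I try to fit a Keras model using the tensorflow.keras API and an iterator created from a tf.data.Dataset, the model complains about the steps_per_epoch argument, even though I have set it to a concrete value.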

Here is my model class:

import tensorflow as tf
import numpy as np
from typing import Union, List
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras import layers
from tftools import TFTools


class TestServe():
    def __init__(self, tfrecords: Union[List[tf.train.Example], tf.train.Example], batch_size: int = 10, input_shape: tuple = (64, 23)) -> None:
        self.tfrecords = tfrecords
        self.batch_size = batch_size
        self.input_shape = input_shape

    def get_model(self):
        ins = layers.Input(shape=self.input_shape)  # reuse the configured shape instead of hard-coding (64, 23)

        l = layers.Reshape((*self.input_shape, 1))(ins)
        l = layers.Conv2D(8, (30, 23), padding='same', activation='relu')(l)
        l = layers.MaxPool2D((4, 5), strides=(4, 5))(l)
        l = layers.Conv2D(16, (3, 3), padding='same', activation='relu')(l)
        l = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(l)
        l = layers.MaxPool2D((2, 2), strides=(2, 2))(l)
        l = layers.Flatten()(l)

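        # A single output unit with softmax always outputs 1.0; sigmoid is the usual choice with binary_crossentropy.
        out = layers.Dense(1, activation='sigmoid')(l)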
        return tf.keras.models.Model(ins, out)

    def train(self):

        # Create Dataset
        dataset = TFTools.create_dataset(self.tfrecords)
        dataset = dataset.repeat(6).batch(self.batch_size)

        val_iterator = dataset.take(300).make_one_shot_iterator()
        train_iterator = dataset.skip(300).make_one_shot_iterator()

        model = self.get_model()
        model.summary()
        model.compile(optimizer='rmsprop',
                      loss='binary_crossentropy', metrics=['accuracy'])
        model.fit(train_iterator, validation_data=val_iterator,
                  epochs=10, verbose=1, steps_per_epoch=20)

    def predict(self, X: np.array) -> np.array:
        pass

ts = TestServe(['./ok.tfrecord', './nok.tfrecord'])
ts.train()

But as soon as I start training, before the first epoch completes, I get an exception from TensorFlow:

2019-06-13 14:22:25.393398: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 1995445000 Hz
2019-06-13 14:22:25.393681: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x2f7d120 executing computations on platform Host. Devices:
2019-06-13 14:22:25.393708: I tensorflow/compiler/xla/service/service.cc:158]   StreamExecutor device (0): <undefined>, <undefined>
Epoch 1/2
19/20 [===========================>..] - ETA: 0s - loss: 1.1921e-07 - acc: 1.0000
Traceback (most recent call last):
  File "TestServe.py", line 62, in <module>
    ts.train()
  File "TestServe.py", line 56, in train
    epochs=2, verbose=1, callbacks=callbacks, steps_per_epoch=20) #The steps_per_epoch is typically samples_per_epoch / batch_size
  File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 880, in fit
    validation_steps=validation_steps)
  File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 364, in model_iteration
    validation_in_fit=True)
  File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 202, in model_iteration
    steps_per_epoch)
  File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 76, in _get_num_samples_or_steps
    'steps_per_epoch')
  File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_utils.py", line 230, in check_num_samples
    if check_steps_argument(ins, steps, steps_name):
  File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_utils.py", line 960, in check_steps_argument
    input_type=input_type_str, steps_name=steps_name))
ValueError: When using data tensors as input to a model, you should specify the `steps_per_epoch` argument.

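The original dataset contains about 1500 samples, but I want to concatenate multiple TFRecord files into one TFRecordDataset, so I will not have any information about its length.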

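Has anyone seen something similar? I don't know where to look for help, since the tf.keras API is relatively new. The create_dataset function just returns a dataset mapped with the correct parse function.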
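If a known epoch length is wanted despite concatenating several TFRecord files, one option is to count the records once by iterating the dataset and derive steps_per_epoch from that count. This is a minimal sketch assuming TensorFlow 2.x with eager execution; the temporary file names, the dummy byte payloads and the helper `count_records` are hypothetical stand-ins, and counting reads every file once, so it can be slow on large data.

```python
import math
import os
import tempfile

import tensorflow as tf


def count_records(dataset: tf.data.Dataset) -> int:
    """Hypothetical helper: count elements by iterating the dataset once."""
    # reduce() walks the whole dataset, adding 1 to the running count per element.
    total = dataset.reduce(tf.constant(0, tf.int64), lambda count, _: count + 1)
    return int(total.numpy())


# Write two small TFRecord files with dummy payloads (stand-ins for ok/nok files).
tmp_dir = tempfile.mkdtemp()
paths = []
for name, n in [("ok.tfrecord", 7), ("nok.tfrecord", 5)]:
    path = os.path.join(tmp_dir, name)
    with tf.io.TFRecordWriter(path) as writer:
        for i in range(n):
            writer.write(f"record-{i}".encode())
    paths.append(path)

# Reading several files through one TFRecordDataset hides the total length.
dataset = tf.data.TFRecordDataset(paths)
num_samples = count_records(dataset)

batch_size = 10
# Round up so the last, partial batch is still counted as a step.
steps_per_epoch = math.ceil(num_samples / batch_size)
print(num_samples, steps_per_epoch)
```

The count only needs to be computed once, before building the batched and repeated pipeline passed to fit().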

2 answers:

Answer 0 (score: 0)

Found the solution.

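Besides steps_per_epoch, you must also specify the validation_steps argument. The validation iterator has no known length either, so Keras cannot tell how many validation batches to draw. The traceback hints at this: the check fails inside the model_iteration call made with validation_in_fit=True, even though the message only names steps_per_epoch.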
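A minimal sketch of that fix, assuming TensorFlow 2.x, where a tf.data.Dataset is passed to fit() directly instead of a one-shot iterator. The random data, the small Flatten/Dense model and the step counts are illustrative assumptions, not the asker's setup; the point is that both steps_per_epoch and validation_steps are given for datasets whose length Keras cannot infer.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Synthetic stand-in for the TFRecord data: 60 samples of shape (64, 23), binary labels.
rng = np.random.default_rng(0)
features = rng.random((60, 64, 23)).astype("float32")
labels = rng.integers(0, 2, size=(60, 1)).astype("float32")

dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(10)

# Split by batches (6 in total), then repeat so neither split runs out during fit().
val_ds = dataset.take(2).repeat()
train_ds = dataset.skip(2).repeat()

inputs = tf.keras.Input(shape=(64, 23))
x = layers.Flatten()(inputs)
# A single output unit for binary classification uses sigmoid, not softmax.
outputs = layers.Dense(1, activation="sigmoid")(x)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="rmsprop", loss="binary_crossentropy", metrics=["accuracy"])

# Repeated datasets have no usable length, so both step counts are set explicitly.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=2,
    steps_per_epoch=4,
    validation_steps=2,
    verbose=0,
)
print(sorted(history.history))
```

Dropping validation_steps from this call is the situation described in the question; with it, one validation pass of two batches runs at the end of each epoch.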

Answer 1 (score: 0)

I got this error when trying to use a TensorFlow 2.0 model while actually having an older version (TensorFlow 1.14) installed locally.

To upgrade to the latest TensorFlow version, run:

python -m pip install --upgrade pip
python -m pip install --upgrade tensorflow
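Before or after upgrading, it can help to confirm which TensorFlow version Python actually imports, since a 1.x/2.x mismatch like the one described is easy to miss. A minimal sketch; the printed warning text is illustrative only.

```python
import tensorflow as tf

# Parse the installed major.minor version, e.g. "1.14.0" -> (1, 14).
major, minor = (int(part) for part in tf.__version__.split(".")[:2])
print(tf.__version__)

if major < 2:
    print("TensorFlow 1.x is installed; code written for TF 2.x may not run as-is.")
```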