滑动游戏拼图的深度学习训练模型是什么?

时间:2019-05-09 14:11:05

标签: python tensorflow training-data

我正在尝试构建一个模型,以通过Python滑动游戏谜题训练深度学习神经网络,并想知道我们可以使用tensorflow构建的最佳模型是什么,为什么?我需要使用哪种层?过滤器值,kernel_size和其他args的值是什么,在我的情况下它们实际上意味着什么? 请记住,这是我第一次尝试处理机器学习,并且有很多我不了解的词汇:)

我已经用Conv2D,Activation和Dense层(很难获得超过0.22的精度)构建了数据生成部分和第一个模型,但是当添加更多的Conv2D层时,我还遇到了一些负尺寸错误。 我指的滑动游戏难题是一个包含颜色的NxM网格(此处我们仅使用2种颜色),目标是移动行和列以使所有相同的图块相邻。移位是通过移动键完成的,该键指的是给定列或行的移位方向。

以下是游戏画面:https://www.youtube.com/watch?v=pCwELYqLAGg

要生成训练数据,我构建了游戏的一点实现,并且对于给定的N和M,我可以生成随机求解状态,其中每种颜色的瓷砖数都是随机的。完成后,我将网格重排10次(也许太多了),并获得了成功的途径。一旦获得成功之路,就将网格的每个状态及其移动键关联到下一个状态(即每次都更接近已解决状态)。重复很多次以生成实际的训练数据集。然后,将网格(特征)和移动键(标签)分开,并在它们上运行模型。

如果有更好的方法来训练nn(更改生成数据的方法,更改实际的数据输入),请告诉我您的想法。

因此,例如,这是我根据YouTube和其他网站上的教程建立的实际模型,这些教程主要处理图像处理和类之间的分类,因此我已经知道这不是解决问题的好方法:

设置变量

input_shape = (-1, height, width, 1) # height and width of grid playground
outputl = len(moves) # number of different moves, the output must be one of these

epoch_nb = 10
batch_size = 32

train_data_gen = data_generator(10000)
eval_data_gen = data_generator(200)

X = [] # feature set
y = [] # label set

设置训练数组

for _set in train_data_gen : 

  for feature, label in _set :
     X.append(feature)
     y.append(label)

X = np.array(X).reshape(*input_shape)
y = to_categorical(y)

建立模型

model = Sequential()
model.add(Conv2D(64, (3,3), input_shape=input_shape[1:], activation='relu'))
model.add(Activation('relu'))

# this part triggers the negative dimension ValueError
# model.add(Conv2D(64, (3,3), input_shape=input_shape[1:], activation='relu'))

model.add(Conv2D(128, kernel_size=(3,3), activation='relu'))
model.add(Flatten())
model.add(Dense(128, activation='tanh'))
model.add(Dropout(0.5))
model.add(Dense(outputl, activation='softmax'))

运行

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.fit(X, y, batch_size=batch_size, validation_split=0.1, epochs=epoch_nb)

这里是输出:

Train on 93465 samples, validate on 10385 samples
Epoch 1/10
93465/93465 [==============================] - 10s 110us/sample - loss: 2.8523 - accuracy: 0.1491 - val_loss: 2.7131 - val_accuracy: 0.2074
Epoch 2/10
93465/93465 [==============================] - 10s 110us/sample - loss: 2.7095 - accuracy: 0.2078 - val_loss: 2.6579 - val_accuracy: 0.2203
...
...
Epoch 9/10
93465/93465 [==============================] - 10s 107us/sample - loss: 2.5634 - accuracy: 0.2466 - val_loss: 2.6467 - val_accuracy: 0.2235
Epoch 10/10
93465/93465 [==============================] - 10s 111us/sample - loss: 2.5504 - accuracy: 0.2500 - val_loss: 2.6484 - val_accuracy: 0.2216

如果您需要有关训练数据的更多详细信息,例如详细信息,等等,我可以进一步解释。

提前感谢您抽出宝贵的时间来帮助我;)

0 个答案:

没有答案
相关问题