如何修复处理CIFAR-10数据集的维度(值)错误,构建CNN?

时间:2019-04-28 20:22:35

标签: python image-processing machine-learning keras conv-neural-network

研究cifar-10数据集以构建CNN并评估损失和准确性。我想做的是使用keras将数据集分为训练和测试数据,然后训练模型。 但是在最后一步,它给了我尺寸错误,而我无能为力。请帮忙!

代码如下:

import numpy as np
import pickle
import tensorflow as tf
import os
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import sklearn

path ='cifar-10-batches-py'

def load_cfar10_batch(path):
    with open(path + '/data_batch_1', mode='rb') as file:
        batch = pickle.load(file, encoding='latin1')   
    features = batch['data']
    labels = batch['labels']
    return features, labels


x = features.reshape((len(features), 3, 32, 32)).transpose(0, 2, 3, 1)
x.shape

y = labels
def one_hot_encode(y):
    encoded = np.zeros((len(y), 10))

    for index, val in enumerate(y):
        encoded[index][val] = 1

    return encoded


def normalize(x):
    x_norm = x/255
    return x_norm

from sklearn import preprocessing
scaler = preprocessing.StandardScaler()
scaled_df = scaler.fit_transform(features)
scaled_df = scaled_df.reshape(10000,3,32,32).transpose(0,2,3,1)
plt.imshow(scaled_df[9999])

def _preprocess_and_save(normalize_and_standardize, one_hot_encode, features, labels, filename):
    features = normalize(x)    
    labels = one_hot_encode(y) 

    pickle.dump((features, labels), open(filename, 'wb')) 


features, labels = load_cfar10_batch(path)


from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(features, labels, test_size=0.2)


def preprocess_and_save_data(path, normalize, one_hot_encode):
    #check where the code for _preprocess_and_save is
    _preprocess_and_save(normalize, one_hot_encode, np.array(x_test), np.array(y_test), 'preprocess_test.p')
    _preprocess_and_save(normalize, one_hot_encode, np.array(x_train), np.array(y_train), 'preprocess_training.p')


preprocess_and_save_data(path, normalize, one_hot_encode)

x_test, y_test = pickle.load(open('preprocess_test.p', mode='rb'))
y_train, y_train = pickle.load(open('preprocess_training.p', mode='rb'))

def tf_reset():
    try:
        sess.close()
    except:
        pass
    tf.reset_default_graph()
    return tf.Session()

sess = tf_reset()

x = tf.placeholder(tf.float32, shape=(None, 32, 32, 3), name='input_x')
y =  tf.placeholder(tf.float32, shape=(None, 10), name='output_y')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')

def conv_net(x, keep_prob):
    #x = tf.reshape(x,[-1,32,32,3])
    conv1_filter = tf.Variable(tf.truncated_normal(shape=[3, 3, 3, 64], mean=0, stddev=0.08))
    conv2_filter = tf.Variable(tf.truncated_normal(shape=[3, 3, 64, 128], mean=0, stddev=0.08))
    conv3_filter = tf.Variable(tf.truncated_normal(shape=[5, 5, 128, 256], mean=0, stddev=0.08))
    conv4_filter = tf.Variable(tf.truncated_normal(shape=[5, 5, 256, 512], mean=0, stddev=0.08))


 #Layer1
    conv1 = tf.nn.conv2d(x, conv1_filter, strides=[1,1,1,1], padding='SAME')
    conv1 = tf.nn.relu(conv1)
    conv1_pool = tf.nn.max_pool(conv1, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
    conv1_bn = tf.layers.batch_normalization(conv1_pool)

 #Layer2
    conv2 = tf.nn.conv2d(conv1_bn, conv2_filter, strides=[1,1,1,1], padding='SAME')
    conv2 = tf.nn.relu(conv2)
    conv2_pool = tf.nn.max_pool(conv2, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')    
    conv2_bn = tf.layers.batch_normalization(conv2_pool)

 #Layer 3
    conv3 = tf.nn.conv2d(conv2_bn, conv3_filter, strides=[1,1,1,1], padding='SAME')
    conv3 = tf.nn.relu(conv3)
    conv3_pool = tf.nn.max_pool(conv3, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')  
    conv3_bn = tf.layers.batch_normalization(conv3_pool)

 #Layer 4
    conv4 = tf.nn.conv2d(conv3_bn, conv4_filter, strides=[1,1,1,1], padding='SAME')
    conv4 = tf.nn.relu(conv4)
    conv4_pool = tf.nn.max_pool(conv4, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
    conv4_bn = tf.layers.batch_normalization(conv4_pool)


    flat = tf.contrib.layers.flatten(conv4_bn)  


    full1 = tf.contrib.layers.fully_connected(inputs=flat, num_outputs=128, activation_fn=tf.nn.relu)
    full1 = tf.nn.dropout(full1, keep_prob)
    full1 = tf.layers.batch_normalization(full1)


    full2 = tf.contrib.layers.fully_connected(inputs=full1, num_outputs=256, activation_fn=tf.nn.relu)
    full2 = tf.nn.dropout(full2, keep_prob)
    full2 = tf.layers.batch_normalization(full2)


    full3 = tf.contrib.layers.fully_connected(inputs=full2, num_outputs=512, activation_fn=tf.nn.relu)
    full3 = tf.nn.dropout(full3, keep_prob)
    full3 = tf.layers.batch_normalization(full3)    


    full4 = tf.contrib.layers.fully_connected(inputs=full3, num_outputs=1024, activation_fn=tf.nn.relu)
    full4 = tf.nn.dropout(full4, keep_prob)
    full4 = tf.layers.batch_normalization(full4)        


    out = tf.contrib.layers.fully_connected(inputs=full3, num_outputs=10, activation_fn=None)

    return out


iterations = 101
batch_size = 128
keep_probability = 0.7
learning_rate = 0.001

logits = conv_net(x, keep_prob)

# Loss and Optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Accuracy
correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy')

def train_neural_network(session, optimizer, keep_probability, feature_batch, label_batch):
    session.run(optimizer, 
                feed_dict={
                    x: feature_batch,
                    y: label_batch,
                    keep_prob: keep_probability
                })

def print_stats(sess, feature_batch, label_batch, cost, accuracy):
    loss = sess.run(cost, 
                    feed_dict={
                        x: feature_batch,
                        y: label_batch,
                        keep_prob: 1.
                    })
    valid_acc = sess.run(accuracy, 
                         feed_dict={
                             x: x_train,
                             y: y_train,
                             keep_prob: 1.
                         })

    print('Loss: {:>10.4f} Validation Accuracy: {:.6f}'.format(loss, valid_acc))

def batch_features_labels(features, labels, batch_size):
    """
    Split features and labels 
    """

    for start in range(0, len(features), batch_size):
        end = min(start + batch_size, len(features))
        yield features[start:end], labels[start:end]

def load_preprocess_training(batch_size):
    """
    Load the Preprocessed Training data and return them in batches of <batch_size> or less
    """
    features = features.reshape((len(features), 3, 32, 32)).transpose(0, 2, 3, 1)
    filename = 'preprocess_training.p'
    features, labels = pickle.load(open(filename, mode='rb'))

    # Return the training data in batches of size <batch_size> or less
    return batch_features_labels(features, labels, batch_size)


print('Training...')
with tf.Session() as sess:
    # Initializing the variables
    sess.run(tf.global_variables_initializer())

    # Training cycle
    for i in range(iterations):
        for batch_features, batch_labels in load_preprocess_training(batch_size):
                train_neural_network(sess, optimizer, keep_probability, batch_features, batch_labels)

                if i % 10 == 0:
                    print('Iterations {}, CIFAR-10 Batch {}:  '.format(i, 1), end='')
                    print_stats(sess, batch_features, batch_labels, cost, accuracy)

ValueError:无法为张量为'(?,32,32,3)'的张量'input_x:0'输入形状(8000,3072)的值

1 个答案:

答案 0 :(得分:0)

问题位于此处:

<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:orientation="vertical"
    tools:context=".StartActivity">

    <ScrollView
        android:layout_width="match_parent"
        android:layout_height="wrap_content"

        android:layout_weight="1"
        android:fillViewport="true">

        <LinearLayout
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
            android:background="@drawable/pg5"
            android:orientation="vertical"
            android:weightSum="5"
            tools:layout_editor_absoluteX="8dp"
            tools:layout_editor_absoluteY="8dp">

            <LinearLayout
                android:id="@+id/gb_text"
                android:layout_width="match_parent"
                android:layout_height="0dp"
                android:layout_weight="2.55"
                android:gravity="center|top"
                android:orientation="vertical"
                android:weightSum="3">


                <LinearLayout
                    android:layout_width="match_parent"
                    android:layout_height="0dp"
                    android:layout_weight="1.2"
                    android:orientation="vertical">  
                    <LinearLayout
                        android:layout_width="match_parent"
                        android:layout_height="match_parent"
                        android:orientation="horizontal"
                        android:weightSum="3">

                        <LinearLayout
                            android:layout_width="0dp"
                            android:layout_height="match_parent"
                            android:layout_weight="0.5"
                            android:orientation="horizontal"></LinearLayout>

                        <LinearLayout
                            android:layout_width="0dp"
                            android:layout_height="match_parent"
                            android:layout_weight="2"
                            android:gravity="center|top"
                            android:orientation="horizontal">

                            <LinearLayout
                                android:layout_width="match_parent"
                                android:layout_height="match_parent"
                                android:layout_weight="1"
                                android:gravity="center|top"
                                android:orientation="vertical">

                                <TextView
                                    android:id="@+id/textView"
                                    android:layout_width="wrap_content"
                                    android:layout_height="wrap_content"
                                    android:text="Duplicate Files Remover"
                                    android:textColor="@android:color/white"
                                    android:textSize="@dimen/DuplicateFileRemovertext" />
                            </LinearLayout>

                        </LinearLayout>

                        <LinearLayout
                            android:layout_width="0dp"
                            android:layout_height="match_parent"
                            android:layout_weight="0.5"
                            android:orientation="horizontal"></LinearLayout>
                    </LinearLayout>
                </LinearLayout>

                <com.intrusoft.sectionedrecyclerviewapp.CircularProgressBar
                    android:id="@+id/circularProgress"
                    android:layout_width="@dimen/circularprogresswidth"
                    android:layout_height="@dimen/circularprogressheight"
                    android:layout_centerHorizontal="true"
                    android:layout_marginTop="@dimen/circularprogresstopmargin" />

                <TextView
                    android:id="@+id/gb_textview"
                    android:layout_width="match_parent"
                    android:layout_height="wrap_content"
                    android:layout_marginTop="2dp"
                    android:gravity="center"
                    android:text="TextView"
                    android:textColor="@android:color/white"
                    android:textSize="@dimen/totalavailablesize" />


            </LinearLayout>

            <LinearLayout
                android:layout_width="match_parent"
                android:layout_height="0dp"
                android:layout_weight="2.45"
                android:orientation="vertical"
                android:weightSum="5">

                <LinearLayout
                    android:id="@+id/scanimages_layout"
                    android:layout_width="match_parent"
                    android:layout_height="0dp"
                    android:layout_margin="3dp"
                    android:layout_weight="1"
                    android:weightSum="2">

                    <android.support.v7.widget.CardView
                        android:layout_width="match_parent"
                        android:layout_height="match_parent"
                        android:layout_marginLeft="10dp"
                        android:layout_marginRight="10dp"
                        app:cardCornerRadius="12dp">

                        <LinearLayout
                            android:id="@+id/piccardlayout"
                            android:layout_width="match_parent"
                            android:layout_height="match_parent"
                            android:background="@color/bluelayout"
                            android:orientation="horizontal"
                            android:weightSum="2">

                            <LinearLayout
                                android:layout_width="0dp"
                                android:layout_height="match_parent"
                                android:layout_weight="0.5"
                                android:orientation="horizontal"
                                android:paddingLeft="8dp">

                                <com.makeramen.roundedimageview.RoundedImageView
                                    android:id="@+id/imageView2"
                                    android:layout_width="match_parent"
                                    android:layout_height="match_parent"
                                    android:layout_alignParentStart="true"
                                    android:layout_alignParentLeft="true"
                                    android:layout_alignParentTop="true"
                                    android:layout_alignParentEnd="true"
                                    android:layout_alignParentRight="true"
                                    android:layout_alignParentBottom="true"
                                    android:layout_marginStart="0dp"
                                    android:layout_marginLeft="0dp"
                                    android:layout_marginTop="0dp"
                                    android:layout_marginEnd="0dp"
                                    android:layout_marginRight="0dp"
                                    android:layout_marginBottom="0dp"
                                    android:layout_weight="1"
                                    android:adjustViewBounds="true"
                                    android:background="@drawable/picicon"
                                    android:scaleType="fitXY"
                                    app:riv_corner_radius="12dip"
                                    app:riv_mutate_background="true"
                                    app:riv_oval="false"
                                    app:riv_tile_mode="clamp" />
                            </LinearLayout>

                            <LinearLayout
                                android:layout_width="0dp"
                                android:layout_height="match_parent"
                                android:layout_weight="1.5"
                                android:gravity="center"
                                android:orientation="horizontal">
                                <TextView
                                    android:id="@+id/textView2"
                                    android:layout_width="0dp"
                                    android:layout_height="wrap_content"
                                    android:layout_weight="0.5"
                                    android:paddingLeft="8dp"
                                    android:text="  Pictures"
                                    android:textColor="@android:color/white"
                                    android:textSize="@dimen/Scanpicturestext" />

                            </LinearLayout>
                        </LinearLayout>
                    </android.support.v7.widget.CardView>


                </LinearLayout>

                <LinearLayout
                    android:id="@+id/scanaudio_layout"
                    android:layout_width="match_parent"
                    android:layout_height="0dp"

                    android:layout_margin="3dp"
                    android:layout_weight="1"
                    android:weightSum="2">

                    <android.support.v7.widget.CardView
                        android:layout_width="match_parent"
                        android:layout_height="match_parent"
                        android:layout_marginLeft="10dp"
                        android:layout_marginRight="10dp"
                        app:cardCornerRadius="12dp">

                        <LinearLayout
                            android:id="@+id/audiocardlayout"
                            android:layout_width="match_parent"
                            android:layout_height="match_parent"
                            android:background="@color/purplelayout"
                            android:orientation="horizontal"
                            android:weightSum="2">

                            <LinearLayout
                                android:layout_width="0dp"
                                android:layout_height="match_parent"
                                android:layout_weight="0.5"
                                android:orientation="horizontal"
                                android:paddingLeft="8dp">

                                <com.makeramen.roundedimageview.RoundedImageView
                                    android:id="@+id/imageView3"
                                    android:layout_width="match_parent"
                                    android:layout_height="match_parent"
                                    android:layout_alignParentStart="true"
                                    android:layout_alignParentLeft="true"
                                    android:layout_alignParentTop="true"
                                    android:layout_alignParentEnd="true"
                                    android:layout_alignParentRight="true"
                                    android:layout_alignParentBottom="true"
                                    android:layout_margin="1dp"
                                    android:layout_marginStart="0dp"
                                    android:layout_marginLeft="0dp"
                                    android:layout_marginTop="0dp"
                                    android:layout_marginEnd="0dp"
                                    android:layout_marginRight="0dp"
                                    android:layout_marginBottom="0dp"
                                    android:layout_weight="1"
                                    android:adjustViewBounds="true"
                                    android:background="@drawable/playicon"
                                    android:scaleType="fitXY"
                                    app:riv_corner_radius="12dip"
                                    app:riv_mutate_background="true"
                                    app:riv_oval="false"
                                    app:riv_tile_mode="clamp" />
                            </LinearLayout>

                            <LinearLayout
                                android:layout_width="0dp"
                                android:layout_height="match_parent"
                                android:layout_weight="1.5"
                                android:gravity="center"
                                android:orientation="horizontal">

                                <TextView
                                    android:id="@+id/textView3"
                                    android:layout_width="0dp"
                                    android:layout_height="wrap_content"
                                    android:layout_weight="0.5"
                                    android:paddingLeft="8dp"
                                    android:text=" Audios"
                                    android:textColor="@android:color/white"
                                    android:textSize="@dimen/Scanaudiostext" />

                            </LinearLayout>
                        </LinearLayout>
                    </android.support.v7.widget.CardView>

                </LinearLayout>

                <LinearLayout
                    android:id="@+id/scanvideos_layout"
                    android:layout_width="match_parent"

                    android:layout_height="0dp"
                    android:layout_margin="3dp"
                    android:layout_weight="1"
                    android:weightSum="2">

                    <android.support.v7.widget.CardView
                        android:layout_width="match_parent"
                        android:layout_height="match_parent"
                        android:layout_marginLeft="10dp"
                        android:layout_marginRight="10dp"
                        app:cardCornerRadius="12dp">

                        <LinearLayout
                            android:id="@+id/videocardlayout"
                            android:layout_width="match_parent"
                            android:layout_height="match_parent"
                            android:background="@color/redlayout"
                            android:orientation="horizontal"
                            android:weightSum="2">

                            <LinearLayout
                                android:layout_width="0dp"
                                android:layout_height="match_parent"
                                android:layout_weight="0.5"
                                android:orientation="horizontal"
                                android:paddingLeft="8dp">

                                <com.makeramen.roundedimageview.RoundedImageView
                                    android:id="@+id/imageView4"
                                    android:layout_width="match_parent"
                                    android:layout_height="match_parent"
                                    android:layout_alignParentStart="true"
                                    android:layout_alignParentLeft="true"
                                    android:layout_alignParentTop="true"
                                    android:layout_alignParentEnd="true"
                                    android:layout_alignParentRight="true"
                                    android:layout_alignParentBottom="true"
                                    android:layout_margin="1dp"
                                    android:layout_marginStart="0dp"
                                    android:layout_marginLeft="0dp"
                                    android:layout_marginTop="0dp"
                                    android:layout_marginEnd="0dp"
                                    android:layout_marginRight="0dp"
                                    android:layout_marginBottom="0dp"
                                    android:layout_weight="1"
                                    android:adjustViewBounds="true"
                                    android:background="@drawable/vidicon"
                                    android:scaleType="fitCenter"
                                    app:riv_corner_radius="12dip"
                                    app:riv_mutate_background="true"
                                    app:riv_oval="false"
                                    app:riv_tile_mode="clamp" />

                            </LinearLayout>

                            <LinearLayout
                                android:layout_width="0dp"
                                android:layout_height="match_parent"
                                android:layout_weight="1.5"
                                android:gravity="center"
                                android:orientation="horizontal">

                                <TextView
                                    android:id="@+id/textView4"
                                    android:layout_width="0dp"
                                    android:layout_height="wrap_content"
                                    android:layout_weight="0.5"
                                    android:paddingLeft="8dp"
                                    android:text="Videos"
                                    android:textColor="@android:color/white"
                                    android:textSize="@dimen/Scanvideostext" />

                            </LinearLayout>
                        </LinearLayout>
                    </android.support.v7.widget.CardView>

                </LinearLayout>

                <LinearLayout
                    android:id="@+id/scandocs_layout"
                    android:layout_width="match_parent"
                    android:layout_height="0dp"
                    android:layout_margin="3dp"
                    android:layout_weight="1"
                    android:weightSum="2">

                    <android.support.v7.widget.CardView
                        android:layout_width="match_parent"
                        android:layout_height="match_parent"
                        android:layout_marginLeft="10dp"
                        android:layout_marginRight="10dp"
                        app:cardCornerRadius="12dp">

                        <LinearLayout
                            android:id="@+id/doccardlayout"
                            android:layout_width="match_parent"
                            android:layout_height="match_parent"
                            android:background="@color/greenlayout"
                            android:orientation="horizontal">

                            <LinearLayout
                                android:layout_width="0dp"
                                android:layout_height="match_parent"
                                android:layout_weight="0.5"
                                android:orientation="horizontal"
                                android:paddingLeft="8dp">

                                <com.makeramen.roundedimageview.RoundedImageView
                                    android:id="@+id/imageView5"
                                    android:layout_width="match_parent"
                                    android:layout_height="match_parent"
                                    android:layout_alignParentStart="true"
                                    android:layout_alignParentLeft="true"
                                    android:layout_alignParentTop="true"
                                    android:layout_alignParentEnd="true"
                                    android:layout_alignParentRight="true"
                                    android:layout_alignParentBottom="true"
                                    android:layout_margin="1dp"
                                    android:layout_marginStart="0dp"
                                    android:layout_marginLeft="0dp"
                                    android:layout_marginTop="0dp"
                                    android:layout_marginEnd="0dp"
                                    android:layout_marginRight="0dp"
                                    android:layout_marginBottom="0dp"
                                    android:layout_weight="1"
                                    android:adjustViewBounds="true"
                                    android:background="@drawable/docxicon"
                                    android:scaleType="fitXY"
                                    app:riv_corner_radius="12dip"
                                    app:riv_mutate_background="true"
                                    app:riv_oval="false"
                                    app:riv_tile_mode="clamp" />

                            </LinearLayout>

                            <LinearLayout
                                android:layout_width="0dp"
                                android:layout_height="match_parent"
                                android:layout_weight="1.5"
                                android:gravity="center"
                                android:orientation="horizontal">

                                <TextView
                                    android:id="@+id/textView5"
                                    android:layout_width="0dp"
                                    android:layout_height="wrap_content"
                                    android:layout_weight="0.5"
                                    android:paddingLeft="8dp"
                                    android:text="Docs"
                                    android:textColor="@android:color/white"
                                    android:textSize="@dimen/Scandocumentstext" />

                            </LinearLayout>
                        </LinearLayout>
                    </android.support.v7.widget.CardView>

                </LinearLayout>

            </LinearLayout>
        </LinearLayout>
    </ScrollView>
</LinearLayout>

您应该将要素中的项目的形状从3072更改为[32,32,3]

祝你好运

相关问题