分类上的张量流

时间:2017-09-29 02:47:16

标签: python tensorflow artificial-intelligence

我目前正在学习张量流。我试图使用softmax模型建立分类模型。 在程序中,我将训练数据集设置在CSV文件中两列右侧和两列右侧的两个标签上。如:

数据1,数据2,label1的,LABEL2
 234,23,1,0#234大于23,因此label1标记为1,label2标记为0  156,113,1,0   1,4,0​​,1,

它可以根据上述训练数据集对最大数量的测试数据进行分类,并将成本值收敛到接近零。

但是,我更改数据集以标记偶数,其目的在于将数据分类为偶数,模型失败,而成本则在波动。数据集如下:

数据1,数据2,label1的,LABEL2
 24,35,1,0#24是偶数,因此label1标记为1,label2标记为0  156,553,1,0   1,4,0​​,1,

我的程序错了吗?为什么它可以区分数据集中的最大数字,而偶数则失败?谢谢大家! 这是我的代码:

import tensorflow as tf
import os
import numpy as np


def next_batch(num, data, labels):
    idx = np.arange(0 , len(data))
    np.random.shuffle(idx)
    idx = idx[:num]
    data_shuffle = [data[ i] for i in idx]

    labels_shuffle = [labels[ i] for i in idx]
    return np.asarray(data_shuffle), np.asarray(labels_shuffle)

dir_path = os.path.dirname(os.path.realpath(__file__))

filename = dir_path + "/classification.csv"

x = tf.placeholder(tf.float32, [None, 2]) 
y = tf.placeholder(tf.float32, [None, 2]) 
W = tf.Variable(tf.zeros([2, 2]))
b = tf.Variable(tf.zeros([2]))

pred =tf.add( tf.matmul(x, W),b)

cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optimizer = tf.train.GradientDescentOptimizer(0.1).minimize(cost)

with tf.Session() as sess:
    sess.run( tf.global_variables_initializer())

    with open(filename) as inf:
        # Skip header
        next(inf)
        result_array = np.shape(4)
        for line in inf:

            data1, data2,label1,label2= line.strip().split(",")

            data1 = float(data1)
            data2 = float(data2)
            label1 = int(label1)
            label2 = int(label2)
            result_array = np.append(result_array, (data1,data2,label1,label2))

    result_array=result_array.reshape(1000,4)
    k=result_array[:,2:4]
    gg=result_array[:,0:2] 
    for i in range(0,3000):
        batch_xs, batch_ys = next_batch(200,gg,k)  

        h,cos=sess.run([optimizer, cost], feed_dict={x: batch_xs,y:batch_ys})
        print(cos)

    print(sess.run(pred,feed_dict={x:[[5,2],[4,9],[4,3],[5,2],[3,6],[30,21],[32,20],[3,4]]})) #testing data

0 个答案:

没有答案