Tensorflow NN准确性/预测不变

时间:2019-06-01 16:50:41

标签: python tensorflow

尝试使用Tensorflow创建NN来预测是否进行了NBA投篮,但是该模型似乎只是预测每次击球都会失手,而且准确性不会因时而异。

我不确定是否是由于数据本身或梯度下降函数引起的。它与虹膜数据集配合使用,无法确定数据更改时发生了什么变化。

import tensorflow as tf
import numpy as np
from sklearn.model_selection import train_test_split
import pandas as pd

RANDOM_SEED = 42
tf.set_random_seed(RANDOM_SEED)


def init_weights(shape):
    """ Weight initialization """
    weights = tf.random_normal(shape, stddev=0.1)
    return tf.Variable(weights)

def forwardprop(X, w_1, w_2):
    """
    Forward-propagation.
    IMPORTANT: yhat is not softmax since TensorFlow's softmax_cross_entropy_with_logits() does that internally.
    """
    h    = tf.nn.sigmoid(tf.matmul(X, w_1))  # The \sigma function
    yhat = tf.matmul(h, w_2)  # The \varphi function
    return yhat

def get_iris_data():
    """ Read the iris data set and split them into training and test sets """
    my_data = pd.read_csv("shot_logs.csv").sample(2000)
    my_data["GAME_CLOCK"] = pd.to_numeric(my_data["GAME_CLOCK"].str.split(":").str[0]) * 60 + pd.to_numeric(
        my_data["GAME_CLOCK"].str.split(":").str[1])
    my_data["LOCATION"] = my_data["LOCATION"].map({
        "A":0,
        "H":1
    })

    data   = my_data[["LOCATION", "SHOT_NUMBER", "PERIOD", "GAME_CLOCK", "SHOT_CLOCK", "DRIBBLES", "TOUCH_TIME",
                   "SHOT_DIST", "CLOSE_DEF_DIST"]]
    target = my_data["FGM"]

    # Prepend the column of 1s for bias
    N, M  = data.shape
    all_X = np.ones((N, M + 1))
    all_X[:, 1:] = data

    # Convert into one-hot vectors
    num_labels = len(np.unique(target))
    print(num_labels)
    print(np.eye(num_labels))
    all_Y = np.eye(num_labels)[target]  # One liner trick!
    return train_test_split(all_X, all_Y, test_size=0.33, random_state=RANDOM_SEED)

def main():
    train_X, test_X, train_y, test_y = get_iris_data()

    # Layer's sizes
    x_size = train_X.shape[1]   # Number of input nodes: 4 features and 1 bias
    h_size = 256                # Number of hidden nodes
    y_size = train_y.shape[1]   # Number of outcomes (3 iris flowers)
    print(y_size)

    # Symbols
    X = tf.placeholder("float", shape=[None, x_size])
    y = tf.placeholder("float", shape=[None, y_size])

    # Weight initializations
    w_1 = init_weights((x_size, h_size))
    w_2 = init_weights((h_size, y_size))

    # Forward propagation
    yhat    = forwardprop(X, w_1, w_2)
    predict = tf.argmax(yhat, axis=1)

    # Backward propagation
    cost    = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=yhat))
    updates = tf.train.GradientDescentOptimizer(0.1).minimize(cost)

    # Run SGD
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)

    for epoch in range(100):
        # Train with each example
        for i in range(len(train_X)):
            sess.run(updates, feed_dict={X: train_X[i: i + 1], y: train_y[i: i + 1]})

        train_accuracy = np.mean(np.argmax(train_y, axis=1) ==
                                 sess.run(predict, feed_dict={X: train_X, y: train_y}))
        test_accuracy  = np.mean(np.argmax(test_y, axis=1) ==
                                 sess.run(predict, feed_dict={X: test_X, y: test_y}))

        print("Epoch = %d, train accuracy = %.2f%%, test accuracy = %.2f%%"
              % (epoch + 1, 100. * train_accuracy, 100. * test_accuracy))

    sess.close()

if __name__ == '__main__':
    main()

这是输出的样子:

Epoch = 1, train accuracy = 54.85%, test accuracy = 54.55%
Epoch = 2, train accuracy = 54.85%, test accuracy = 54.55%
Epoch = 3, train accuracy = 54.85%, test accuracy = 54.55%
Epoch = 4, train accuracy = 54.85%, test accuracy = 54.55%
Epoch = 5, train accuracy = 54.85%, test accuracy = 54.55%
Epoch = 6, train accuracy = 54.85%, test accuracy = 54.55%
Epoch = 7, train accuracy = 54.85%, test accuracy = 54.55%
Epoch = 8, train accuracy = 54.85%, test accuracy = 54.55%
Epoch = 9, train accuracy = 54.85%, test accuracy = 54.55%
Epoch = 10, train accuracy = 54.85%, test accuracy = 54.55%
Epoch = 11, train accuracy = 54.85%, test accuracy = 54.55%

0 个答案:

没有答案
相关问题