Training with mxnet does not work in Spyder's IPython console

Time: 2017-11-11 17:05:29

Tags: ipython spyder mxnet

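I am running the mxnet MNIST example in the Spyder IDE. Training makes no progress, as if the learning rate were 0 (see the output below).

If I run the same file from a terminal with python or ipython, it works as expected (accuracy improves by the second epoch).

If I run the file in Spyder's plain Python console, it also works. But if I run it in Spyder's IPython console, I get the output shown below.

I am new to Python. Any ideas?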

Source code

#!/usr/bin/env python2
# -*- coding: utf-8 -*-

#%%
import mxnet as mx

#%% load mnist

mnist = mx.test_utils.get_mnist()

#%% define data iterators

batch_size = 100
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)

#%% create input variable

data = mx.sym.var('data')
# Flatten the data from 4-D shape into 2-D (batch_size, num_channel*width*height)
data = mx.sym.flatten(data=data)

#%% Multilayer Perceptron with softmax

# The first fully-connected layer and the corresponding activation function
fc1  = mx.sym.FullyConnected(data=data, num_hidden=128)
act1 = mx.sym.Activation(data=fc1, act_type="relu")

# The second fully-connected layer and the corresponding activation function
fc2  = mx.sym.FullyConnected(data=act1, num_hidden=64)
act2 = mx.sym.Activation(data=fc2, act_type="relu")

# MNIST has 10 classes
fc3  = mx.sym.FullyConnected(data=act2, num_hidden=10)
# Softmax with cross entropy loss
mlp  = mx.sym.SoftmaxOutput(data=fc3, name='softmax')

#%% Training
import logging
logging.getLogger().setLevel(logging.DEBUG)  # logging to stdout
# create a trainable module on CPU
mlp_model = mx.mod.Module(symbol=mlp, context=mx.cpu())
mlp_model.fit(train_iter,  # train data
              eval_data=val_iter,  # validation data
              optimizer='sgd',  # use SGD to train
              optimizer_params={'learning_rate':0.1},  # use fixed learning rate
              eval_metric='acc',  # report accuracy during training
              batch_end_callback=mx.callback.Speedometer(batch_size, 100),  # output progress for each 100 data batches
              num_epoch=10)  # train for at most 10 dataset passes

Output of the IPython console in Spyder

runfile('/home/thomas/workspace/clover-python/mnist.py', wdir='/home/thomas/workspace/clover-python')
INFO:root:train-labels-idx1-ubyte.gz exists, skipping download
INFO:root:train-images-idx3-ubyte.gz exists, skipping download
INFO:root:t10k-labels-idx1-ubyte.gz exists, skipping download
INFO:root:t10k-images-idx3-ubyte.gz exists, skipping download
INFO:root:Epoch[0] Batch [100]  Speed: 1304.05 samples/sec      accuracy=0.097921
INFO:root:Epoch[0] Batch [200]  Speed: 1059.19 samples/sec      accuracy=0.098200
INFO:root:Epoch[0] Batch [300]  Speed: 1178.64 samples/sec      accuracy=0.099600
INFO:root:Epoch[0] Batch [400]  Speed: 1292.71 samples/sec      accuracy=0.098900
INFO:root:Epoch[0] Batch [500]  Speed: 1394.21 samples/sec      accuracy=0.096500
INFO:root:Epoch[0] Train-accuracy=0.101212
INFO:root:Epoch[0] Time cost=47.798
INFO:root:Epoch[0] Validation-accuracy=0.098000
INFO:root:Epoch[1] Batch [100]  Speed: 1247.47 samples/sec      accuracy=0.097921
INFO:root:Epoch[1] Batch [200]  Speed: 1673.79 samples/sec      accuracy=0.098200
INFO:root:Epoch[1] Batch [300]  Speed: 1283.91 samples/sec      accuracy=0.099600
INFO:root:Epoch[1] Batch [400]  Speed: 1247.79 samples/sec      accuracy=0.098900
INFO:root:Epoch[1] Batch [500]  Speed: 1371.93 samples/sec      accuracy=0.096500
INFO:root:Epoch[1] Train-accuracy=0.101212
INFO:root:Epoch[1] Time cost=44.201
INFO:root:Epoch[1] Validation-accuracy=0.098000
INFO:root:Epoch[2] Batch [100]  Speed: 1387.72 samples/sec      accuracy=0.097921
INFO:root:Epoch[2] Batch [200]  Speed: 1196.37 samples/sec      accuracy=0.098200
INFO:root:Epoch[2] Batch [300]  Speed: 1220.44 samples/sec      accuracy=0.099600
INFO:root:Epoch[2] Batch [400]  Speed: 1387.75 samples/sec      accuracy=0.098900
INFO:root:Epoch[2] Batch [500]  Speed: 1279.58 samples/sec      accuracy=0.096500
INFO:root:Epoch[2] Train-accuracy=0.101212
INFO:root:Epoch[2] Time cost=46.929
INFO:root:Epoch[2] Validation-accuracy=0.098000
INFO:root:Epoch[3] Batch [100]  Speed: 1266.24 samples/sec      accuracy=0.097921

and so on…

2 answers:

Answer 0: (score: 1)

Following Eric's comment, let's verify that you are running the same version of mxnet in each environment. You can enter:

import mxnet
print(mxnet.__version__)
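
To make the comparison easier, the version check can be extended to also report which interpreter and which installed copy of the package each console picks up. The following is only a minimal sketch: the helper name `describe_environment` is hypothetical (it is not part of mxnet or Spyder), and it uses only the standard library to import the package by name.

```python
import importlib
import sys


def describe_environment(module_name="mxnet"):
    """Collect interpreter and package details for the current console.

    `describe_environment` is a hypothetical helper written for this
    answer; it is not provided by mxnet or Spyder.
    """
    info = {
        "executable": sys.executable,              # path of the running interpreter
        "python_version": sys.version.split()[0],  # e.g. "2.7.x" or "3.x.y"
    }
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # The package is not importable from this interpreter.
        info["module_version"] = None
        info["module_path"] = None
    else:
        info["module_version"] = getattr(module, "__version__", None)
        info["module_path"] = getattr(module, "__file__", None)
    return info


if __name__ == "__main__":
    for key, value in sorted(describe_environment().items()):
        print("{}: {}".format(key, value))
```

Run the same snippet in the terminal, in Spyder's Python console, and in Spyder's IPython console. If the executable, the version, or the module path differs between them, the consoles are not using the same installation.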

Answer 1: (score: 0)

Do you have multiple versions of MXNet installed? One thing I can think of is to verify that the plain Python console and the IPython console are using the same versions of Python and MXNet. In IPython, you can print mxnet.__version__ to see which version of MXNet you are using.