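Getting data into TensorFlow (SVHN)

Time: 2016-04-03 10:59:00

Tags: machine-learning tensorflow

I have installed TensorFlow successfully and worked through the simple MNIST tutorial. Now I want to build a model and train it on the SVHN data. Unfortunately, I cannot find anywhere how to feed that data into a model, which is basically the first step for every model. The data is stored in a dict, with the images under the key 'X' and the labels under the key 'y'. The shapes are as follows: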

print(traindata['X'].shape)

(32, 32, 3, 73257)

print(traindata['y'].shape)

(73257, 1)

Can anyone give me a hint, or a link, on how to get this data into TensorFlow?

Thanks

4 answers:

Answer 0 (score: 1)

SVHN's digitStruct.mat is a MATLAB file format, so you need to convert it first.

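Below is code that converts digitStruct.mat to JSON. For the cropped 32x32 .mat files like the ones in your question, scipy.io.loadmat should be enough on its own.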

# coding: utf-8
# SVHN extracts data from the digitStruct.mat full numbers files.  The data can be downloaded
# from the Street View House Number (SVHN) web site: http://ufldl.stanford.edu/housenumbers.
#
# This is an A2iA tweak (YG -9 Jan 2014) of the script found here :
# http://blog.grimwisdom.com/python/street-view-house-numbers-svhn-and-octave
#
# The digitStruct.mat files in the full numbers tars (train.tar.gz, test.tar.gz, and extra.tar.gz)
# are only compatible with matlab.  This Python program can be run at the command line and will generate
# a json version of the dataset.
#
# Command line usage:
#       SVHN_dataextract.py [-f input] [-o output_without_extension]
#    >  python SVHN_dataextract.py -f digitStruct.mat -o digitStruct
#
# Issues:
#    The ability to split in several files has been removed from the original
#    script.
#

import tqdm
import h5py
import optparse
from json import JSONEncoder

parser = optparse.OptionParser()
parser.add_option("-f", dest="fin", help="Matlab full number SVHN input file", default="digitStruct.mat")
parser.add_option("-o", dest="filePrefix", help="name for the json output file", default="digitStruct")
options, args = parser.parse_args()

fin = options.fin


# The DigitStructFile is just a wrapper around the h5py data.  It basically references
#    inf:              The input h5 matlab file
#    digitStructName:  The h5 ref to all the file names
#    digitStructBbox:  The h5 ref to all struct data
class DigitStructFile:
    def __init__(self, inf):
        self.inf = h5py.File(inf, 'r')
        self.digitStructName = self.inf['digitStruct']['name']
        self.digitStructBbox = self.inf['digitStruct']['bbox']

    # getName returns the 'name' string for the n(th) digitStruct.
    def getName(self, n):
        # Dataset.value was removed in h5py 3.0; indexing with [()] reads the whole dataset.
        return ''.join([chr(c[0]) for c in self.inf[self.digitStructName[n][0]][()]])

    # bboxHelper handles the coding difference when there is exactly one bbox or an array of bbox.
    def bboxHelper(self, attr):
        if len(attr) > 1:
            # Several boxes: each entry is an HDF5 reference to a 1x1 dataset.
            attr = [self.inf[attr[()][j].item()][()][0][0] for j in range(len(attr))]
        else:
            # Single box: the value is stored inline.
            attr = [attr[()][0][0]]
        return attr

    # getBbox returns a dict of data for the n(th) bbox.
    def getBbox(self, n):
        bbox = {}
        bb = self.digitStructBbox[n].item()
        bbox['height'] = self.bboxHelper(self.inf[bb]["height"])
        bbox['label'] = self.bboxHelper(self.inf[bb]["label"])
        bbox['left'] = self.bboxHelper(self.inf[bb]["left"])
        bbox['top'] = self.bboxHelper(self.inf[bb]["top"])
        bbox['width'] = self.bboxHelper(self.inf[bb]["width"])
        return bbox

    def getDigitStructure(self, n):
        s = self.getBbox(n)
        s['name'] = self.getName(n)
        return s

    # getAllDigitStructure returns all the digitStruct from the input file.
    def getAllDigitStructure(self):
        print('Starting get all digit structure')
        return [self.getDigitStructure(i) for i in tqdm.tqdm(range(len(self.digitStructName)))]

    # Return a restructured version of the dataset (one structure by boxed digit).
    #
    #   Return a list of such dicts :
    #      'filename' : filename of the samples
    #      'boxes' : list of such dicts (one by digit) :
    #          'label' : 1 to 9 corresponding digits. 10 for digit '0' in image.
    #          'left', 'top' : position of bounding box
    #          'width', 'height' : dimension of bounding box
    #
    # Note: We may turn this to a generator, if memory issues arise.
    def getAllDigitStructure_ByDigit(self):
        pictDat = self.getAllDigitStructure()
        result = []
        print('Starting pack JSON dict')
        for i in tqdm.tqdm(range(len(pictDat))):
            item = {'filename': pictDat[i]["name"]}
            figures = []
            # Build one dict per boxed digit in this image.
            for j in range(len(pictDat[i]['height'])):
                figure = dict()
                figure['height'] = pictDat[i]['height'][j]
                figure['label']  = pictDat[i]['label'][j]
                figure['left']   = pictDat[i]['left'][j]
                figure['top']    = pictDat[i]['top'][j]
                figure['width']  = pictDat[i]['width'][j]
                figures.append(figure)
            item['boxes'] = figures
            result.append(item)
        return result


dsf = DigitStructFile(fin)
dataset = dsf.getAllDigitStructure_ByDigit()
# Write the restructured dataset to <filePrefix>.json; the file is closed automatically.
with open(options.filePrefix + ".json", 'w') as fout:
    fout.write(JSONEncoder(indent=True).encode(dataset))

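After that, you should write code that loads the data into numpy.

As I see it, your task is not so much loading data into TensorFlow as loading all the images into numpy. So you should also use the PIL library to read the images into numpy arrays.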
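As a rough illustration of that last step, here is a minimal sketch that crops the boxed digits out of one image with PIL and stacks them into a numpy array. The helper name `crop_digits`, the 32x32 target size, and the synthetic image and entry are assumptions made for the example; in practice the entry would come from the generated JSON file and the image from `Image.open(entry['filename'])`.

```python
import numpy as np
from PIL import Image


def crop_digits(image, boxes, size=32):
    """Crop each bounding box out of a PIL image.

    Returns a uint8 array of shape (len(boxes), size, size, 3).
    """
    rgb = image.convert('RGB')
    crops = []
    for box in boxes:
        left, top = int(box['left']), int(box['top'])
        right = left + int(box['width'])
        bottom = top + int(box['height'])
        # PIL's crop takes a (left, upper, right, lower) pixel box.
        patch = rgb.crop((left, top, right, bottom)).resize((size, size))
        crops.append(np.asarray(patch, dtype=np.uint8))
    return np.stack(crops)


# Stand-in for one entry of the JSON file (hypothetical values).
entry = {
    'filename': '1.png',
    'boxes': [
        {'left': 2.0, 'top': 3.0, 'width': 10.0, 'height': 20.0, 'label': 1.0},
        {'left': 14.0, 'top': 3.0, 'width': 10.0, 'height': 20.0, 'label': 10.0},
    ],
}
# Synthetic image used here instead of Image.open(entry['filename']).
image = Image.new('RGB', (40, 30), color=(128, 64, 32))

digits = crop_digits(image, entry['boxes'])
# SVHN stores the digit '0' as label 10, so map it back to 0.
labels = np.array([int(b['label']) % 10 for b in entry['boxes']])
print(digits.shape, labels)
```

Resizing every crop to a fixed size is one possible choice; it distorts the aspect ratio of narrow digits, which may or may not matter for your model.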

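Answer 1 (score: 0)

I think a good idea would be to reshape your data into the following:

print(traindata['X'].shape)
(73257, 32 * 32 * 3) = (73257, 3072)

I am new to TensorFlow and Python too, but I think you can do this in numpy with np.reshape. Note that the sample axis comes last in your array, so it has to be moved to the front before flattening; otherwise the rows will not correspond to single images.

See the documentation: http://docs.scipy.org/doc/numpy/reference/generated/numpy.reshape.html

Let me know if it works for you :)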
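A minimal sketch of that reshape, using a small random array in place of the real data. The helper name `flatten_images` is made up for this example. It assumes the (height, width, channels, samples) layout shown in the question, which is why the last axis is transposed to the front before flattening.

```python
import numpy as np


def flatten_images(X):
    """Turn an (H, W, C, N) array into (N, H*W*C), one row per image."""
    # Move the sample axis to the front: (H, W, C, N) -> (N, H, W, C).
    X = np.transpose(X, (3, 0, 1, 2))
    # Flatten everything except the sample axis.
    return X.reshape(X.shape[0], -1)


# Five random "images" standing in for traindata['X'].
rng = np.random.RandomState(0)
X = rng.randint(0, 256, size=(32, 32, 3, 5)).astype(np.uint8)

flat = flatten_images(X)
print(flat.shape)  # (5, 3072)
```

Calling `X.reshape(5, -1)` directly on the original array would give the same shape but would mix pixels from different images into each row.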

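Answer 2 (score: 0)

TensorFlow works on the following idea: first you define a graph; next you train the graph; finally you use the graph.

When defining the graph, you create placeholders. These act as the input nodes of the graph. At this point, however, they are "empty" in the sense that they are not yet tied to any input data.

At both training time and test time, you "feed" data into the graph by referring to these predefined input nodes.

This concept may be new to you, so I suggest working through some tutorials. TensorFlow itself has good pages called "TensorFlow Mechanics 101" and "Basic Usage". If you are a more visual learner, I can recommend the YouTube channel "Dan Does Data", where he explores TensorFlow concepts in a humorous way.

If you are more the kind of person who likes example code, have a look at this example, in which I built a small CNN for MNIST. Look at the placeholders "x" and "y_"; those are the variables you are interested in.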
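Feeding is usually done one mini-batch at a time. The sketch below covers only the data side, in plain numpy: a hypothetical helper `iterate_minibatches` (not part of TensorFlow) that shuffles the data and yields batches. The zero-filled arrays are stand-ins for flattened SVHN images and labels.

```python
import numpy as np


def iterate_minibatches(images, labels, batch_size, seed=0):
    """Yield shuffled (images, labels) batches that cover the data once."""
    assert len(images) == len(labels)
    order = np.random.RandomState(seed).permutation(len(images))
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield images[idx], labels[idx]


# Stand-in data: 10 flattened images and their labels.
images = np.zeros((10, 3072), dtype=np.float32)
labels = np.arange(10)

batches = list(iterate_minibatches(images, labels, batch_size=4))
print([len(xb) for xb, _ in batches])  # [4, 4, 2]
```

In graph-style TensorFlow code, each `(xb, yb)` pair would then be passed to `Session.run` through its `feed_dict` argument, for example `feed_dict={x: xb, y_: yb}` with the placeholders named above.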

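Answer 3 (score: 0)

I am taking a course that uses this dataset. I needed to convert the images to grayscale while keeping the dimensions, scale the image data to [0, 1], change the label 10 to 0, and use moveaxis so that the image index comes first. In the end I did the following. After that, calls such as model.fit(xTrain, yTrain, ...) worked.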

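import numpy as np

# xTrain0, xTest0: raw image arrays of shape (32, 32, 3, N); yTrain, yTest: label arrays.
# Average over the channel axis to get grayscale, keeping that axis with size 1.
xTrain, xTest = np.mean(xTrain0, axis=2, keepdims=True), np.mean(xTest0, axis=2, keepdims=True)
# Scale pixel values from [0, 255] to [0, 1].
xTrain /= 255; xTest /= 255
print(f'Min: {xTrain.min()}, Max: {xTrain.max()}')
# SVHN uses label 10 for the digit 0.
yTrain[yTrain > 9] = 0; yTest[yTest > 9] = 0
print(f'Min: {yTrain.min()}, Max: {yTrain.max()}')
# Move the sample axis to the front: (32, 32, 1, N) -> (N, 32, 32, 1).
xTrain, xTest = np.moveaxis(xTrain, -1, 0), np.moveaxis(xTest, -1, 0)
print(xTrain.shape, xTest.shape)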
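The snippet above assumes that xTrain0 and yTrain already exist. One way they might be obtained, assuming the cropped-digit .mat layout described in the question (keys 'X' and 'y'), is scipy.io.loadmat. The sketch below writes a tiny fake file first so that it runs without the real download; the file name and its contents are made up for illustration.

```python
import os
import tempfile

import numpy as np
from scipy.io import loadmat, savemat

# Write a tiny fake file with the same layout as the question's data
# (images under 'X', labels under 'y'); real data would be downloaded instead.
tmp_path = os.path.join(tempfile.mkdtemp(), 'fake_train_32x32.mat')
savemat(tmp_path, {
    'X': np.zeros((32, 32, 3, 4), dtype=np.uint8),
    'y': np.array([[1], [10], [3], [9]], dtype=np.uint8),
})

data = loadmat(tmp_path)
xTrain0 = data['X']                          # shape (32, 32, 3, N)
yTrain = data['y'].ravel().astype(np.int64)  # (N, 1) -> (N,)
print(xTrain0.shape, yTrain.shape)
```

Flattening the labels with `ravel` is a choice made for this sketch; the preprocessing above works on the original (N, 1) shape as well.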