Simple RNN network uses a huge amount of memory

Time: 2017-07-09 08:18:00

Tags: python windows memory-management tensorflow tflearn

I am trying to learn from a dataset in which the longest sequence is more than 6000 steps long (shape = (?, ?, 4)), so I pad the sequences like this:

for data in y_data:
    count = max_len - len(data)
    for i in range(count):
        data.append(np.zeros(4))
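
The padding above grows nested Python lists element by element. A sketch of an alternative (the helper name is hypothetical, this is not the code I actually ran) that writes everything into one preallocated float32 array instead:

import numpy as np

def pad_to_array(sequences, max_len, n_features=4):
    # Hypothetical helper: one dense float32 array; padded rows simply stay zero
    out = np.zeros((len(sequences), max_len, n_features), dtype=np.float32)
    for i, seq in enumerate(sequences):
        out[i, :len(seq)] = seq
    return out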

My model looks like this:

def get_net(sequence_length):
    net = tflearn.input_data(shape=[None, sequence_length, 4])
    net = tflearn.lstm(net, n_units=128, dropout=0.8)
    net = tflearn.fully_connected(net, 2, activation='linear')
    net = tflearn.regression(net, optimizer="adam", loss='mean_square')
    return tflearn.DNN(net, tensorboard_verbose=0)
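
For comparison, a variant I have not actually run: tflearn.lstm also accepts a dynamic argument, which according to the tflearn documentation skips RNN steps beyond each sequence's real length, so the zero-padded tail would not be computed:

def get_dynamic_net(sequence_length):
    # Same architecture as above, only with dynamic=True on the LSTM layer
    net = tflearn.input_data(shape=[None, sequence_length, 4])
    net = tflearn.lstm(net, n_units=128, dropout=0.8, dynamic=True)
    net = tflearn.fully_connected(net, 2, activation='linear')
    net = tflearn.regression(net, optimizer="adam", loss='mean_square')
    return tflearn.DNN(net, tensorboard_verbose=0)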

When I try to learn from this data, the computation takes a very long time. After it has filled about 3 GB of memory it starts building the model, and after an even longer wait there is a strange jump to roughly 10 GB, at which point the model can finally start computing.

I really don't think this should happen. Am I making a mistake somewhere?

OS: Windows 10, TensorFlow: 1.1.0, Python: 3.5.0

Edit:

import os
import numpy as np

# Global input data (x) and labels (y)
x_data = []
y_data = []
max_len = 0

# The csv files are formatted like this: "%i,%i,%i,%i;%f,%f"
for csv_file in files[:10]: 

    # Open csv-files
    with open(os.path.join(data_dir, csv_file), 'r') as f:
        # Per-file input data and labels
        csv_x_data = []
        csv_y_data = []
        for line in f:
            x_part, y_part = line.split(';')
            # Append each row's values to the per-file lists
            csv_x_data.append([int(x) for x in x_part.split(',')])
            csv_y_data.append([float(x) for x in y_part.split(',')])

        # Append this file's sequences to the global lists
        x_data.append(csv_x_data)
        y_data.append(csv_y_data)

        # Keep track of max_length for padding
        csv_len = len(csv_x_data)
        if csv_len > max_len:
            max_len = csv_len

# Padding the sequences
for data in x_data:
    count = max_len - len(data)
    for i in range(count):
        data.append(np.zeros(4))

print("data loaded")

# Not sure if I have to format the labels in the same way as the training data
# for data in y_data:
#     count = max_len - len(data)
#     for i in range(count):
#         data.append(np.zeros(4))

model = get_net(max_len)
print("model loaded")

x_data = np.array(x_data)
y_data = np.array(y_data)
model.fit(x_data, y_data, validation_set=0.1, batch_size=5)  # batch_size kept small for testing purposes
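
For debugging, a small check I could run right before model.fit (not part of the code above) to confirm that the conversion really produced dense numeric arrays rather than object arrays of unequal-length lists:

print(x_data.shape, x_data.dtype, "%.1f MB" % (x_data.nbytes / 1e6))
print(y_data.shape, y_data.dtype)
# dtype == object here would mean the sequences were not all padded to max_len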

I originally started with 500 files, but I had to realize that the memory problem is so severe that even 10 files are too many.
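
To get a feeling for the numbers (a rough sketch, assuming the padded input ends up as float64, with the ~6000-step length and 4 features from above):

import numpy as np

steps, features = 6000, 4
for n_files in (10, 500):
    padded = np.zeros((n_files, steps, features), dtype=np.float64)
    print(n_files, "files ->", padded.nbytes / 1e6, "MB")
# 10 files  -> about 2 MB
# 500 files -> about 96 MB

So the padded input data alone stays far below the 3-10 GB I see, which makes me assume most of the memory goes into building the LSTM graph over 6000 time steps rather than into the data itself.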

Example csv-files

0 Answers:

There are no answers yet.