tf.gather_nd多次使用时确实很慢

时间:2018-07-02 03:18:35

标签: tensorflow slice

我想要张量流中的损失函数,它是许多元素的复杂组合。例如,这段代码:

import tensorflow as tf
import numpy as np
import time

input_layer = tf.placeholder(tf.float64, shape=[64,4])
output_layer = input_layer + 0.5*tf.tanh(tf.Variable(tf.random_uniform(shape=[64,4],\
                                                       minval=-1,maxval=1,dtype=tf.float64)))

# random_combination is 2-d numpy array of the form:
# [[32, 34, 23, 56],[23,54,33,21],...]
random_combination = np.random.randint(64, size=(210000000, 4))

# a collector to collect the values 
collector=[]

print('start looping')   
print(time.asctime(time.localtime(time.time())))

# loop through random_combination and pick the elements of output_layer
for i in range(len(random_combination)):
    [i,j,k,l] = [random_combination[i][0],random_combination[i][1],\
                 random_combination[i][2],random_combination[i][3]]

    # pick the needed element from output_layer
    f1 = tf.gather_nd(output_layer,[i,0])
    f2 = tf.gather_nd(output_layer,[i,2])
    f3 = tf.gather_nd(output_layer,[i,3])
    f4 = tf.gather_nd(output_layer,[i,4])

    tf1 = f1+1
    tf2 = f2+1
    tf3 = f3+1
    tf4 = f4+1
    collector.append(0.3*tf.abs(f1*f2*tf3*tf4-tf1*tf2*f3*f4))

print('end looping')   
print(time.asctime(time.localtime(time.time())))

# loss function
loss = tf.add_n(collector)

这在我的计算机上大约需要50分钟。 我的问题是,这是在张量流中进行编码的正确方法吗? 还是有一种更省时的方法来索引元素?

0 个答案:

没有答案
相关问题