如何在TensorFlow中打印SparseTensor内容?

时间:2018-03-08 21:00:50

标签: python tensorflow

为了快速调试,我试图打印出刚刚初始化的SparseTensor。

内置打印功能只是说它是一个SparseTensor对象,而tf.Print()给出了一个错误。错误语句会打印对象的内容,但不会以显示实际条目的方式打印(除非它告诉我它是空的,还有一些:0s我不喜欢'知道的重要性。

rows = tf.Print(rows, [rows])

TypeError: Failed to convert object of type <class 'tensorflow.python.framework.sparse_tensor.SparseTensor'> to Tensor. Contents: SparseTensor(indices=Tensor("SparseTensor/indices:0", shape=(6, 2), dtype=int64), values=Tensor("SparseTensor/values:0", shape=(6,), dtype=float32), dense_shape=Tensor("SparseTensor/dense_shape:0", shape=(2,), dtype=int64)). Consider casting elements to a supported type.

2 个答案:

答案 0 :(得分:6)

方式0:运行SparseTensor并打印结果

运行图形(在本例中只是SparseTensor对象)返回一个SparseTensorValue对象,该对象的打印格式与用于初始化SparseTensor的调用相同,这最终是我想要的。

with tf.Session() as sess:
  rows = sess.run(rows)
  print(rows)

方式1:转换为密集矩阵后使用打印

要使用Print功能,我可以在我的情况下转换为密集矩阵。但是Print仅在您运行图表时执行:

rows = tf.sparse_tensor_to_dense(rows)
rows = tf.Print(rows, [rows], summarize=100)
with tf.Session() as sess:
  sess.run(rows)

注意“汇总” - 默认设置只打印出零,因为它获得了以密集形式表示的稀疏矩阵的前几个条目!

方式2:使用tf.test.TestCase

我发现TestCase.evaluate方法给了我一种我想要的漂亮格式,与上面的Way 0相同:

print(str(self.evaluate(rows)))

输出,例如:

SparseTensorValue(indices=array([[1, 2],
   [1, 7],
   [1, 8],
   [2, 2],
   [3, 4],
   [3, 5]]), values=array([1., 1., 1., 1., 1., 1.], dtype=float32), dense_shape=array([4, 9]))

答案 1 :(得分:1)

您看到此错误是因为SparseTensor并不是真正的张量,它是一个包含3个密集张量的MetaTensor。

尝试在SparseTensor上使用print(),您会看到内部详细信息:

indices=Tensor(…), values=Tensor(…), dense_shape=Tensor(…))

您可以使用tf.Print打印任何这些“内部”张量。例如,tf.Print(my_sparse_tensor.values,[my_sparse_tensor.values])将成功。

SparseTensor文档描述了内部数据结构:

https://www.tensorflow.org/api_docs/python/tf/sparse/SparseTensor

TensorFlow将稀疏张量表示为三个单独的密集张量:索引,值和density_shape。在Python中,这三个张量被收集到SparseTensor类中以便于使用。如果您有单独的索引,值和density_shape张量,请在将它们传递到下面的ops之前,将它们包装在SparseTensor对象中。

具体来说,稀疏张量SparseTensor(索引,值,dense_shape)包含以下组件,其中N和ndims分别是SparseTensor中的值数和维数:

indices:密度为[N,ndims]的2-D int64张量,它指定稀疏张量中包含非零值(元素为零索引)的元素的索引。例如,indexs = [[1,3],[2,4]]指定索引为[1,3]和[2,4]的元素具有非零值。

values:任何类型和密实形状[N]的一维张量,它为索引中的每个元素提供值。例如,给定索引= [[1,3],[2,4]],参数值= [18,3.6]指定稀疏张量的元素[1,3]的值为18,而元素[ [2,4]的张量的值为3.6。

dense_shape:density_shape [ndims]的一维int64张量,它指定稀疏张量的density_shape。获取一个列表,该列表指示每个维中的元素数。例如,density_shape = [3,6]指定二维3x6张量,density_shape = [2,3,4]指定三维2x3x4张量,density_shape = [9]指定包含9个元素的一维张量

相应的密集张量满足:

dense.shape = dense_shape
dense[tuple(indices[i])] = values[i]

按照惯例,索引应按行优先顺序(或在元组索引[i]上按字典顺序等效)排序。构造SparseTensor对象时,不会强制执行此操作,但是大多数操作都假定正确的顺序。如果稀疏张量st的顺序错误,则可以通过调用tf.sparse_reorder(st)获得固定版本。

例如:稀疏张量

SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])

代表密集的张量:

[[1, 0, 0, 0]
 [0, 0, 2, 0]
 [0, 0, 0, 0]]
相关问题