我正在使用PyTorch增强我的技能,并且想知道是否有一种方法可以执行一些Numpy可能实现的常见功能。具体来说,删除具有特定平均值或特定标准偏差的列。
在Numpy中,这很容易做到,只需找到按列的均值/标准差即可做到
columns_to_keep = np.concatenate( [stdColX != 0 ] )
X = X[:, columns_to_keep]
但是,通过PyTorch表示法,在获得“平均值/标准差”列之后,我找不到一种简单的方法。
我做完
X_colStds = X.std(0)
要获取我的PyTorch张量中每一列的标准偏差,该怎么办?如果可以避免的话,我想避免在numpy之间来回跳动。
谢谢!