找到theano中的最小N个元素

时间:2015-03-19 20:29:11

标签: python-2.7 theano

我有一个theano函数可以计算2个矩阵的欧氏距离 - Xn vectors x k features)和Ym vectors x k features)。结果是n x m中每个向量(或行)X中每个向量(或行)的成对距离的Y矩阵。

import theano
from theano import tensor as T

X, Y = T.dmatrices('X', 'Y')
X_squared_sum = T.sum(X ** 2, axis=1, keepdims=True)
Y_squared_sum = T.sum(Y.T ** 2, axis=0, keepdims=True)
squared_distances = X_squared_sum + Y_squared_sum - 2 * T.dot(X, Y.T)
f_distance = theano.function([X, Y], T.sqrt(squared_distances))

假设我将上述函数更改为接受单个向量,向量数组和最小距离数。我想要的是一个可以找到N个最小距离的theano函数,类似于下面:

import numpy as np
import theano
from theano import tensor as T

X = T.dvector('X')
Y = T.dmatrix('Y')
N = T.iscalar('N')
X_squared_sum = T.dot(X, X)
Y_squared_sum = T.sum(Y.T ** 2, axis=0)
squared_distances = X_squared_sum + Y_squared_sum - 2 * T.dot(X, Y.T)
dist_sorted = T.FIND_N_SMALLEST(T.sqrt(squared_distances), N)
n_closest = theano.function([X, Y, N], dist_sorted)

U = np.array([[1, 1, 1, 1]])
V = np.array([
      [  4,   4,   4,   4],
      [  2,   2,   2,   2],
      [  3,   3,   3,   3],
      [  1,   1,   1,   1]])

n_closest(U, V, 2)  # [0.0, 2.0]

我想避免明确地对所有距离进行排序,因为我想要的数字通常远远小于总距离。

0 个答案:

没有答案