How to improve numpy broadcasting

Time: 2016-05-01 01:23:03

Tags: numpy machine-learning mnist numpy-broadcasting mahalanobis

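I am trying to implement k-NN with the Mahalanobis distance in Python using numpy. However, the code below runs very slowly when I use broadcasting. Please show me how to speed up the numpy code or how to implement this in a better way.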

from __future__ import division
from sklearn.utils import shuffle
from sklearn.metrics import f1_score
from sklearn.datasets import fetch_mldata
from sklearn.cross_validation import train_test_split

import numpy as np
import matplotlib.pyplot as plt

mnist = fetch_mldata('MNIST original')
mnist_X, mnist_y = shuffle(mnist.data, mnist.target.astype('int32'))

mnist_X = mnist_X/255.0

train_X, test_X, train_y, test_y = train_test_split(mnist_X, mnist_y, test_size=0.2)

k = 2

def data_gen(n):
    # Training samples whose label is n
    return train_X[train_y == n]

train_X_num = [data_gen(i) for i in range(10)]
# Inverse covariance matrix per class, regularized so that it is invertible
inv_cov = [np.linalg.inv(np.cov(train_X_num[i], rowvar=0) + np.eye(784) * 0.00001) for i in range(10)]
d = [None] * 10  # d[i] will hold the smallest distances to class i
for i in range(10):
    ivec = train_X_num[i]  # shape: (number of class-i samples, 784)
    ivec = ivec - test_X[:, np.newaxis, :]  # This line is very slow and uses a huge amount of memory
    iinv_cov = inv_cov[i]
    # Compute x.T inverse(sigma) x for every pair, then keep the k+1 smallest distances.
    # np.sort returns a sorted copy; ndarray.sort() sorts in place and returns None.
    d[i] = np.sort(np.add.reduce(np.dot(ivec, iinv_cov) * ivec, axis=2), axis=1)[:, :k+1]
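
One possible way around the slow line, offered as a sketch rather than a tested fix for the full MNIST run: the subtraction `train - test[:, np.newaxis, :]` materializes an array of shape (n_test, n_train, 784), which for 14,000 test rows and roughly 5,600 training rows per class would be hundreds of gigabytes in float64. Assuming the inverse covariance matrix S is symmetric, the quadratic form can be expanded as (a - b)^T S (a - b) = a^T S a + b^T S b - 2 a^T S b, so only 2-D matrix products are needed. The function names `mahalanobis_sq` and `k_smallest` and the small random arrays are hypothetical, chosen only for illustration.

```python
import numpy as np

def mahalanobis_sq(test, train, inv_cov):
    """Squared Mahalanobis distances, shape (len(test), len(train)).

    Assumes inv_cov is symmetric, so that
    (a - b)^T S (a - b) = a^T S a + b^T S b - 2 a^T S b.
    """
    test_s = np.dot(test, inv_cov)                # (n_test, n_features)
    train_s = np.dot(train, inv_cov)              # (n_train, n_features)
    tt = np.einsum('ij,ij->i', test_s, test)      # a^T S a for each test row
    rr = np.einsum('ij,ij->i', train_s, train)    # b^T S b for each train row
    cross = np.dot(test_s, train.T)               # a^T S b for every pair
    d2 = tt[:, np.newaxis] + rr[np.newaxis, :] - 2.0 * cross
    return np.maximum(d2, 0.0)                    # clip tiny negatives from rounding

def k_smallest(d2, count):
    """Return the `count` smallest values of each row, in ascending order."""
    part = np.partition(d2, count - 1, axis=1)[:, :count]
    return np.sort(part, axis=1)

# Small random stand-in for one class of training data and some test data
rng = np.random.RandomState(0)
train = rng.rand(50, 8)
test = rng.rand(20, 8)
inv_cov = np.linalg.inv(np.cov(train, rowvar=0) + np.eye(8) * 1e-5)

# Reference: the broadcasting approach from the question, feasible at this size
diff = train - test[:, np.newaxis, :]
reference = np.add.reduce(np.dot(diff, inv_cov) * diff, axis=2)

fast = mahalanobis_sq(test, train, inv_cov)
print(np.allclose(reference, fast))
print(k_smallest(fast, 3).shape)
```

`np.partition` avoids fully sorting every row when only a few of the smallest distances are needed. The (n_test, n_train) result is still large for the full dataset, so processing `test` in row chunks may be necessary if memory remains tight.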

0 answers:

No answers