如何对此代码进行矢量化?

时间:2017-10-04 12:54:10

标签: matlab vectorization

我写了一个递归函数,然而,它花了很多时间。因此我对它进行了矢量化,但它不会产生与递归函数相同的结果。这是我的非矢量化代码:

function visited = procedure_explore( u, adj_mat, visited )
visited(u) = 1;
neighbours = find(adj_mat(u,:));
for ii = 1:length(neighbours)
    if (visited(neighbours(ii)) == 0)
        visited = procedure_explore( neighbours(ii), adj_mat, visited );
    end
end
end

这是我的矢量化代码:

function visited = procedure_explore_vec( u, adj_mat, visited )
visited(u) = 1;
neighbours = find(adj_mat(u,:));
len_neighbours=length(neighbours);
visited_neighbours_zero=visited(neighbours(1:len_neighbours)) == 0;
if(~isempty(visited_neighbours_zero))
    visited = procedure_explore_vec( neighbours(visited_neighbours_zero), adj_mat, visited );
end
end

这是测试代码

function main
    adj_mat=[0 0 0 0;
             1 0 1 1;
             1 0 0 0;
             1 0 0 1];
    u=2;
    visited=zeros(size(adj_mat,1));
    tic
    visited = procedure_explore( u, adj_mat, visited )
    toc
    visited=zeros(size(adj_mat,1));
    tic
    visited = procedure_explore_vec( u, adj_mat, visited )
    toc
end

这是我尝试实施的算法: enter image description here

如果不能进行矢量化,那么mex解决方案也会很好。

更新基准:此基准测试基于MATLAB 2017a。它表明原始代码比其他方法更快

Speed up between original and logical methods is 0.39672
Speed up between original and nearest methods is 0.0042583

完整代码

function main_recersive
    adj_mat=[0 0 0 0;
             1 0 1 1;
             1 0 0 0;
             1 0 0 1];
    u=2;
    visited=zeros(size(adj_mat,1));
    f_original=@()(procedure_explore( u, adj_mat, visited ));
    t_original=timeit(f_original);

    f_logical=@()(procedure_explore_logical( u, adj_mat ));
    t_logical=timeit(f_logical);

    f_nearest=@()(procedure_explore_nearest( u, adj_mat,visited ));
    t_nearest=timeit(f_nearest);

    disp(['Speed up between original and logical methods is ',num2str(t_original/t_logical)])
    disp(['Speed up between original and nearest methods is ',num2str(t_original/t_nearest)])    

end

function visited = procedure_explore( u, adj_mat, visited )
    visited(u) = 1;
    neighbours = find(adj_mat(u,:));
    for ii = 1:length(neighbours)
        if (visited(neighbours(ii)) == 0)
            visited = procedure_explore( neighbours(ii), adj_mat, visited );
        end
    end
end

function visited = procedure_explore_nearest( u, adj_mat, visited )
    % add u since your function also includes it.
    nodeIDs = [nearest(digraph(adj_mat),u,inf) ; u];
    % transform to output format of your function
    visited = zeros(size(adj_mat,1));
    visited(nodeIDs) = 1;

end 

function visited = procedure_explore_logical( u, adj_mat )
   visited = false(1, size(adj_mat, 1));
   visited(u) = true;
   new_visited = visited;
   while any(new_visited)
      visited = any([visited; new_visited], 1);
      new_visited = any(adj_mat(new_visited, :), 1);
      new_visited = and(new_visited, ~visited);
   end
end

6 个答案:

答案 0 :(得分:4)

这是一个有趣的小功能,可以在图表上进行非递归的广度优先搜索。

function visited = procedure_explore_logical( u, adj_mat )
   visited = false(1, size(adj_mat, 1));
   visited(u) = true;
   new_visited = visited;

   while any(new_visited)
      visited = any([visited; new_visited], 1);
      new_visited = any(adj_mat(new_visited, :), 1);
      new_visited = and(new_visited, ~visited);
   end
end

在Octave中,这比100x100邻接矩阵上的递归版本快50倍。你必须在MATLAB上对它进行基准测试,看看你得到了什么。

答案 1 :(得分:2)

您可以将邻接矩阵视为长度恰好为1的路径列表。您可以通过将其加到矩阵的第n个幂来生成其他长度为n的路径。 (adj_mat ^ 0是单位矩阵)

在具有n个节点的图表中,最长路径不能长于n-1,因此您可以将所有权力相加以进行可达性分析:

adj_mat + adj_mat^2 + adj_mat^3
ans =
   0   0   0   0
   4   0   1   3
   1   0   0   0
   3   0   0   3

这是您可以用来从一个节点转到另一个节点的(不同)方式的数量。对于简单的反应性,请检查此值是否大于零:

visited(v) = ans(v, :) > 0;

根据您的定义,您可能必须更改结果中的列和行(即采用ans(:,v))。

为了提高性能,您可以使用较低的功率来制作更高的功率。例如,可以有效地计算类似A + A ^ 2 + A ^ 3 + A ^ 4 + A ^ 5的东西:

A2 = A^2;
A3 = A2*A
A4 = A2^2;
A5 = A4*A;
allWalks= A + A2 + A3 + A4 + A5;

注意:如果要将原始节点包含为可从其自身访问,则应在总和中包含单位矩阵。

这最大限度地减少了矩阵乘法的数量,MATLAB也可能比常规乘法更快地执行矩阵平方。

根据我的经验,矩阵乘法在MATLAB中相对较快,这将立即产生图中所有节点的结果(可达性)向量。如果您只对大图的一小部分感兴趣,这可能不是最好的解决方案。

另见答案:https://stackoverflow.com/a/7276595/1974021

答案 2 :(得分:1)

我认为你无法正确地实现你的功能:你的原始功能永远不会多次到达同一个节点。通过矢量化,您可以将所有直接连接的节点同时传递给下一个函数。因此,在以下实例中可能会多次到达同一节点。例如。在您的示例节点1中将达到3次。因此,当您不再有循环时,根据您的网络,该函数可能会递归调用更多次,这会增加计算时间。

话虽如此,通常不可能在没有循环或递归调用的情况下找到所有可到达的节点。例如,您可以检查所有(有效或无效)路径。但是这与您的功能有很大不同,并且根据节点的数量,可能会由于要检查的路径数量惊人而导致性能下降。您当前的功能并不太糟糕,并且可以在大型网络中很好地扩展。

有点offtopic,但自Matlab 2016a以来,您可以使用nearest()查找所有可到达的节点(没有起始节点)。与深度优先算法相比,它调用广度优先算法:

% add u since your function also includes it.
nodeIDs = [nearest(digraph(adj_mat),u,inf) ; u]; 

% transform to output format of your function
visited = zeros(size(adj_mat,1));
visited(nodeIDs) = 1;

如果这是针对学生项目的,你可以争辩说,虽然你的功能有效,但出于性能原因你使用了内置功能。<​​/ p>

答案 3 :(得分:1)

递归函数的问题与visited(u) = 1;有关。这是因为MATLAB使用写时​​复制技术来传递/分配变量。如果你没有在函数体中更改visited,则不会复制它,但是当它被修改时,会创建它的副本并对其副本进行修改。为了防止这种情况,您可以通过引用该函数来传递handle object

定义句柄类(将其保存到visited_class.m):

classdef visited_class < handle
    properties
        visited
    end
    methods
        function obj = visited_class(adj_mat)
            obj.visited = zeros(1, size(adj_mat,1));
        end
    end
end

递归函数:

function procedure_explore_handle( u, adj_mat,visited_handle )
    visited_handle.visited(u) = 1;
    neighbours = find(adj_mat(u,:));
    for n = neighbours
        if (visited_handle.visited(n) == 0)
            procedure_explore_handle( n, adj_mat , visited_handle );
        end
    end
end

初始化变量:

adj_mat=[0 0 0 0;
         1 0 1 1;
         1 0 0 0;
         1 0 0 1];
visited_handle = visited_class(adj_mat);
u = 2;

将其命名为:

procedure_explore_handle( u, adj_mat,visited_handle );

结果已保存到visited_handle

disp(visited_handle.visited)

答案 4 :(得分:0)

如果你想从图中的一个点转到另一个点,那么在资源方面找到它的最有效方法是Dijkstra的算法。 Floyd-Warshall算法计算所有点之间的所有距离,并且可以并行化(从多个点开始)。

为什么需要矢量化(或使用mex编程)?如果您只是想充分利用Matlab的快速矩阵乘法程序,那么使用A的产品应该可以快速到达那里:

adj_mat2=adj_mat^2;               % allowed to use 2 steps
while (adj_mat2 ~= adj_mat)       % check if new points were reached
      adj_mat=adj_mat2;           % current set of reachable points
      adj_mat2=(adj_mat^2)>0;     % allowed all steps again: power method
end

答案 5 :(得分:0)

这个答案只是给出DasKrümelmonster's answer建议的一个明确的,矢量化的实现,我认为它比问题中的代码更快(至少如果矩阵的维度不是太大)。它使用polyvalm函数来评估邻接矩阵的幂的总和。

function visited = procedure_explore_vec(u, adj_mat)
    connectivity_matrix = polyvalm(ones(size(adj_mat,1),1),adj_mat)>0;

    visited = connectivity_matrix(u,:);
end
相关问题