通过引用可变参数函数传递N-D数组

时间:2018-07-22 21:41:26

标签: c++ multidimensional-array variadic-templates

我想使函数multi_dimensional通过引用接受多维数组。

是否可以通过以下语法的变体来完成此操作,该语法可用于three_dimensional

#include <utility>

// this works, but number of dimensions must be known (not variadic)
template <size_t x, size_t y, size_t z>
void three_dimensional(int (&nd_array)[x][y][z]) {}

// error: parameter packs not expanded with ‘...’
template <size_t... dims>
void multi_dimensional(int (&nd_array)[dims]...) {}

int main() {
    int array[2][3][2] = {
        { {0,1}, {2,3}, {4,5} },
        { {6,7}, {8,9}, {10,11} }
    };
    three_dimensional(array); // OK
    // multi_dimensional(array); // error: no matching function
    return 0;
}

1 个答案:

答案 0 :(得分:2)

主要问题是您不能使数组维数本身可变。因此,无论采用哪种方式,您几乎都肯定需要某种递归方法来处理各个数组层。这种方法的确切外观将主要取决于一旦将数组提供给您之后,您打算如何处理该数组。

如果实际上您想要的只是一个可以被赋予任何多维数组的函数,那么只要编写一个可以被赋予only exists以外的任何东西的函数,只要任何东西都是一个数组即可:

template <typename T>
std::enable_if_t<std::is_array_v<T>> multi_dimensional(T& a)
{
    constexpr int dimensions = std::rank_v<T>;

    // ...
}

但是,这本身很可能不会使您走得太远。为了对给定的数组进行有意义的事情,您很可能需要在子数组中进行一些递归遍历。除非您真的只想查看结构的最顶层。

另一种方法是使用递归模板剥离各个数组级别,例如:

// we've reached the bottom
template <typename T, int N>
void multi_dimensional(T (&a)[N])
{
    // ...
}

// this matches any array with more than one dimension
template <typename T, int N, int M>
void multi_dimensional(T (&a)[N][M])
{
    // peel off one dimension, invoke function for each element on next layer
    for (int i = 0; i < N; ++i)
        multi_dimensional(a[i]);
}

但是,我建议至少考虑使用std::array<>而不是原始数组,因为原始数组的语法和特殊行为往往会立即使所有内容变成混乱的混乱。通常,可能值得实现自己的多维数组类型,例如NDArray<int, 2, 3, 2>,它在内部使用扁平化表示形式,并且仅将多维索引映射到线性索引。这种方法的优点(除了更简洁的语法之外)是您可以轻松地更改映射,例如从行主布局切换到列主布局,例如,以进行性能优化……

更新:计算扁平化的 n D索引

要实现具有静态维的常规 n D数组,我将引入一个助手类,以封装从 n D索引进行的线性索引的递归计算: / p>

template <std::size_t... D>
struct row_major;

template <std::size_t D_n>
struct row_major<D_n>
{
    static constexpr std::size_t SIZE = D_n;

    std::size_t operator ()(std::size_t i_n) const
    {
        return i_n;
    }
};

template <std::size_t D_1, std::size_t... D_n>
struct row_major<D_1, D_n...> : private row_major<D_n...>
{
    static constexpr std::size_t SIZE = D_1 * row_major<D_n...>::SIZE;

    template <typename... Tail>
    std::size_t operator ()(std::size_t i_1, Tail&&... tail) const
    {
        return i_1 + D_1 * row_major<D_n...>::operator ()(std::forward<Tail>(tail)...);
    }
};

然后:

template <typename T, std::size_t... D>
class NDArray
{
    using memory_layout_t = row_major<D...>;

    T data[memory_layout_t::SIZE];

public:
    template <typename... Args>
    T& operator ()(Args&&... args)
    {
        memory_layout_t memory_layout;
        return data[memory_layout(std::forward<Args>(args)...)];
    }
};


NDArray<int, 2, 3, 5> arr;

int main()
{
    int x = arr(1, 2, 3);
}
相关问题