首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在C++中简洁地声明和初始化多维数组

在C++中简洁地声明和初始化多维数组
EN

Stack Overflow用户
提问于 2021-12-24 23:17:26
回答 2查看 148关注 0票数 0

例如,在三维中,我通常会做一些类似的事情

代码语言:javascript
复制
vector<vector<vector<T>>> v(x, vector<vector<T>>(y, vector<T>(z, val)));

然而,对于复杂的类型和大的维度,这会变得很乏味。是否可以定义一种类型(例如,tensor ),其用法如下:

代码语言:javascript
复制
tensor<T> t(x, y, z, val1);
t[i][j][k] = val2;
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2021-12-24 23:29:14

使用模板元编程是可能的。

定义向量NVector

代码语言:javascript
复制
template<int D, typename T>
struct NVector : public vector<NVector<D - 1, T>> {
    template<typename... Args>
    NVector(int n = 0, Args... args) : vector<NVector<D - 1, T>>(n, NVector<D - 1, T>(args...)) {
    }
};

template<typename T>
struct NVector<1, T> : public vector<T> {
    NVector(int n = 0, const T &val = T()) : vector<T>(n, val) {
    }
};

你可以这样用它

代码语言:javascript
复制
    const int n = 5, m = 5, k = 5;
    NVector<3, int> a(n, m, k, 0);
    cout << a[0][0][0] << '\n';

我觉得很清楚它是怎么用的。我们还是说NVector<# of dimensions, type> a(lengths of each dimension separated by coma (optional)..., default value (optional))吧。

票数 3
EN

Stack Overflow用户

发布于 2021-12-25 05:39:06

另一个答案显示了用模板元编程生成向量向量的好方法。如果您希望在下面具有较少的分配和连续存储的多维数组数据结构,下面是一个示例,说明如何使用NDArray模板类包装对底层向量的访问。这可以扩展到定义非默认的operator=、复制操作符、每个维度的调试边界检查、行主存储或列主存储等,以获得额外的方便。

NDArray.h

代码语言:javascript
复制
#pragma once

#include <array>
#include <vector>

template<int N, typename ValueType>
class NDArray {
public:
    template<typename... Args>
    NDArray(Args... args)
    : dims({ args... }),
      offsets(compute_offsets(dims)),
      data(compute_size(dims), ValueType{})
    {
        static_assert(sizeof...(args) == N, 
            "Incorrect number of NDArray dimension arguments");
    }

    void fill(ValueType val) {
        std::fill(data.begin(), data.end(), val);
    }

    template<typename... Args>
    inline void resize(Args... args) {
        static_assert(sizeof...(args) == N,
            "Incorrect number of NDArray resize arguments");
        dims = { args... };
        offsets = compute_offsets(dims);
        data.resize(compute_size(dims));
        fill(ValueType{});
    }

    template<typename... Args>
    inline ValueType operator()(Args... args) const {
        static_assert(sizeof...(args) == N, 
            "Incorrect number of NDArray index arguments");
        return data[calc_index({ args... })];
    }

    template<typename... Args>
    inline ValueType& operator()(Args... args) {
        static_assert(sizeof...(args) == N, 
            "Incorrect number of NDArray index arguments");
        return data[calc_index({ args... })];
    }

    int length(int axis) const { return dims[axis]; }

    const int num_dims = N;

private:
    static std::array<int, N> compute_offsets(const std::array<int, N>& dims) {
        std::array<int, N> offsets{};
        offsets[0] = 1;
        for (int i = 1; i < N; ++i) {
            offsets[i] = offsets[i - 1] * dims[i - 1];
        }
        return offsets;
    }

    static int compute_size(const std::array<int, N>& dims) {
        int size = 1;
        for (auto&& d : dims) size *= d;
        return size;
    }

    inline int calc_index(const std::array<int, N>& indices) const {
        int idx = 0;
        for (int i = 0; i < N; ++i) idx += offsets[i] * indices[i];
        return idx;
    }

    std::array<int, N> dims;
    std::array<int, N> offsets;
    std::vector<ValueType> data;
};

这会用正确的参数数覆盖operator(),如果给出了错误的参数数,它将不会编译。一些例子使用

代码语言:javascript
复制
using Array2D = NDArray<2,double>;
using Array3D = NDArray<3,double>;

auto a = Array2D(3, 6);
a.fill(1.0);
a(2, 4) = 2.0;
//a(2,4,4) will not compile
std::cout << "a = " << std::endl << a << std::endl;

a.resize(2,2);
a(1,1) = 1.2;
std::cout << "a = " << std::endl << a << std::endl;

//auto b = Array3D(4, 4); // will not compile

auto b = Array3D(4, 3, 2);
b.fill(-1.0);
b(0, 0, 0) = 4.0;
b(1, 1, 1) = 2.0;
std::cout << "b = " << std::endl << b << std::endl;

(使用2D和3D数组的辅助输出方法)

代码语言:javascript
复制
std::ostream& operator<<(std::ostream& os, const Array2D& arr) {
    for (int i = 0; i < arr.length(0); ++i) {
        for (int j = 0; j < arr.length(1); ++j) {
            os << arr(i,j) << " ";
        }
        os << std::endl;
    }
    return os;
}

std::ostream& operator<<(std::ostream& os, const Array3D& arr) {
    for (int k = 0; k < arr.length(2); ++k) {
        os << "array(:,:,"<<k<<") = " << std::endl;
        for (int i = 0; i < arr.length(0); ++i) {
            os << "  ";
            for (int j = 0; j < arr.length(1); ++j) {
                os << arr(i, j, k) << " ";
            }
            os << std::endl;
        }
        os << std::endl;
    }
    return os;
}
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70477099

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档