当前位置 博文首页 > c艹用户:C++基于armadillo im2col的实现

    c艹用户:C++基于armadillo im2col的实现

    作者:c艹用户 时间:2021-05-23 18:21

    col2im的实现,这是im2col的逆过程
    最近学习CNN,需要用到im2col这个函数,无奈网上没有多少使用armadillo的例子,而且armadillo库中似乎也没有这个函数,因此自己写了。
    im2col的原理网上一大把,我懒得写了。

    1. field<某类>

    field<class oT> 是armadillo库中的类,类似于矩阵, 不过这个“矩阵”的每一个元素都是向量或者矩阵。因此用field可以作为四维输入数据使用。

    2. 矩阵展开

    这个其实还挺简单,使用reshape函数将矩阵变形。不过,armadillo中变形是按照竖向变形的。比如:

    1 2 3
    4 5 6
    7 8 9
    

    这样的矩阵变形成1×9的向量的话:

    1 4 7 2 5 8 3 6 9
    

    会成这样??。。。
    但是也不影响,滤波器也是这么变得,相对位置没变呗。。

    3. 排列组合

    鄙人才疏学浅,只会用一堆for循环来排列组合。。。貌似没找到更好的办法。

    4. 其他细节

    像是步数、填充什么的,多注意一下就行了。

    5. 实现代码

    mat im2col(field<mat> input_data, int filter_h, int filter_w, int stride, int pad)
    {
    	int N, C, H, W;
    	N = input_data.n_rows;
    	C = input_data.n_cols;
    	H = input_data(0, 0).n_rows;
    	W = input_data(0, 0).n_cols;
    	int out_h = (H + 2 * pad - filter_h) / stride + 1;
    	int out_w = (W + 2 * pad - filter_w) / stride + 1;
    	field<mat> img = input_data;
    	img.for_each([H, W, pad](mat& X) {X.insert_rows(0, pad); X.insert_rows(H + pad, pad); X.insert_cols(0, pad); X.insert_cols(W + pad, pad); });
    	mat col(out_h * out_w * N, C * filter_h * filter_w, fill::zeros);
    	for (int n = 0, z = 0; n < N; n++)
    	{
    		for (int i = 0; i < out_h; i++)
    		{
    			for (int j = 0; j < out_w; j++, z++)
    			{
    				for (int k = 0; k < C; k++)
    				{
    					mat filter(filter_h, filter_w, fill::zeros);
    					filter = img(n, k)(span(i * stride, i * stride + filter_h - 1), span(j * stride, j * stride + filter_w - 1));
    					filter.reshape(1, filter_h * filter_w);
    					int x = z;
    					int y0 = filter_h * filter_w * k;
    					int y1 = filter_h * filter_w * k + filter_h * filter_w - 1;
    					col(span(x, x), span(y0, y1)) = filter;
    				}
    			}
    		}
    	}
    	return col;
    }
    

    头文件就是声明和引用。

    bk