【算法分析】Reduce类算子-ArgMax解析

本文深入讲解了Reduce算子的工作原理及应用,特别是针对三维矩阵在指定维度上进行ArgMax计算的过程,并提供了详细的算法示例。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目录

 

什么是Reduce算子

理解Reduce的过程

理解Reduce的算法

3维矩阵示例

CPU算法示例

总结

TBD


什么是Reduce算子

常见的如TensorFlow的tf.reduce_sum、MNN框架中的ArgMax/ArgMin,概况起来说就是对输入的多维张量(Tensor)数据,在某一维上执行特定的计算(比如sum、求Max/Min),从而达到降低维度的目的,下图先展示了一个典型的Reduce算子计算过程,本文围绕该过程对算法做详细说明

3维矩阵(N,H,W) 在H维度方向做Reduce

理解Reduce的过程

如上图所示,2维矩阵(n,m) 有x和y两个维度的方向,在x维度做reduce计算得到的结果是1维向量(n)即x维度reduce为1仅剩下y维度;在y维度做reduce计算得到的结果是1位向量(m)即y维度reduce为1仅剩下x维度;

通常在机器学习或者并行计算中采用Axis来指定输入矩阵的各个维度,把上图的例子泛化后描述:

  • 输入的张量数据A 其中 rank(A)=n,(D0,D1,... Dn-1)其中Di表示Axis=i的维数;Reduce(A,Axis=i)= B 其中rank(B)=n-1, (D0, D1 ... , Di-1, Di+1, ... Dn-1)即Di被reduce为1
  • 所以2维的矩阵在某个Axis上Reduce后成为1维矩阵,3维矩阵在某个Axis上Reduce后成为2维矩阵,丢失的那一维即是执行Reduce计算的对应Axis

 

理解ArgMax的算法

3维矩阵示例

Reduce的算法很多,这里以开源推理框架MNNArgMax算子为例说明Reduce算子的一种ArgMax算子是如何在CPU上实现的(代码链接

代码的本质是实现算法,所以在解释代码前先以3维矩阵作为例子把算法的思路梳理如下:

3维矩阵与内存中的布局

求ArgMax(A, Axis=1)可以得到B(N,W),整个计算过程可以参考“什么是Reduce算子”章节的动画

沿Axis=1方向计算ArgMax

求解B(N,W)矩阵中的元素bij:

参考上图,2维的坐标(i,j)如果不容易理解,其实将reduce为1的Axis=1的轴(H维度方向)加上,扩展成3维作为(i,1,j)就可以看到,bij 其实是对第i个输入矩阵,在第j列上沿着Axis=1的轴(H维度方向)求ArgMax做Reduce计算

计算访问A的内存地址

按照输入矩阵A(N,H,W)在内存中按行优先存储的布局,索引到第i个输入矩阵的内存起始地址为: A + i*H*W ; 索引到第i个输入矩阵的第j列的内存起始地址为: A + i*H*W + j;索引到第i个输入矩阵的第j列的内存结束地址为:A + i*H*W + (H-1)*W + j ; 第i个输入矩阵的第j列的相邻两个元素之间的offset为 W

算法过程示意

所以如上图所示,整个算法的核心思路就是3个for循环依次遍历所有输入矩阵,每个输入矩阵的所有列,每一列的所有行,最后计算出B的每个元素bij

CPU算法示例

泛化上面的3维矩阵实例到一般情况,算法的CPU版本实现(代码链接)和描述如下

输入的张量数据A 其中 rank(A)=n,(D0,D1,... Dn-1)求ArgMax(A, Axis=i)

根据“理解Reduce的过程”章节可以知道计算结果=B,其中rank(B)=n-1, (D0, D1 ... , Di-1, Di+1, ... Dn-1)Di被reduce为1

核心算法的代码很短(仅考虑输入是NHWC的格式),但要看明白需要花一番功夫,这里解析如下

//求解ArgMax(input, axis),求输入张量在axis轴上做ArgMax计算的结果

        if (mMode == ARGMAX) {
            auto srcOrigin = input->host<float>();//获得输入A的起始地址
            auto dstOrigin = output->host<int>();//获得输出B的起始地址
            for (int i = 0; i < mNum; ++i) {//第1个for循环,遍历所有输入的高维矩阵
                auto iptr = srcOrigin + i * mDim * mKeyExtent;//计算A中第i个高维矩阵的起始地址
                auto optr = dstOrigin + i * mKeyExtent;//计算B中存放第i个高维矩阵ArgMax计算结果的起始地址

                for(int k = 0; k < mKeyExtent; ++k){//第2个for循环,遍历第i个高维矩阵在axis轴之后所有的维数
                    int index      = 0;
                    float maxValue = -FLT_MAX;
                    for (int j = 0; j < mDim; ++j) {//第3个for循环,遍历axis轴上的每个元素做ArgMax计算
                        auto val = iptr[k + j * mKeyExtent];
                        if (val > maxValue) {
                            maxValue = val;
                            index    = j;
                        }
                    }
                    optr[k] = index;//保存计算结果到B
                }
            }

该算法最难理解的部分是3个for循环的条件参数,说明如下:

        const int dimensions = input->dimensions();//获得A的rank
        for (int i = 0; i < mAxis; ++i) {
            mNum = mNum * input->length(i);//计算A的axis轴之前所有的总维数,即axis轴之前所有输入的高维矩阵的总个数(相当于3维矩阵实例中的N)
        }
        mDim = input->length(mAxis);//进行ArgMax计算的axis轴的维数(相当于3维矩阵实例中的H)
        for (int i = mAxis + 1; i < dimensions; ++i) {
            mKeyExtent = mKeyExtent * input->length(i);//计算A的axis轴之后的总维数,即对于第i个高维矩阵,生成多少个计算结果(相当于3维矩阵实例中的W)
        }

 

总结

张量数据A 其中 rank(A)=n,(D0,D1,... Dn-1)求 ArgMax(A, axis=i) 的算法可以概况为3点

  • 求解结果B 其中rank(B)=n-1, (D0, D1 ... , Di-1, Di+1, ... Dn-1)Di被Reduce为1并忽略(如果是Reduce为topK个元素,那么Di维度降为topK)
  • 遍历axis=i 轴之前的所有维数 \prod_{k=0}^{i-1}Dk 作为输入高维矩阵
  • 每个输入高维矩阵可以计算出 \prod_{k=i+1}^{n-1}Dk 个ArgMax结果,其中每个ArgMax结果是通过遍历axis=i轴上Di 个元素来计算

 

TBD

引入CUDA 并行计算后GPU版本的算法和它的计算过程会变成什么样子

 

资源下载链接为: https://pan.quark.cn/s/1bfadf00ae14 “STC单片机电压测量”是一个以STC系列单片机为基础的电压检测应用案例,它涵盖了硬件电路设计、软件编程以及数据处理等核心知识点。STC单片机凭借其低功耗、高性价比和丰富的I/O接口,在电子工程领域得到了广泛应用。 STC是Specialized Technology Corporation的缩写,该公司的单片机基于8051内核,具备内部振荡器、高速运算能力、ISP(在系统编程)和IAP(在应用编程)功能,非常适合用于各种嵌入式控制系统。 在源代码方面,“浅雪”风格的代码通常简洁易懂,非常适合初学者学习。其中,“main.c”文件是程序的入口,包含了电压测量的核心逻辑;“STARTUP.A51”是启动代码,负责初始化单片机的硬件环境;“电压测量_uvopt.bak”和“电压测量_uvproj.bak”可能是Keil编译器的配置文件备份,用于设置编译选项和项目配置。 对于3S锂电池电压测量,3S锂电池由三节锂离子电池串联而成,标称电压为11.1V。测量时需要考虑电池的串联特性,通过分压电路将高电压转换为单片机可接受的范围,并实时监控,防止过充或过放,以确保电池的安全和寿命。 在电压测量电路设计中,“电压测量.lnp”文件可能包含电路布局信息,而“.hex”文件是编译后的机器码,用于烧录到单片机中。电路中通常会使用ADC(模拟数字转换器)将模拟电压信号转换为数字信号供单片机处理。 在软件编程方面,“StringData.h”文件可能包含程序中使用的字符串常量和数据结构定义。处理电压数据时,可能涉及浮点数运算,需要了解STC单片机对浮点数的支持情况,以及如何高效地存储和显示电压值。 用户界面方面,“电压测量.uvgui.kidd”可能是用户界面的配置文件,用于显示测量结果。在嵌入式系统中,用
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值