A White Paper on Neural Network Deployment
Fusing BN and Conv Layers


Layer fusion reduces kernel-launch overhead and memory traffic, which improves efficiency. In addition, some computations, once optimized through fusion, can be merged with other computations.

  • Vertical layer fusion: the common case, e.g. fusing conv + BN + ReLU into a single kernel

  • Horizontal layer fusion: fusing sibling layers that take the same input and perform the same kind of operation

Recall the Batch Normalization formulas:

$$
\mu_B=\frac{1}{B}\sum_{i=1}^{B}x_i \qquad
\sigma_B^2=\frac{1}{B}\sum_{i=1}^{B}(x_i-\mu_B)^2
$$

$$
\widehat{x_i}=\frac{x_i-\mu_B}{\sqrt{\sigma_B^2+\epsilon}} \qquad
y_i=\gamma*\widehat{x_i}+\beta
\;\xrightarrow{\text{expand}}\;
y_i=\gamma*\frac{x_i-\mu_B}{\sqrt{\sigma_B^2+\epsilon}}+\beta
$$

Substituting the convolution output $x_i = w*x+b$ and distributing:

$$
\begin{aligned}
y&=\gamma*\frac{w*x+b-\mu_B}{\sqrt{\sigma_B^2+\epsilon}}+\beta \\
y&=\frac{\gamma*w}{\sqrt{\sigma_B^2+\epsilon}}*x+\frac{\gamma}{\sqrt{\sigma_B^2+\epsilon}}(b-\mu_B)+\beta \\
y&=\left(\frac{\gamma*w}{\sqrt{\sigma_B^2+\epsilon}}\right)*x+\left(\frac{\gamma}{\sqrt{\sigma_B^2+\epsilon}}(b-\mu_B)+\beta\right)
\end{aligned}
$$

The fused weight and bias are therefore:

$$
\widehat{w}=\frac{\gamma*w}{\sqrt{\sigma_B^2+\epsilon}} \qquad
\widehat{b}=\frac{\gamma}{\sqrt{\sigma_B^2+\epsilon}}(b-\mu_B)+\beta
$$

These two parameters can be computed ahead of time, so at inference the BN layer collapses into the convolution's weight and bias.
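As a quick sanity check, the scalar form of $\widehat{w}$ and $\widehat{b}$ can be verified with plain Python. All numeric values below are arbitrary, chosen only for illustration:

```python
import math

# Hypothetical per-channel values, for illustration only
w, b = 0.5, 0.1                  # conv weight and bias
gamma, beta = 1.2, -0.3          # BN affine parameters
mu, var, eps = 0.05, 0.8, 1e-5   # BN running statistics

s = math.sqrt(var + eps)

# Precomputed fused parameters
w_hat = gamma * w / s
b_hat = gamma / s * (b - mu) + beta

x = 2.0
# BN applied after the conv output w*x + b ...
y_bn = gamma * (w * x + b - mu) / s + beta
# ... matches a single affine transform with the fused parameters
y_fused = w_hat * x + b_hat
print(abs(y_bn - y_fused) < 1e-12)  # True
```

The two expressions agree up to floating-point rounding, which is exactly why the fusion is lossless in exact arithmetic.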

Many models use a variety of activation functions such as GELU, Swish, and Mish. These are often hard to accelerate because of their computational complexity; it is worth trying to replace them with ReLU and measuring both the change in accuracy and the performance gain.
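A minimal sketch of such an experiment, assuming the model exposes its activations as standard `torch.nn` modules (the helper name `replace_activations` and the toy model are illustrative, not part of any library):

```python
import torch.nn as nn

def replace_activations(module, targets=(nn.GELU, nn.SiLU, nn.Mish)):
    """Recursively swap the given activation types for ReLU, in place."""
    for name, child in module.named_children():
        if isinstance(child, targets):
            setattr(module, name, nn.ReLU(inplace=True))
        else:
            replace_activations(child, targets)

# Toy model for demonstration
model = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 4), nn.SiLU())
replace_activations(model)
print(model)  # both activations are now ReLU
```

After a swap like this the model's accuracy must be re-validated (and ideally fine-tuned), since ReLU is not numerically equivalent to GELU or Swish.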

Writing the formulas above in matrix form:

For a feature map $F$ of shape $C\times H\times W$, the normalized result $\hat{F}$ is computed as:

$$
\begin{pmatrix}\hat{F}_{1,i,j}\\\hat{F}_{2,i,j}\\\vdots\\\hat{F}_{C-1,i,j}\\\hat{F}_{C,i,j}\end{pmatrix}
=
\begin{pmatrix}
\frac{\gamma_1}{\sqrt{\sigma_1^2+\epsilon}}&0&\cdots&0\\
0&\frac{\gamma_2}{\sqrt{\sigma_2^2+\epsilon}}&\cdots&0\\
\vdots&&\ddots&\vdots\\
0&\cdots&\frac{\gamma_{C-1}}{\sqrt{\sigma_{C-1}^2+\epsilon}}&0\\
0&\cdots&0&\frac{\gamma_C}{\sqrt{\sigma_C^2+\epsilon}}
\end{pmatrix}
\cdot
\begin{pmatrix}F_{1,i,j}\\F_{2,i,j}\\\vdots\\F_{C-1,i,j}\\F_{C,i,j}\end{pmatrix}
+
\begin{pmatrix}
\beta_1-\gamma_1\frac{\mu_1}{\sqrt{\sigma_1^2+\epsilon}}\\
\beta_2-\gamma_2\frac{\mu_2}{\sqrt{\sigma_2^2+\epsilon}}\\
\vdots\\
\beta_{C-1}-\gamma_{C-1}\frac{\mu_{C-1}}{\sqrt{\sigma_{C-1}^2+\epsilon}}\\
\beta_C-\gamma_C\frac{\mu_C}{\sqrt{\sigma_C^2+\epsilon}}
\end{pmatrix}
$$
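This matrix form can be checked numerically: in eval mode, a `BatchNorm2d` layer is exactly a per-channel scale plus shift built from its running statistics. A small sketch (channel count and random statistics chosen arbitrarily):

```python
import torch

torch.manual_seed(0)
C = 4
bn = torch.nn.BatchNorm2d(C).eval()
# Give the running statistics and affine parameters nontrivial values
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-0.5, 0.5)

x = torch.randn(1, C, 3, 3)
with torch.no_grad():
    # Per-channel scale gamma/sqrt(var+eps) and shift beta - gamma*mu/sqrt(var+eps)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    shift = bn.bias - scale * bn.running_mean
    y_affine = scale.view(1, C, 1, 1) * x + shift.view(1, C, 1, 1)
    y_bn = bn(x)
print(torch.allclose(y_bn, y_affine, atol=1e-5))  # True
```

The diagonal matrix in the equation is exactly this `scale` vector, which is what `torch.diag` builds in the fusion code below.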

The corresponding code:

```python
import torch


def fuse_conv_and_bn(conv, bn):
    # Initialize the fused conv with the same geometry; a bias is always needed
    # because the BN shift folds into it
    fusedconv = torch.nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=True
    )

    with torch.no_grad():
        # Fused weight: w_hat = gamma / sqrt(var + eps) * w
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
        fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))

        # Fused bias: b_hat = gamma / sqrt(var + eps) * (b - mean) + beta
        if conv.bias is not None:
            b_conv = conv.bias
        else:
            b_conv = torch.zeros(conv.weight.size(0))
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
        fusedconv.bias.copy_(torch.matmul(w_bn, b_conv) + b_bn)

    return fusedconv


if __name__ == "__main__":
    import torchvision

    torch.set_grad_enabled(False)
    x = torch.randn(16, 3, 256, 256)
    rn18 = torchvision.models.resnet18(pretrained=True)
    rn18.eval()
    net = torch.nn.Sequential(rn18.conv1, rn18.bn1)
    y1 = net(x)
    fusedconv = fuse_conv_and_bn(net[0], net[1])
    y2 = fusedconv(x)
    d = (y1 - y2).norm().div(y1.norm()).item()
    print("error: %.8f" % d)  # error: 0.00000030
```