🥳0x00自定义算子

⚠️This post is based on the official MMDeploy tutorial, with some minor modifications for clarity and context.

创建Pytorch模型

用下面的代码来创建一个经典的超分辨率模型 SRCNN

import os
import cv2
import numpy as np
import requests
import torch
import torch.onnx
from torch import nn

class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor):
        super().__init__()
        self.upscale_factor = upscale_factor
        self.img_upsampler = nn.Upsample(
            scale_factor=self.upscale_factor,
            mode='bicubic',
            align_corners=False)

        self.conv1 = nn.Conv2d(3,64,kernel_size=9,padding=4)
        self.conv2 = nn.Conv2d(64,32,kernel_size=1,padding=0)
        self.conv3 = nn.Conv2d(32,3,kernel_size=5,padding=2)

        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.img_upsampler(x)
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.conv3(out)
        return out
        
# Download checkpoint and test image
urls = ['https://download.openmmlab.com/mmediting/restorers/srcnn/srcnn_x4k915_1x16_1000k_div2k_20200608-4186f232.pth',
    'https://raw.githubusercontent.com/open-mmlab/mmediting/master/tests/data/face/000001.png']
names = ['srcnn.pth', 'face.png']
for url, name in zip(urls, names):
    if not os.path.exists(name):
        open(name, 'wb').write(requests.get(url).content)

def init_torch_model():
    torch_model = SuperResolutionNet(upscale_factor=3)
    
    state_dict = torch.load('srcnn.pth')['state_dict']
    
    # Adapt the checkpoint
    for old_key in list(state_dict.keys()):
        new_key = '.'.join(old_key.split('.')[1:])
        state_dict[new_key] = state_dict.pop(old_key)

    torch_model.load_state_dict(state_dict)
    torch_model.eval()
    return torch_model

model = init_torch_model()
input_img = cv2.imread('face.png').astype(np.float32)

# HWC to NCHW
input_img = np.transpose(input_img, [2, 0, 1])
input_img = np.expand_dims(input_img, 0)

# Inference
torch_output = model(torch.from_numpy(input_img)).detach().numpy()

# NCHW to HWC
torch_output = np.squeeze(torch_output, 0)
torch_output = np.clip(torch_output, 0, 255)
torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8)

# Show image
cv2.imwrite("face_torch.png", torch_output)

SRCNN 先把图像上采样到对应分辨率,再用 3 个卷积层处理图像。为了方便起见,我们跳过训练网络的步骤,直接下载模型权重(由于 MMEditing 中 SRCNN 的权重结构和我们定义的模型不太一样,我们修改了权重字典的 key 来适配我们定义的模型),同时下载好输入图片。为了让模型输出成正确的图片格式,我们把模型的输出转换成 HWC 格式,并保证每一通道的颜色值都在 0~255 之间。如果脚本正常运行的话,一幅超分辨率的人脸照片会保存在 “face_torch.png” 中。

在 PyTorch 模型测试正确后,我们来正式开始部署这个模型。我们下一步的任务是把 PyTorch 模型转换成用中间表示 ONNX 描述的模型。

假设 state_dict 中的键如下:

{
    'module.conv1.weight': tensor(...),
    'module.conv1.bias': tensor(...),
    'module.fc1.weight': tensor(...),
    'module.fc1.bias': tensor(...),
}

经过

# Adapt the checkpoint
    for old_key in list(state_dict.keys()):
        new_key = '.'.join(old_key.split('.')[1:])
        state_dict[new_key] = state_dict.pop(old_key)

这段代码后,state_dict 会变成:

{
    'conv1.weight': tensor(...),
    'conv1.bias': tensor(...),
    'fc1.weight': tensor(...),
    'fc1.bias': tensor(...),
}

键名中的 'module.' 前缀被移除了。

构造函数

torch.nn.Upsample(size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None)

参数解释:

  1. sizetuple or int, optional):

    • 这个参数指定输出张量的目标尺寸。可以传入一个整数(对所有维度进行相同的上采样)或一个元组,元组的长度应等于输入张量的空间维度数。例如,对于图像数据,元组通常是 (H_out, W_out)

    • 如果指定了 size,就会强制将输出张量的尺寸调整为该值。

  2. scale_factorfloat or tuple of float, optional):

    • 表示上采样的倍率,即输入张量在每个维度上扩大多少倍。如果传入一个浮点数,它将作用于每个维度。如果传入一个元组,则每个元素分别对应每个空间维度的缩放系数。

    • 通常会指定 scale_factor 来表示整体的放大倍数。例如,scale_factor=2 会让输入张量的每个空间维度放大两倍。

  3. modestr, optional, default='nearest'):

    • 用于指定上采样时使用的插值方法。PyTorch 提供了以下几种插值方式:

      • 'nearest':最近邻插值。

      • 'linear':线性插值(适用于 3D 张量)。

      • 'bilinear':双线性插值(适用于 4D 张量,如图像数据)。

      • 'bicubic':双三次插值(仅适用于 4D 张量)。

      • 'trilinear':三线性插值(适用于 5D 张量,如 3D 体积数据)。

      • 'area':区域插值,根据输入大小和输出大小执行平均池化。

  4. align_cornersbool, optional, default=None):

    • 该参数用于控制当插值时如何对齐角点。对于 bilineartrilinearbicubic 插值方式,它控制插值的行为。

    • 如果 align_corners=True,输入张量的角像素会与输出张量的角像素对齐,避免边界像素的平滑效果。

    • 如果 align_corners=False,输入和输出张量的角像素不对齐,默认会使得插值的效果更加平滑。通常推荐将其设为 False

    • 对于 nearest 插值,该参数没有作用。

  5. recompute_scale_factorbool, optional, default=None):

    • 控制是否在计算输出大小时重新计算缩放因子(scale factor)。通常,PyTorch会将 scale_factor 直接用于大小计算,但在某些情况下,如果需要更准确的输出尺寸,可以设置 recompute_scale_factor=True

注意:

  • sizescale_factor 不能同时指定;必须二选一。

  • mode 指定插值方法,默认是最近邻插值,插值方式会影响到生成的张量的平滑度与细节恢复。

上采样的工作原理

上采样的基本思想是通过插值或复制来扩大输入张量的空间尺寸。最常见的插值方法有:

  • 最近邻插值nearest):简单地将最近的像素值复制到更大的空间中,不会产生平滑过渡。速度较快,通常用于快速上采样。

  • 双线性插值bilinear):基于相邻像素的线性关系插值,生成平滑的过渡效果。通常在处理图像时使用,能够提供更平滑的上采样效果。

  • 双三次插值bicubic):考虑更大范围的像素进行插值,生成更细腻的上采样效果,代价是计算复杂度稍高。

  • 三线性插值trilinear):用于5维数据(如3D体积数据)的插值方式。

align_corners 的解释与使用

align_corners 控制插值时角点的对齐方式,它主要影响到双线性、三线性和双三次插值。该参数的主要影响体现在输出结果的平滑程度上。

  • align_corners=True:会将输入张量的角点像素对齐到输出张量的角点位置,这可能会造成插值结果在边缘的像素值变化较大。

    示例:

    upsample_bilinear = nn.Upsample(size=(50, 50), mode='bilinear', align_corners=True)
  • align_corners=False(默认):输入和输出张量的角点不对齐。这个选项通常能提供更平滑的插值效果,并且PyTorch推荐在大多数情况下使用 align_corners=False

转换成ONNX模型

让我们用下面的代码来把 PyTorch 的模型转换成 ONNX 格式的模型:

x = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    torch.onnx.export(
        model,
        x,
        "srcnn.onnx",
        opset_version=11,
        input_names=['input'],
        output_names=['output'])

其中,torch.onnx.export 是 PyTorch 自带的把模型转换成 ONNX 格式的函数。让我们先看一下前三个必选参数:前三个参数分别是要转换的模型、模型的任意一组输入、导出的 ONNX 文件的文件名。转换模型时,需要原模型和输出文件名是很容易理解的,但为什么需要为模型提供一组输入呢?这就涉及到 ONNX 转换的原理了。从 PyTorch 的模型到 ONNX 的模型,本质上是一种语言上的翻译。直觉上的想法是像编译器一样彻底解析原模型的代码,记录所有控制流。但前面也讲到,我们通常只用 ONNX 记录不考虑控制流的静态图。因此,PyTorch 提供了一种叫做追踪(trace)的模型转换方法:给定一组输入,再实际执行一遍模型,即把这组输入对应的计算图记录下来,保存为 ONNX 格式。export 函数用的就是追踪导出方法,需要给任意一组输入,让模型跑起来。我们的测试图片是三通道,256x256大小的,这里也构造一个同样形状的随机张量。

剩下的参数中,opset_version 表示 ONNX 算子集的版本。深度学习的发展会不断诞生新算子,为了支持这些新增的算子,ONNX会经常发布新的算子集,目前已经更新15个版本。我们令 opset_version = 11,即使用第11个 ONNX 算子集,是因为 SRCNN 中的 bicubic (双三次插值)在 opset11 中才得到支持。剩下的两个参数 input_names, output_names 是输入、输出 tensor 的名称,我们稍后会用到这些名称。

如果上述代码运行成功,目录下会新增一个"srcnn.onnx"的 ONNX 模型文件。我们可以用下面的脚本来验证一下模型文件是否正确。

import onnx

onnx_model = onnx.load("srcnn.onnx")
try:
    onnx.checker.check_model(onnx_model)
except Exception:
    print("Model incorrect")
else:
    print("Model correct")

其中,onnx.load 函数用于读取一个 ONNX 模型。onnx.checker.check_model 用于检查模型格式是否正确,如果有错误的话该函数会直接报错。我们的模型是正确的,控制台中应该会打印出"Model correct"。

接下来,让我们来看一看 ONNX 模型具体的结构是怎么样的。我们可以使用 Netron (开源的模型可视化工具)来可视化 ONNX 模型。把 srcnn.onnx 文件从本地的文件系统拖入网站,即可看到如下的可视化结果:

点击 input 或者 output,可以查看 ONNX 模型的基本信息,包括模型的版本信息,以及模型输入、输出的名称和数据类型。

点击某一个算子节点,可以看到算子的具体信息。比如点击第一个 Conv 可以看到:

现在,我们有了 SRCNN 的 ONNX 模型。让我们看看最后该如何把这个模型运行起来。

推理引擎 —— ONNX Runtime

ONNX Runtime 是由微软维护的一个跨平台机器学习推理加速器,也就是我们前面提到的”推理引擎“。ONNX Runtime 是直接对接 ONNX 的,即 ONNX Runtime 可以直接读取并运行 .onnx 文件, 而不需要再把 .onnx 格式的文件转换成其他格式的文件。也就是说,对于 PyTorch - ONNX - ONNX Runtime 这条部署流水线,只要在目标设备中得到 .onnx 文件,并在 ONNX Runtime 上运行模型,模型部署就算大功告成了。

通过刚刚的操作,我们把 PyTorch 编写的模型转换成了 ONNX 模型,并通过可视化检查了模型的正确性。最后,让我们用 ONNX Runtime 运行一下模型,完成模型部署的最后一步。

ONNX Runtime 提供了 Python 接口。接着刚才的脚本,我们可以添加如下代码运行模型:

import onnxruntime

ort_session = onnxruntime.InferenceSession("srcnn.onnx")
ort_inputs = {'input': input_img}
ort_output = ort_session.run(['output'], ort_inputs)[0]

ort_output = np.squeeze(ort_output, 0)
ort_output = np.clip(ort_output, 0, 255)
ort_output = np.transpose(ort_output, [1, 2, 0]).astype(np.uint8)
cv2.imwrite("face_ort.png", ort_output)

这段代码中,除去后处理操作外,和 ONNX Runtime 相关的代码只有三行。让我们简单解析一下这三行代码。onnxruntime.InferenceSession 用于获取一个 ONNX Runtime 推理器,其参数是用于推理的 ONNX 模型文件。推理器的 run 方法用于模型推理,其第一个参数为输出张量名的列表,第二个参数为输入值的字典。其中输入值字典的 key 为张量名,value 为 numpy 类型的张量值。输入输出张量的名称需要和 torch.onnx.export 中设置的输入输出名对应。

如果代码正常运行的话,另一幅超分辨率照片会保存在"face_ort.png"中。这幅图片和刚刚得到的"face_torch.png"是一模一样的。这说明 ONNX Runtime 成功运行了 SRCNN 模型,模型部署完成了!以后有用户想实现超分辨率的操作,我们只需要提供一个 "srcnn.onnx" 文件,并帮助用户配置好 ONNX Runtime 的 Python 环境,用几行代码就可以运行模型了。或者还有更简便的方法,我们可以利用 ONNX Runtime 编译出一个可以直接执行模型的应用程序。我们只需要给用户提供 ONNX 模型文件,并让用户在应用程序选择要执行的 ONNX 模型文件名就可以运行模型了。

模型部署中常见的难题

  • 模型的动态化。出于性能的考虑,各推理框架都默认模型的输入形状、输出形状、结构是静态的。而为了让模型的泛用性更强,部署时需要在尽可能不影响原有逻辑的前提下,让模型的输入输出或是结构动态化。

  • 新算子的实现。深度学习技术日新月异,提出新算子的速度往往快于 ONNX 维护者支持的速度。为了部署最新的模型,部署工程师往往需要自己在 ONNX 和推理引擎中支持新算子。

  • 中间表示与推理引擎的兼容问题。由于各推理引擎的实现不同,对 ONNX 难以形成统一的支持。为了确保模型在不同的推理引擎中有同样的运行效果,部署工程师往往得为某个推理引擎定制模型代码,这为模型部署引入了许多工作量。

问题:实现动态放大的超分辨率模型

在原来的 SRCNN 中,图片的放大比例是写死在模型里的:

class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor):
        super().__init__()
        self.upscale_factor = upscale_factor
        self.img_upsampler = nn.Upsample(
            scale_factor=self.upscale_factor,
            mode='bicubic',
            align_corners=False)

...

def init_torch_model():
    torch_model = SuperResolutionNet(upscale_factor=3)

我们使用 upscale_factor 来控制模型的放大比例。初始化模型的时候,我们默认令 upscale_factor 为 3,生成了一个放大 3 倍的 PyTorch 模型。这个 PyTorch 模型最终被转换成了 ONNX 格式的模型。如果我们需要一个放大 4 倍的模型,需要重新生成一遍模型,再做一次到 ONNX 的转换。

现在,假设我们要做一个超分辨率的应用。我们的用户希望图片的放大倍数能够自由设置。而我们交给用户的,只有一个 .onnx 文件和运行超分辨率模型的应用程序。我们在不修改 .onnx 文件的前提下改变放大倍数。

因此,我们必须修改原来的模型,令模型的放大倍数变成推理时的输入。在上面 Python 脚本的基础上,我们做一些修改,得到这样的脚本:

import torch
from torch import nn
from torch.nn.functional import interpolate
import torch.onnx
import cv2
import numpy as np


class SuperResolutionNet(nn.Module):

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)

        self.relu = nn.ReLU()

    def forward(self, x, upscale_factor):
        x = interpolate(x,
                        scale_factor=upscale_factor,
                        mode='bicubic',
                        align_corners=False)
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.conv3(out)
        return out


def init_torch_model():
    torch_model = SuperResolutionNet()

    state_dict = torch.load('srcnn.pth')['state_dict']

    # Adapt the checkpoint
    for old_key in list(state_dict.keys()):
        new_key = '.'.join(old_key.split('.')[1:])
        state_dict[new_key] = state_dict.pop(old_key)

    torch_model.load_state_dict(state_dict)
    torch_model.eval()
    return torch_model


model = init_torch_model()

input_img = cv2.imread('face.png').astype(np.float32)

# HWC to NCHW
input_img = np.transpose(input_img, [2, 0, 1])
input_img = np.expand_dims(input_img, 0)

# Inference
torch_output = model(torch.from_numpy(input_img), 3).detach().numpy()

# NCHW to HWC
torch_output = np.squeeze(torch_output, 0)
torch_output = np.clip(torch_output, 0, 255)
torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8)

# Show image
cv2.imwrite("face_torch_2.png", torch_output)

SuperResolutionNet 未修改之前,nn.Upsample 在初始化阶段固化了放大倍数,而 PyTorch 的 interpolate 插值算子可以在运行阶段选择放大倍数。因此,我们在新脚本中使用 interpolate 代替 nn.Upsample,从而让模型支持动态放大倍数的超分。在第 55 行使用模型推理时,我们把放大倍数设置为 3。最后,图片保存在文件 "face_torch_2.png" 中。一切正常的话,"face_torch_2.png" 和 "face_torch.png" 的内容一模一样。

通过简单的修改,PyTorch 模型已经支持了动态分辨率。现在我们来尝试一下导出模型:

x = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    torch.onnx.export(model, (x, 3),
                      "srcnn2.onnx",
                      opset_version=11,
                      input_names=['input', 'factor'],
                      output_names=['output'])

运行这些脚本时,会报一长串错误。没办法,我们碰到了模型部署中的兼容性问题。

    return torch._C._nn.upsample_bicubic2d(input, output_size, align_corners, scale_factors)
TypeError: upsample_bicubic2d() received an invalid combination of arguments - got (Tensor, NoneType, bool, list), but expected one of:
 * (Tensor input, tuple of ints output_size, bool align_corners, tuple of floats scale_factors)
      didn't match because some of the arguments have invalid types: (Tensor, NoneType, bool, list of [Tensor, Tensor])
 * (Tensor input, tuple of ints output_size, bool align_corners, float scales_h, float scales_w, *, Tensor out)

解决方法:自定义算子

直接使用 PyTorch 模型的话,我们修改几行代码就能实现模型输入的动态化。但在模型部署中,我们要花数倍的时间来设法解决这一问题。现在,让我们顺着解决问题的思路,体验一下模型部署的困难,并学习使用自定义算子的方式,解决超分辨率模型的动态化问题。

刚刚的报错是因为 PyTorch 模型在导出到 ONNX 模型时,模型的输入参数的类型必须全部是 torch.Tensor。而实际上我们传入的第二个参数" 3 "是一个整形变量。这不符合 PyTorch 转 ONNX 的规定。我们必须要修改一下原来的模型的输入。为了保证输入的所有参数都是 torch.Tensor 类型的,我们做如下修改:

...

class SuperResolutionNet(nn.Module):

    def forward(self, x, upscale_factor):
        x = interpolate(x,
                        scale_factor=upscale_factor.item(),
                        mode='bicubic',
                        align_corners=False)
        
...

# Inference
# Note that the second input is torch.tensor(3)
torch_output = model(torch.from_numpy(input_img), torch.tensor(3)).detach().numpy()

...

with torch.no_grad():
    torch.onnx.export(model, (x, torch.tensor(3)),
                      "srcnn2.onnx",
                      opset_version=11,
                      input_names=['input', 'factor'],
                      output_names=['output'])

由于 PyTorch 中 interpolate 的 scale_factor 参数必须是一个数值,我们使用 torch.Tensor.item() 来把只有一个元素的 torch.Tensor 转换成数值。之后,在模型推理时,我们使用 torch.tensor(3) 代替 3,以使得我们的所有输入都满足要求。现在运行脚本的话,无论是直接运行模型,还是导出 ONNX 模型,都不会报错了。

但是,导出 ONNX 时却报了一条 TraceWarning 的警告。这条警告说有一些量可能会追踪失败。

TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  scale_factor=upscale_factor.item(),
============ Diagnostic Run torch.onnx.export version 2.0.0+nv23.05 ============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

这是怎么回事呢?让我们把生成的 srcnn2.onnx 用 Netron 可视化一下:

可以发现,虽然我们把模型推理的输入设置为了两个,但 ONNX 模型还是长得和原来一模一样,只有一个叫 " input " 的输入。这是由于我们使用了 torch.Tensor.item() 把数据从 Tensor 里取出来,而导出 ONNX 模型时这个操作是无法被记录的,只好报了一条 TraceWarning。这导致 interpolate 插值函数的放大倍数还是被设置成了" 3 "这个固定值,我们导出的" srcnn2.onnx "和最开始的" srcnn.onnx "完全相同。

直接修改原来的模型似乎行不通,我们得从 PyTorch 转 ONNX 的原理入手,强行令 ONNX 模型明白我们的想法了。

仔细观察 Netron 上可视化出的 ONNX 模型,可以发现在 PyTorch 中无论是使用最早的 nn.Upsample,还是后来的 interpolate,PyTorch 里的插值操作最后都会转换成 ONNX 定义的 Resize 操作。也就是说,所谓 PyTorch 转 ONNX,实际上就是把每个 PyTorch 的操作映射成了 ONNX 定义的算子。

点击该算子,可以看到它的详细参数如下:

其中,展开 scales,可以看到 scales 是一个长度为 4 的一维张量,其内容为 [1, 1, 3, 3], 表示 Resize 操作每一个维度的缩放系数;其类型为 Initializer,表示这个值是根据常量直接初始化出来的。如果我们能够自己生成一个 ONNX 的 Resize 算子,让 scales 成为一个可变量而不是常量,就像它上面的 X 一样,那这个超分辨率模型就能动态缩放了。

现有实现插值的 PyTorch 算子有一套规定好的映射到 ONNX Resize 算子的方法,这些映射出的 Resize 算子的 scales 只能是常量,无法满足我们的需求。我们得自己定义一个实现插值的 PyTorch 算子,然后让它映射到一个我们期望的 ONNX Resize 算子上。

下面的脚本定义了一个 PyTorch 插值算子,并在模型里使用了它。我们先通过运行模型来验证该算子的正确性:

import torch
from torch import nn
from torch.nn.functional import interpolate
import torch.onnx
import cv2
import numpy as np


class NewInterpolate(torch.autograd.Function):

    @staticmethod
    def symbolic(g, input, scales):
        return g.op("Resize",
                    input,
                    g.op("Constant",
                         value_t=torch.tensor([], dtype=torch.float32)),
                    scales,
                    coordinate_transformation_mode_s="pytorch_half_pixel",
                    cubic_coeff_a_f=-0.75,
                    mode_s='cubic',
                    nearest_mode_s="floor")

    @staticmethod
    def forward(ctx, input, scales):
        scales = scales.tolist()[-2:]
        return interpolate(input,
                           scale_factor=scales,
                           mode='bicubic',
                           align_corners=False)


class StrangeSuperResolutionNet(nn.Module):

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)

        self.relu = nn.ReLU()

    def forward(self, x, upscale_factor):
        x = NewInterpolate.apply(x, upscale_factor)
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.conv3(out)
        return out


def init_torch_model():
    torch_model = StrangeSuperResolutionNet()

    state_dict = torch.load('srcnn.pth')['state_dict']

    # Adapt the checkpoint
    for old_key in list(state_dict.keys()):
        new_key = '.'.join(old_key.split('.')[1:])
        state_dict[new_key] = state_dict.pop(old_key)

    torch_model.load_state_dict(state_dict)
    torch_model.eval()
    return torch_model


model = init_torch_model()
factor = torch.tensor([1, 1, 3, 3], dtype=torch.float)

input_img = cv2.imread('face.png').astype(np.float32)

# HWC to NCHW
input_img = np.transpose(input_img, [2, 0, 1])
input_img = np.expand_dims(input_img, 0)

# Inference
torch_output = model(torch.from_numpy(input_img), factor).detach().numpy()

# NCHW to HWC
torch_output = np.squeeze(torch_output, 0)
torch_output = np.clip(torch_output, 0, 255)
torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8)

# Show image
cv2.imwrite("face_torch_3.png", torch_output)

模型运行正常的话,一幅放大3倍的超分辨率图片会保存在"face_torch_3.png"中,其内容和"face_torch.png"完全相同。

在刚刚那个脚本中,我们定义 PyTorch 插值算子的代码如下:

class NewInterpolate(torch.autograd.Function):

    @staticmethod
    def symbolic(g, input, scales):
        return g.op("Resize",
                    input,
                    g.op("Constant",
                         value_t=torch.tensor([], dtype=torch.float32)),
                    scales,
                    coordinate_transformation_mode_s="pytorch_half_pixel",
                    cubic_coeff_a_f=-0.75,
                    mode_s='cubic',
                    nearest_mode_s="floor")

    @staticmethod
    def forward(ctx, input, scales):
        scales = scales.tolist()[-2:]
        return interpolate(input,
                           scale_factor=scales,
                           mode='bicubic',
                           align_corners=False)

在具体介绍这个算子的实现前,让我们先理清一下思路。我们希望新的插值算子有两个输入,一个是被用于操作的图像,一个是图像的放缩比例。前面讲到,为了对接 ONNX 中 Resize 算子的 scales 参数,这个放缩比例是一个 [1, 1, x, x] 的张量,其中 x 为放大倍数。在之前放大3倍的模型中,这个参数被固定成了[1, 1, 3, 3]。因此,在插值算子中,我们希望模型的第二个输入是一个 [1, 1, w, h] 的张量,其中 w 和 h 分别是图片宽和高的放大倍数。

搞清楚了插值算子的输入,再看一看算子的具体实现。算子的推理行为由算子的 foward 方法决定。该方法的第一个参数必须为 ctx,后面的参数为算子的自定义输入,我们设置两个输入,分别为被操作的图像和放缩比例。为保证推理正确,需要把 [1, 1, w, h] 格式的输入对接到原来的 interpolate 函数上。我们的做法是截取输入张量的后两个元素,把这两个元素以 list 的格式传入 interpolate 的 scale_factor 参数。

接下来,我们要决定新算子映射到 ONNX 算子的方法。映射到 ONNX 的方法由一个算子的 symbolic 方法决定。symbolic 方法第一个参数必须是g,之后的参数是算子的自定义输入,和 forward 函数一样。ONNX 算子的具体定义由 g.op 实现。g.op 的每个参数都可以映射到 ONNX 中的算子属性

对于其他参数,我们可以照着现在的 Resize 算子填。而要注意的是,我们现在希望 scales 参数是由输入动态决定的。因此,在填入 ONNX 的 scales 时,我们要把 symbolic 方法的输入参数中的 scales 填入。

g.op浅析

g.op 用于将 PyTorch 的操作符映射到 ONNX 的操作符。

g.op 的基本用法

g.op(op_name, *inputs, **attrs)

关键参数详解:

  1. op_name(字符串):

    • 这是 ONNX 操作符的名称。例如,对于卷积操作,ONNX 使用的是 "Conv";对于上采样,使用的是 "Resize"

  2. *inputs(张量):

    • 这些是操作符的输入,可以是张量或其他 ONNX 操作符的输出。

    • 如果操作符有多个输入,可以依次传递多个输入张量。

  3. **attrs(关键字参数):

    • 这些是操作符的属性。每个 ONNX 操作符都有不同的属性,这些属性控制操作符的具体行为。属性的类型可以是整数、浮点数、字符串、张量或列表。

1. 简单的操作符

假设我们要将 PyTorch 中的 Add 操作导出为 ONNX 中的 Add 操作符,可以通过以下方式定义:

def symbolic(g, input1, input2):
    return g.op("Add", input1, input2)

这个函数将 PyTorch 的两个输入张量 input1input2 映射为 ONNX 的 Add 操作符。这里没有特殊的属性。

2. 带属性的操作符

有些操作符除了输入外,还需要定义一些属性。例如,卷积(Conv)操作符有卷积核的大小、步长、填充等属性:

def symbolic(g, input, weight, bias=None):
    return g.op("Conv", input, weight,
                bias if bias is not None else g.constant(0),
                kernel_shape_i=(3, 3),  # 卷积核大小
                strides_i=(1, 1),       # 步长
                pads_i=(1, 1, 1, 1))    # 填充

在这个例子中,g.op("Conv", ...) 定义了一个卷积操作:

  • input 是输入张量。

  • weight 是卷积核权重。

  • bias 是可选的,如果没有提供则使用常数 0。

  • 属性部分

    • kernel_shape_i:卷积核的大小,这里是 3x3

    • strides_i:步长,这里是 (1, 1)

    • pads_i:填充大小,这里指定了四个维度的填充量 (1, 1, 1, 1)

注意:

  • 属性名称后面的 _i,表示这是一个整型的属性。PyTorch 中会根据类型对属性进行不同的后缀标记:

    • _i:整型(integer)

    • _f:浮点型(float)

    • _s:字符串(string)

    • _t:张量(tensor)

3. 带常量输入的操作符

如果操作符需要常量输入,比如 Resize 操作中的 roi 是常量,可以通过 g.op("Constant", ...) 来定义常量输入:

def symbolic(g, input, scales):
    roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))  # 定义一个空的 roi 常量
    return g.op("Resize", input, roi, scales, mode_s="nearest")

在这里,roi 被定义为一个常量,并作为 Resize 操作的第二个输入。g.op("Constant", ...) 表示一个常量节点,value_t 是张量常量的值。

接着,让我们把新模型导出成 ONNX 模型:

x = torch.randn(1, 3, 256, 256)
factor = torch.tensor([1, 1, 3, 3], dtype=torch.float)

with torch.no_grad():
    torch.onnx.export(model, (x, factor),
                      "srcnn3.onnx",
                      opset_version=11,
                      input_names=['input', 'factor'],
                      output_names=['output'])

可以看到,正如我们所期望的,导出的 ONNX 模型有了两个输入!第二个输入表示图像的放缩比例。

之前在验证 PyTorch 模型和导出 ONNX 模型时,我们宽高的缩放比例设置成了 3x3。现在,在用 ONNX Runtime 推理时,我们尝试使用 4x4 的缩放比例:

import onnxruntime

input_factor = np.array([1, 1, 4, 4], dtype=np.float32)
ort_session = onnxruntime.InferenceSession("srcnn3.onnx")
ort_inputs = {'input': input_img, 'factor': input_factor}
ort_output = ort_session.run(None, ort_inputs)[0]

ort_output = np.squeeze(ort_output, 0)
ort_output = np.clip(ort_output, 0, 255)
ort_output = np.transpose(ort_output, [1, 2, 0]).astype(np.uint8)
cv2.imwrite("face_ort_3.png", ort_output)

运行上面的代码,可以得到一个边长放大4倍的超分辨率图片 "face_ort_3.png"。动态的超分辨率模型生成成功了!只要修改 input_factor,我们就可以自由地控制图片的缩放比例。

事实上,我们不仅可以在现有的 ONNX 算子的基础上进行改造,还可以定义新的 ONNX 算子以拓展 ONNX 的表达能力。

Last updated