> For the complete documentation index, see [llms.txt](https://deployment.gitbook.io/love/llms.txt). Markdown versions of documentation pages are available by appending `.md` to page URLs; this page is available as [Markdown](https://deployment.gitbook.io/love/whitepaper/onnx_intro/0x00.md).

# 0x00自定义算子

### 创建Pytorch模型

用下面的代码来创建一个经典的超分辨率模型 SRCNN

```python
import os
import cv2
import numpy as np
import requests
import torch
import torch.onnx
from torch import nn

class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor):
        super().__init__()
        self.upscale_factor = upscale_factor
        self.img_upsampler = nn.Upsample(
            scale_factor=self.upscale_factor,
            mode='bicubic',
            align_corners=False)

        self.conv1 = nn.Conv2d(3,64,kernel_size=9,padding=4)
        self.conv2 = nn.Conv2d(64,32,kernel_size=1,padding=0)
        self.conv3 = nn.Conv2d(32,3,kernel_size=5,padding=2)

        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.img_upsampler(x)
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.conv3(out)
        return out
        
# Download checkpoint and test image
urls = ['https://download.openmmlab.com/mmediting/restorers/srcnn/srcnn_x4k915_1x16_1000k_div2k_20200608-4186f232.pth',
    'https://raw.githubusercontent.com/open-mmlab/mmediting/master/tests/data/face/000001.png']
names = ['srcnn.pth', 'face.png']
for url, name in zip(urls, names):
    if not os.path.exists(name):
        open(name, 'wb').write(requests.get(url).content)

def init_torch_model():
    torch_model = SuperResolutionNet(upscale_factor=3)
    
    state_dict = torch.load('srcnn.pth')['state_dict']
    
    # Adapt the checkpoint
    for old_key in list(state_dict.keys()):
        new_key = '.'.join(old_key.split('.')[1:])
        state_dict[new_key] = state_dict.pop(old_key)

    torch_model.load_state_dict(state_dict)
    torch_model.eval()
    return torch_model

model = init_torch_model()
input_img = cv2.imread('face.png').astype(np.float32)

# HWC to NCHW
input_img = np.transpose(input_img, [2, 0, 1])
input_img = np.expand_dims(input_img, 0)

# Inference
torch_output = model(torch.from_numpy(input_img)).detach().numpy()

# NCHW to HWC
torch_output = np.squeeze(torch_output, 0)
torch_output = np.clip(torch_output, 0, 255)
torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8)

# Show image
cv2.imwrite("face_torch.png", torch_output)
```

SRCNN 先把图像上采样到对应分辨率，再用 3 个卷积层处理图像。为了方便起见，我们跳过训练网络的步骤，直接下载模型权重（由于 MMEditing 中 SRCNN 的权重结构和我们定义的模型不太一样，我们修改了权重字典的 key 来适配我们定义的模型），同时下载好输入图片。<mark style="color:red;">为了让模型输出成正确的图片格式，我们把模型的输出转换成 HWC 格式，并保证每一通道的颜色值都在 0\~255 之间。</mark>如果脚本正常运行的话，一幅超分辨率的人脸照片会保存在 `“face_torch.png”` 中。

<figure><img src="/files/jvVyMuQrrOx8NoqDiWT1" alt=""><figcaption></figcaption></figure>

在 PyTorch 模型测试正确后，我们来正式开始部署这个模型。我们下一步的任务是把 PyTorch 模型转换成用中间表示 ONNX 描述的模型。

> 假设 `state_dict` 中的键如下：
>
> ```python
> {
>     'module.conv1.weight': tensor(...),
>     'module.conv1.bias': tensor(...),
>     'module.fc1.weight': tensor(...),
>     'module.fc1.bias': tensor(...),
> }
> ```
>
> 经过
>
> ```python
> # Adapt the checkpoint
>     for old_key in list(state_dict.keys()):
>         new_key = '.'.join(old_key.split('.')[1:])
>         state_dict[new_key] = state_dict.pop(old_key)
> ```
>
> 这段代码后，`state_dict` 会变成：
>
> ```python
> {
>     'conv1.weight': tensor(...),
>     'conv1.bias': tensor(...),
>     'fc1.weight': tensor(...),
>     'fc1.bias': tensor(...),
> }
> ```
>
> 键名中的 `'module.'` 前缀被移除了。

#### **构造函数**

```python
torch.nn.Upsample(size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None)
```

**参数解释：**

1. **`size`**（`tuple` or `int`, optional）：
   * 这个参数指定输出张量的目标尺寸。可以传入一个整数（对所有维度进行相同的上采样）或一个元组，元组的长度应等于输入张量的空间维度数。例如，对于图像数据，元组通常是 `(H_out, W_out)`。
   * 如果指定了 `size`，就会强制将输出张量的尺寸调整为该值。
2. **`scale_factor`**（`float` or `tuple` of `float`, optional）：
   * 表示上采样的倍率，即输入张量在每个维度上扩大多少倍。如果传入一个浮点数，它将作用于每个维度。如果传入一个元组，则每个元素分别对应每个空间维度的缩放系数。
   * 通常会指定 `scale_factor` 来表示整体的放大倍数。例如，`scale_factor=2` 会让输入张量的每个空间维度放大两倍。
3. **`mode`**（`str`, optional, default=`'nearest'`）：
   * 用于指定上采样时使用的插值方法。PyTorch 提供了以下几种插值方式：
     * `'nearest'`：最近邻插值。
     * `'linear'`：线性插值（适用于 3D 张量）。
     * `'bilinear'`：双线性插值（适用于 4D 张量，如图像数据）。
     * `'bicubic'`：双三次插值（仅适用于 4D 张量）。
     * `'trilinear'`：三线性插值（适用于 5D 张量，如 3D 体积数据）。
     * `'area'`：区域插值，根据输入大小和输出大小执行平均池化。
4. **`align_corners`**（`bool`, optional, default=`None`）：
   * 该参数用于控制当插值时如何对齐角点。对于 `bilinear`、`trilinear` 和 `bicubic` 插值方式，它控制插值的行为。
   * 如果 `align_corners=True`，输入张量的角像素会与输出张量的角像素对齐，避免边界像素的平滑效果。
   * 如果 `align_corners=False`，输入和输出张量的角像素不对齐，默认会使得插值的效果更加平滑。通常推荐将其设为 `False`。
   * 对于 `nearest` 插值，该参数没有作用。
5. **`recompute_scale_factor`**（`bool`, optional, default=`None`）：
   * 控制是否在计算输出大小时重新计算缩放因子（scale factor）。通常，PyTorch会将 `scale_factor` 直接用于大小计算，但在某些情况下，如果需要更准确的输出尺寸，可以设置 `recompute_scale_factor=True`。

**注意：**

* `size` 和 `scale_factor` 不能同时指定；必须二选一。
* `mode` 指定插值方法，默认是最近邻插值，插值方式会影响到生成的张量的平滑度与细节恢复。

#### **上采样的工作原理**

上采样的基本思想是通过插值或复制来扩大输入张量的空间尺寸。最常见的插值方法有：

* **最近邻插值**（`nearest`）：简单地将最近的像素值复制到更大的空间中，不会产生平滑过渡。速度较快，通常用于快速上采样。
* **双线性插值**（`bilinear`）：基于相邻像素的线性关系插值，生成平滑的过渡效果。通常在处理图像时使用，能够提供更平滑的上采样效果。
* **双三次插值**（`bicubic`）：考虑更大范围的像素进行插值，生成更细腻的上采样效果，代价是计算复杂度稍高。
* **三线性插值**（`trilinear`）：用于5维数据（如3D体积数据）的插值方式。

#### **`align_corners` 的解释与使用**

`align_corners` 控制插值时角点的对齐方式，它主要影响到双线性、三线性和双三次插值。该参数的主要影响体现在输出结果的平滑程度上。

* **`align_corners=True`**：会将输入张量的角点像素对齐到输出张量的角点位置，这可能会造成插值结果在边缘的像素值变化较大。

  示例：

  ```python
  upsample_bilinear = nn.Upsample(size=(50, 50), mode='bilinear', align_corners=True)
  ```
* **`align_corners=False`**（默认）：输入和输出张量的角点不对齐。这个选项通常能提供更平滑的插值效果，并且PyTorch推荐在大多数情况下使用 `align_corners=False`。

### 转换成ONNX模型

让我们用下面的代码来把 PyTorch 的模型转换成 ONNX 格式的模型：

```python
x = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    torch.onnx.export(
        model,
        x,
        "srcnn.onnx",
        opset_version=11,
        input_names=['input'],
        output_names=['output'])
```

其中，**torch.onnx.export** 是 PyTorch 自带的把模型转换成 ONNX 格式的函数。让我们先看一下前三个必选参数：前三个参数分别是要转换的模型、模型的任意一组输入、导出的 ONNX 文件的文件名。转换模型时，需要原模型和输出文件名是很容易理解的，但为什么需要为模型提供一组输入呢？这就涉及到 ONNX 转换的原理了。<mark style="color:red;">从 PyTorch 的模型到 ONNX 的模型，本质上是一种语言上的翻译。直觉上的想法是像编译器一样彻底解析原模型的代码，记录所有控制流。但前面也讲到，我们通常只用 ONNX 记录不考虑控制流的静态图。因此，PyTorch 提供了一种叫做追踪（trace）的模型转换方法：给定一组输入，再实际执行一遍模型，即把这组输入对应的计算图记录下来，保存为 ONNX 格式。</mark>export 函数用的就是追踪导出方法，需要给任意一组输入，让模型跑起来。我们的测试图片是三通道，256x256大小的，这里也构造一个同样形状的随机张量。

剩下的参数中，opset\_version 表示 ONNX 算子集的版本。深度学习的发展会不断诞生新算子，为了支持这些新增的算子，ONNX会经常发布新的算子集，目前已经更新15个版本。我们令 opset\_version = 11，即使用第11个 ONNX 算子集，是因为 SRCNN 中的 bicubic （双三次插值）在 opset11 中才得到支持。剩下的两个参数 input\_names, output\_names 是输入、输出 tensor 的名称，我们稍后会用到这些名称。

如果上述代码运行成功，目录下会新增一个"srcnn.onnx"的 ONNX 模型文件。我们可以用下面的脚本来验证一下模型文件是否正确。

```python
import onnx

onnx_model = onnx.load("srcnn.onnx")
try:
    onnx.checker.check_model(onnx_model)
except Exception:
    print("Model incorrect")
else:
    print("Model correct")
```

其中，**onnx.load** 函数用于读取一个 ONNX 模型。<mark style="color:red;">**onnx.checker.check\_model**</mark> <mark style="color:red;"></mark><mark style="color:red;">用于检查模型格式是否正确，如果有错误的话该函数会直接报错。</mark>我们的模型是正确的，控制台中应该会打印出"Model correct"。

接下来，让我们来看一看 ONNX 模型具体的结构是怎么样的。我们可以使用 **Netron** （开源的模型可视化工具）来可视化 ONNX 模型。把 srcnn.onnx 文件从本地的文件系统拖入网站，即可看到如下的可视化结果：

<figure><img src="/files/YlI7RpKY6fV2Y7kigAh1" alt=""><figcaption></figcaption></figure>

<mark style="color:red;">点击 input 或者 output，可以查看 ONNX 模型的基本信息，包括模型的版本信息，以及模型输入、输出的名称和数据类型。</mark>

<figure><img src="/files/mAxLPPE0zHbgNThoYO79" alt=""><figcaption></figcaption></figure>

<mark style="color:red;">点击某一个算子节点，可以看到算子的具体信息。比如点击第一个 Conv 可以看到：</mark>

<figure><img src="/files/wVYnXuZgHjaWz42xVwh2" alt=""><figcaption></figcaption></figure>

现在，我们有了 SRCNN 的 ONNX 模型。让我们看看最后该如何把这个模型运行起来。

### **推理引擎 —— ONNX Runtime**

**ONNX Runtime** 是由微软维护的一个跨平台机器学习推理加速器，也就是我们前面提到的”推理引擎“。ONNX Runtime 是直接对接 ONNX 的，即 ONNX Runtime 可以直接读取并运行 .onnx 文件, 而不需要再把 .onnx 格式的文件转换成其他格式的文件。也就是说，对于 `PyTorch - ONNX - ONNX Runtime` 这条部署流水线，只要在目标设备中得到 .onnx 文件，并在 `ONNX Runtime` 上运行模型，模型部署就算大功告成了。

<mark style="color:red;">通过刚刚的操作，我们把 PyTorch 编写的模型转换成了 ONNX 模型，并通过可视化检查了模型的正确性。最后，让我们用 ONNX Runtime 运行一下模型，完成模型部署的最后一步。</mark>

ONNX Runtime 提供了 Python 接口。接着刚才的脚本，我们可以添加如下代码运行模型：

```python
import onnxruntime

ort_session = onnxruntime.InferenceSession("srcnn.onnx")
ort_inputs = {'input': input_img}
ort_output = ort_session.run(['output'], ort_inputs)[0]

ort_output = np.squeeze(ort_output, 0)
ort_output = np.clip(ort_output, 0, 255)
ort_output = np.transpose(ort_output, [1, 2, 0]).astype(np.uint8)
cv2.imwrite("face_ort.png", ort_output)
```

这段代码中，除去后处理操作外，和 ONNX Runtime 相关的代码只有三行。让我们简单解析一下这三行代码。**onnxruntime.InferenceSession** 用于获取一个 ONNX Runtime 推理器，其参数是用于推理的 ONNX 模型文件。<mark style="color:red;">**推理器的 run 方法用于模型推理，其第一个参数为输出张量名的列表，第二个参数为输入值的字典。其中输入值字典的 key 为张量名，value 为 numpy 类型的张量值。输入输出张量的名称需要和 torch.onnx.export 中设置的输入输出名对应。**</mark>

如果代码正常运行的话，另一幅超分辨率照片会保存在`"face_ort.png"`中。这幅图片和刚刚得到的`"face_torch.png"`是一模一样的。这说明 ONNX Runtime 成功运行了 SRCNN 模型，模型部署完成了！以后有用户想实现超分辨率的操作，我们只需要提供一个 "srcnn.onnx" 文件，并帮助用户配置好 ONNX Runtime 的 Python 环境，用几行代码就可以运行模型了。或者还有更简便的方法，我们可以利用 ONNX Runtime 编译出一个可以直接执行模型的应用程序。我们只需要给用户提供 ONNX 模型文件，并让用户在应用程序选择要执行的 ONNX 模型文件名就可以运行模型了。

### **模型部署中常见的难题**

* **模型的动态化。**&#x51FA;于性能的考虑，各推理框架都默认模型的输入形状、输出形状、结构是静态的。而为了让模型的泛用性更强，部署时需要在尽可能不影响原有逻辑的前提下，让模型的输入输出或是结构动态化。
* **新算子的实现。**&#x6DF1;度学习技术日新月异，提出新算子的速度往往快于 ONNX 维护者支持的速度。为了部署最新的模型，部署工程师往往需要自己在 ONNX 和推理引擎中支持新算子。
* **中间表示与推理引擎的兼容问题。**&#x7531;于各推理引擎的实现不同，对 ONNX 难以形成统一的支持。为了确保模型在不同的推理引擎中有同样的运行效果，部署工程师往往得为某个推理引擎定制模型代码，这为模型部署引入了许多工作量。

### **问题：实现动态放大的超分辨率模型**

在原来的 SRCNN 中，图片的放大比例是写死在模型里的：

```python
class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor):
        super().__init__()
        self.upscale_factor = upscale_factor
        self.img_upsampler = nn.Upsample(
            scale_factor=self.upscale_factor,
            mode='bicubic',
            align_corners=False)

...

def init_torch_model():
    torch_model = SuperResolutionNet(upscale_factor=3)
```

我们使用 upscale\_factor 来控制模型的放大比例。初始化模型的时候，我们默认令 upscale\_factor 为 3，生成了一个放大 3 倍的 PyTorch 模型。这个 PyTorch 模型最终被转换成了 ONNX 格式的模型。如果我们需要一个放大 4 倍的模型，需要重新生成一遍模型，再做一次到 ONNX 的转换。

现在，假设我们要做一个超分辨率的应用。我们的用户希望图片的放大倍数能够自由设置。而我们交给用户的，只有一个 .onnx 文件和运行超分辨率模型的应用程序。我们在不修改 .onnx 文件的前提下改变放大倍数。

因此，我们必须修改原来的模型，令模型的放大倍数变成推理时的输入。在上面 Python 脚本的基础上，我们做一些修改，得到这样的脚本：

```python
import torch
from torch import nn
from torch.nn.functional import interpolate
import torch.onnx
import cv2
import numpy as np


class SuperResolutionNet(nn.Module):

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)

        self.relu = nn.ReLU()

    def forward(self, x, upscale_factor):
        x = interpolate(x,
                        scale_factor=upscale_factor,
                        mode='bicubic',
                        align_corners=False)
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.conv3(out)
        return out


def init_torch_model():
    torch_model = SuperResolutionNet()

    state_dict = torch.load('srcnn.pth')['state_dict']

    # Adapt the checkpoint
    for old_key in list(state_dict.keys()):
        new_key = '.'.join(old_key.split('.')[1:])
        state_dict[new_key] = state_dict.pop(old_key)

    torch_model.load_state_dict(state_dict)
    torch_model.eval()
    return torch_model


model = init_torch_model()

input_img = cv2.imread('face.png').astype(np.float32)

# HWC to NCHW
input_img = np.transpose(input_img, [2, 0, 1])
input_img = np.expand_dims(input_img, 0)

# Inference
torch_output = model(torch.from_numpy(input_img), 3).detach().numpy()

# NCHW to HWC
torch_output = np.squeeze(torch_output, 0)
torch_output = np.clip(torch_output, 0, 255)
torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8)

# Show image
cv2.imwrite("face_torch_2.png", torch_output)
```

SuperResolutionNet 未修改之前，nn.Upsample 在初始化阶段固化了放大倍数，而 PyTorch 的 interpolate 插值算子可以在运行阶段选择放大倍数。因此，我们在新脚本中使用 interpolate 代替 nn.Upsample，从而让模型支持动态放大倍数的超分。在第 55 行使用模型推理时，我们把放大倍数设置为 3。最后，图片保存在文件 "face\_torch\_2.png" 中。一切正常的话，"face\_torch\_2.png" 和 "face\_torch.png" 的内容一模一样。

通过简单的修改，PyTorch 模型已经支持了动态分辨率。现在我们来尝试一下导出模型：

```python
x = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    torch.onnx.export(model, (x, 3),
                      "srcnn2.onnx",
                      opset_version=11,
                      input_names=['input', 'factor'],
                      output_names=['output'])
```

运行这些脚本时，会报一长串错误。没办法，我们碰到了模型部署中的兼容性问题。

```bash
    return torch._C._nn.upsample_bicubic2d(input, output_size, align_corners, scale_factors)
TypeError: upsample_bicubic2d() received an invalid combination of arguments - got (Tensor, NoneType, bool, list), but expected one of:
 * (Tensor input, tuple of ints output_size, bool align_corners, tuple of floats scale_factors)
      didn't match because some of the arguments have invalid types: (Tensor, NoneType, bool, list of [Tensor, Tensor])
 * (Tensor input, tuple of ints output_size, bool align_corners, float scales_h, float scales_w, *, Tensor out)
```

### **解决方法：自定义算子**

直接使用 PyTorch 模型的话，我们修改几行代码就能实现模型输入的动态化。但在模型部署中，我们要花数倍的时间来设法解决这一问题。现在，让我们顺着解决问题的思路，体验一下模型部署的困难，并学习使用自定义算子的方式，解决超分辨率模型的动态化问题。

<mark style="color:red;">刚刚的报错是因为 PyTorch 模型在导出到 ONNX 模型时，模型的输入参数的类型必须全部是 torch.Tensor。而实际上我们传入的第二个参数" 3 "是一个整形变量。这不符合 PyTorch 转 ONNX 的规定。</mark>我们必须要修改一下原来的模型的输入。为了保证输入的所有参数都是 torch.Tensor 类型的，我们做如下修改：

```python
...

class SuperResolutionNet(nn.Module):

    def forward(self, x, upscale_factor):
        x = interpolate(x,
                        scale_factor=upscale_factor.item(),
                        mode='bicubic',
                        align_corners=False)
        
...

# Inference
# Note that the second input is torch.tensor(3)
torch_output = model(torch.from_numpy(input_img), torch.tensor(3)).detach().numpy()

...

with torch.no_grad():
    torch.onnx.export(model, (x, torch.tensor(3)),
                      "srcnn2.onnx",
                      opset_version=11,
                      input_names=['input', 'factor'],
                      output_names=['output'])
```

<mark style="color:red;">由于 PyTorch 中 interpolate 的 scale\_factor 参数必须是一个数值，我们使用 torch.Tensor.item() 来把只有一个元素的 torch.Tensor 转换成数值。</mark>之后，在模型推理时，我们使用 torch.tensor(3) 代替 3，以使得我们的所有输入都满足要求。现在运行脚本的话，无论是直接运行模型，还是导出 ONNX 模型，都不会报错了。

但是，导出 ONNX 时却报了一条 TraceWarning 的警告。这条警告说有一些量可能会追踪失败。

```bash
TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  scale_factor=upscale_factor.item(),
============ Diagnostic Run torch.onnx.export version 2.0.0+nv23.05 ============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
```

这是怎么回事呢？让我们把生成的 srcnn2.onnx 用 Netron 可视化一下：

<figure><img src="/files/2JML83bKESuZQi2vXuuM" alt="" width="375"><figcaption></figcaption></figure>

可以发现，虽然我们把模型推理的输入设置为了两个，但 ONNX 模型还是长得和原来一模一样，只有一个叫 " input " 的输入。<mark style="color:red;">这是由于我们使用了 torch.Tensor.item() 把数据从 Tensor 里取出来，而导出 ONNX 模型时这个操作是无法被记录的，只好报了一条 TraceWarning。</mark>这导致 interpolate 插值函数的放大倍数还是被设置成了" 3 "这个固定值，我们导出的" srcnn2.onnx "和最开始的" srcnn.onnx "完全相同。

直接修改原来的模型似乎行不通，我们得从 PyTorch 转 ONNX 的原理入手，强行令 ONNX 模型明白我们的想法了。

仔细观察 Netron 上可视化出的 ONNX 模型，可以发现在 PyTorch 中无论是使用最早的 nn.Upsample，还是后来的 interpolate，PyTorch 里的插值操作最后都会转换成 ONNX 定义的 Resize 操作。也就是说，所谓 PyTorch 转 ONNX，实际上就是把每个 PyTorch 的操作映射成了 ONNX 定义的算子。

点击该算子，可以看到它的详细参数如下：

<figure><img src="/files/6LJxCals6xFqTj4g2MZW" alt="" width="563"><figcaption></figcaption></figure>

其中，展开 scales，可以看到 scales 是一个长度为 4 的一维张量，其内容为 \[1, 1, 3, 3], 表示 Resize  操作每一个维度的缩放系数；其类型为 Initializer，表示这个值是根据常量直接初始化出来的。<mark style="color:red;">如果我们能够自己生成一个 ONNX 的 Resize 算子，让 scales 成为一个可变量而不是常量，就像它上面的 X 一样，那这个超分辨率模型就能动态缩放了。</mark>

<mark style="color:red;">现有实现插值的 PyTorch 算子有一套规定好的映射到 ONNX Resize 算子的方法，这些映射出的 Resize 算子的 scales 只能是常量，无法满足我们的需求。我们得自己定义一个实现插值的 PyTorch 算子，然后让它映射到一个我们期望的 ONNX Resize 算子上。</mark>

下面的脚本定义了一个 PyTorch 插值算子，并在模型里使用了它。我们先通过运行模型来验证该算子的正确性：

```python
import torch
from torch import nn
from torch.nn.functional import interpolate
import torch.onnx
import cv2
import numpy as np


class NewInterpolate(torch.autograd.Function):

    @staticmethod
    def symbolic(g, input, scales):
        return g.op("Resize",
                    input,
                    g.op("Constant",
                         value_t=torch.tensor([], dtype=torch.float32)),
                    scales,
                    coordinate_transformation_mode_s="pytorch_half_pixel",
                    cubic_coeff_a_f=-0.75,
                    mode_s='cubic',
                    nearest_mode_s="floor")

    @staticmethod
    def forward(ctx, input, scales):
        scales = scales.tolist()[-2:]
        return interpolate(input,
                           scale_factor=scales,
                           mode='bicubic',
                           align_corners=False)


class StrangeSuperResolutionNet(nn.Module):

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)

        self.relu = nn.ReLU()

    def forward(self, x, upscale_factor):
        x = NewInterpolate.apply(x, upscale_factor)
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.conv3(out)
        return out


def init_torch_model():
    torch_model = StrangeSuperResolutionNet()

    state_dict = torch.load('srcnn.pth')['state_dict']

    # Adapt the checkpoint
    for old_key in list(state_dict.keys()):
        new_key = '.'.join(old_key.split('.')[1:])
        state_dict[new_key] = state_dict.pop(old_key)

    torch_model.load_state_dict(state_dict)
    torch_model.eval()
    return torch_model


model = init_torch_model()
factor = torch.tensor([1, 1, 3, 3], dtype=torch.float)

input_img = cv2.imread('face.png').astype(np.float32)

# HWC to NCHW
input_img = np.transpose(input_img, [2, 0, 1])
input_img = np.expand_dims(input_img, 0)

# Inference
torch_output = model(torch.from_numpy(input_img), factor).detach().numpy()

# NCHW to HWC
torch_output = np.squeeze(torch_output, 0)
torch_output = np.clip(torch_output, 0, 255)
torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8)

# Show image
cv2.imwrite("face_torch_3.png", torch_output)
```

模型运行正常的话，一幅放大3倍的超分辨率图片会保存在"face\_torch\_3.png"中，其内容和"face\_torch.png"完全相同。

在刚刚那个脚本中，我们定义 PyTorch 插值算子的代码如下：

```python
class NewInterpolate(torch.autograd.Function):

    @staticmethod
    def symbolic(g, input, scales):
        return g.op("Resize",
                    input,
                    g.op("Constant",
                         value_t=torch.tensor([], dtype=torch.float32)),
                    scales,
                    coordinate_transformation_mode_s="pytorch_half_pixel",
                    cubic_coeff_a_f=-0.75,
                    mode_s='cubic',
                    nearest_mode_s="floor")

    @staticmethod
    def forward(ctx, input, scales):
        scales = scales.tolist()[-2:]
        return interpolate(input,
                           scale_factor=scales,
                           mode='bicubic',
                           align_corners=False)
```

在具体介绍这个算子的实现前，让我们先理清一下思路。<mark style="color:red;">我们希望新的插值算子有两个输入，一个是被用于操作的图像，一个是图像的放缩比例。</mark>前面讲到，为了对接 ONNX 中 Resize 算子的 scales 参数，这个放缩比例是一个 \[1, 1, x, x] 的张量，其中 x 为放大倍数。在之前放大3倍的模型中，这个参数被固定成了\[1, 1, 3, 3]。因此，在插值算子中，我们希望模型的第二个输入是一个 \[1, 1, w, h] 的张量，其中 w 和 h 分别是图片宽和高的放大倍数。

搞清楚了插值算子的输入，再看一看算子的具体实现。<mark style="color:red;">算子的推理行为由算子的 foward 方法决定。该方法的第一个参数必须为 ctx，后面的参数为算子的自定义输入</mark>，我们设置两个输入，分别为被操作的图像和放缩比例。为保证推理正确，需要把 \[1, 1, w, h] 格式的输入对接到原来的 interpolate 函数上。我们的做法是截取输入张量的后两个元素，把这两个元素以 list 的格式传入 interpolate 的 scale\_factor 参数。

接下来，<mark style="color:red;">我们要决定新算子映射到 ONNX 算子的方法。映射到 ONNX 的方法由一个算子的 symbolic 方法决定。symbolic 方法第一个参数必须是g</mark>，之后的参数是算子的自定义输入，和 forward 函数一样。<mark style="color:red;">ONNX 算子的具体定义由 g.op 实现。g.op 的每个参数都可以映射到 ONNX 中的算子属性</mark>：

<figure><img src="/files/8OzF8NWZkkEfHKUeyEFI" alt="" width="563"><figcaption></figcaption></figure>

对于其他参数，我们可以照着现在的 Resize 算子填。而要注意的是，我们现在希望 scales 参数是由输入动态决定的。因此，在填入 ONNX 的 scales 时，我们要把 symbolic 方法的输入参数中的 scales 填入。

#### g.op浅析

`g.op` 用于将 PyTorch 的操作符映射到 ONNX 的操作符。

`g.op` 的基本用法

```python
g.op(op_name, *inputs, **attrs)
```

**关键参数详解：**

1. **`op_name`**（字符串）：
   * 这是 ONNX 操作符的名称。例如，对于卷积操作，ONNX 使用的是 `"Conv"`；对于上采样，使用的是 `"Resize"`。
2. **`*inputs`**（张量）：
   * 这些是操作符的输入，可以是张量或其他 ONNX 操作符的输出。
   * 如果操作符有多个输入，可以依次传递多个输入张量。
3. **`**attrs`**（关键字参数）：
   * 这些是操作符的属性。每个 ONNX 操作符都有不同的属性，这些属性控制操作符的具体行为。属性的类型可以是整数、浮点数、字符串、张量或列表。

**1. 简单的操作符**

假设我们要将 PyTorch 中的 `Add` 操作导出为 ONNX 中的 `Add` 操作符，可以通过以下方式定义：

```python
def symbolic(g, input1, input2):
    return g.op("Add", input1, input2)
```

这个函数将 PyTorch 的两个输入张量 `input1` 和 `input2` 映射为 ONNX 的 `Add` 操作符。这里没有特殊的属性。

**2. 带属性的操作符**

有些操作符除了输入外，还需要定义一些属性。例如，卷积（`Conv`）操作符有卷积核的大小、步长、填充等属性：

```python
def symbolic(g, input, weight, bias=None):
    return g.op("Conv", input, weight,
                bias if bias is not None else g.constant(0),
                kernel_shape_i=(3, 3),  # 卷积核大小
                strides_i=(1, 1),       # 步长
                pads_i=(1, 1, 1, 1))    # 填充
```

在这个例子中，`g.op("Conv", ...)` 定义了一个卷积操作：

* `input` 是输入张量。
* `weight` 是卷积核权重。
* `bias` 是可选的，如果没有提供则使用常数 0。
* **属性部分**：
  * `kernel_shape_i`：卷积核的大小，这里是 `3x3`。
  * `strides_i`：步长，这里是 `(1, 1)`。
  * `pads_i`：填充大小，这里指定了四个维度的填充量 `(1, 1, 1, 1)`。

注意：

* 属性名称后面的 `_i`，表示这是一个整型的属性。PyTorch 中会根据类型对属性进行不同的后缀标记：
  * `_i`：整型（integer）
  * `_f`：浮点型（float）
  * `_s`：字符串（string）
  * `_t`：张量（tensor）

**3. 带常量输入的操作符**

如果操作符需要常量输入，比如 `Resize` 操作中的 `roi` 是常量，可以通过 `g.op("Constant", ...)` 来定义常量输入：

```python
def symbolic(g, input, scales):
    roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))  # 定义一个空的 roi 常量
    return g.op("Resize", input, roi, scales, mode_s="nearest")
```

在这里，`roi` 被定义为一个常量，并作为 `Resize` 操作的第二个输入。`g.op("Constant", ...)` 表示一个常量节点，`value_t` 是张量常量的值。

接着，让我们把新模型导出成 ONNX 模型：

```python
x = torch.randn(1, 3, 256, 256)
factor = torch.tensor([1, 1, 3, 3], dtype=torch.float)

with torch.no_grad():
    torch.onnx.export(model, (x, factor),
                      "srcnn3.onnx",
                      opset_version=11,
                      input_names=['input', 'factor'],
                      output_names=['output'])
```

<figure><img src="/files/0VW9ndLSfYbTBCdSn0jp" alt=""><figcaption></figcaption></figure>

可以看到，正如我们所期望的，导出的 ONNX 模型有了两个输入！第二个输入表示图像的放缩比例。

之前在验证 PyTorch 模型和导出 ONNX 模型时，我们宽高的缩放比例设置成了 3x3。现在，在用 ONNX Runtime 推理时，我们尝试使用 4x4 的缩放比例：

```python
import onnxruntime

input_factor = np.array([1, 1, 4, 4], dtype=np.float32)
ort_session = onnxruntime.InferenceSession("srcnn3.onnx")
ort_inputs = {'input': input_img, 'factor': input_factor}
ort_output = ort_session.run(None, ort_inputs)[0]

ort_output = np.squeeze(ort_output, 0)
ort_output = np.clip(ort_output, 0, 255)
ort_output = np.transpose(ort_output, [1, 2, 0]).astype(np.uint8)
cv2.imwrite("face_ort_3.png", ort_output)
```

运行上面的代码，可以得到一个边长放大4倍的超分辨率图片 "face\_ort\_3.png"。动态的超分辨率模型生成成功了！只要修改 input\_factor，我们就可以自由地控制图片的缩放比例。

> 事实上，我们不仅可以在现有的 ONNX 算子的基础上进行改造，还可以定义新的 ONNX 算子以拓展 ONNX 的表达能力。
