81. 深度学习模型推理和部署#

81.1. 介绍#

深度学习模型推理和部署是工程应用过程中非常重要的一块内容。本次实验中,我们将了解 TensorFlow 官方提供的 TensorFlow Serving 架构,以及学习并使用 ONNX 神经网络模型开放格式。

81.2. 知识点#

  • TensorFlow Serving

  • ONNX 开放模型格式

前面,我们学习过 scikit-learn 模型的部署,并使用 Flask 构建了一个 RESTful API 以供推理调用。那么对于 TensorFlow 和 PyTorch 这类深度学习框架训练好的模型怎么部署和推理呢?

实际上,我们可以使用和 scikit-learn 部署时相似的架构。例如下面这张示意图,其展示了一个 TensorFlow 模型部署的经典流程。

image

这里使用了 AWS 提供的服务,首先把训练数据挂载在 AWS S3 上,并通过 GPU 实例完成训练和验证。使用 TensorFlow 提供的模型保存方法将训练完成的模型存储下来并挂载在 AWS S3 上,最后使用 TensorFlow 提供的 Serving 服务部署为 RESTful API 以供客户端调用。

其中,TensorFlow Serving 是一种灵活的高性能服务系统,适用于机器学习模型,专为生产环境而设计。TensorFlow 提供了多组示例,并演示了如何使用 Docker 容器 技术来部署模型。

81.3. ONNX 开放模型格式#

一般情况下,不同深度学习框架所储存的模型只能由相应框架所构建的代码来加载。这其实是一个很大的障碍,尤其是对于学术研究和生产环境部署时,不同的人可能有不同的框架偏好选择。

image

于是,在 微软,亚马逊,Facebook 和 IBM 等公司共同推动下,创建了 ONNX 开放模型格式。它使得不同的深度学习框架可以采用相同格式存储模型数据并交互。也就是说,一个框架中进行训练的模型也可以转移到另一个框架中进行推理。ONNX 目前 官方支持 模型并进行推理的深度学习框架有:PyTorch, TensorFlow, Caffe2, MXNet,ML.NET,TensorRT 和 Microsoft CNTK。

下面,我们来学习 ONNX 的使用。这里,首先使用 PyTorch 构建一个 MNIST 分类模型,直接沿用前面课程中的相关代码。

import torch
import torchvision

# 加载训练数据,参数 train=True,供 60000 条
train = torchvision.datasets.MNIST(root='.', train=True, download=True,
                                   transform=torchvision.transforms.ToTensor())
# 加载测试数据,参数 train=False,供 10000 条
test = torchvision.datasets.MNIST(root='.', train=False, download=True,
                                  transform=torchvision.transforms.ToTensor())
# 训练数据打乱,使用 64 小批量
train_loader = torch.utils.data.DataLoader(dataset=train, batch_size=64,
                                           shuffle=True)
# 测试数据无需打乱,使用 64 小批量
test_loader = torch.utils.data.DataLoader(dataset=test, batch_size=64,
                                          shuffle=False)
train_loader, test_loader
(<torch.utils.data.dataloader.DataLoader at 0x292550d00>,
 <torch.utils.data.dataloader.DataLoader at 0x2924cf790>)

这里,使用 PyTorch 定义一个神经网络:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 512)  # 784 是因为训练是我们会把 28*28 展平
        self.fc2 = nn.Linear(512, 128)  # 使用 nn 类初始化线性层(全连接层)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))  # 直接使用 relu 函数,也可以自己初始化一个 nn 下面的 Relu 类使用
        x = F.relu(self.fc2(x))
        x = self.fc3(x)  # 输出层一般不激活
        return x

model = Net()
loss_fn = nn.CrossEntropyLoss()  # 交叉熵损失函数
opt = torch.optim.Adam(model.parameters(), lr=0.002)  # Adam 优化器
model
Net(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=10, bias=True)
)
def fit(epochs, model, opt):
    print("Start training, please be patient.")
    for epoch in range(epochs):
        # 从数据加载器中读取 Batch 数据开始训练
        for i, (images, labels) in enumerate(train_loader):
            images = images.reshape(-1, 28*28)  # 对特征数据展平,变成 784
            labels = labels  # 真实标签
            outputs = model(images)  # 前向传播
            loss = loss_fn(outputs, labels)  # 传入模型输出和真实标签
            opt.zero_grad()  # 优化器梯度清零,否则会累计
            loss.backward()  # 从最后 loss 开始反向传播
            opt.step()  # 优化器迭代
            # 自定义训练输出样式
            if (i+1) % 100 == 0:
                print('Epoch [{}/{}], Batch [{}/{}], Train loss: [{:.3f}]'
                      .format(epoch+1, epochs, i+1, len(train_loader), loss.item()))
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.reshape(-1, 28*28)
            labels = labels
            outputs = model(images)
            # 得到输出最大值 _ 及其索引 predicted
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()  # 如果预测结果和真实值相等则计数 +1
            total += labels.size(0)  # 总测试样本数据计数
        print('============= Test accuracy: {:.3f} =============='.format(
            correct / total))

最后,完成网络训练:

fit(epochs=1, model=model, opt=opt)  # 训练 1 个 Epoch,预计持续 10 分钟
Start training, please be patient.
Epoch [1/1], Batch [100/938], Train loss: [0.583]
Epoch [1/1], Batch [200/938], Train loss: [0.179]
Epoch [1/1], Batch [300/938], Train loss: [0.176]
Epoch [1/1], Batch [400/938], Train loss: [0.161]
Epoch [1/1], Batch [500/938], Train loss: [0.288]
Epoch [1/1], Batch [600/938], Train loss: [0.151]
Epoch [1/1], Batch [700/938], Train loss: [0.191]
Epoch [1/1], Batch [800/938], Train loss: [0.152]
Epoch [1/1], Batch [900/938], Train loss: [0.066]
============= Test accuracy: 0.965 ==============

我们可以使用 torch.save 将模型保存下来,这里只保存模型中的参数 model.state_dict()

torch.save(model.state_dict(), 'mnist.pth')

下面,我们将 PyTorch 模型转换为 ONNX 格式。目前,PyTorch → ONNX 的方法已经集成在 PyTorch 代码中,直接使用 torch.onnx.export 即可调用。

值得注意的是,保存时需要指定 dummy_input,即单个样本输入到模型的张量形状。这个形状是根据网络决定的,不是随意指定的,错误的形状会引发报错。

from torch.autograd import Variable

# 加载 PyTorch 模型参数
trained_model = Net()
trained_model.load_state_dict(torch.load('mnist.pth'))

# 导出训练好的模型为 ONNX
dummy_input = Variable(torch.randn([1, 784]))
torch.onnx.export(trained_model, dummy_input, "mnist.onnx")

有了 ONNX 模型文件,下面就可以将其转换为 TensorFlow 支持的模型了。这里我们需要利用 ONNX 官方提供的 Tensorflow Backend for ONNX 组件。

!git clone https://github.com/onnx/onnx-tensorflow.git && cd onnx-tensorflow && pip install -e .

安装完成之后,请重启 Jupyter Notebook 环境内核,以保证 Tensorflow Backend for ONNX 能被正常加载。

import onnx
from onnx_tf.backend import prepare
# 加载 ONNX 模型文件
model = onnx.load('mnist.onnx')
# 导入 ONNX 模型到 Tensorflow
tf_rep = prepare(model)
tf_rep
<onnx_tf.backend_rep.TensorflowRep at 0x28a31e7d0>

ONNX 官方开放了 ONNX Model Zoo 项目,其提供了大量的预训练模型以供使用。你可以直接部署这些模型或者用于迁移学习任务中。

例如,我们可以下载 MobileNet v2 在 ImageNet 上的预训练模型:

wget -nc "https://cdn.huhuhang.com/hands-on-ai/files/mobilenetv2-1.0.onnx"
--2023-11-14 11:09:09--  https://cdn.huhuhang.com/hands-on-ai/files/mobilenetv2-1.0.onnx
正在解析主机 cdn.huhuhang.com (cdn.huhuhang.com)... 198.18.7.59
正在连接 cdn.huhuhang.com (cdn.huhuhang.com)|198.18.7.59|:443... 已连接。
已发出 HTTP 请求,正在等待回应... 200 OK
长度:14246826 (14M) [application/octet-stream]
正在保存至: “mobilenetv2-1.0.onnx”

mobilenetv2-1.0.onn 100%[===================>]  13.59M  3.07MB/s  用时 4.4s      

2023-11-14 11:09:17 (3.07 MB/s) - 已保存 “mobilenetv2-1.0.onnx” [14246826/14246826])

这里需要选择一个框架作为后端进行推理。

from onnx_tf import backend

mobilenetv2 = onnx.load('mobilenetv2-1.0.onnx')

此时,可以通过 mobilenetv2.graph.input[0] 查看输入层要求的数组形状。

mobilenetv2.graph.input[0]
name: "data"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 3
      }
      dim {
        dim_value: 224
      }
      dim {
        dim_value: 224
      }
    }
  }
}

明显,这里需要彩色图片组成的 [1, 3, 224, 224] 形状的数组。下面,我们直接使用 NumPy 生成一组满足形状的随机数进行测试,不再使用真实图片进行推理了。

backend.run_model(mobilenetv2, np.random.randn(1, 3, 224, 224).astype(np.float32))
Outputs(mobilenetv20_output_flatten0_reshape0=array([[-2.17918873e-01,  4.75229979e+00,  3.84785867e+00,
         3.01842427e+00,  5.65132666e+00,  3.55456161e+00,
         6.01144171e+00, -2.49251819e+00,  1.34723496e+00,
        -2.51869464e+00,  2.64187407e+00,  1.88138056e+00,
         3.49288440e+00,  3.22448373e+00,  2.32277083e+00,
         1.43812585e+00,  2.60400295e-01, -1.09106526e-01,
         2.88156056e+00,  1.09300959e+00,  4.17400777e-01,
         1.68558586e+00, -2.53207147e-01,  3.04420352e+00,
        -1.04951906e+00, -2.00535917e+00, -5.25695503e-01,
         6.80131555e-01, -1.48528785e-01, -2.07431960e+00,
         2.95197988e+00,  3.99375856e-01, -6.44338846e-01,
         4.99150902e-01,  5.11501193e-01,  9.59893465e-01,
         2.31016755e+00, -3.58879834e-01,  1.81526855e-01,
        -4.59364623e-01,  1.00105858e+00,  2.04941675e-01,
        -1.07820022e+00, -1.36745608e+00, -6.82603002e-01,
        -1.48531663e+00,  1.43777537e+00, -3.11603117e+00,
         9.30862188e-01, -8.74181509e-01, -1.87898785e-01,
        -2.34333634e+00,  1.06189549e-01, -6.87525719e-02,
        -5.13176620e-01,  3.30288410e+00, -2.09389234e+00,
        -5.56719601e-01,  1.94359183e+00,  2.67331839e+00,
         2.03160882e+00, -7.85668790e-01,  1.17350090e+00,
        -2.57338524e-01,  1.79152572e+00,  3.82038021e+00,
         5.98783731e-01,  3.18593621e+00,  9.35994908e-02,
         1.90697658e+00,  7.89554536e-01,  1.47086513e+00,
        -3.03895020e+00, -1.11876011e+00, -1.38503742e+00,
         4.18293118e-01, -1.49195015e+00,  2.97811627e-01,
         3.74538350e+00,  4.11389971e+00,  1.24962711e+00,
         1.19030821e+00,  2.35595036e+00, -5.74236989e-01,
         4.39513063e+00,  1.02527726e+00,  1.87954128e+00,
         2.93551505e-01, -1.80891201e-01,  2.13482332e+00,
         6.03989482e-01,  1.77726662e+00,  1.43735981e+00,
         7.77451932e-01,  3.96851110e+00, -2.11307406e+00,
         2.72765326e+00, -6.45401537e-01, -1.99402952e+00,
         1.74325025e+00, -1.20393193e+00, -4.13333273e+00,
        -1.97058928e+00,  1.22431636e+00,  1.02035820e+00,
        -1.00788939e+00,  9.02926564e-01,  5.49970293e+00,
         1.62760937e+00,  5.25525665e+00,  3.62832665e+00,
         4.93371582e+00,  2.95067215e+00,  3.52915764e+00,
         4.20611525e+00,  1.59759164e+00, -4.17489819e-02,
         2.82020092e-01,  1.30935168e+00, -7.01524317e-01,
         2.00631475e+00, -1.31790698e+00, -7.51845360e-01,
         5.17542183e-01,  8.04478824e-01,  6.92325711e-01,
         1.58367002e+00, -2.29452944e+00, -2.54239559e+00,
        -2.46707964e+00,  9.35787499e-01, -1.14612556e+00,
         2.03999710e+00,  2.82344389e+00,  2.89001882e-01,
        -7.91936755e-01, -1.84974745e-01,  2.03111410e-01,
         6.52239621e-01, -3.18499565e-01, -8.60790491e-01,
         4.70759690e-01, -1.15969265e+00, -1.24945605e+00,
        -9.56427336e-01,  1.38786578e+00, -1.13875556e+00,
        -9.91407871e-01, -3.55776691e+00,  1.29339385e+00,
         2.17420197e+00,  7.36300528e-01, -4.05283594e+00,
        -1.54314435e+00, -2.68563819e+00, -1.59641159e+00,
        -4.48659229e+00, -3.96959138e+00, -4.03299236e+00,
        -8.60120893e-01, -5.15840626e+00, -1.04804826e+00,
        -2.61472046e-01, -1.46869779e+00, -3.80501986e+00,
        -4.32001448e+00, -1.65017724e+00, -4.22842169e+00,
        -1.52900636e+00, -3.29103899e+00, -4.82584238e+00,
         1.02863050e+00, -3.88118029e+00, -1.32468247e+00,
        -3.84607530e+00, -3.22713828e+00, -5.95544052e+00,
        -5.94322586e+00,  1.10945933e-01, -2.64615726e+00,
        -1.65175343e+00, -4.91664767e-01, -5.85528970e-01,
        -1.92149258e+00, -1.63642430e+00, -1.52183259e+00,
        -2.81970382e+00, -1.51661479e+00, -3.28479469e-01,
        -1.55888355e+00, -2.17251158e+00, -1.09711671e+00,
        -1.71703672e+00, -1.63081264e+00, -4.33476830e+00,
        -7.36499846e-01, -1.43112302e+00, -4.16999960e+00,
        -3.15626097e+00, -1.05905771e+00, -3.86533403e+00,
        -2.23281527e+00, -2.97160292e+00, -1.36809027e+00,
        -9.38341558e-01, -3.71469402e+00, -2.66338682e+00,
        -1.17009544e+00, -1.76349151e+00, -2.62561464e+00,
        -3.38372898e+00, -5.00743836e-02, -1.40081668e+00,
        -1.72890866e+00, -3.69821453e+00, -2.78117609e+00,
        -2.03931785e+00, -1.63897371e+00, -2.17677402e+00,
        -8.03870916e-01, -2.57224131e+00,  2.17500851e-01,
        -3.75606370e+00, -3.76434708e+00, -3.93135023e+00,
        -2.72206736e+00, -5.75111341e+00, -4.53542614e+00,
        -3.75281167e+00, -3.58237600e+00, -3.29762721e+00,
        -2.75329018e+00, -3.21418285e+00, -5.32918262e+00,
        -3.39491415e+00, -2.86598563e+00, -1.86687207e+00,
        -3.20478582e+00, -7.90369213e-01, -4.67666197e+00,
        -3.67946553e+00, -3.48915315e+00, -5.31197190e-01,
        -2.35967994e+00, -5.84845352e+00, -1.95235109e+00,
        -2.29551077e+00, -1.88898444e+00, -1.17966509e+00,
        -2.10919929e+00, -1.05902064e+00, -4.74287122e-01,
        -2.88320017e+00, -1.84879971e+00, -2.48982525e+00,
        -3.87248755e+00, -2.94255137e+00, -3.66363645e+00,
        -3.00854874e+00, -2.09684920e+00, -1.36282849e+00,
        -4.35816479e+00, -5.47735167e+00, -2.51229501e+00,
        -1.93667758e+00, -7.76217759e-01, -1.21973228e+00,
        -2.63028979e+00, -3.49632978e+00, -2.19640136e+00,
        -2.88776255e+00, -1.34260690e+00, -7.36173093e-01,
        -1.68438911e+00, -2.64637423e+00, -2.64436007e+00,
        -2.34884000e+00,  1.91925913e-01, -5.27190268e-01,
        -6.24125302e-01, -2.50945866e-01,  7.92718172e-01,
         1.52890730e+00,  1.70753753e+00,  9.83535469e-01,
         2.34873995e-01, -2.07392365e-01, -1.63242733e+00,
        -7.03844905e-01, -4.43327236e+00, -2.51538181e+00,
        -4.06490612e+00, -3.57556748e+00, -1.45213950e+00,
        -3.51834416e-01, -1.36355627e+00,  1.69787848e+00,
        -2.17246938e+00,  2.22532058e+00,  1.64712918e+00,
        -1.50832975e+00,  4.41281796e+00, -8.70427191e-01,
         1.72901881e+00,  1.80536377e+00, -7.54105508e-01,
         1.04408884e+00,  8.39269012e-02,  2.91959333e+00,
         6.06098771e-01,  4.23403835e+00,  1.31669974e+00,
         7.38263428e-01,  2.06975985e+00,  4.62217808e+00,
         1.43908453e+00,  2.14631677e+00,  5.04689598e+00,
         4.76549292e+00,  3.72261226e-01,  1.25207007e+00,
        -3.31836653e+00, -1.90342593e+00, -4.12565112e-01,
        -7.28223264e-01, -8.86083186e-01, -1.87105370e+00,
         1.43135321e+00,  1.34187758e+00,  8.42997611e-01,
         3.12780976e+00,  1.98254490e+00, -6.48530841e-01,
        -2.27205229e+00, -1.22997165e-03,  9.08855796e-01,
         6.59826517e-01,  1.79793310e+00, -2.56135345e+00,
        -3.25729394e+00, -4.21057272e+00, -3.68100023e+00,
        -3.22765446e+00, -3.54969716e+00, -7.71103680e-01,
        -2.01898861e+00, -3.57962084e+00, -2.90791678e+00,
        -2.83112526e+00, -4.18578625e+00, -2.43274665e+00,
        -3.58720398e+00, -2.58402967e+00, -6.03637516e-01,
        -1.40604544e+00, -2.33785391e+00,  6.65213466e-01,
        -1.50168669e+00, -1.35840964e+00, -5.11506975e-01,
         1.18511474e+00,  3.92596126e-01,  8.20672274e-01,
        -2.11432171e+00,  2.12864828e+00,  1.13581800e+00,
        -1.24782968e+00, -9.32157636e-01, -1.69053090e+00,
        -2.37530756e+00, -1.72154319e+00, -1.14811516e+00,
        -3.62779117e+00, -2.64389420e+00, -2.55285883e+00,
        -2.77751994e+00, -2.61594486e+00,  9.10730183e-01,
        -6.75414622e-01, -8.72081757e-01,  1.23529665e-01,
        -1.22607601e+00,  8.83704275e-02, -1.90711156e-01,
        -3.25564575e+00, -2.29798436e+00, -2.89960575e+00,
        -3.20943141e+00, -4.03142023e+00,  3.75934863e+00,
         4.54327917e+00,  3.58013487e+00,  2.20520234e+00,
         2.48583412e+00,  7.01974750e-01,  5.11725426e+00,
         1.59577084e+00,  4.27544498e+00,  5.25448418e+00,
         2.86719322e-01,  2.23582327e-01, -1.35090923e+00,
        -2.04949045e+00, -5.80793619e+00, -2.15647578e+00,
         1.26754344e-02, -3.73078728e+00, -3.78148246e+00,
        -3.22023702e+00,  1.44259477e+00, -6.48111105e-02,
         5.74393988e+00,  2.34845138e+00,  1.30464864e+00,
         2.26204705e+00, -4.66091156e+00,  8.15229893e-01,
         2.91571045e+00,  2.99940419e+00,  3.30518174e+00,
        -1.09469771e-01,  3.44727826e+00, -1.79197526e+00,
        -3.62107754e+00, -1.49195623e+00, -6.43744588e-01,
         6.64691925e-01,  6.12119138e-02, -6.82851374e-01,
         2.44240904e+00, -8.64460826e-01, -1.01271741e-01,
         2.16802549e+00,  1.24422276e+00,  3.73224902e+00,
         1.27984703e+00, -1.05274928e+00,  5.63550532e-01,
         1.55162263e+00, -5.65338755e+00,  4.46175051e+00,
         3.22314143e+00, -3.01666045e+00,  6.50161028e+00,
        -1.43130112e+00,  7.39034414e-01,  6.55038929e+00,
         1.07933432e-01,  1.05543464e-01, -2.82925797e+00,
        -8.58651257e+00,  3.00574160e+00,  2.00312328e+00,
        -9.90903974e-01, -1.25102961e+00,  7.57402134e+00,
        -1.85941625e+00,  2.24710894e+00,  3.17299843e+00,
         2.58878326e+00, -2.74121225e-01,  2.05823374e+00,
         5.07509851e+00,  2.73586559e+00,  9.25174892e-01,
         1.95816851e+00, -6.41819191e+00, -1.15951777e+00,
        -5.76262653e-01,  3.47740710e-01,  2.95022869e+00,
        -5.19253016e-01, -2.18517804e+00, -1.50224710e+00,
         4.93859339e+00,  6.35736287e-01, -2.68830633e+00,
        -2.66313934e+00, -9.75158036e-01, -1.56268227e+00,
         1.98537946e+00,  7.43496060e-01, -3.28562093e+00,
        -1.46836996e+00, -1.44725740e+00, -2.20083427e+00,
        -1.27595425e+00, -3.88760418e-01,  7.31318521e+00,
         5.02191591e+00,  7.00381279e+00, -2.29816723e+00,
         1.37847555e+00, -1.82301685e-01, -3.73122394e-02,
        -2.66596675e+00,  1.23226404e+00, -1.47836566e+00,
         5.87795079e-01,  9.18014348e-01, -5.98944783e-01,
         1.01020169e+00,  3.56583238e-01,  1.30092525e+00,
         1.85662901e+00, -1.57888710e+00,  4.12323236e+00,
         1.10063553e+00,  2.89762282e+00,  9.19300139e-01,
        -3.42635584e+00, -5.18322289e-01,  8.09627831e-01,
         5.76076627e-01, -6.84943795e-03,  9.87595797e-01,
         3.01800179e+00,  1.63622022e+00,  5.92687547e-01,
        -1.18962634e+00,  2.92661428e+00, -2.93515229e+00,
         2.14856744e+00,  9.06814992e-01,  1.18594790e+00,
        -5.48650599e+00, -2.43133569e+00,  1.41218221e+00,
        -9.85596955e-01,  2.16115236e+00,  2.23670030e+00,
        -1.04340482e+00, -2.75633931e+00,  3.78816271e+00,
         2.44632030e+00, -6.47061539e+00, -1.80698502e+00,
        -6.02159595e+00,  4.12956595e-01,  4.37763262e+00,
        -7.84419537e-01,  1.12717450e-02,  2.84427905e+00,
         9.60683227e-01, -2.13115954e+00,  1.43547511e+00,
        -7.94499159e-01, -8.27979374e+00, -2.70068288e+00,
         6.06783819e+00, -3.16984868e+00, -1.35335720e+00,
         2.16869187e+00,  1.19414377e+00,  3.29736024e-01,
         4.40872461e-01,  6.24655056e+00,  2.47257924e+00,
         9.67887163e-01,  1.13619447e+00, -3.23841989e-01,
        -4.37616825e+00,  6.20965624e+00,  1.00040543e+00,
        -1.22685170e+00, -1.31229079e+00,  4.82931316e-01,
        -3.83289158e-01, -3.28239709e-01, -3.34979701e+00,
         1.61709106e+00, -2.76014161e+00,  1.06432438e+00,
        -5.19821835e+00,  5.60067940e+00, -2.08112836e+00,
        -2.84561920e+00, -1.31095946e-02,  2.92052031e-01,
        -1.37379873e+00,  1.89190865e+00,  1.29148376e+00,
        -2.34446001e+00,  1.19223821e+00,  1.64357221e+00,
        -9.26005095e-02, -3.86853361e+00,  5.12957621e+00,
         3.71036744e+00, -2.37145138e+00,  4.86025155e-01,
         4.49000025e+00, -2.49683475e+00, -9.00746942e-01,
        -6.81072712e-01, -1.76341474e+00,  3.44168496e+00,
         1.17514050e+00,  6.07610941e-01,  4.19197273e+00,
         5.43518209e+00,  6.23007298e-01,  1.97886705e+00,
        -6.68706465e+00,  3.65284991e+00, -6.29918873e-01,
         7.33301282e-01, -9.07186866e-01,  4.68629313e+00,
        -3.59519076e+00,  3.91423202e+00,  3.29417777e+00,
        -5.25915051e+00,  1.06234103e-01,  2.58010328e-01,
         2.15742916e-01,  3.48545957e+00,  8.93554986e-01,
         3.05062294e+00,  1.92996025e+00,  3.20788741e+00,
         1.53058574e-01,  3.57414436e+00,  1.72561216e+00,
        -1.91309083e+00, -1.47555196e+00,  3.94888544e+00,
        -5.78994894e+00, -8.49439919e-01,  4.83203903e-02,
        -1.12441611e+00,  2.27104449e+00,  2.48128462e+00,
         3.74482441e+00, -4.11819410e+00,  9.37119603e-01,
         2.58896565e+00,  1.35843658e+00,  8.43623459e-01,
         2.59628743e-01,  3.52006865e+00, -5.37169337e-01,
        -6.62730336e-01,  1.72018260e-01,  4.91915894e+00,
        -3.64804626e-01,  3.90878272e+00,  1.92415154e+00,
        -1.62087131e+00,  1.48866689e+00,  4.85203409e+00,
         7.46661365e-01,  1.21980146e-01, -1.63700008e+00,
        -6.12033939e+00,  3.59952378e+00, -2.84760594e+00,
        -3.82449865e-01,  1.87183976e+00,  2.36486816e+00,
        -2.24170923e+00, -4.47580862e+00,  3.74309719e-02,
        -6.49170697e-01,  1.00716007e+00, -3.91869378e+00,
        -6.10592961e-03, -1.60673749e+00,  7.83551991e-01,
         2.89197612e+00, -1.58202517e+00, -7.28074265e+00,
        -2.04833770e+00,  9.64782983e-02,  1.38752127e+00,
        -5.41232347e+00, -2.56652069e+00,  4.89874363e+00,
        -9.59260702e-01,  3.47580433e+00,  2.53640366e+00,
         1.18827975e+00,  1.02208757e+00,  1.72018898e+00,
         1.39889455e+00, -3.03857517e+00,  4.34782147e-01,
        -4.03067780e+00,  2.39652231e-01,  1.31940758e+00,
        -6.51659822e+00,  9.37106252e-01,  3.39382839e+00,
         1.24303472e+00, -2.99718857e+00,  1.47961664e+00,
         4.17689323e+00,  1.80935299e+00,  9.49615777e-01,
         1.35060549e+00,  5.54529476e+00,  2.80935812e+00,
         1.77588618e+00, -1.93963408e+00,  2.26257205e+00,
        -5.23418617e+00, -1.47641885e+00,  2.23707885e-01,
         4.02806103e-01,  2.08055758e+00, -1.35911429e+00,
         3.13439107e+00,  3.43396163e+00, -8.72751474e-01,
         1.85963941e+00, -1.09747708e+00,  4.22233045e-01,
        -5.23085880e+00, -2.75562072e+00,  1.54278719e+00,
         2.58603501e+00,  3.91459727e+00,  2.44442177e+00,
         2.48551559e+00, -1.59395587e+00,  3.24753332e+00,
        -4.16821480e+00,  2.06871295e+00,  5.62646389e+00,
         4.64301109e-01, -2.69743824e+00,  5.52843618e+00,
        -2.47451735e+00,  6.38625956e+00, -4.35492516e+00,
         4.35812759e+00,  8.53062332e-01,  4.31005955e+00,
         1.25292078e-01, -4.19782352e+00,  4.19049561e-01,
         5.45236969e+00,  1.74472284e+00, -9.98820603e-01,
        -6.72462463e-01, -2.82103109e+00,  2.88248897e+00,
         6.06165454e-02,  3.59094048e+00,  4.48974705e+00,
         4.85561943e+00, -3.97314668e+00, -8.82922351e-01,
        -1.75218880e-01,  1.13065112e+00,  2.84188718e-01,
        -1.51487100e+00, -5.67327118e+00,  2.38555026e+00,
        -1.58298981e+00,  4.40290976e+00,  2.60201454e+00,
        -1.74241769e+00,  2.39598274e+00,  3.63118005e+00,
        -2.69875407e+00, -2.21399093e+00,  1.20327568e+00,
         1.17902851e+00,  3.00376868e+00, -2.52832174e-01,
         4.09321517e-01,  3.02216554e+00,  2.31983042e+00,
        -2.72359967e-01,  3.35919046e+00,  2.15303373e+00,
         2.02424216e+00,  3.08725119e+00, -2.66263318e+00,
        -3.59353960e-01, -1.96025133e+00,  2.54136467e+00,
         4.97461700e+00,  4.54417753e+00, -5.40532947e-01,
        -1.24374437e+00,  4.96543467e-01, -1.35619867e+00,
         2.85133219e+00, -9.65558767e-01,  1.90204120e+00,
         4.88859463e+00,  1.48145902e+00,  7.23354721e+00,
        -9.65938330e-01, -3.72505367e-01,  1.91981864e+00,
         8.32346499e-01,  8.80843282e-01, -1.81890094e+00,
         1.60175407e+00, -3.47283578e+00, -3.79743385e+00,
         9.37828600e-01,  1.20582545e+00,  3.22654080e+00,
        -9.99832630e-01,  7.62916088e-01,  3.58098477e-01,
        -4.32951152e-01,  1.65683150e+00, -1.39699662e+00,
         2.50779557e+00,  8.53580475e-01,  2.93004155e+00,
         2.51207113e+00, -4.64979887e+00,  5.98403645e+00,
         6.55933678e-01, -5.23787642e+00, -2.75699520e+00,
        -1.13288140e+00,  1.88558429e-01,  6.82950830e+00,
         2.11598665e-01, -1.46717060e+00, -1.32634652e+00,
         4.51558256e+00, -3.03100300e+00, -7.77974784e-01,
         2.04075456e-01, -7.43443370e-01, -5.15884447e+00,
         3.28733373e+00,  6.00705445e-01,  6.30431056e-01,
         2.62039256e+00,  1.22509456e+00, -2.66100073e+00,
         4.62406683e+00,  3.91742325e+00,  3.08012056e+00,
         2.90776849e+00,  5.88140869e+00,  1.99387634e+00,
         2.40093440e-01, -3.97538590e+00, -3.37095809e+00,
         2.69452596e+00, -2.01809734e-01,  2.21964526e+00,
         3.04176974e+00, -2.10473251e+00,  4.89241958e-01,
         2.49010754e+00, -4.15451241e+00, -2.16438746e+00,
         6.58748674e+00,  3.50101054e-01, -1.13290405e+00,
         3.46972680e+00,  1.63381696e+00, -1.91263092e+00,
        -3.02048779e+00, -2.84515905e+00, -1.28598166e+00,
        -3.28682423e+00,  3.51568794e+00,  7.56309181e-02,
        -1.41291094e+00, -1.51283979e+00,  2.29272985e+00,
        -8.75395238e-02, -6.11999798e+00,  7.28092015e-01,
         7.14829922e-01, -2.08875370e+00, -1.30002409e-01,
         4.10728931e+00, -2.94424844e+00, -8.09582233e-01,
         2.36096597e+00,  3.51311040e+00, -1.29023716e-01,
         4.02777481e+00, -1.50548649e+00, -9.65521693e-01,
        -3.00948977e+00, -3.19941044e-01, -9.31656480e-01,
        -7.05912888e-01,  3.79872441e+00,  2.44292140e+00,
         2.78852820e+00,  7.48423636e-01,  2.39480758e+00,
        -7.10983872e-01,  5.03461933e+00, -3.08670342e-01,
        -1.41015375e+00, -1.52869582e+00,  1.13087010e+00,
         1.16927892e-01,  7.62843037e+00,  3.64698887e+00,
         4.31580687e+00,  2.14642882e+00,  4.92044210e+00,
         4.04457033e-01,  3.06233597e+00,  6.21044493e+00,
        -1.03981376e+00,  3.12776518e+00, -1.87236428e+00,
        -1.02444792e+00,  6.09659851e-01, -9.93071079e-01,
         1.49931467e+00,  1.97498417e+00,  2.45144749e+00,
         1.93861985e+00,  2.16339493e+00, -2.71444654e+00,
        -3.48313403e+00, -2.40161180e+00, -2.22023773e+00,
        -1.32474422e+00,  1.41033149e+00,  3.52650666e+00,
         5.02738468e-02, -9.12927032e-01,  1.15016484e+00,
        -3.56898141e+00, -4.13017654e+00,  1.37760448e+00,
        -1.75844884e+00,  2.75605083e+00, -4.61719424e-01,
        -2.06876278e+00, -2.98703384e+00, -2.61709976e+00,
        -2.99704480e+00,  6.29932284e-01, -4.80658340e+00,
         1.16191745e+00, -2.05834651e+00, -4.85292017e-01,
         7.60220289e-02,  6.74728930e-01,  2.60395098e+00,
         2.87417603e+00, -1.01589990e+00,  7.75488138e-01,
        -6.88945293e-01,  1.12021005e+00, -2.92550373e+00,
         7.58213341e-01,  1.60940790e+00, -2.70977211e+00,
        -9.32885349e-01,  2.71826744e-01, -3.01931286e+00,
        -2.60964465e+00, -4.23301506e+00, -5.37273645e+00,
         2.77103806e+00,  1.36932266e+00,  3.88070369e+00,
        -1.53155482e+00, -2.64897084e+00,  1.03036194e+01,
        -1.82979715e+00,  4.68517780e+00,  1.45898998e+00,
         4.71672773e-01,  5.28277196e-02,  3.51470685e+00,
         3.54170036e+00, -9.84236300e-01,  1.58287191e+00,
         2.72989094e-01, -1.11029014e-01,  3.28318739e+00,
         1.86232042e+00,  3.03744006e+00, -2.53968024e+00,
         3.25116205e+00,  2.03336453e+00, -4.36543897e-02,
         1.87464237e+00, -3.22686434e-01,  7.00683072e-02,
        -4.40989733e+00, -1.94826424e+00, -3.23508596e+00,
        -3.22833610e+00, -3.03528214e+00,  2.86764312e+00,
         1.89342952e+00]], dtype=float32))

可以看到,backend.run_model(model, data) 最终返回了推理结果,这是 Softmax 概率,需要对应到原 ImageNet 数据集的标签中去。由于实验已经完成主要目标,且本身是传入的随机数据,这里就不再对应最终标签了。

81.4. 总结#

本次实验,我们重点学习了深度学习模型训练和部署过程中可能会用到的 ONNX 开放格式。并选择 PyTorch 训练模型后,使用 TensorFlow 重载模型用于推理。目前,ONNX 对于可框架的支持还不算完美,只有部分框架支持模型导出。同时,我们推荐使用 Caffe2 作为 Backend 端用于推理,相对而言可能会遇到更少的 Bug。你可以通过 ONNX 的仓库页面关注 各类框架的支持进度

相关链接