ONNX Runtime介绍_onnxruntime-程序员宅基地

技术标签: PyTorch  ONNX Runtime  Deep Learning  

      ONNX Runtime:由微软推出,用于优化和加速机器学习推理和训练,适用于ONNX模型,是一个跨平台推理和训练机器学习加速器(ONNX Runtime is a cross-platform inference and training machine-learning accelerator),源码地址:https://github.com/microsoft/onnxruntime,最新发布版本为v1.11.1,License为MIT:

      1.ONNX Runtime Inferencing:高性能推理引擎

      (1).可在不同的操作系统上运行,包括Windows、Linux、Mac、Android、iOS等;

      (2).可利用硬件增加性能,包括CUDA、TensorRT、DirectML、OpenVINO等;

      (3).支持PyTorch、TensorFlow等深度学习框架的模型,需先调用相应接口转换为ONNX模型;

      (4).在Python中训练,确可部署到C++/Java等应用程序中。

      2.ONNX Runtime Training:于2021年4月发布,可加快PyTorch对模型训练,可通过CUDA加速,目前多用于Linux平台。

      通过conda命令安装执行:

conda install -c conda-forge onnxruntime

      以下为测试代码:通过ResNet-50对图像进行分类

import numpy as np
import onnxruntime
import onnx
from onnx import numpy_helper
import urllib.request
import os
import tarfile
import json
import cv2

# reference: https://github.com/onnx/onnx-docker/blob/master/onnx-ecosystem/inference_demos/resnet50_modelzoo_onnxruntime_inference.ipynb
def download_onnx_model():
    labels_file_name = "imagenet-simple-labels.json"
    model_tar_name = "resnet50v2.tar.gz"
    model_directory_name = "resnet50v2"

    if os.path.exists(model_tar_name) and os.path.exists(labels_file_name):
        print("files exist, don't need to download")
    else:
        print("files don't exist, need to download ...")

        onnx_model_url = "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.tar.gz"
        imagenet_labels_url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"

        # retrieve our model from the ONNX model zoo
        urllib.request.urlretrieve(onnx_model_url, filename=model_tar_name)
        urllib.request.urlretrieve(imagenet_labels_url, filename=labels_file_name)

        print("download completed, start decompress ...")
        file = tarfile.open(model_tar_name)
        file.extractall("./")
        file.close()

    return model_directory_name, labels_file_name

def load_labels(path):
    with open(path) as f:
        data = json.load(f)
    return np.asarray(data)

def images_preprocess(images_path, images_name):
    input_data = []

    for name in images_name:
        img = cv2.imread(images_path + name)
        img = cv2.resize(img, (224, 224))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        data = np.array(img).transpose(2, 0, 1)
        #print(f"name: {name}, opencv image shape(h,w,c): {img.shape}, transpose shape(c,h,w): {data.shape}")
        # convert the input data into the float32 input
        data = data.astype('float32')

        # normalize
        mean_vec = np.array([0.485, 0.456, 0.406])
        stddev_vec = np.array([0.229, 0.224, 0.225])
        norm_data = np.zeros(data.shape).astype('float32')
        for i in range(data.shape[0]):
            norm_data[i,:,:] = (data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i]

        # add batch channel
        norm_data = norm_data.reshape(1, 3, 224, 224).astype('float32')
        input_data.append(norm_data)

    return input_data

def softmax(x):
    x = x.reshape(-1)
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

def postprocess(result):
    return softmax(np.array(result)).tolist()

def inference(onnx_model, labels, input_data, images_name, images_label):
    session = onnxruntime.InferenceSession(onnx_model, None)
    # get the name of the first input of the model
    input_name = session.get_inputs()[0].name
    count = 0
    for data in input_data:
        print(f"{count+1}. image name: {images_name[count]}, actual value: {images_label[count]}")
        count += 1

        raw_result = session.run([], {input_name: data})

        res = postprocess(raw_result)

        idx = np.argmax(res)
        print(f"  result: idx: {idx}, label: {labels[idx]}, percentage: {round(res[idx]*100, 4)}%")

        sort_idx = np.flip(np.squeeze(np.argsort(res)))
        print("  top 5 labels are:", labels[sort_idx[:5]])

def main():
    model_directory_name, labels_file_name = download_onnx_model()

    labels = load_labels(labels_file_name)
    print("the number of categories is:", len(labels)) # 1000

    images_path = "../../data/image/"
    images_name = ["5.jpg", "6.jpg", "7.jpg", "8.jpg", "9.jpg", "10.jpg"]
    images_label = ["goldfish", "hen", "ostrich", "crocodile", "goose", "sheep"]
    if len(images_name) != len(images_label):
        print("Error: images count and labes'length don't match")
        return

    input_data = images_preprocess(images_path, images_name)

    onnx_model = model_directory_name + "/resnet50v2.onnx"
    inference(onnx_model, labels, input_data, images_name, images_label)

    print("test finish")

if __name__ == "__main__":
    main()

      测试图像如下所示:

      执行结果如下所示:

 

      GitHub: https://github.com/fengbingchun/PyTorch_Test

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/fengbingchun/article/details/125951896

智能推荐

【WebApi】————.net WebApi开发(一)_webapi .net-程序员宅基地

文章浏览阅读8.5k次。【1】.部署环境.net4及以上版本。【2】.vs2010 开发需单独安装vs2010 sp1和mvc4mvc4:http://www.asp.net/mvc/mvc4【3】.开发1.新建项目选择ASP.net MVC 4 Web应用程序2.选择Web API 3.在新建立的项目里面有已经生成的webapi模版其中App_Start文件夹下WebApiCo..._webapi .net

几招教你阻止百度搜索自动跳转百度APP(其他网站也适用)!_百度自动跳转app怎么解决-程序员宅基地

文章浏览阅读10w+次,点赞15次,收藏33次。最近阿虚看到个消息说「百度」发布了新政策,禁止网站通过搜索引擎打开后折叠内容强迫下载APP客户端听起来似乎是百度难得良心一回?但实际上该政策仅限于手机百度APP内如果你是通过浏览器用百度搜索则与新政策完全没关系正好前不久不少粉丝来问过我这样一个问题:怎么屏蔽手机浏览器上的「跳转某某APP打开查看」提示那今天阿虚就来教一下怎么解决吧,毕竟这东西的确是有点烦人…屏蔽「跳转某某APP打开查看」这个问题我细看了下,还得分俩类:文章只能显示部分,然后提示你需要安装APP才能查看的,这种应该是大_百度自动跳转app怎么解决

PHP快速入门12-异常处理,自定义异常、抛出异常、断言异常等示例_php 抛出异常-程序员宅基地

文章浏览阅读843次。PHP的异常处理机制可以帮助我们在程序运行时遇到错误或异常情况时,及时发出警告并停止程序继续运行。下面是10个例子,分别展示了PHP异常处理的不同用法。_php 抛出异常

linux 清空docker容器日志_linux清理docker容器log-程序员宅基地

文章浏览阅读221次。【代码】linux 清空docker容器日志。_linux清理docker容器log

青岛大学开源OJ平台搭建_github oj开源-程序员宅基地

文章浏览阅读7.3k次,点赞3次,收藏15次。源码地址为:https://github.com/QingdaoU/OnlineJudge可参考的文档为:https://github.com/QingdaoU/OnlineJudgeDeploy/tree/2.0一、安装所依赖的环境sudo apt-get update && sudo apt-get install -y vim python-pip curl g..._github oj开源

浅谈数据安全-程序员宅基地

文章浏览阅读4.4k次。在《网络安全法》中,虽然已经明确了要求保障网络数据的完整性、保密性、可用性的能力,但随着近些年数据安全热点事件的出现,如数据泄露事件、个人信息滥用事件。表明对数据保护的要求仅依赖《网络安全法》中的几款条例是不足以支撑的。因此2021年9月1日《中华人民共和国数据安全法》便正式诞生,从此数据安全也被推上了风口浪尖。那么数据安全如何定义?与传统网络安全有何区别?数据安全体系又应该如何建立?..._数据安全

随便推点

docker安装及部署mysql_docker部署mysql-程序员宅基地

文章浏览阅读1.5k次,点赞2次,收藏9次。docker安装与mysql部署_docker部署mysql

联想笔记本G510升级固态硬盘(SSD)血泪教程!!!_联想g510更换固态硬盘-程序员宅基地

文章浏览阅读8.5w次,点赞23次,收藏55次。#联想笔记本G510升级固态硬盘(SSD)血泪教程!!!用了5年的联想笔记本G510,经过了四年的游戏历程,然后四年后还老当益壮的挣扎在我工作的战斗一线,是我并肩作战多年,比兄弟还要亲的兄弟,虽然此时已经身躯残破,反应迟缓我依旧不舍得抛弃它(主要是没钱!)然后为了我个人的用户体验决定花少量的票子,让它多挣扎一会,最好是能坚持到我度过贫困期. 下面是我升级的悲催历程! - 首先为了提升运行速..._联想g510更换固态硬盘

问题记录——正则表达式匹配控制符_正则表达式匹配控制字符-程序员宅基地

文章浏览阅读910次。问题前端用xterm.js通过websocket连接docker虚拟终端,返回的字符中包括如下字符串,其中有两个控制字符,“ESC"和"BEL” ,想通过正则表达式匹配这一段字符,然后去掉这段字符:参考文档控制字符编码表转义符对照表通过上面查询得知,"ESC"和"BEL"这两个控制符的ASCII码分别为:十进制为27和7,十六进制为0x1B和0x07,转义符分别为:\e和\a代码**注意:**直接使用ASCII码匹配是不行的,一定要用转义符才行。如下测试代码中,只有regex3才能匹_正则表达式匹配控制字符

Android RIL框架分析-程序员宅基地

文章浏览阅读1.5k次。1.RIL框架 RIL,Radio Interface Layer。本层为一个协议转换层,提供Android Telephony与无线通信设备之间的抽象层。 Android RIL位于Telephony Frameworks之下,Modem之上的,根据源码,RIL可以分为两个部分:Frameworks 框架层中的java程序,简称RILJ。HAL层中C/C++程序,简称RILC,RILC具体的又包括LibRIL、Rild和Reference-RIL这三个部分。 Andr..._ril框架

Python编程基础:第六节 math包的基础使用Math Functions_ps math function-程序员宅基地

文章浏览阅读565次。第六节 math包的基础使用前言实践前言我们通常会对数值型变量进行计算,这里我们给出一些常用的函数用于辅助你的计算过程。常用的数学计算函数均在math包。实践首先我们导入math包,并定义一个浮点型变量pi将其赋值为3.14:import mathpi = 3.14如果我们需要计算浮点型变量四舍五入后的计算结果,用函数round()即可:print(round(pi))>>> 3如果我们需要向上取整,那就需要函数math.ceil():print(math.cei_ps math function

canal异常 Could not find first log file name in binary log index file_canal could not find first log file name in binary-程序员宅基地

文章浏览阅读4.4k次,点赞3次,收藏2次。Could not find first log file name in binary log index file问题解决解决过程问题最近在使用canal来监测数据库的变化,处理变动的数据。由于有一段时间没有用了,这次启动在日志文件中看到这个异常 Could not find first log file name in binary log index file,详细信息如下:2020-12-16 19:14:42.053 [destination = tradeAndRefund , addr_canal could not find first log file name in binary log index file