服务端大多都是用Java做的,而深度学习模型大多又是用Python写的,所以很多人都是用Java调Python的接口,这样效率低,而且也不优雅,最重要的是如果想使用Android做推理,那就必须要用Java写了。
本文使用了一个重要的工具:Deep Java Library,这是一个用Java进行深度学习的库,你可以用它来进行模型推理,甚至是训练模型。很多文章也都介绍过该模型,但是他们都漏了一个重要的内容:深度学习代码不只是推理部分,还有很多预处理和后续处理的部分需要很多Tensor操作,但是他们都没说怎么做。
为了符合大家的实际需求,本文不使用DJL进行模型训练,只做推理。本文的具体内容包括:
- DJL核心内容讲解
- DJL加载Pytorch模型
- DJL的Tensor操作
- DJL简单案例(DJL使用Pytorch模型完成图片分类)
DJL是一个开源的深度学习 Java 框架(支持Android),其可以用于深度学习模型构建和训练、Tensor操作、使用预训练好的常见模型(MXNet、Pytorch、TensorFlow等)。Java1.8 以上就可以用,且支持GPU
在实际案例之前,先讲解下DJL的核心API,这样在后续的案例也知道代码是做什么的。
2.2.1 Criteria
Criteria 类对象定义了模型的情况,如模型路径、输入和输出等。
例如,这是一段初始化DJL模型的代码:
在上述代码中,Criteria描述了模型的情况,主要包含以下几点:
- 定义了模型输入和输出。这里的 和 可以是自定义的类,也可以使用DJL提供的类。
- :这个代码是必须的。直接从泛型的,是获取不到输入和输出的对象的,所以需要手动设置一下。
- :模型的输入和输出是一个Tensor类型。这里就是设置你的类和类应该如何与Tensor类型进行转化。后续会具体讲。
- :设置一下模型名称
定义好模型的情况,就可以使用方法实例化出 Model Zoo 对象了。
Model Zoo 是DJL的模型,你需要通过该类对象对模型进行进行管理,例如创建模型、创建Predictor,保存模型等。
2.2.2 Translator
在上一节中,模型的输入类和输出类是可以自定义的,但Pytorch模型不可能接收你自己定义的类对象啊,它只会接受Tensor类型,所以我们就需要使用接口来定义如何将我们的自定义输入输出类转换为Tensor类型。
接口包含两个接口:
- :将输入类对象转化为Tensor。这里的Input就是输入类对象,而就是Tensor的集合(因为模型的forward可能会接收多个Tensor参数)。在DJL中,Tensor对应的类为(类似numpy中的ndarray),后续会详细讲解。
- :将模型输出的Tensor转换为自定义类。由于模型可能会输出多个Tensor,所以这里也是。
上述这两个方法还包含一个重要的参数,这个保存了Translator的上下文,可以用它来拿到一些对象(Model, Predictor等),也可以通过 和 方法来存取一些东西。
在官方的例子中,Translator是对图像进行处理,但Translator并非只能处理图像,这里的Input和Output可以是任意Java类。
2.2.3 NDArray
在python中,我们有numpy,而在Java中,我们有DJL的NDArray,使用该类,我们几乎可以实现Numpy中的所有Tensor操作。本节将会介绍常用的tensor操作。
开始前先介绍与NDArray相关的几个类:
- :相当于,可以通过方法获取其shape
- :NDArray的管理类,全局new一个就行了,需要用该类对象创建NDArray
- :用于对Tensor进行切片
- : 创建NDArray的时候,需要指定Shape。获取NDArray的Shape时返回的也是该类的对象。
接下来开始具体演示Tensor的常见操作(这里只举几个例子,有不会的操作可以在评论区告知,我会进行补充):
创建NDArray(Tensor)
创建一个Shape为的Tensor
ndManager全局应只创建一个
指定值创建:
变更数据类型
变为float类型
变为float数组:
注意,在toArray()前需要将NDArray转变为相对应的类型,且字节数要对上。例如在java中float是使用32个bit(4个字节)存储的,所以NDArray的类型必须是Float32,不能是Float64,否则会报错。
运算
加减乘除:
也可以使用,类似:
切片
等价于python中的
DJL的切片好像不能指定index,例如 x = [1,2,3], y = [2,3,4],然后切片 nums[x, y]。 DJL中我还没找到应该如何这样切,所以我只能自己用for循环实现,如果大家知道怎么弄,欢迎在评论区告诉我
赋值
等价于Python的
翻转
在Python中,对数组进行翻转可以使用,但java中不行,但可以利用函数实现
2.2.3 Predictor
创建好模型后,需要new一个Predictor,然后用这个进行预测:
到这里DJL常用的API就讲完了,接下来使用一个简单的案例进行实战。
这里使用Pytorch提供的resnet18模型完成一个图片分类任务。
- 首先引入依赖:
- 导出pytorch的resnet18模型:
- 将导出的模型拷贝到项目的model目录下:
- 创建Translator,这里我们定义输入为类型,表示图片的输入路径;输出也为,表示类别。将图片送入Resnet18网络,需要做一些预处理:
这里利用Java的NDArray的
- 定义,然后实例化模型,并new
- 准备一张图片,我这里放在项目的test目录下:
- 进行预测
由于resnet可以识别1000个物体,太多了,所以我只输出了index,全部的类别可以到该链接查找。最终输出为:
258对应的类别为Samoyed(萨摩耶),可以看得到预测对了。
DJL更多的例子可以参考官方Demo。
Deep Java Library官方文档:https://docs.djl.ai/
Dive Into Deep Learning: https://d2l.djl.ai/chapter_preliminaries/ndarray.html
djl-demo: https://github.com/deepjavalibrary/djl-demo
版权声明:
本文来源网络,所有图片文章版权属于原作者,如有侵权,联系删除。
本文网址:https://www.bianchenghao6.com/java-jiao-cheng/9268.html