欢迎访问我的GitHub

https://github.com/zq2599/blog_demos

内容:所有原创文章分类汇总及配套源码,涉及Java、Docker、Kubernetes、DevOPS等;

本篇概览

  • 作为《DL4J》实战的第三篇,目标是在DL4J框架下创建经典的LeNet-5卷积神经网络模型,对MNIST数据集进行训练和测试,本篇由以下内容构成:
  1. LeNet-5简介
  2. MNIST简介
  3. 数据集简介
  4. 关于版本和环境
  5. 编码
  6. 验证

LeNet-5简介

  • 是Yann LeCun于1998年设计的卷积神经网络,用于手写数字识别,例如当年美国很多银行用其识别支票上的手写数字,LeNet-5是早期卷积神经网络最有代表性的实验系统之一
  • LeNet-5网络结构如下图所示,一共七层:C1 -> S2 -> C3 -> S4 -> C5 -> F6 -> OUTPUT

  • 按照上图简单分析一下,用于指导接下来的开发:
  1. 每张图片都是28*28的单通道,矩阵应该是[1, 28,28]
  2. C1是卷积层,所用卷积核尺寸5*5,滑动步长1,卷积核数目20,所以尺寸变化是:28-5+1=24(想象为宽度为5的窗口在宽度为28的窗口内滑动,能滑多少次),输出矩阵是[20,24,24]
  3. S2是池化层,核尺寸2*2,步长2,类型是MAX,池化操作后尺寸减半,变成了[20,12,12]
  4. C3是卷积层,所用卷积核尺寸5*5,滑动步长1,卷积核数目50,所以尺寸变化是:12-5+1=8,输出矩阵[50,8,8]
  5. S4是池化层,核尺寸2*2,步长2,类型是MAX,池化操作后尺寸减半,变成了[50,4,4]
  6. C5是全连接层(FC),神经元数目500,接relu激活函数
  7. 最后是全连接层Output,共10个节点,代表数字0到9,激活函数是softmax

MNIST简介

  • MNIST是经典的计算机视觉数据集,来源是National Institute of Standards and Technology (NIST,美国国家标准与技术研究所),包含各种手写数字图片,其中训练集60,000张,测试集 10,000张,
  • MNIST来源于250 个不同人的手写,其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员.,测试集(test set) 也是同样比例的手写数字数据
  • MNIST官网:http://yann.lecun.com/exdb/mnist/

数据集简介

  • 从MNIST官网下载的原始数据并非图片文件,需要按官方给出的格式说明做解析处理才能转为一张张图片,这些事情显然不是本篇的主题,因此咱们可以直接使用DL4J为我们准备好的数据集(下载地址稍后给出),该数据集中是一张张独立的图片,这些图片所在目录的名字就是该图片具体的数字,如下图,目录0里面全是数字0的图片:

  • 上述数据集的下载地址有两个:
  1. 可以在CSDN下载(0积分):https://download.csdn.net/download/boling_cavalry/19846603
  2. github:https://raw.githubusercontent.com/zq2599/blog_download_files/master/files/mnist_png.tar.gz
  • 下载之后解压开,是个名为mnist_png的文件夹,稍后的实战中咱们会用到它

关于DL4J版本

  • 《DL4J实战》系列的源码采用了maven的父子工程结构,DL4J的版本在父工程dlfj-tutorials中定义为1.0.0-beta7
  • 本篇的代码虽然还是dlfj-tutorials的子工程,但是DL4J版本却使用了更低的1.0.0-beta6,之所以这么做,是因为下一篇文章,咱们会把本篇的训练和测试工作交给GPU来完成,而对应的CUDA库只有1.0.0-beta6
  • 扯了这么多,可以开始编码了

源码下载

名称 链接 备注
项目主页 https://github.com/zq2599/blog_demos 该项目在GitHub上的主页
git仓库地址(https) https://github.com/zq2599/blog_demos.git 该项目源码的仓库地址,https协议
git仓库地址(ssh) git@github.com:zq2599/blog_demos.git 该项目源码的仓库地址,ssh协议
  • 这个git项目中有多个文件夹,《DL4J实战》系列的源码在dl4j-tutorials文件夹下,如下图红框所示:

  • dl4j-tutorials文件夹下有多个子工程,本次实战代码在simple-convolution目录下,如下图红框:

编码

  • 在父工程 dl4j-tutorials下新建名为 simple-convolution的子工程,其pom.xml如下,可见这里的dl4j版本被指定为1.0.0-beta6:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>dlfj-tutorials</artifactId>
<groupId>com.bolingcavalry</groupId>
<version>1.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion> <artifactId>simple-convolution</artifactId> <properties>
<dl4j-master.version>1.0.0-beta6</dl4j-master.version>
</properties> <dependencies>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency> <dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
</dependency> <dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${dl4j-master.version}</version>
</dependency> <dependency>
<groupId>org.nd4j</groupId>
<artifactId>${nd4j.backend}</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
</dependencies>
</project>
  • 接下来按照前面的分析实现代码,已经添加了详细注释,就不再赘述了:
package com.bolingcavalry.convolution;

import lombok.extern.slf4j.Slf4j;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import java.io.File;
import java.util.HashMap;
import java.util.Map;
import java.util.Random; @Slf4j
public class LeNetMNISTReLu { // 存放文件的地址,请酌情修改
// private static final String BASE_PATH = System.getProperty("java.io.tmpdir") + "/mnist"; private static final String BASE_PATH = "E:\\temp\\202106\\26"; public static void main(String[] args) throws Exception {
// 图片像素高
int height = 28;
// 图片像素宽
int width = 28;
// 因为是黑白图像,所以颜色通道只有一个
int channels = 1;
// 分类结果,0-9,共十种数字
int outputNum = 10;
// 批大小
int batchSize = 54;
// 循环次数
int nEpochs = 1;
// 初始化伪随机数的种子
int seed = 1234; // 随机数工具
Random randNumGen = new Random(seed); log.info("检查数据集文件夹是否存在:{}", BASE_PATH + "/mnist_png"); if (!new File(BASE_PATH + "/mnist_png").exists()) {
log.info("数据集文件不存在,请下载压缩包并解压到:{}", BASE_PATH);
return;
} // 标签生成器,将指定文件的父目录作为标签
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
// 归一化配置(像素值从0-255变为0-1)
DataNormalization imageScaler = new ImagePreProcessingScaler(); // 不论训练集还是测试集,初始化操作都是相同套路:
// 1. 读取图片,数据格式为NCHW
// 2. 根据批大小创建的迭代器
// 3. 将归一化器作为预处理器 log.info("训练集的矢量化操作...");
// 初始化训练集
File trainData = new File(BASE_PATH + "/mnist_png/training");
FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
trainRR.initialize(trainSplit);
DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum);
// 拟合数据(实现类中实际上什么也没做)
imageScaler.fit(trainIter);
trainIter.setPreProcessor(imageScaler); log.info("测试集的矢量化操作...");
// 初始化测试集,与前面的训练集操作类似
File testData = new File(BASE_PATH + "/mnist_png/testing");
FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);
testRR.initialize(testSplit);
DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);
testIter.setPreProcessor(imageScaler); // same normalization for better results log.info("配置神经网络"); // 在训练中,将学习率配置为随着迭代阶梯性下降
Map<Integer, Double> learningRateSchedule = new HashMap<>();
learningRateSchedule.put(0, 0.06);
learningRateSchedule.put(200, 0.05);
learningRateSchedule.put(600, 0.028);
learningRateSchedule.put(800, 0.0060);
learningRateSchedule.put(1000, 0.001); // 超参数
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
// L2正则化系数
.l2(0.0005)
// 梯度下降的学习率设置
.updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, learningRateSchedule)))
// 权重初始化
.weightInit(WeightInit.XAVIER)
// 准备分层
.list()
// 卷积层
.layer(new ConvolutionLayer.Builder(5, 5)
.nIn(channels)
.stride(1, 1)
.nOut(20)
.activation(Activation.IDENTITY)
.build())
// 下采样,即池化
.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
.stride(2, 2)
.build())
// 卷积层
.layer(new ConvolutionLayer.Builder(5, 5)
.stride(1, 1) // nIn need not specified in later layers
.nOut(50)
.activation(Activation.IDENTITY)
.build())
// 下采样,即池化
.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
.stride(2, 2)
.build())
// 稠密层,即全连接
.layer(new DenseLayer.Builder().activation(Activation.RELU)
.nOut(500)
.build())
// 输出
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(outputNum)
.activation(Activation.SOFTMAX)
.build())
.setInputType(InputType.convolutionalFlat(height, width, channels)) // InputType.convolutional for normal image
.build(); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); // 每十个迭代打印一次损失函数值
net.setListeners(new ScoreIterationListener(10)); log.info("神经网络共[{}]个参数", net.numParams()); long startTime = System.currentTimeMillis();
// 循环操作
for (int i = 0; i < nEpochs; i++) {
log.info("第[{}]个循环", i);
net.fit(trainIter);
Evaluation eval = net.evaluate(testIter);
log.info(eval.stats());
trainIter.reset();
testIter.reset();
}
log.info("完成训练和测试,耗时[{}]毫秒", System.currentTimeMillis()-startTime); // 保存模型
File ministModelPath = new File(BASE_PATH + "/minist-model.zip");
ModelSerializer.writeModel(net, ministModelPath, true);
log.info("最新的MINIST模型保存在[{}]", ministModelPath.getPath());
}
}
  • 执行上述代码,日志输出如下,训练和测试都顺利完成,准确率达到0.9886:
21:19:15.355 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 1110 is 0.18300625613640034
21:19:15.365 [main] DEBUG org.nd4j.linalg.dataset.AsyncDataSetIterator - Manually destroying ADSI workspace
21:19:16.632 [main] DEBUG org.nd4j.linalg.dataset.AsyncDataSetIterator - Manually destroying ADSI workspace
21:19:16.642 [main] INFO com.bolingcavalry.convolution.LeNetMNISTReLu - ========================Evaluation Metrics========================
# of classes: 10
Accuracy: 0.9886
Precision: 0.9885
Recall: 0.9886
F1 Score: 0.9885
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes) =========================Confusion Matrix=========================
0 1 2 3 4 5 6 7 8 9
---------------------------------------------------
972 0 0 0 0 0 2 2 2 2 | 0 = 0
0 1126 0 3 0 2 1 1 2 0 | 1 = 1
1 1 1019 2 0 0 0 6 3 0 | 2 = 2
0 0 1 1002 0 5 0 1 1 0 | 3 = 3
0 0 2 0 971 0 3 2 1 3 | 4 = 4
0 0 0 3 0 886 2 1 0 0 | 5 = 5
6 2 0 1 1 5 942 0 1 0 | 6 = 6
0 1 6 0 0 0 0 1015 1 5 | 7 = 7
1 0 1 1 0 2 0 2 962 5 | 8 = 8
1 2 1 3 5 3 0 2 1 991 | 9 = 9 Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
21:19:16.643 [main] INFO com.bolingcavalry.convolution.LeNetMNISTReLu - 完成训练和测试,耗时[27467]毫秒
21:19:17.019 [main] INFO com.bolingcavalry.convolution.LeNetMNISTReLu - 最新的MINIST模型保存在[E:\temp\202106\26\minist-model.zip] Process finished with exit code 0

关于准确率

  • 前面的测试结果显示准确率为0.9886,这是1.0.0-beta6版本DL4J的训练结果,如果换成1.0.0-beta7,准确率可以达到0.99以上,您可以尝试一下;

  • 至此,DL4J框架下的经典卷积实战就完成了,截止目前,咱们的训练和测试工作都是CPU完成的,工作中CPU使用率的上升十分明显,下一篇文章,咱们把今天的工作交给GPU执行试试,看能否借助CUDA加速训练和测试工作;

你不孤单,欣宸原创一路相伴

  1. Java系列
  2. Spring系列
  3. Docker系列
  4. kubernetes系列
  5. 数据库+中间件系列
  6. DevOps系列

欢迎关注公众号:程序员欣宸

微信搜索「程序员欣宸」,我是欣宸,期待与您一同畅游Java世界...

https://github.com/zq2599/blog_demos

DL4J实战之三:经典卷积实例(LeNet-5)的更多相关文章

  1. 经典卷积网络模型 — LeNet模型笔记

    LeNet-5包含于输入层在内的8层深度卷积神经网络.其中卷积层可以使得原信号特征增强,并且降低噪音.而池化层利用图像相关性原理,对图像进行子采样,可以减少参数个数,减少模型的过拟合程度,同时也可以保 ...

  2. TensorFlow实战之实现AlexNet经典卷积神经网络

    本文根据最近学习TensorFlow书籍网络文章的情况,特将一些学习心得做了总结,详情如下.如有不当之处,请各位大拿多多指点,在此谢过. 一.AlexNet模型及其基本原理阐述 1.关于AlexNet ...

  3. 经典卷积神经网络(LeNet、AlexNet、VGG、GoogleNet、ResNet)的实现(MXNet版本)

    卷积神经网络(Convolutional Neural Network, CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出色表现. 其中 文章 详解卷 ...

  4. 五大经典卷积神经网络介绍:LeNet / AlexNet / GoogLeNet / VGGNet/ ResNet

    欢迎大家关注我们的网站和系列教程:http://www.tensorflownews.com/,学习更多的机器学习.深度学习的知识! LeNet / AlexNet / GoogLeNet / VGG ...

  5. 经典卷积神经网络结构——LeNet-5、AlexNet、VGG-16

    经典卷积神经网络的结构一般满足如下表达式: 输出层 -> (卷积层+ -> 池化层?)+  -> 全连接层+ 上述公式中,“+”表示一个或者多个,“?”表示一个或者零个,如“卷积层+ ...

  6. Flink的sink实战之三:cassandra3

    欢迎访问我的GitHub https://github.com/zq2599/blog_demos 内容:所有原创文章分类汇总及配套源码,涉及Java.Docker.Kubernetes.DevOPS ...

  7. kubebuilder实战之三:基础知识速览

    欢迎访问我的GitHub https://github.com/zq2599/blog_demos 内容:所有原创文章分类汇总及配套源码,涉及Java.Docker.Kubernetes.DevOPS ...

  8. [原创].NET 分布式架构开发实战之三 数据访问深入一点的思考

    原文:[原创].NET 分布式架构开发实战之三 数据访问深入一点的思考 .NET 分布式架构开发实战之三 数据访问深入一点的思考 前言:首先,感谢园子里的朋友对文章的支持,感谢大家,希望本系列的文章能 ...

  9. 超多经典 canvas 实例,动态离子背景、移动炫彩小球、贪吃蛇、坦克大战、是男人就下100层、心形文字等等等

    超多经典 canvas 实例 普及:<canvas> 元素用于在网页上绘制图形.这是一个图形容器,您可以控制其每一像素,必须使用脚本来绘制图形. 注意:IE 8 以及更早的版本不支持 &l ...

  10. Zookeeper原理和实战开发经典视频教程 百度云网盘下载

    Zookeeper原理和实战开发 经典视频教程 百度云网盘下载 资源下载地址:http://pan.baidu.com/s/1o7ZjPeM   密码:r5yf   

随机推荐

  1. iOS上传文件代码,自定义组装body

    以下代码为上传文件所用代码,简单方便,搞了好久,终于知道这么简单的方式来上传. 其它类库也就是把这几句代码封装的乱七八糟得,让你老久搞不懂原理.不就是在body上面加点字符串,body下面加点字符串, ...

  2. Android Studio下SlidingMenu的导入与基础使用

    一.关于这个控件,其实我们现在很多app都在用,最简单的,你打开QQ,当看资料卡的时候,首先要侧拉一下,那个就是SlidingMenu 这几天查了很多资料,各种方法都试了,但是一直都没有成功,最后在一 ...

  3. DQL、DML、DDL、DCL的概念与区别

    SQL(Structure Query Language)语言是数据库的核心语言. SQL的发展是从1974年开始的,其发展过程如下:1974年-----由Boyce和Chamberlin提出,当时称 ...

  4. Linux 信号详解一(signal函数)

    信号列表 SIGABRT 进程停止运行 SIGALRM 警告钟 SIGFPE 算述运算例外 SIGHUP 系统挂断 SIGILL 非法指令 SIGINT 终端中断 SIGKILL 停止进程(此信号不能 ...

  5. Java多线程——Semaphore信号灯

    Semaphore可以维护当前访问自身的线程个数,并提供了同步机制.使用Semaphore可以控制同时访问资源的线程个数(即允许n个任务同时访问这个资源),例如,实现一个文件允许的并发访问数. Sem ...

  6. Hibernate关联关系映射

    1.  Hibernate关联关系映射 1.1.  one to one <class name="Person"> <id name="id" ...

  7. DAC,MAC和SELinux,SEAndroid

    1. 被ROOT了怎么办 2. SELinux 3. SEAndroid 4. JB(4.3) MR2的漏洞弥补 ------------------------------------------- ...

  8. 解决CUDA driver version is insufficient for CUDA runtime version

    问题 在服务器上安装mxne的GPU版本 sudo pip install mxnet-cu80==1.2.1 然后在gpu上创建数据 import mxnet as mx mx.nd.array([ ...

  9. InnoDB的关键特性-插入缓存,两次写,自适应hash索引

    InnoDB存储引擎的关键特性包括插入缓冲.两次写(double write).自适应哈希索引(adaptive hash index).这些特性为InnoDB存储引擎带来了更好的性能和更高的可靠性. ...

  10. Swift: 用Alamofire做http请求,用ObjectMapper解析JSON

    跟不上时代的人突然间走在了时代的前列,果然有别样的风景.首先鄙视一下AFNetworking.这个东西实在太难用了.不想封装都不行,要不写一大堆代码. NSURL *URL = [NSURL URLW ...