接上篇。

在(一)和(二)中,程序的体系是Net,Propagation,Trainer,Learner,DataProvider。这篇重构这个体系。

Net

首先是Net,在上篇重新定义了激活函数和误差函数后,内容大致是这样的:

List<DoubleMatrix> weights = new ArrayList<DoubleMatrix>();
List<DoubleMatrix> bs = new ArrayList<>();
List<ActivationFunction> activations = new ArrayList<>();
CostFunction costFunc;
CostFunction accuracyFunc;
int[] nodesNum;
int layersNum; public CompactDoubleMatrix getCompact(){
return new CompactDoubleMatrix(this.weights,this.bs);
}

函数getCompact()生成对应的超矩阵。

DataProvider

DataProvider是数据的提供者。

public interface DataProvider {
DoubleMatrix getInput();
DoubleMatrix getTarget();
}

如果输入为向量,还包含一个向量字典。

public interface DictDataProvider extends DataProvider {
public DoubleMatrix getIndexs();
public DoubleMatrix getDict();
}

每一列为一个样本。getIndexs()返回输入向量在字典中的索引。

我写了一个有用的类BatchDataProviderFactory来对样本进行批量分割,分割成minibatch。

int batchSize;
int dataLen;
DataProvider originalProvider;
List<Integer> endPositions;
List<DataProvider> providers; public BatchDataProviderFactory(int batchSize, DataProvider originalProvider) {
super();
this.batchSize = batchSize;
this.originalProvider = originalProvider;
this.dataLen = this.originalProvider.getTarget().columns;
this.initEndPositions();
this.initProviders();
} public BatchDataProviderFactory(DataProvider originalProvider) {
this(4, originalProvider);
} public List<DataProvider> getProviders() {
return providers;
}

batchSize指明要分多少批,getProviders返回生成的minibatch,被分的原始数据为originalProvider。

Propagation

Propagation负责对神经网络的正向传播过程和反向传播过程。接口定义如下:

public interface Propagation {
public PropagationResult propagate(Net net,DataProvider provider);
}

传播函数propagate用指定数据对指定网络进行传播操作,返回执行结果。

BasePropagation实现了该接口,实现了简单的反向传播:

public class BasePropagation implements Propagation{

	// 多个样本。
protected ForwardResult forward(Net net,DoubleMatrix input) { ForwardResult result = new ForwardResult();
result.input = input;
DoubleMatrix currentResult = input;
int index = -1;
for (DoubleMatrix weight : net.weights) {
index++;
DoubleMatrix b = net.bs.get(index);
final ActivationFunction activation = net.activations
.get(index);
currentResult = weight.mmul(currentResult).addColumnVector(b);
result.netResult.add(currentResult); // 乘以导数
DoubleMatrix derivative = activation.derivativeAt(currentResult);
result.derivativeResult.add(derivative); currentResult = activation.valueAt(currentResult);
result.finalResult.add(currentResult); } result.netResult=null;// 不再需要。 return result;
} // 多个样本梯度平均值。
protected BackwardResult backward(Net net,DoubleMatrix target,
ForwardResult forwardResult) {
BackwardResult result = new BackwardResult(); DoubleMatrix output = forwardResult.getOutput();
DoubleMatrix outputDerivative = forwardResult.getOutputDerivative(); result.cost = net.costFunc.valueAt(output, target);
DoubleMatrix outputDelta = net.costFunc.derivativeAt(output, target).muli(outputDerivative);
if (net.accuracyFunc != null) {
result.accuracy=net.accuracyFunc.valueAt(output, target);
} result.deltas.add(outputDelta);
for (int i = net.layersNum - 1; i >= 0; i--) {
DoubleMatrix pdelta = result.deltas.get(result.deltas.size() - 1); // 梯度计算,取所有样本平均
DoubleMatrix layerInput = i == 0 ? forwardResult.input
: forwardResult.finalResult.get(i - 1);
DoubleMatrix gradient = pdelta.mmul(layerInput.transpose()).div(
target.columns);
result.gradients.add(gradient);
// 偏置梯度
result.biasGradients.add(pdelta.rowMeans()); // 计算前一层delta,若i=0,delta为输入层误差,即input调整梯度,不作平均处理。
DoubleMatrix delta = net.weights.get(i).transpose().mmul(pdelta);
if (i > 0)
delta = delta.muli(forwardResult.derivativeResult.get(i - 1));
result.deltas.add(delta);
}
Collections.reverse(result.gradients);
Collections.reverse(result.biasGradients); //其它的delta都不需要。
DoubleMatrix inputDeltas=result.deltas.get(result.deltas.size()-1);
result.deltas.clear();
result.deltas.add(inputDeltas); return result;
} @Override
public PropagationResult propagate(Net net, DataProvider provider) {
ForwardResult forwardResult=this.forward(net, provider.getInput());
BackwardResult backwardResult=this.backward(net, provider.getTarget(), forwardResult);
PropagationResult result=new PropagationResult(backwardResult);
result.output=forwardResult.getOutput();
return result;
}

我们定义的PropagationResult略为:

public class PropagationResult{
DoubleMatrix output;// 输出结果矩阵:outputLen*sampleLength
DoubleMatrix cost;// 误差矩阵:1*sampleLength
DoubleMatrix accuracy;// 准确度矩阵:1*sampleLength
private List<DoubleMatrix> gradients;// 权重梯度矩阵
private List<DoubleMatrix> biasGradients;// 偏置梯度矩阵
DoubleMatrix inputDeltas;//输入层delta矩阵:inputLen*sampleLength public CompactDoubleMatrix getCompact(){
return new CompactDoubleMatrix(gradients,biasGradients);
} }

另一个实现了该接口的类为MiniBatchPropagation。他在内部用并行方式对样本进行传播,然后对每个minipatch结果进行综合,内部用到了BatchDataProviderFactory类和BasePropagation类。

Trainer

Trainer接口定义为:

public interface Trainer {
public void train(Net net,DataProvider provider);
}

简单的实现类为:

public class CommonTrainer implements Trainer {
int ecophs;
Learner learner;
Propagation propagation;
List<Double> costs = new ArrayList<>();
List<Double> accuracys = new ArrayList<>();
public void trainOne(Net net, DataProvider provider) {
PropagationResult propResult = this.propagation
.propagate(net, provider);
learner.learn(net, propResult, provider); Double cost = propResult.getMeanCost();
Double accuracy = propResult.getMeanAccuracy();
if (cost != null)
costs.add(cost);
if (accuracy != null)
accuracys.add(accuracy);
} @Override
public void train(Net net, DataProvider provider) {
for (int i = 0; i < this.ecophs; i++) {
System.out.println("echops:"+i);
this.trainOne(net, provider);
} }
}

简单的迭代echops此,没有智能停止功能,每次迭代用Learner调节权重。

Learner

Learner根据每次传播结果对网络权重进行调整,接口定义如下:

public interface Learner<N extends Net,P extends DataProvider> {
public void learn(N net,PropagationResult propResult,P provider);
}

一个简单的根据动量因子-自适应学习率进行调整的实现类为:

public class MomentAdaptLearner<N extends Net, P extends DataProvider>
implements Learner<N, P> {
double moment = 0.7;
double lmd = 1.05;
double preCost = 0;
double eta = 0.01;
double currentEta = eta;
double currentMoment = moment;
CompactDoubleMatrix preGradient; public MomentAdaptLearner(double moment, double eta) {
super();
this.moment = moment;
this.eta = eta;
this.currentEta = eta;
this.currentMoment = moment;
} public MomentAdaptLearner() { } @Override
public void learn(N net, PropagationResult propResult, P provider) {
if (this.preGradient == null)
init(net, propResult, provider); double cost = propResult.getMeanCost();
this.modifyParameter(cost);
System.out.println("current eta:" + this.currentEta);
System.out.println("current moment:" + this.currentMoment);
this.updateGradient(net, propResult, provider); } public void updateGradient(N net, PropagationResult propResult, P provider) {
CompactDoubleMatrix netCompact = this.getNetCompact(net, propResult,
provider);
CompactDoubleMatrix gradCompact = this.getGradientCompact(net,
propResult, provider);
gradCompact = gradCompact.mul(currentEta * (1 - currentMoment)).addi(
preGradient.mul(currentMoment));
netCompact.subi(gradCompact);
this.preGradient = gradCompact;
} public CompactDoubleMatrix getNetCompact(N net,
PropagationResult propResult, P provider) {
return net.getCompact();
} public CompactDoubleMatrix getGradientCompact(N net,
PropagationResult propResult, P provider) {
return propResult.getCompact();
} public void modifyParameter(double cost) { if (this.currentEta > 10) {
this.currentEta = 10;
} else if (this.currentEta < 0.0001) {
this.currentEta = 0.0001;
} else if (cost < this.preCost) {
this.currentEta *= 1.05;
this.currentMoment = moment;
} else if (cost < 1.04 * this.preCost) {
this.currentEta *= 0.7;
this.currentMoment *= 0.7;
} else {
this.currentEta = eta;
this.currentMoment = 0.1;
}
this.preCost = cost;
} public void init(Net net, PropagationResult propResult, P provider) {
PropagationResult pResult = new PropagationResult(net);
preGradient = pResult.getCompact().dup();
} }

在上面的代码中,我们可以看到CompactDoubleMatrix类对权重自变量的封装,使代码更加简洁,它在此表现出来的就是一个超矩阵,超向量,完全忽略了内部的结构。

同时,其子类实现了同步更新字典的功能,代码也很简洁,只是简单的把需要调整的矩阵append到超矩阵中去即可,在父类中会统一对其进行调整:

public class DictMomentLearner extends
MomentAdaptLearner<Net, DictDataProvider> { public DictMomentLearner(double moment, double eta) {
super(moment, eta);
} public DictMomentLearner() {
super();
} @Override
public CompactDoubleMatrix getNetCompact(Net net,
PropagationResult propResult, DictDataProvider provider) {
CompactDoubleMatrix result = super.getNetCompact(net, propResult,
provider);
result.append(provider.getDict());
return result;
} @Override
public CompactDoubleMatrix getGradientCompact(Net net,
PropagationResult propResult, DictDataProvider provider) {
CompactDoubleMatrix result = super.getGradientCompact(net, propResult,
provider);
result.append(DictUtil.getDictGradient(provider, propResult));
return result;
} @Override
public void init(Net net, PropagationResult propResult,
DictDataProvider provider) {
DoubleMatrix preDictGradient = DoubleMatrix.zeros(
provider.getDict().rows, provider.getDict().columns);
super.init(net, propResult, provider);
this.preGradient.append(preDictGradient);
}
}

用java写bp神经网络(四)的更多相关文章

  1. 用java写bp神经网络(一)

    根据前篇博文<神经网络之后向传播算法>,现在用java实现一个bp神经网络.矩阵运算采用jblas库,然后逐渐增加功能,支持并行计算,然后支持输入向量调整,最后支持L-BFGS学习算法. ...

  2. 用java写bp神经网络(三)

    孔子曰,吾日三省吾身.我们如果跟程序打交道,除了一日三省吾身外,还要三日一省吾代码.看代码是否可以更简洁,更易懂,更容易扩展,更通用,算法是否可以再优化,结构是否可以再往上抽象.代码在不断的重构过程中 ...

  3. 用java写bp神经网络(二)

    接上篇. Net和Propagation具备后,我们就可以训练了.训练师要做的事情就是,怎么把一大批样本分成小批训练,然后把小批的结果合并成完整的结果(批量/增量):什么时候调用学习师根据训练的结果进 ...

  4. python手写bp神经网络实现人脸性别识别1.0

    写在前面:本实验用到的图片均来自google图片,侵删! 实验介绍 用python手写一个简单bp神经网络,实现人脸的性别识别.由于本人的机器配置比较差,所以无法使用网上很红的人脸大数据数据集(如lf ...

  5. JAVA实现BP神经网络算法

    工作中需要预测一个过程的时间,就想到了使用BP神经网络来进行预测. 简介 BP神经网络(Back Propagation Neural Network)是一种基于BP算法的人工神经网络,其使用BP算法 ...

  6. java写卷积神经网络---CupCnn简介

    https://blog.csdn.net/u011913612/article/details/79253450

  7. 【机器学习】BP神经网络实现手写数字识别

    最近用python写了一个实现手写数字识别的BP神经网络,BP的推导到处都是,但是一动手才知道,会理论推导跟实现它是两回事.关于BP神经网络的实现网上有一些代码,可惜或多或少都有各种问题,在下手写了一 ...

  8. BP神经网络—java实现(转载)

    神经网络的结构 神经网络的网络结构由输入层,隐含层,输出层组成.隐含层的个数+输出层的个数=神经网络的层数,也就是说神经网络的层数不包括输入层.下面是一个三层的神经网络,包含了两层隐含层,一个输出层. ...

  9. BP神经网络的手写数字识别

    BP神经网络的手写数字识别 ANN 人工神经网络算法在实践中往往给人难以琢磨的印象,有句老话叫“出来混总是要还的”,大概是由于具有很强的非线性模拟和处理能力,因此作为代价上帝让它“黑盒”化了.作为一种 ...

随机推荐

  1. XE3随想14:关于 SO 与 SA 函数

    通过 SuperObject 的公用函数 SO 实现一个 ISuperObject 接口非常方便; 前面都是给它一个字符串参数, 它的参数可以是任一类型甚至是常数数组. SA 和 SO 都是返回一 I ...

  2. 【JavsScript】webapp的优化整理

    单页or多页 webapp 现状 优劣之分 网络传输优化 综述 fake页-首屏加速 降低请求数 降低请求量 缓存Ajax/localstorage DOM操作优化 综述 关于页面渲染 减少使用定位属 ...

  3. 用JavaScript探测页面上的广告是否被AdBlock屏蔽了的方法

    每个人都讨厌广告.看电视.看电影.看优酷.看网页时,对满天飞的广告也是深恶痛绝.广告是一个不招人喜欢的东西.但是,对一个中小网站站长/博客主来说,广告几乎是唯一的能成支持网站/博客正常运转的资金来源. ...

  4. Android的UI两大基石

        说到Android的UI就不得不从一切的开始View开始说.     让我们从Android Developer上的View的Overview和UI Overview来开始吧.     Cla ...

  5. CodeForces 70

    题目 A题 #include<bits/stdc++.h> using namespace std; int n,b,sum; int main(){ scanf("%d&quo ...

  6. SQL SERVER-时间戳(timestamp)与时间格式(datetime)互相转换

    SQL里面有个DATEADD的函数.时间戳就是一个从1970-01-01 08:00:00到时间的相隔的秒数.所以只要把这个时间戳加上1970-01-01 08:00:00这个时间就可以得到你想要的时 ...

  7. python - 闭包,迭代器

    一.第一类对象 1.函数名的运用     函数名是一个变量,但它是一个特殊的变量,与括号配合可以执行函数的变量     1.函数名的内存地址 def func1(): print('你是谁,你来自哪里 ...

  8. CSS学习之路,指定值,计算值,使用值。

    前面被问过这几个值得区别,没太研究,有点抠文字的感觉,既然到这儿了 ,就简答梳理下吧. 指定值(specified value):通过样式表样式规则定义的值:可以来自层叠样式表,如果没有指定,则考虑父 ...

  9. $.post() 和 $.get() 如何同步请求

    由于$.post() 和 $.get() 默认是 异步请求,如果需要同步请求,则可以进行如下使用: 在$.post()前把ajax设置为同步:$.ajaxSettings.async = false; ...

  10. DQL完整语法及示例

    DQL:Data Query Language,数据查询语言,其实它也是DML(数据库操作语言的一种),下面看一看完整的语法: 注意,关键字建议大写,不带[ ]是必需的,带[ ]是可选的. SELEC ...