package org.deeplearning4j.examples.feedforward.regression;
// Math function implementations used as regression targets (e.g. SinXDivXMathFunction)
import org.deeplearning4j.examples.feedforward.regression.function.*;
// Activation functions (TANH for the hidden layers, IDENTITY for the output layer)
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import javax.swing.*;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Example: Train a network to reproduce certain mathematical functions, and plot the results.
 * Plotting of the network output occurs every 'plotFrequency' epochs. Thus, the plot shows the accuracy of the network
 * predictions as training progresses.
 * A number of mathematical functions are implemented here.
 * Note the use of the identity activation function on the network output layer, as is standard for regression.
 *
 * @author Alex Black
 */
public class RegressionMathFunctions {

    //Random number generator seed, for reproducibility
    public static final int seed = 12345;
    //Number of iterations per minibatch
    public static final int iterations = 1;
    //Number of epochs (full passes of the data)
    public static final int nEpochs = 2000;
    //How frequently should we plot the network output?
    public static final int plotFrequency = 500;
    //Number of data points
    public static final int nSamples = 1000;
    //Batch size: i.e., each epoch has nSamples/batchSize parameter updates
    public static final int batchSize = 100;
    //Network learning rate
    public static final double learningRate = 0.01;
    //Random number generator, and network input/output sizes (a single x value in, a single predicted y value out)
    public static final Random rng = new Random(seed);
    public static final int numInputs = 1;
    public static final int numOutputs = 1;

    public static void main(final String[] args){

        //Switch these two options to do different functions with different networks
        final MathFunction fn = new SinXDivXMathFunction();
        final MultiLayerConfiguration conf = getDeepDenseLayerNetworkConfiguration();

        //Generate the training data
        final INDArray x = Nd4j.linspace(-10,10,nSamples).reshape(nSamples, 1);
        final DataSetIterator iterator = getTrainingData(x,fn,batchSize,rng);

        //Create the network
        final MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(1));

        //Train the network on the full data set, and evaluate it periodically
        final INDArray[] networkPredictions = new INDArray[nEpochs / plotFrequency];
        for( int i=0; i<nEpochs; i++ ){
            iterator.reset();
            net.fit(iterator);
            if((i+1) % plotFrequency == 0) networkPredictions[i / plotFrequency] = net.output(x, false);
        }

        //Plot the target data and the network predictions
        plot(fn,x,fn.getFunctionValues(x),networkPredictions);
    }
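
    // Note: the builder calls below (iterations(), learningRate(), pretrain()/backprop()) belong to
    // the pre-1.0 DL4J configuration API; on DL4J 1.0.0-beta and later they were removed or folded
    // into the updater configuration, so this method may need adjusting on newer releases.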

    /** Returns the network configuration: 2 hidden DenseLayers of size 50.
     */
    private static MultiLayerConfiguration getDeepDenseLayerNetworkConfiguration() {
        final int numHiddenNodes = 50;
        return new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(learningRate)
            .weightInit(WeightInit.XAVIER)
            .updater(new Nesterovs(0.9))
            .list()
            .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                .activation(Activation.TANH).build())
            .layer(1, new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
                .activation(Activation.TANH).build())
            .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                .activation(Activation.IDENTITY)
                .nIn(numHiddenNodes).nOut(numOutputs).build())
            .pretrain(false).backprop(true).build();
    }

    /** Create a DataSetIterator for training
     * @param x X values
     * @param function Function to evaluate
     * @param batchSize Batch size (number of examples for every call of DataSetIterator.next())
     * @param rng Random number generator (for repeatability)
     */
    private static DataSetIterator getTrainingData(final INDArray x, final MathFunction function, final int batchSize, final Random rng) {
        final INDArray y = function.getFunctionValues(x);
        final DataSet allData = new DataSet(x,y);

        final List<DataSet> list = allData.asList();
        Collections.shuffle(list,rng);
        return new ListDataSetIterator(list,batchSize);
    }

    //Plot the data
    private static void plot(final MathFunction function, final INDArray x, final INDArray y, final INDArray... predicted) {
        final XYSeriesCollection dataSet = new XYSeriesCollection();
        addSeries(dataSet,x,y,"True Function (Labels)");

        for( int i=0; i<predicted.length; i++ ){
            addSeries(dataSet,x,predicted[i],String.valueOf(i));
        }

        final JFreeChart chart = ChartFactory.createXYLineChart(
            "Regression Example - " + function.getName(),   // chart title
            "X",                                             // x axis label
            function.getName() + "(X)",                      // y axis label
            dataSet,                                         // data
            PlotOrientation.VERTICAL,
            true,                                            // include legend
            true,                                            // tooltips
            false                                            // urls
        );

        final ChartPanel panel = new ChartPanel(chart);
        final JFrame f = new JFrame();
        f.add(panel);
        f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
        f.pack();
        f.setVisible(true);
    }

    private static void addSeries(final XYSeriesCollection dataSet, final INDArray x, final INDArray y, final String label){
        //Copy the x and y values out of the INDArrays into plain double arrays for JFreeChart
        final double[] xd = x.data().asDouble();
        final double[] yd = y.data().asDouble();
        final XYSeries s = new XYSeries(label);
        for( int j=0; j<xd.length; j++ ) s.add(xd[j],yd[j]);
        dataSet.addSeries(s);
    }
}
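
/*
 * For reference, a minimal sketch of what one of the MathFunction implementations used above
 * (e.g. SinXDivXMathFunction from the .function package) might look like. The exact interface
 * is defined in that package and may differ; only getFunctionValues(INDArray) and getName()
 * are assumed here, since those are the two methods this class calls.
 *
 *     import org.nd4j.linalg.api.ndarray.INDArray;
 *     import org.nd4j.linalg.ops.transforms.Transforms;
 *
 *     public class SinXDivXMathFunction implements MathFunction {
 *
 *         @Override
 *         public INDArray getFunctionValues(final INDArray x) {
 *             // Element-wise sin(x)/x: copy-and-transform with ND4J, then divide in place
 *             return Transforms.sin(x, true).divi(x);
 *         }
 *
 *         @Override
 *         public String getName() {
 *             return "SinXDivX";
 *         }
 *     }
 */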