#BAEL-18260 Attempt to update deeplearning4j version (the two examples used two different versions)

This commit is contained in:
Alessio Stalla
2019-10-13 21:30:33 +02:00
parent 3a7c2ac8cc
commit 104f935120

View File

@@ -3,15 +3,16 @@ package com.baeldung.deeplearning4j;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.NetworkUtils;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
@@ -19,6 +20,7 @@ import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.IOException;
@@ -49,11 +51,11 @@ public class IrisClassifier {
DataSet testData = testAndTrain.getTest();
MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
.iterations(1000)
.maxNumLineSearchIterations(1000)
.activation(Activation.TANH)
.weightInit(WeightInit.XAVIER)
.learningRate(0.1)
.regularization(true).l2(0.0001)
//.regularization(true)
.l2(0.0001)
.list()
.layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(3)
.build())
@@ -62,14 +64,15 @@ public class IrisClassifier {
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX)
.nIn(3).nOut(CLASSES_COUNT).build())
.backprop(true).pretrain(false)
.backpropType(BackpropType.Standard)//.pretrain(false)
.build();
MultiLayerNetwork model = new MultiLayerNetwork(configuration);
model.init();
NetworkUtils.setLearningRate(model, 0.1);
model.fit(trainingData);
INDArray output = model.output(testData.getFeatureMatrix());
INDArray output = model.output(testData.getFeatures());
Evaluation eval = new Evaluation(CLASSES_COUNT);
eval.eval(testData.getLabels(), output);