#BAEL-18260 Attempt to update deeplearning4j version (the two examples used two different versions)
This commit is contained in:
@@ -3,15 +3,16 @@ package com.baeldung.deeplearning4j;
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.util.ClassPathResource;
|
||||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||
import org.deeplearning4j.eval.Evaluation;
|
||||
import org.deeplearning4j.nn.conf.BackpropType;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.util.NetworkUtils;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
@@ -19,6 +20,7 @@ import org.nd4j.linalg.dataset.SplitTestAndTrain;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
|
||||
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
import java.io.IOException;
|
||||
@@ -49,11 +51,11 @@ public class IrisClassifier {
|
||||
DataSet testData = testAndTrain.getTest();
|
||||
|
||||
MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
|
||||
.iterations(1000)
|
||||
.maxNumLineSearchIterations(1000)
|
||||
.activation(Activation.TANH)
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.learningRate(0.1)
|
||||
.regularization(true).l2(0.0001)
|
||||
//.regularization(true)
|
||||
.l2(0.0001)
|
||||
.list()
|
||||
.layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(3)
|
||||
.build())
|
||||
@@ -62,14 +64,15 @@ public class IrisClassifier {
|
||||
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
||||
.activation(Activation.SOFTMAX)
|
||||
.nIn(3).nOut(CLASSES_COUNT).build())
|
||||
.backprop(true).pretrain(false)
|
||||
.backpropType(BackpropType.Standard)//.pretrain(false)
|
||||
.build();
|
||||
|
||||
MultiLayerNetwork model = new MultiLayerNetwork(configuration);
|
||||
model.init();
|
||||
NetworkUtils.setLearningRate(model, 0.1);
|
||||
model.fit(trainingData);
|
||||
|
||||
INDArray output = model.output(testData.getFeatureMatrix());
|
||||
INDArray output = model.output(testData.getFeatures());
|
||||
|
||||
Evaluation eval = new Evaluation(CLASSES_COUNT);
|
||||
eval.eval(testData.getLabels(), output);
|
||||
|
||||
Reference in New Issue
Block a user