diff --git a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/IrisClassifier.java b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/IrisClassifier.java index bf341209e1..b9c794650e 100644 --- a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/IrisClassifier.java +++ b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/IrisClassifier.java @@ -3,15 +3,16 @@ package com.baeldung.deeplearning4j; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; -import org.datavec.api.util.ClassPathResource; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; -import org.deeplearning4j.eval.Evaluation; +import org.deeplearning4j.nn.conf.BackpropType; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.util.NetworkUtils; +import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -19,6 +20,7 @@ import org.nd4j.linalg.dataset.SplitTestAndTrain; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; +import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.io.IOException; @@ -49,11 +51,11 @@ public class IrisClassifier { DataSet testData = testAndTrain.getTest(); MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder() - .iterations(1000) + .maxNumLineSearchIterations(1000) .activation(Activation.TANH) .weightInit(WeightInit.XAVIER) - .learningRate(0.1) - .regularization(true).l2(0.0001) + //.regularization(true) + .l2(0.0001) .list() .layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(3) .build()) @@ -62,14 +64,15 @@ public class IrisClassifier { .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .activation(Activation.SOFTMAX) .nIn(3).nOut(CLASSES_COUNT).build()) - .backprop(true).pretrain(false) + .backpropType(BackpropType.Standard)//.pretrain(false) .build(); MultiLayerNetwork model = new MultiLayerNetwork(configuration); model.init(); + NetworkUtils.setLearningRate(model, 0.1); model.fit(trainingData); - INDArray output = model.output(testData.getFeatureMatrix()); + INDArray output = model.output(testData.getFeatures()); Evaluation eval = new Evaluation(CLASSES_COUNT); eval.eval(testData.getLabels(), output);