From 6e1cab0051109d09cf1869c2869675a42c3cbac0 Mon Sep 17 00:00:00 2001
From: helga_sh
Date: Tue, 21 Jul 2020 16:24:31 +0300
Subject: [PATCH] CNN example with Deeplearning4j in Java

Add a small convolutional-network example built on Deeplearning4j.

* CnnExample: entry point; builds a CnnModel, trains it and logs the
  evaluation statistics.
* CnnModel: seven-layer MultiLayerNetwork - three ReLU convolution
  layers, each followed by a max-pooling layer, and a softmax output
  layer. train() initialises the network and fits it once per epoch;
  evaluate() runs the test iterator through the network.
* CnnModelProperties: number of epochs (512), learning rate (0.001)
  and updater (ADAM).
* IDataSetService / CifarDataSetService: train and test iterators,
  input type (32x32x3) and labels, backed by CifarDataSetIterator
  (512 training / 128 test images, batch sizes 16 / 8).
* pom.xml: add slf4j-api and slf4j-log4j12 1.7.5.
---
Review notes (ignored by git am):

* The XML markup of the pom.xml hunk was restored from tag-stripped
  text. The blank context line after the added block is an assumption;
  check it against the real pom.xml before applying.
* The author's e-mail address is missing from the From: header.
* The epoch log line now prints "Epoch <n>/<total>" (was "%d?%d") and
  labels() is declared as List<String> instead of the raw List. Both
  edits stay on their original lines, so the hunk sizes are unchanged,
  but the blob ids on the affected "index" lines are stale.

 deeplearning4j/pom.xml                        |  10 ++
 .../deeplearning4j/cnn/CnnExample.java        |  21 +++
 .../cnn/domain/network/CnnModel.java          | 120 ++++++++++++++++++
 .../domain/network/CnnModelProperties.java    |  13 ++
 .../service/dataset/CifarDataSetService.java  |  46 +++++++
 .../cnn/service/dataset/IDataSetService.java  |  16 +++
 6 files changed, 226 insertions(+)
 create mode 100644 deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnExample.java
 create mode 100644 deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModel.java
 create mode 100644 deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModelProperties.java
 create mode 100644 deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/CifarDataSetService.java
 create mode 100644 deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/IDataSetService.java

diff --git a/deeplearning4j/pom.xml b/deeplearning4j/pom.xml
index c8fa18cbd4..d88c877aa4 100644
--- a/deeplearning4j/pom.xml
+++ b/deeplearning4j/pom.xml
@@ -37,6 +37,16 @@
             <artifactId>deeplearning4j-nn</artifactId>
             <version>${dl4j.version}</version>
         </dependency>
+        <dependency>
+            <groupId>org.slf4j</groupId>
+            <artifactId>slf4j-api</artifactId>
+            <version>1.7.5</version>
+        </dependency>
+        <dependency>
+            <groupId>org.slf4j</groupId>
+            <artifactId>slf4j-log4j12</artifactId>
+            <version>1.7.5</version>
+        </dependency>
 
         <dependency>
             <groupId>org.datavec</groupId>
diff --git a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnExample.java b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnExample.java
new file mode 100644
index 0000000000..2e2d4392b8
--- /dev/null
+++ b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnExample.java
@@ -0,0 +1,21 @@
+package com.baeldung.deeplearning4j.cnn;
+
+
+import com.baeldung.deeplearning4j.cnn.domain.network.CnnModel;
+import com.baeldung.deeplearning4j.cnn.domain.network.CnnModelProperties;
+import com.baeldung.deeplearning4j.cnn.service.dataset.CifarDataSetService;
+import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.eval.Evaluation;
+
+@Slf4j
+public class CnnExample {
+
+    public static void main(String... args) {
+        CnnModel network = new CnnModel(new CifarDataSetService(), new CnnModelProperties());
+
+        network.train();
+        Evaluation evaluation = network.evaluate();
+
+        log.info(evaluation.stats());
+    }
+}
diff --git a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModel.java b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModel.java
new file mode 100644
index 0000000000..037d14529c
--- /dev/null
+++ b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModel.java
@@ -0,0 +1,120 @@
+package com.baeldung.deeplearning4j.cnn.domain.network;
+
+import com.baeldung.deeplearning4j.cnn.service.dataset.IDataSetService;
+import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.eval.Evaluation;
+import org.deeplearning4j.nn.api.OptimizationAlgorithm;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
+import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+
+import java.util.stream.IntStream;
+
+@Slf4j
+public class CnnModel {
+
+    private final IDataSetService dataSetService;
+
+    private MultiLayerNetwork network;
+
+    private final CnnModelProperties properties;
+
+    public CnnModel(IDataSetService dataSetService, CnnModelProperties properties) {
+
+        this.dataSetService = dataSetService;
+        this.properties = properties;
+
+        MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
+                .seed(1611)
+                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
+                .learningRate(properties.getLearningRate())
+                .regularization(true)
+                .updater(properties.getOptimizer())
+                .list()
+                .layer(0, conv5x5())
+                .layer(1, pooling2x2Stride2())
+                .layer(2, conv3x3Stride1Padding2())
+                .layer(3, pooling2x2Stride1())
+                .layer(4, conv3x3Stride1Padding1())
+                .layer(5, pooling2x2Stride1())
+                .layer(6, dense())
+                .pretrain(false)
+                .backprop(true)
+                .setInputType(dataSetService.inputType())
+                .build();
+
+        network = new MultiLayerNetwork(configuration);
+    }
+
+    public void train() {
+        network.init();
+        int epochsNum = properties.getEpochsNum();
+        IntStream.range(1, epochsNum + 1).forEach(epoch -> {
+            log.info("Epoch {}/{}", epoch, epochsNum);
+            network.fit(dataSetService.trainIterator());
+        });
+    }
+
+    public Evaluation evaluate() {
+        return network.evaluate(dataSetService.testIterator());
+    }
+
+    private ConvolutionLayer conv5x5() {
+        return new ConvolutionLayer.Builder(5, 5)
+                .nIn(3)
+                .nOut(16)
+                .stride(1, 1)
+                .padding(1, 1)
+                .weightInit(WeightInit.XAVIER_UNIFORM)
+                .activation(Activation.RELU)
+                .build();
+    }
+
+    private SubsamplingLayer pooling2x2Stride2() {
+        return new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
+                .kernelSize(2, 2)
+                .stride(2, 2)
+                .build();
+    }
+
+    private ConvolutionLayer conv3x3Stride1Padding2() {
+        return new ConvolutionLayer.Builder(3, 3)
+                .nOut(32)
+                .stride(1, 1)
+                .padding(2, 2)
+                .weightInit(WeightInit.XAVIER_UNIFORM)
+                .activation(Activation.RELU)
+                .build();
+    }
+
+    private SubsamplingLayer pooling2x2Stride1() {
+        return new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
+                .kernelSize(2,2)
+                .stride(1, 1)
+                .build();
+    }
+
+    private ConvolutionLayer conv3x3Stride1Padding1() {
+        return new ConvolutionLayer.Builder(3, 3)
+                .nOut(64)
+                .stride(1, 1)
+                .padding(1, 1)
+                .weightInit(WeightInit.XAVIER_UNIFORM)
+                .activation(Activation.RELU)
+                .build();
+    }
+
+    private OutputLayer dense() {
+        return new OutputLayer.Builder(LossFunctions.LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR)
+                .activation(Activation.SOFTMAX)
+                .weightInit(WeightInit.XAVIER_UNIFORM)
+                .nOut(dataSetService.labels().size() - 1) // NOTE(review): one fewer than labels() - presumably the list carries an extra entry; confirm
+                .build();
+    }
+}
diff --git a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModelProperties.java b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModelProperties.java
new file mode 100644
index 0000000000..7ea3a71363
--- /dev/null
+++ b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModelProperties.java
@@ -0,0 +1,13 @@
+package com.baeldung.deeplearning4j.cnn.domain.network;
+
+import lombok.Value;
+import org.deeplearning4j.nn.conf.Updater;
+
+@Value
+public class CnnModelProperties {
+    private final int epochsNum = 512;
+
+    private final double learningRate = 0.001;
+
+    private final Updater optimizer = Updater.ADAM;
+}
diff --git a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/CifarDataSetService.java b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/CifarDataSetService.java
new file mode 100644
index 0000000000..cb69d0c818
--- /dev/null
+++ b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/CifarDataSetService.java
@@ -0,0 +1,46 @@
+package com.baeldung.deeplearning4j.cnn.service.dataset;
+
+import lombok.Getter;
+import org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+
+import java.util.List;
+
+@Getter
+public class CifarDataSetService implements IDataSetService {
+
+    private CifarDataSetIterator trainIterator;
+    private CifarDataSetIterator testIterator;
+
+    private final InputType inputType = InputType.convolutional(32,32,3);
+    private final int trainImagesNum = 512;
+    private final int testImagesNum = 128;
+    private final int trainBatch = 16;
+    private final int testBatch = 8;
+
+    public CifarDataSetService() {
+        trainIterator = new CifarDataSetIterator(trainBatch, trainImagesNum, true);
+        testIterator = new CifarDataSetIterator(testBatch, testImagesNum, false);
+    }
+
+    @Override
+    public DataSetIterator trainIterator() {
+        return trainIterator;
+    }
+
+    @Override
+    public DataSetIterator testIterator() {
+        return testIterator;
+    }
+
+    @Override
+    public InputType inputType() {
+        return inputType;
+    }
+
+    @Override
+    public List<String> labels() {
+        return trainIterator.getLabels();
+    }
+}
diff --git a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/IDataSetService.java b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/IDataSetService.java
new file mode 100644
index 0000000000..c27e566076
--- /dev/null
+++ b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/IDataSetService.java
@@ -0,0 +1,16 @@
+package com.baeldung.deeplearning4j.cnn.service.dataset;
+
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+
+import java.util.List;
+
+public interface IDataSetService {
+    DataSetIterator trainIterator();
+
+    DataSetIterator testIterator();
+
+    InputType inputType();
+
+    List<String> labels();
+}