diff --git a/deeplearning4j/pom.xml b/deeplearning4j/pom.xml index d88c877aa4..e1e4998c98 100644 --- a/deeplearning4j/pom.xml +++ b/deeplearning4j/pom.xml @@ -40,12 +40,12 @@ org.slf4j slf4j-api - 1.7.5 + ${sl4j.version} org.slf4j slf4j-log4j12 - 1.7.5 + ${sl4j.version} @@ -63,6 +63,7 @@ 0.9.1 4.3.5 + 1.7.5 diff --git a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/CifarDataSetService.java b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CifarDataSetService.java similarity index 79% rename from deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/CifarDataSetService.java rename to deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CifarDataSetService.java index cb69d0c818..70348a6d9e 100644 --- a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/CifarDataSetService.java +++ b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CifarDataSetService.java @@ -1,4 +1,4 @@ -package com.baeldung.deeplearning4j.cnn.service.dataset; +package com.baeldung.deeplearning4j.cnn; import lombok.Getter; import org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator; @@ -8,18 +8,19 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import java.util.List; @Getter -public class CifarDataSetService implements IDataSetService { +class CifarDataSetService implements IDataSetService { - private CifarDataSetIterator trainIterator; - private CifarDataSetIterator testIterator; - - private final InputType inputType = InputType.convolutional(32,32,3); + private final InputType inputType = InputType.convolutional(32, 32, 3); private final int trainImagesNum = 512; private final int testImagesNum = 128; private final int trainBatch = 16; private final int testBatch = 8; - public CifarDataSetService() { + private final CifarDataSetIterator trainIterator; + + private final CifarDataSetIterator testIterator; + + CifarDataSetService() { trainIterator = new CifarDataSetIterator(trainBatch, trainImagesNum, true); testIterator = new CifarDataSetIterator(testBatch, testImagesNum, false); } diff --git a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnExample.java b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnExample.java index 2e2d4392b8..224062c388 100644 --- a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnExample.java +++ b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnExample.java @@ -1,14 +1,11 @@ package com.baeldung.deeplearning4j.cnn; -import com.baeldung.deeplearning4j.cnn.domain.network.CnnModel; -import com.baeldung.deeplearning4j.cnn.domain.network.CnnModelProperties; -import com.baeldung.deeplearning4j.cnn.service.dataset.CifarDataSetService; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.eval.Evaluation; @Slf4j -public class CnnExample { +class CnnExample { public static void main(String... args) { CnnModel network = new CnnModel(new CifarDataSetService(), new CnnModelProperties()); diff --git a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModel.java b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnModel.java similarity index 86% rename from deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModel.java rename to deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnModel.java index 037d14529c..bd87709c0e 100644 --- a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModel.java +++ b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnModel.java @@ -1,6 +1,5 @@ -package com.baeldung.deeplearning4j.cnn.domain.network; +package com.baeldung.deeplearning4j.cnn; -import com.baeldung.deeplearning4j.cnn.service.dataset.IDataSetService; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.api.OptimizationAlgorithm; @@ -17,15 +16,15 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.stream.IntStream; @Slf4j -public class CnnModel { +class CnnModel { private final IDataSetService dataSetService; - private MultiLayerNetwork network; + private final MultiLayerNetwork network; private final CnnModelProperties properties; - public CnnModel(IDataSetService dataSetService, CnnModelProperties properties) { + CnnModel(IDataSetService dataSetService, CnnModelProperties properties) { this.dataSetService = dataSetService; this.properties = properties; @@ -52,17 +51,17 @@ public class CnnModel { network = new MultiLayerNetwork(configuration); } - public void train() { + void train() { network.init(); int epochsNum = properties.getEpochsNum(); IntStream.range(1, epochsNum + 1).forEach(epoch -> { - log.info(String.format("Epoch %d?%d", epoch, epochsNum)); + log.info("Epoch {} / {}", epoch, epochsNum); network.fit(dataSetService.trainIterator()); }); } - public Evaluation evaluate() { - return network.evaluate(dataSetService.testIterator()); + Evaluation evaluate() { + return network.evaluate(dataSetService.testIterator()); } private ConvolutionLayer conv5x5() { @@ -84,7 +83,7 @@ public class CnnModel { } private ConvolutionLayer conv3x3Stride1Padding2() { - return new ConvolutionLayer.Builder(3, 3) + return new ConvolutionLayer.Builder(3, 3) .nOut(32) .stride(1, 1) .padding(2, 2) @@ -95,7 +94,7 @@ public class CnnModel { private SubsamplingLayer pooling2x2Stride1() { return new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) - .kernelSize(2,2) + .kernelSize(2, 2) .stride(1, 1) .build(); } diff --git a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModelProperties.java b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnModelProperties.java similarity index 70% rename from deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModelProperties.java rename to deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnModelProperties.java index 7ea3a71363..d010d789c8 100644 --- a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModelProperties.java +++ b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnModelProperties.java @@ -1,10 +1,10 @@ -package com.baeldung.deeplearning4j.cnn.domain.network; +package com.baeldung.deeplearning4j.cnn; import lombok.Value; import org.deeplearning4j.nn.conf.Updater; @Value -public class CnnModelProperties { +class CnnModelProperties { private final int epochsNum = 512; private final double learningRate = 0.001; diff --git a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/IDataSetService.java b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/IDataSetService.java similarity index 74% rename from deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/IDataSetService.java rename to deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/IDataSetService.java index c27e566076..ea88bf550c 100644 --- a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/IDataSetService.java +++ b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/IDataSetService.java @@ -1,11 +1,11 @@ -package com.baeldung.deeplearning4j.cnn.service.dataset; +package com.baeldung.deeplearning4j.cnn; import org.deeplearning4j.nn.conf.inputs.InputType; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import java.util.List; -public interface IDataSetService { +interface IDataSetService { DataSetIterator trainIterator(); DataSetIterator testIterator();