diff --git a/deeplearning4j/pom.xml b/deeplearning4j/pom.xml
index d88c877aa4..e1e4998c98 100644
--- a/deeplearning4j/pom.xml
+++ b/deeplearning4j/pom.xml
@@ -40,12 +40,12 @@
org.slf4j
slf4j-api
- 1.7.5
+ ${sl4j.version}
org.slf4j
slf4j-log4j12
- 1.7.5
+ ${sl4j.version}
@@ -63,6 +63,7 @@
0.9.1
4.3.5
+ 1.7.5
diff --git a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/CifarDataSetService.java b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CifarDataSetService.java
similarity index 79%
rename from deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/CifarDataSetService.java
rename to deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CifarDataSetService.java
index cb69d0c818..70348a6d9e 100644
--- a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/CifarDataSetService.java
+++ b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CifarDataSetService.java
@@ -1,4 +1,4 @@
-package com.baeldung.deeplearning4j.cnn.service.dataset;
+package com.baeldung.deeplearning4j.cnn;
import lombok.Getter;
import org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator;
@@ -8,18 +8,19 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.util.List;
@Getter
-public class CifarDataSetService implements IDataSetService {
+class CifarDataSetService implements IDataSetService {
- private CifarDataSetIterator trainIterator;
- private CifarDataSetIterator testIterator;
-
- private final InputType inputType = InputType.convolutional(32,32,3);
+ private final InputType inputType = InputType.convolutional(32, 32, 3);
private final int trainImagesNum = 512;
private final int testImagesNum = 128;
private final int trainBatch = 16;
private final int testBatch = 8;
- public CifarDataSetService() {
+ private final CifarDataSetIterator trainIterator;
+
+ private final CifarDataSetIterator testIterator;
+
+ CifarDataSetService() {
trainIterator = new CifarDataSetIterator(trainBatch, trainImagesNum, true);
testIterator = new CifarDataSetIterator(testBatch, testImagesNum, false);
}
diff --git a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnExample.java b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnExample.java
index 2e2d4392b8..224062c388 100644
--- a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnExample.java
+++ b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnExample.java
@@ -1,14 +1,11 @@
package com.baeldung.deeplearning4j.cnn;
-import com.baeldung.deeplearning4j.cnn.domain.network.CnnModel;
-import com.baeldung.deeplearning4j.cnn.domain.network.CnnModelProperties;
-import com.baeldung.deeplearning4j.cnn.service.dataset.CifarDataSetService;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.eval.Evaluation;
@Slf4j
-public class CnnExample {
+class CnnExample {
public static void main(String... args) {
CnnModel network = new CnnModel(new CifarDataSetService(), new CnnModelProperties());
diff --git a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModel.java b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnModel.java
similarity index 86%
rename from deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModel.java
rename to deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnModel.java
index 037d14529c..bd87709c0e 100644
--- a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModel.java
+++ b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnModel.java
@@ -1,6 +1,5 @@
-package com.baeldung.deeplearning4j.cnn.domain.network;
+package com.baeldung.deeplearning4j.cnn;
-import com.baeldung.deeplearning4j.cnn.service.dataset.IDataSetService;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
@@ -17,15 +16,15 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.stream.IntStream;
@Slf4j
-public class CnnModel {
+class CnnModel {
private final IDataSetService dataSetService;
- private MultiLayerNetwork network;
+ private final MultiLayerNetwork network;
private final CnnModelProperties properties;
- public CnnModel(IDataSetService dataSetService, CnnModelProperties properties) {
+ CnnModel(IDataSetService dataSetService, CnnModelProperties properties) {
this.dataSetService = dataSetService;
this.properties = properties;
@@ -52,17 +51,17 @@ public class CnnModel {
network = new MultiLayerNetwork(configuration);
}
- public void train() {
+ void train() {
network.init();
int epochsNum = properties.getEpochsNum();
IntStream.range(1, epochsNum + 1).forEach(epoch -> {
- log.info(String.format("Epoch %d?%d", epoch, epochsNum));
+ log.info("Epoch {} / {}", epoch, epochsNum);
network.fit(dataSetService.trainIterator());
});
}
- public Evaluation evaluate() {
- return network.evaluate(dataSetService.testIterator());
+ Evaluation evaluate() {
+ return network.evaluate(dataSetService.testIterator());
}
private ConvolutionLayer conv5x5() {
@@ -84,7 +83,7 @@ public class CnnModel {
}
private ConvolutionLayer conv3x3Stride1Padding2() {
- return new ConvolutionLayer.Builder(3, 3)
+ return new ConvolutionLayer.Builder(3, 3)
.nOut(32)
.stride(1, 1)
.padding(2, 2)
@@ -95,7 +94,7 @@ public class CnnModel {
private SubsamplingLayer pooling2x2Stride1() {
return new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
- .kernelSize(2,2)
+ .kernelSize(2, 2)
.stride(1, 1)
.build();
}
diff --git a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModelProperties.java b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnModelProperties.java
similarity index 70%
rename from deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModelProperties.java
rename to deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnModelProperties.java
index 7ea3a71363..d010d789c8 100644
--- a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModelProperties.java
+++ b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnModelProperties.java
@@ -1,10 +1,10 @@
-package com.baeldung.deeplearning4j.cnn.domain.network;
+package com.baeldung.deeplearning4j.cnn;
import lombok.Value;
import org.deeplearning4j.nn.conf.Updater;
@Value
-public class CnnModelProperties {
+class CnnModelProperties {
private final int epochsNum = 512;
private final double learningRate = 0.001;
diff --git a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/IDataSetService.java b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/IDataSetService.java
similarity index 74%
rename from deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/IDataSetService.java
rename to deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/IDataSetService.java
index c27e566076..ea88bf550c 100644
--- a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/IDataSetService.java
+++ b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/IDataSetService.java
@@ -1,11 +1,11 @@
-package com.baeldung.deeplearning4j.cnn.service.dataset;
+package com.baeldung.deeplearning4j.cnn;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.util.List;
-public interface IDataSetService {
+interface IDataSetService {
DataSetIterator trainIterator();
DataSetIterator testIterator();