From c20918329f1ff4c03f0700c0a525603f9ec7a1ad Mon Sep 17 00:00:00 2001 From: Andrew Shcherbakov Date: Wed, 4 Sep 2019 22:25:22 +0200 Subject: [PATCH] Refactor the utility class --- .../com/baeldung/logreg/DataUtilities.java | 102 ----------------- .../com/baeldung/logreg/MnistClassifier.java | 16 +-- .../com/baeldung/logreg/MnistPrediction.java | 10 +- .../main/java/com/baeldung/logreg/Utils.java | 103 ++++++++++++++++++ 4 files changed, 119 insertions(+), 112 deletions(-) delete mode 100644 ml/src/main/java/com/baeldung/logreg/DataUtilities.java create mode 100644 ml/src/main/java/com/baeldung/logreg/Utils.java diff --git a/ml/src/main/java/com/baeldung/logreg/DataUtilities.java b/ml/src/main/java/com/baeldung/logreg/DataUtilities.java deleted file mode 100644 index 2f18d30219..0000000000 --- a/ml/src/main/java/com/baeldung/logreg/DataUtilities.java +++ /dev/null @@ -1,102 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package com.baeldung.logreg; - -import org.apache.commons.compress.archivers.tar.TarArchiveEntry; -import org.apache.commons.compress.archivers.tar.TarArchiveInputStream; -import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream; -import org.apache.http.HttpEntity; -import org.apache.http.client.methods.CloseableHttpResponse; -import org.apache.http.client.methods.HttpGet; -import org.apache.http.impl.client.CloseableHttpClient; -import org.apache.http.impl.client.HttpClientBuilder; - -import java.io.*; - -/** - * Common data utility functions. - * - * @author fvaleri - */ -public class DataUtilities { - - /** - * Download a remote file if it doesn't exist. - * @param remoteUrl URL of the remote file. - * @param localPath Where to download the file. - * @return True if and only if the file has been downloaded. - * @throws Exception IO error. - */ - public static boolean downloadFile(String remoteUrl, String localPath) throws IOException { - boolean downloaded = false; - if (remoteUrl == null || localPath == null) - return downloaded; - File file = new File(localPath); - if (!file.exists()) { - file.getParentFile().mkdirs(); - HttpClientBuilder builder = HttpClientBuilder.create(); - CloseableHttpClient client = builder.build(); - try (CloseableHttpResponse response = client.execute(new HttpGet(remoteUrl))) { - HttpEntity entity = response.getEntity(); - if (entity != null) { - try (FileOutputStream outstream = new FileOutputStream(file)) { - entity.writeTo(outstream); - outstream.flush(); - outstream.close(); - } - } - } - downloaded = true; - } - if (!file.exists()) - throw new IOException("File doesn't exist: " + localPath); - return downloaded; - } - - /** - * Extract a "tar.gz" file into a local folder. - * @param inputPath Input file path. - * @param outputPath Output directory path. 
- * @throws IOException IO error. - */ - public static void extractTarGz(String inputPath, String outputPath) throws IOException { - if (inputPath == null || outputPath == null) - return; - final int bufferSize = 4096; - if (!outputPath.endsWith("" + File.separatorChar)) - outputPath = outputPath + File.separatorChar; - try (TarArchiveInputStream tais = new TarArchiveInputStream( - new GzipCompressorInputStream(new BufferedInputStream(new FileInputStream(inputPath))))) { - TarArchiveEntry entry; - while ((entry = (TarArchiveEntry) tais.getNextEntry()) != null) { - if (entry.isDirectory()) { - new File(outputPath + entry.getName()).mkdirs(); - } else { - int count; - byte data[] = new byte[bufferSize]; - FileOutputStream fos = new FileOutputStream(outputPath + entry.getName()); - BufferedOutputStream dest = new BufferedOutputStream(fos, bufferSize); - while ((count = tais.read(data, 0, bufferSize)) != -1) { - dest.write(data, 0, count); - } - dest.close(); - } - } - } - } - -} diff --git a/ml/src/main/java/com/baeldung/logreg/MnistClassifier.java b/ml/src/main/java/com/baeldung/logreg/MnistClassifier.java index 395307712d..1246de973f 100644 --- a/ml/src/main/java/com/baeldung/logreg/MnistClassifier.java +++ b/ml/src/main/java/com/baeldung/logreg/MnistClassifier.java @@ -67,18 +67,20 @@ public class MnistClassifier { final String path = basePath + "mnist_png" + File.separator; if (!new File(path).exists()) { - logger.debug("Downloading data {}", dataUrl); + logger.info("Downloading data {}", dataUrl); String localFilePath = basePath + "mnist_png.tar.gz"; - logger.info("local file: {}", localFilePath); - if (DataUtilities.downloadFile(dataUrl, localFilePath)) { - DataUtilities.extractTarGz(localFilePath, basePath); + File file = new File(localFilePath); + if (!file.exists()) { + file.getParentFile() + .mkdirs(); + Utils.downloadAndSave(dataUrl, file); + Utils.extractTarArchive(file, basePath); } } else { - logger.info("local file exists {}", path); - + 
logger.info("Using the local data from folder {}", path); } - logger.info("Vectorizing data..."); + logger.info("Vectorizing the data from folder {}", path); // vectorization of train data File trainData = new File(path + "training"); FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen); diff --git a/ml/src/main/java/com/baeldung/logreg/MnistPrediction.java b/ml/src/main/java/com/baeldung/logreg/MnistPrediction.java index 5ec1348e07..56097d9a45 100644 --- a/ml/src/main/java/com/baeldung/logreg/MnistPrediction.java +++ b/ml/src/main/java/com/baeldung/logreg/MnistPrediction.java @@ -36,18 +36,22 @@ public class MnistPrediction { } public static void main(String[] args) throws IOException { - String path = fileChose().toString(); + if (!modelPath.exists()) { + logger.info("The model not found. Have you trained it?"); + return; + } MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelPath); + String path = fileChose(); File file = new File(path); INDArray image = new NativeImageLoader(height, width, channels).asMatrix(file); new ImagePreProcessingScaler(0, 1).transform(image); - + // Pass through to neural Net INDArray output = model.output(image); logger.info("File: {}", path); - logger.info(output.toString()); + logger.info("Probabilities: {}", output); } } diff --git a/ml/src/main/java/com/baeldung/logreg/Utils.java b/ml/src/main/java/com/baeldung/logreg/Utils.java new file mode 100644 index 0000000000..fa4be127cd --- /dev/null +++ b/ml/src/main/java/com/baeldung/logreg/Utils.java @@ -0,0 +1,103 @@ +package com.baeldung.logreg; + +import java.io.BufferedInputStream; +import java.io.BufferedOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; + +import org.apache.commons.compress.archivers.ArchiveEntry; +import org.apache.commons.compress.archivers.tar.TarArchiveEntry; +import 
org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
import org.apache.http.HttpEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Utility class for digit classifier: downloads a remote file and unpacks a "tar.gz" archive.
 */
public class Utils {

    private static final Logger logger = LoggerFactory.getLogger(Utils.class);

    /** Size in bytes of the buffer used to copy archive entries to disk. */
    private static final int BUFFER_SIZE = 4096;

    private Utils() {
    }

    /**
     * Download the content of the given url and save it into a file.
     * Nothing is written when the response carries no entity.
     *
     * @param url address requested with an HTTP GET
     * @param file local file that receives the response body
     * @throws IOException if the request fails, the server answers with a non-2xx status,
     *         or the file cannot be written
     */
    public static void downloadAndSave(String url, File file) throws IOException {
        // The client is closeable as well as the response; both are released here.
        try (CloseableHttpClient client = HttpClientBuilder.create()
            .build()) {
            logger.info("Connecting to {}", url);
            try (CloseableHttpResponse response = client.execute(new HttpGet(url))) {
                // Fail before the file is created, so an error page is never stored as the download.
                final int status = response.getStatusLine()
                    .getStatusCode();
                if (status < 200 || status >= 300) {
                    throw new IOException("Unexpected HTTP status " + status + " for " + url);
                }
                HttpEntity entity = response.getEntity();
                if (entity != null) {
                    // The value is the length announced for the entity; it is negative when unknown.
                    logger.info("Downloaded {} bytes", entity.getContentLength());
                    try (FileOutputStream outstream = new FileOutputStream(file)) {
                        logger.info("Saving to the local file");
                        entity.writeTo(outstream);
                        logger.info("Local file saved");
                    }
                }
            }
        }
    }

    /**
     * Extract a "tar.gz" file into a given folder.
     *
     * @param file gzip-compressed tar archive to read
     * @param folder destination folder; every entry is created below it
     * @throws IOException if the archive cannot be read or an entry cannot be written
     */
    public static void extractTarArchive(File file, String folder) throws IOException {
        logger.info("Extracting archive {} into folder {}", file.getName(), folder);
        // @formatter:off
        try (FileInputStream fis = new FileInputStream(file);
             BufferedInputStream bis = new BufferedInputStream(fis);
             GzipCompressorInputStream gzip = new GzipCompressorInputStream(bis);
             TarArchiveInputStream tar = new TarArchiveInputStream(gzip)) {
            // @formatter:on
            TarArchiveEntry entry;
            while ((entry = (TarArchiveEntry) tar.getNextEntry()) != null) {
                extractEntry(entry, tar, folder);
            }
        }
        logger.info("Archive extracted");
    }

    /**
     * Extract an entry of the input stream into a given folder.
     * A directory entry is created as a directory; any other entry is written as a file
     * with the bytes read from {@code tar} until it reports the end of data.
     *
     * @param entry archive entry that supplies the relative name and the kind of the item
     * @param tar stream that delivers the content of {@code entry}
     * @param folder destination folder, with or without a trailing separator
     * @throws IOException if the entry name resolves outside of {@code folder}, a directory
     *         cannot be created, or the content cannot be written
     */
    public static void extractEntry(ArchiveEntry entry, InputStream tar, String folder) throws IOException {
        // Resolving the name against the folder (instead of concatenating strings) makes the
        // result independent of a trailing separator in "folder".
        final File baseDir = new File(folder).getAbsoluteFile();
        final File target = new File(baseDir, entry.getName());
        // The name comes from the archive: refuse names such as "../x" that leave the destination.
        if (!target.toPath()
            .normalize()
            .startsWith(baseDir.toPath()
                .normalize())) {
            throw new IOException("Archive entry is outside of the target folder: " + entry.getName());
        }
        if (entry.isDirectory()) {
            createDirectory(target);
            return;
        }
        // A file entry may appear without a preceding entry for its directory.
        final File parent = target.getParentFile();
        if (parent != null) {
            createDirectory(parent);
        }
        int count;
        byte[] data = new byte[BUFFER_SIZE];
        // @formatter:off
        try (FileOutputStream os = new FileOutputStream(target);
             BufferedOutputStream dest = new BufferedOutputStream(os, BUFFER_SIZE)) {
            // @formatter:on
            while ((count = tar.read(data, 0, BUFFER_SIZE)) != -1) {
                dest.write(data, 0, count);
            }
        }
    }

    /** Create the directory with all missing parents, failing loudly when that is not possible. */
    private static void createDirectory(File dir) throws IOException {
        // mkdirs() also returns false when the directory already exists, hence the second check.
        if (!dir.mkdirs() && !dir.isDirectory()) {
            throw new IOException("Cannot create directory " + dir);
        }
    }
}