From dc72b8b39755a78670bf790d451001b331ea67bb Mon Sep 17 00:00:00 2001 From: Kumar Chandrakant Date: Fri, 15 Mar 2019 11:22:27 +0530 Subject: [PATCH] Adding source code for the tutorial tracked under BAEL-2759 (#6533) --- pom.xml | 5 ++ tensorflow-java/.gitignore | 6 +++ tensorflow-java/README.md | 3 ++ tensorflow-java/pom.xml | 52 +++++++++++++++++++ .../baeldung/tensorflow/TensorflowGraph.java | 41 +++++++++++++++ .../tensorflow/TensorflowSavedModel.java | 14 +++++ .../src/main/python/tensorflowGraph.py | 16 ++++++ .../tensorflow/TensorflowGraphUnitTest.java | 21 ++++++++ 8 files changed, 158 insertions(+) create mode 100644 tensorflow-java/.gitignore create mode 100644 tensorflow-java/README.md create mode 100644 tensorflow-java/pom.xml create mode 100644 tensorflow-java/src/main/java/org/baeldung/tensorflow/TensorflowGraph.java create mode 100644 tensorflow-java/src/main/java/org/baeldung/tensorflow/TensorflowSavedModel.java create mode 100644 tensorflow-java/src/main/python/tensorflowGraph.py create mode 100644 tensorflow-java/src/test/java/org/baeldung/tensorflow/TensorflowGraphUnitTest.java diff --git a/pom.xml b/pom.xml index 18e3ace31b..826b8eea63 100644 --- a/pom.xml +++ b/pom.xml @@ -526,6 +526,9 @@ rxjava rxjava-2 software-security/sql-injection-samples + + tensorflow-java + @@ -742,6 +745,8 @@ xml xmlunit-2 xstream + + tensorflow-java diff --git a/tensorflow-java/.gitignore b/tensorflow-java/.gitignore new file mode 100644 index 0000000000..eaea64ae48 --- /dev/null +++ b/tensorflow-java/.gitignore @@ -0,0 +1,6 @@ +/.settings +/model +/target +.classpath +.project +.springBeans \ No newline at end of file diff --git a/tensorflow-java/README.md b/tensorflow-java/README.md new file mode 100644 index 0000000000..aac5b7544c --- /dev/null +++ b/tensorflow-java/README.md @@ -0,0 +1,3 @@ +## Relevant articles: + +- [TensorFlow for Java](https://www.baeldung.com/xxxx) diff --git a/tensorflow-java/pom.xml b/tensorflow-java/pom.xml new file mode 100644 index 0000000000..e9d92157e8 --- /dev/null +++ b/tensorflow-java/pom.xml @@ -0,0 +1,52 @@ + + + 4.0.0 + com.baeldung + tensorflow-java + 1.0-SNAPSHOT + jar + http://maven.apache.org + + + com.baeldung + parent-modules + 1.0.0-SNAPSHOT + + + + 1.8 + 1.12.0 + 5.4.0 + + + + + org.tensorflow + tensorflow + ${tensorflow.version} + + + org.junit.jupiter + junit-jupiter-api + ${junit.jupiter.version} + test + + + org.junit.jupiter + junit-jupiter-engine + ${junit.jupiter.version} + test + + + + + + + org.springframework.boot + spring-boot-maven-plugin + + + + \ No newline at end of file diff --git a/tensorflow-java/src/main/java/org/baeldung/tensorflow/TensorflowGraph.java b/tensorflow-java/src/main/java/org/baeldung/tensorflow/TensorflowGraph.java new file mode 100644 index 0000000000..a44ef4c4ee --- /dev/null +++ b/tensorflow-java/src/main/java/org/baeldung/tensorflow/TensorflowGraph.java @@ -0,0 +1,41 @@ +package org.baeldung.tensorflow; + +import org.tensorflow.DataType; +import org.tensorflow.Graph; +import org.tensorflow.Operation; +import org.tensorflow.Session; +import org.tensorflow.Tensor; + +public class TensorflowGraph { + + public static Graph createGraph() { + Graph graph = new Graph(); + Operation a = graph.opBuilder("Const", "a").setAttr("dtype", DataType.fromClass(Double.class)) + .setAttr("value", Tensor.create(3.0, Double.class)).build(); + Operation b = graph.opBuilder("Const", "b").setAttr("dtype", DataType.fromClass(Double.class)) + .setAttr("value", Tensor.create(2.0, Double.class)).build(); + Operation x = graph.opBuilder("Placeholder", "x").setAttr("dtype", DataType.fromClass(Double.class)).build(); + Operation y = graph.opBuilder("Placeholder", "y").setAttr("dtype", DataType.fromClass(Double.class)).build(); + Operation ax = graph.opBuilder("Mul", "ax").addInput(a.output(0)).addInput(x.output(0)).build(); + Operation by = graph.opBuilder("Mul", "by").addInput(b.output(0)).addInput(y.output(0)).build(); + graph.opBuilder("Add", "z").addInput(ax.output(0)).addInput(by.output(0)).build(); + return graph; + } + + public static Object runGraph(Graph graph, Double x, Double y) { + Object result; + try (Session sess = new Session(graph)) { + result = sess.runner().fetch("z").feed("x", Tensor.create(x, Double.class)) + .feed("y", Tensor.create(y, Double.class)).run().get(0).expect(Double.class) + .doubleValue(); + } + return result; + } + + public static void main(String[] args) { + Graph graph = TensorflowGraph.createGraph(); + Object result = TensorflowGraph.runGraph(graph, 3.0, 6.0); + System.out.println(result); + graph.close(); + } +} \ No newline at end of file diff --git a/tensorflow-java/src/main/java/org/baeldung/tensorflow/TensorflowSavedModel.java b/tensorflow-java/src/main/java/org/baeldung/tensorflow/TensorflowSavedModel.java new file mode 100644 index 0000000000..4259a787e8 --- /dev/null +++ b/tensorflow-java/src/main/java/org/baeldung/tensorflow/TensorflowSavedModel.java @@ -0,0 +1,14 @@ +package org.baeldung.tensorflow; + +import org.tensorflow.SavedModelBundle; +import org.tensorflow.Tensor; + +public class TensorflowSavedModel { + + public static void main(String[] args) { + SavedModelBundle model = SavedModelBundle.load("./model", "serve"); + Tensor tensor = model.session().runner().fetch("z").feed("x", Tensor.create(3, Integer.class)) + .feed("y", Tensor.create(3, Integer.class)).run().get(0).expect(Integer.class); + System.out.println(tensor.intValue()); + } +} \ No newline at end of file diff --git a/tensorflow-java/src/main/python/tensorflowGraph.py b/tensorflow-java/src/main/python/tensorflowGraph.py new file mode 100644 index 0000000000..ab7f8810ac --- /dev/null +++ b/tensorflow-java/src/main/python/tensorflowGraph.py @@ -0,0 +1,16 @@ +import tensorflow as tf +graph = tf.Graph() +builder = tf.saved_model.builder.SavedModelBuilder('./model') +writer = tf.summary.FileWriter('.') +with graph.as_default(): + a = tf.constant(2, name='a') + b = tf.constant(3, name='b') + x = tf.placeholder(tf.int32, name='x') + y = tf.placeholder(tf.int32, name='y') + z = tf.math.add(a*x, b*y, name='z') + writer.add_graph(tf.get_default_graph()) + writer.flush() + sess = tf.Session() + sess.run(z, feed_dict = {x: 2, y: 3}) + builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING]) + builder.save() diff --git a/tensorflow-java/src/test/java/org/baeldung/tensorflow/TensorflowGraphUnitTest.java b/tensorflow-java/src/test/java/org/baeldung/tensorflow/TensorflowGraphUnitTest.java new file mode 100644 index 0000000000..51df6a4322 --- /dev/null +++ b/tensorflow-java/src/test/java/org/baeldung/tensorflow/TensorflowGraphUnitTest.java @@ -0,0 +1,21 @@ +package org.baeldung.tensorflow; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.Test; +import org.tensorflow.Graph; + +public class TensorflowGraphUnitTest { + + @Test + public void givenTensorflowGraphWhenRunInSessionReturnsExpectedResult() { + + Graph graph = TensorflowGraph.createGraph(); + Object result = TensorflowGraph.runGraph(graph, 3.0, 6.0); + assertEquals(21.0, result); + System.out.println(result); + graph.close(); + + } + +}