Adding source code for the tutorial tracked under BAEL-2759 (#6533)

This commit is contained in:
Kumar Chandrakant
2019-03-15 11:22:27 +05:30
committed by Grzegorz Piwowarek
parent b3fc27088b
commit dc72b8b397
8 changed files with 158 additions and 0 deletions

View File

@@ -0,0 +1,41 @@
package org.baeldung.tensorflow;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
public class TensorflowGraph {
public static Graph createGraph() {
Graph graph = new Graph();
Operation a = graph.opBuilder("Const", "a").setAttr("dtype", DataType.fromClass(Double.class))
.setAttr("value", Tensor.<Double>create(3.0, Double.class)).build();
Operation b = graph.opBuilder("Const", "b").setAttr("dtype", DataType.fromClass(Double.class))
.setAttr("value", Tensor.<Double>create(2.0, Double.class)).build();
Operation x = graph.opBuilder("Placeholder", "x").setAttr("dtype", DataType.fromClass(Double.class)).build();
Operation y = graph.opBuilder("Placeholder", "y").setAttr("dtype", DataType.fromClass(Double.class)).build();
Operation ax = graph.opBuilder("Mul", "ax").addInput(a.output(0)).addInput(x.output(0)).build();
Operation by = graph.opBuilder("Mul", "by").addInput(b.output(0)).addInput(y.output(0)).build();
graph.opBuilder("Add", "z").addInput(ax.output(0)).addInput(by.output(0)).build();
return graph;
}
public static Object runGraph(Graph graph, Double x, Double y) {
Object result;
try (Session sess = new Session(graph)) {
result = sess.runner().fetch("z").feed("x", Tensor.<Double>create(x, Double.class))
.feed("y", Tensor.<Double>create(y, Double.class)).run().get(0).expect(Double.class)
.doubleValue();
}
return result;
}
public static void main(String[] args) {
Graph graph = TensorflowGraph.createGraph();
Object result = TensorflowGraph.runGraph(graph, 3.0, 6.0);
System.out.println(result);
graph.close();
}
}

View File

@@ -0,0 +1,14 @@
package org.baeldung.tensorflow;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
public class TensorflowSavedModel {
public static void main(String[] args) {
SavedModelBundle model = SavedModelBundle.load("./model", "serve");
Tensor<Integer> tensor = model.session().runner().fetch("z").feed("x", Tensor.<Integer>create(3, Integer.class))
.feed("y", Tensor.<Integer>create(3, Integer.class)).run().get(0).expect(Integer.class);
System.out.println(tensor.intValue());
}
}

View File

@@ -0,0 +1,16 @@
import tensorflow as tf
graph = tf.Graph()
builder = tf.saved_model.builder.SavedModelBuilder('./model')
writer = tf.summary.FileWriter('.')
with graph.as_default():
a = tf.constant(2, name='a')
b = tf.constant(3, name='b')
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')
z = tf.math.add(a*x, b*y, name='z')
writer.add_graph(tf.get_default_graph())
writer.flush()
sess = tf.Session()
sess.run(z, feed_dict = {x: 2, y: 3})
builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])
builder.save()

View File

@@ -0,0 +1,21 @@
package org.baeldung.tensorflow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.Test;
import org.tensorflow.Graph;
public class TensorflowGraphUnitTest {
@Test
public void givenTensorflowGraphWhenRunInSessionReturnsExpectedResult() {
Graph graph = TensorflowGraph.createGraph();
Object result = TensorflowGraph.runGraph(graph, 3.0, 6.0);
assertEquals(21.0, result);
System.out.println(result);
graph.close();
}
}