Introducción a Tensorflow para Java

1. Información general

TensorFlow es una biblioteca de código abierto para la programación de flujo de datos . Esto fue desarrollado originalmente por Google y está disponible para una amplia gama de plataformas. Aunque TensorFlow puede funcionar en un solo núcleo, también puede beneficiarse fácilmente de múltiples CPU, GPU o TPU disponibles .

En este tutorial, repasaremos los conceptos básicos de TensorFlow y cómo usarlo en Java. Tenga en cuenta que la API de TensorFlow Java es una API experimental y, por lo tanto, no está cubierta por ninguna garantía de estabilidad. Más adelante, en el tutorial, cubriremos posibles casos de uso para usar la API de TensorFlow Java.

2. Conceptos básicos

El cálculo de TensorFlow básicamente gira en torno a dos conceptos fundamentales: gráfico y sesión . Repasemos rápidamente para obtener los antecedentes necesarios para seguir con el resto del tutorial.

2.1. Gráfico de TensorFlow

Para empezar, comprendamos los componentes básicos de los programas de TensorFlow. Los cálculos se representan como gráficos en TensorFlow . Un gráfico suele ser un gráfico acíclico dirigido de operaciones y datos, por ejemplo:

La imagen de arriba representa el gráfico computacional de la siguiente ecuación:

f(x, y) = z = a*x + b*y

Un gráfico computacional de TensorFlow consta de dos elementos:

  1. Tensor: esta es la unidad principal de datos en TensorFlow. Se representan como los bordes en un gráfico computacional, que representan el flujo de datos a través del gráfico. Un tensor puede tener una forma con cualquier número de dimensiones. El número de dimensiones en un tensor generalmente se conoce como su rango. Entonces, un escalar es un tensor de rango 0, un vector es un tensor de rango 1, una matriz es un tensor de rango 2, y así sucesivamente.
  2. Operación: Estos son los nodos en un gráfico computacional. Se refieren a una amplia variedad de cálculos que pueden ocurrir en los tensores que alimentan la operación. A menudo también dan como resultado tensores que emanan de la operación en un gráfico computacional.

2.2. Sesión de TensorFlow

Ahora, un gráfico de TensorFlow es un simple esquema del cálculo que en realidad no tiene valores. Dicho gráfico debe ejecutarse dentro de lo que se llama una sesión de TensorFlow para que se evalúen los tensores en el gráfico . La sesión puede tomar varios tensores para evaluarlos desde un gráfico como parámetros de entrada. Luego, se ejecuta hacia atrás en el gráfico y ejecuta todos los nodos necesarios para evaluar esos tensores.

Con este conocimiento, ¡ahora estamos listos para tomar esto y aplicarlo a la API de Java!

3. Configuración de Maven

Configuraremos un proyecto rápido de Maven para crear y ejecutar un gráfico de TensorFlow en Java. Solo necesitamos la dependencia de tensorflow :

 org.tensorflow tensorflow 1.12.0 

4. Crear el gráfico

Intentemos ahora construir el gráfico que discutimos en la sección anterior usando la API de TensorFlow Java. Más precisamente, para este tutorial usaremos la API de Java TensorFlow para resolver la función representada por la siguiente ecuación:

z = 3*x + 2*y

El primer paso es declarar e inicializar un gráfico:

Graph graph = new Graph()

Ahora, tenemos que definir todas las operaciones necesarias. Recuerde que las operaciones en TensorFlow consumen y producen cero o más tensores . Además, cada nodo del gráfico es una operación que incluye constantes y marcadores de posición. Esto puede parecer contrario a la intuición, ¡pero ten paciencia por un momento!

La clase Graph tiene una función genérica llamada opBuilder () para crear cualquier tipo de operación en TensorFlow.

4.1. Definición de constantes

Para empezar, definamos operaciones constantes en nuestro gráfico anterior. Tenga en cuenta que una operación constante necesitará un tensor para su valor :

Operation a = graph.opBuilder("Const", "a") .setAttr("dtype", DataType.fromClass(Double.class)) .setAttr("value", Tensor.create(3.0, Double.class)) .build(); Operation b = graph.opBuilder("Const", "b") .setAttr("dtype", DataType.fromClass(Double.class)) .setAttr("value", Tensor.create(2.0, Double.class)) .build();

Aquí, hemos definido una Operación de tipo constante, alimentando el Tensor con valores Double 2.0 y 3.0. Puede parecer un poco abrumador para empezar, pero así es en la API de Java por ahora. Estas construcciones son mucho más concisas en lenguajes como Python.

4.2. Definición de marcadores de posición

Si bien necesitamos proporcionar valores a nuestras constantes, los marcadores de posición no necesitan un valor en el momento de la definición . Los valores de los marcadores de posición deben proporcionarse cuando el gráfico se ejecuta dentro de una sesión. Repasaremos esa parte más adelante en el tutorial.

Por ahora, veamos cómo podemos definir nuestros marcadores de posición:

Operation x = graph.opBuilder("Placeholder", "x") .setAttr("dtype", DataType.fromClass(Double.class)) .build(); Operation y = graph.opBuilder("Placeholder", "y") .setAttr("dtype", DataType.fromClass(Double.class)) .build();

Tenga en cuenta que no tuvimos que proporcionar ningún valor para nuestros marcadores de posición. Estos valores se alimentarán como tensores cuando se ejecuten.

4.3. Definición de funciones

Finalmente, necesitamos definir las operaciones matemáticas de nuestra ecuación, a saber, multiplicación y suma para obtener el resultado.

De nuevo, estos no son más que Operation s en TensorFlow y Graph.opBuilder () es útil una vez más:

Operation ax = graph.opBuilder("Mul", "ax") .addInput(a.output(0)) .addInput(x.output(0)) .build(); Operation by = graph.opBuilder("Mul", "by") .addInput(b.output(0)) .addInput(y.output(0)) .build(); Operation z = graph.opBuilder("Add", "z") .addInput(ax.output(0)) .addInput(by.output(0)) .build();

Aquí hemos definido su Operación , dos para multiplicar nuestras entradas y la final para sumar los resultados intermedios. Tenga en cuenta que las operaciones aquí reciben tensores que no son más que el resultado de nuestras operaciones anteriores.

Tenga en cuenta que estamos obteniendo el tensor de salida de la Operación usando el índice '0'. Como discutimos anteriormente, una Operación puede resultar en uno o más Tensor y, por lo tanto, mientras recuperamos un identificador, debemos mencionar el índice. Como sabemos que nuestras operaciones solo devuelven un tensor , ¡'0' funciona bien!

5. Visualización del gráfico

Es difícil mantener una pestaña en el gráfico a medida que aumenta de tamaño. Esto hace que sea importante visualizarlo de alguna manera . Siempre podemos crear un dibujo a mano como el gráfico pequeño que creamos anteriormente, pero no es práctico para gráficos más grandes. TensorFlow proporciona una utilidad llamada TensorBoard para facilitar esto .

Desafortunadamente, la API de Java no tiene la capacidad de generar un archivo de eventos que consume TensorBoard. Pero usando API en Python podemos generar un archivo de eventos como:

writer = tf.summary.FileWriter('.') ...... writer.add_graph(tf.get_default_graph()) writer.flush()

No se moleste si esto no tiene sentido en el contexto de Java, esto se ha agregado aquí solo para completar y no es necesario para continuar con el resto del tutorial.

Ahora podemos cargar y visualizar el archivo de eventos en TensorBoard como:

tensorboard --logdir .

TensorBoard viene como parte de la instalación de TensorFlow.

¡Tenga en cuenta la similitud entre este y el gráfico dibujado manualmente anteriormente!

6. Trabajar con sesión

Ahora hemos creado un gráfico computacional para nuestra ecuación simple en la API de TensorFlow Java. ¿Pero cómo lo ejecutamos? Antes de abordar eso, veamos cuál es el estado de Graph que acabamos de crear en este punto. Si intentamos imprimir la salida de nuestra Operación final "z":

System.out.println(z.output(0));

Esto resultará en algo como:


    

¡Esto no es lo que esperábamos! Pero si recordamos lo que discutimos anteriormente, esto realmente tiene sentido. El gráfico que acabamos de definir aún no se ha ejecutado, por lo que los tensores en él no tienen ningún valor real. La salida anterior solo dice que este será un tensor de tipo Double .

Definamos ahora una sesión para ejecutar nuestro gráfico :

Session sess = new Session(graph)

Finally, we are now ready to run our Graph and get the output we have been expecting:

Tensor tensor = sess.runner().fetch("z") .feed("x", Tensor.create(3.0, Double.class)) .feed("y", Tensor.create(6.0, Double.class)) .run().get(0).expect(Double.class); System.out.println(tensor.doubleValue());

So what are we doing here? It should be fairly intuitive:

  • Get a Runner from the Session
  • Define the Operation to fetch by its name “z”
  • Feed in tensors for our placeholders “x” and “y”
  • Run the Graph in the Session

And now we see the scalar output:

21.0

This is what we expected, isn't it!

7. The Use Case for Java API

At this point, TensorFlow may sound like overkill for performing basic operations. But, of course, TensorFlow is meant to run graphs much much larger than this.

Additionally, the tensors it deals with in real-world models are much larger in size and rank. These are the actual machine learning models where TensorFlow finds its real use.

It's not difficult to see that working with the core API in TensorFlow can become very cumbersome as the size of the graph increases. To this end, TensorFlow provides high-level APIs like Keras to work with complex models. Unfortunately, there is little to no official support for Keras on Java just yet.

However, we can use Python to define and train complex models either directly in TensorFlow or using high-level APIs like Keras. Subsequently, we can export a trained model and use that in Java using the TensorFlow Java API.

Now, why would we want to do something like that? This is particularly useful for situations where we want to use machine learning enabled features in existing clients running on Java. For instance, recommending caption for user images on an Android device. Nevertheless, there are several instances where we are interested in the output of a machine learning model but do not necessarily want to create and train that model in Java.

This is where TensorFlow Java API finds the bulk of its use. We'll go through how this can be achieved in the next section.

8. Using Saved Models

We'll now understand how we can save a model in TensorFlow to the file system and load that back possibly in a completely different language and platform. TensorFlow provides APIs to generate model files in a language and platform neutral structure called Protocol Buffer.

8.1. Saving Models to the File System

We'll begin by defining the same graph we created earlier in Python and saving that to the file system.

Let's see we can do this in Python:

import tensorflow as tf graph = tf.Graph() builder = tf.saved_model.builder.SavedModelBuilder('./model') with graph.as_default(): a = tf.constant(2, name="a") b = tf.constant(3, name="b") x = tf.placeholder(tf.int32, name="x") y = tf.placeholder(tf.int32, name="y") z = tf.math.add(a*x, b*y, name="z") sess = tf.Session() sess.run(z, feed_dict = {x: 2, y: 3}) builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING]) builder.save()

As the focus of this tutorial in Java, let's not pay much attention to the details of this code in Python, except for the fact that it generates a file called “saved_model.pb”. Do note in passing the brevity in defining a similar graph compared to Java!

8.2. Loading Models from the File System

We'll now load “saved_model.pb” into Java. Java TensorFlow API has SavedModelBundle to work with saved models:

SavedModelBundle model = SavedModelBundle.load("./model", "serve"); Tensor tensor = model.session().runner().fetch("z") .feed("x", Tensor.create(3, Integer.class)) .feed("y", Tensor.create(3, Integer.class)) .run().get(0).expect(Integer.class); System.out.println(tensor.intValue());

It should by now be fairly intuitive to understand what the above code is doing. It simply loads the model graph from the protocol buffer and makes available the session therein. From there onward, we can pretty much do anything with this graph as we would have done for a locally-defined graph.

9. Conclusion

To sum up, in this tutorial we went through the basic concepts related to the TensorFlow computational graph. We saw how to use the TensorFlow Java API to create and run such a graph. Then, we talked about the use cases for the Java API with respect to TensorFlow.

In the process, we also understood how to visualize the graph using TensorBoard, and save and reload a model using Protocol Buffer.

Como siempre, el código de los ejemplos está disponible en GitHub.