Aprendizaje automático con Spark MLlib

1. Información general

En este tutorial, entenderemos cómo aprovechar Apache Spark MLlib para desarrollar productos de aprendizaje automático. Desarrollaremos un producto de aprendizaje automático simple con Spark MLlib para demostrar los conceptos básicos.

2. Breve introducción al aprendizaje automático

El aprendizaje automático es parte de un paraguas más amplio conocido como inteligencia artificial . El aprendizaje automático se refiere al estudio de modelos estadísticos para resolver problemas específicos con patrones e inferencias. Estos modelos son "entrenados" para el problema específico por medio de datos de entrenamiento extraídos del espacio del problema.

Veremos qué implica exactamente esta definición a medida que adoptemos nuestro ejemplo.

2.1. Categorías de aprendizaje automático

Podemos categorizar ampliamente el aprendizaje automático en categorías supervisadas y no supervisadas según el enfoque. También hay otras categorías, pero nos limitaremos a estas dos:

  • El aprendizaje supervisado funciona con un conjunto de datos que contiene tanto las entradas como la salida deseada , por ejemplo, un conjunto de datos que contiene varias características de una propiedad y los ingresos por alquiler esperados. El aprendizaje supervisado se divide a su vez en dos amplias subcategorías denominadas clasificación y regresión:
    • Los algoritmos de clasificación están relacionados con la salida categórica, como si una propiedad está ocupada o no
    • Los algoritmos de regresión están relacionados con un rango de salida continuo, como el valor de una propiedad
  • El aprendizaje no supervisado, por otro lado, funciona con un conjunto de datos que solo tienen valores de entrada . Funciona intentando identificar la estructura inherente en los datos de entrada. Por ejemplo, encontrar diferentes tipos de consumidores a través de un conjunto de datos de su comportamiento de consumo.

2.2. Flujo de trabajo de aprendizaje automático

El aprendizaje automático es verdaderamente un área de estudio interdisciplinar. Requiere conocimiento del dominio empresarial, estadística, probabilidad, álgebra lineal y programación. Como esto claramente puede resultar abrumador, es mejor abordarlo de manera ordenada , lo que normalmente llamamos un flujo de trabajo de aprendizaje automático:

Como podemos ver, cada proyecto de aprendizaje automático debe comenzar con una declaración de problema claramente definida. Esto debe ir seguido de una serie de pasos relacionados con los datos que potencialmente pueden responder al problema.

Luego, normalmente seleccionamos un modelo que analiza la naturaleza del problema. A esto le sigue una serie de entrenamiento y validación del modelo, lo que se conoce como ajuste fino del modelo. Finalmente, probamos el modelo en datos no vistos anteriormente y lo implementamos en producción si es satisfactorio.

3. ¿Qué es Spark MLlib ?

Spark MLlib es un módulo sobre Spark Core que proporciona primitivas de aprendizaje automático como API. El aprendizaje automático generalmente trata con una gran cantidad de datos para el entrenamiento de modelos.

El marco informático base de Spark es un gran beneficio. Además de esto, MLlib proporciona la mayoría de los algoritmos estadísticos y de aprendizaje automático populares. Esto simplifica enormemente la tarea de trabajar en un proyecto de aprendizaje automático a gran escala.

4. Aprendizaje automático con MLlib

Ahora tenemos suficiente contexto sobre el aprendizaje automático y cómo MLlib puede ayudar en este esfuerzo. Comencemos con nuestro ejemplo básico de implementación de un proyecto de aprendizaje automático con Spark MLlib.

Si recordamos nuestra discusión sobre el flujo de trabajo de aprendizaje automático, deberíamos comenzar con una declaración del problema y luego pasar a los datos. Afortunadamente para nosotros, elegiremos el "hola mundo" del aprendizaje automático, Iris Dataset. Este es un conjunto de datos etiquetado multivariante, que consta de la longitud y el ancho de sépalos y pétalos de diferentes especies de Iris.

Esto nos da el objetivo de nuestro problema: ¿podemos predecir la especie de un iris a partir del largo y ancho de su sépalo y pétalo ?

4.1. Establecer las dependencias

Primero, tenemos que definir la siguiente dependencia en Maven para extraer las bibliotecas relevantes:

 org.apache.spark spark-mllib_2.11 2.4.3 provided 

Y necesitamos inicializar SparkContext para que funcione con las API de Spark:

SparkConf conf = new SparkConf() .setAppName("Main") .setMaster("local[2]"); JavaSparkContext sc = new JavaSparkContext(conf);

4.2. Cargando los datos

Lo primero es lo primero, debemos descargar los datos, que están disponibles como un archivo de texto en formato CSV. Luego tenemos que cargar estos datos en Spark:

String dataFile = "data\\iris.data"; JavaRDD data = sc.textFile(dataFile);

Spark MLlib ofrece varios tipos de datos, tanto locales como distribuidos, para representar los datos de entrada y las etiquetas correspondientes. El más simple de los tipos de datos es Vector :

JavaRDD inputData = data .map(line -> { String[] parts = line.split(","); double[] v = new double[parts.length - 1]; for (int i = 0; i < parts.length - 1; i++) { v[i] = Double.parseDouble(parts[i]); } return Vectors.dense(v); });

Tenga en cuenta que aquí solo hemos incluido las funciones de entrada, principalmente para realizar análisis estadísticos.

Un ejemplo de entrenamiento generalmente consta de múltiples características de entrada y una etiqueta, representada por la clase La labelPoint :

Map map = new HashMap(); map.put("Iris-setosa", 0); map.put("Iris-versicolor", 1); map.put("Iris-virginica", 2); JavaRDD labeledData = data .map(line -> { String[] parts = line.split(","); double[] v = new double[parts.length - 1]; for (int i = 0; i < parts.length - 1; i++) { v[i] = Double.parseDouble(parts[i]); } return new LabeledPoint(map.get(parts[parts.length - 1]), Vectors.dense(v)); });

Nuestra etiqueta de salida en el conjunto de datos es textual, lo que significa la especie de Iris. Para alimentar esto en un modelo de aprendizaje automático, tenemos que convertirlo en valores numéricos.

4.3. Análisis exploratorio de datos

El análisis de datos exploratorios implica analizar los datos disponibles. Ahora, los algoritmos de aprendizaje automático son sensibles a la calidad de los datos , por lo que los datos de mayor calidad tienen mejores perspectivas de ofrecer el resultado deseado.

Los objetivos de análisis típicos incluyen eliminar anomalías y detectar patrones. Esto incluso se incorpora a los pasos críticos de la ingeniería de funciones para llegar a funciones útiles a partir de los datos disponibles.

Our dataset, in this example, is small and well-formed. Hence we don't have to indulge in a lot of data analysis. Spark MLlib, however, is equipped with APIs to offer quite an insight.

Let's begin with some simple statistical analysis:

MultivariateStatisticalSummary summary = Statistics.colStats(inputData.rdd()); System.out.println("Summary Mean:"); System.out.println(summary.mean()); System.out.println("Summary Variance:"); System.out.println(summary.variance()); System.out.println("Summary Non-zero:"); System.out.println(summary.numNonzeros());

Here, we're observing the mean and variance of the features we have. This is helpful in determining if we need to perform normalization of features. It's useful to have all features on a similar scale. We are also taking a note of non-zero values, which can adversely impact model performance.

Here is the output for our input data:

Summary Mean: [5.843333333333332,3.0540000000000003,3.7586666666666666,1.1986666666666668] Summary Variance: [0.6856935123042509,0.18800402684563744,3.113179418344516,0.5824143176733783] Summary Non-zero: [150.0,150.0,150.0,150.0]

Another important metric to analyze is the correlation between features in the input data:

Matrix correlMatrix = Statistics.corr(inputData.rdd(), "pearson"); System.out.println("Correlation Matrix:"); System.out.println(correlMatrix.toString());

A high correlation between any two features suggests they are not adding any incremental value and one of them can be dropped. Here is how our features are correlated:

Correlation Matrix: 1.0 -0.10936924995064387 0.8717541573048727 0.8179536333691672 -0.10936924995064387 1.0 -0.4205160964011671 -0.3565440896138163 0.8717541573048727 -0.4205160964011671 1.0 0.9627570970509661 0.8179536333691672 -0.3565440896138163 0.9627570970509661 1.0

4.4. Splitting the Data

If we recall our discussion of machine learning workflow, it involves several iterations of model training and validation followed by final testing.

For this to happen, we have to split our training data into training, validation, and test sets. To keep things simple, we'll skip the validation part. So, let's split our data into training and test sets:

JavaRDD[] splits = parsedData.randomSplit(new double[] { 0.8, 0.2 }, 11L); JavaRDD trainingData = splits[0]; JavaRDD testData = splits[1];

4.5. Model Training

So, we've reached a stage where we've analyzed and prepared our dataset. All that's left is to feed this into a model and start the magic! Well, easier said than done. We need to pick a suitable algorithm for our problem – recall the different categories of machine learning we spoke of earlier.

It isn't difficult to understand that our problem fits into classification within the supervised category. Now, there are quite a few algorithms available for use under this category.

The simplest of them is Logistic Regression (let the word regression not confuse us; it is, after all, a classification algorithm):

LogisticRegressionModel model = new LogisticRegressionWithLBFGS() .setNumClasses(3) .run(trainingData.rdd());

Here, we are using a three-class Limited Memory BFGS based classifier. The details of this algorithm are beyond the scope of this tutorial, but this is one of the most widely used ones.

4.6. Model Evaluation

Remember that model training involves multiple iterations, but for simplicity, we've just used a single pass here. Now that we've trained our model, it's time to test this on the test dataset:

JavaPairRDD predictionAndLabels = testData .mapToPair(p -> new Tuple2(model.predict(p.features()), p.label())); MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); double accuracy = metrics.accuracy(); System.out.println("Model Accuracy on Test Data: " + accuracy);

Now, how do we measure the effectiveness of a model? There are several metrics that we can use, but one of the simplest is Accuracy. Simply put, accuracy is a ratio of the correct number of predictions and the total number of predictions. Here is what we can achieve in a single run of our model:

Model Accuracy on Test Data: 0.9310344827586207

Note that this will vary slightly from run to run due to the stochastic nature of the algorithm.

However, accuracy is not a very effective metric in some problem domains. Other more sophisticated metrics are Precision and Recall (F1 Score), ROC Curve, and Confusion Matrix.

4.7. Saving and Loading the Model

Finally, we often need to save the trained model to the filesystem and load it for prediction on production data. This is trivial in Spark:

model.save(sc, "model\\logistic-regression"); LogisticRegressionModel sameModel = LogisticRegressionModel .load(sc, "model\\logistic-regression"); Vector newData = Vectors.dense(new double[]{1,1,1,1}); double prediction = sameModel.predict(newData); System.out.println("Model Prediction on New Data = " + prediction);

So, we're saving the model to the filesystem and loading it back. After loading, the model can be straight away used to predict output on new data. Here is a sample prediction on random new data:

Model Prediction on New Data = 2.0

5. Beyond the Primitive Example

While the example we went through covers the workflow of a machine learning project broadly, it leaves a lot of subtle and important points. While it isn't possible to discuss them in detail here, we can certainly go through some of the important ones.

Spark MLlib through its APIs has extensive support in all these areas.

5.1. Model Selection

Model selection is often one of the complex and critical tasks. Training a model is an involved process and is much better to do on a model that we're more confident will produce the desired results.

While the nature of the problem can help us identify the category of machine learning algorithm to pick from, it isn't a job fully done. Within a category like classification, as we saw earlier, there are often many possible different algorithms and their variations to choose from.

Often the best course of action is quick prototyping on a much smaller set of data. A library like Spark MLlib makes the job of quick prototyping much easier.

5.2. Model Hyper-Parameter Tuning

A typical model consists of features, parameters, and hyper-parameters. Features are what we feed into the model as input data. Model parameters are variables which model learns during the training process. Depending on the model, there are certain additional parameters that we have to set based on experience and adjust iteratively. These are called model hyper-parameters.

For instance, the learning rate is a typical hyper-parameter in gradient-descent based algorithms. Learning rate controls how fast parameters are adjusted during training cycles. This has to be aptly set for the model to learn effectively at a reasonable pace.

While we can begin with an initial value of such hyper-parameters based on experience, we have to perform model validation and manually tune them iteratively.

5.3. Model Performance

A statistical model, while being trained, is prone to overfitting and underfitting, both causing poor model performance. Underfitting refers to the case where the model does not pick the general details from the data sufficiently. On the other hand, overfitting happens when the model starts to pick up noise from the data as well.

There are several methods for avoiding the problems of underfitting and overfitting, which are often employed in combination. For instance, to counter overfitting, the most employed techniques include cross-validation and regularization. Similarly, to improve underfitting, we can increase the complexity of the model and increase the training time.

Spark MLlib has fantastic support for most of these techniques like regularization and cross-validation. In fact, most of the algorithms have default support for them.

6. Spark MLlib in Comparision

While Spark MLlib is quite a powerful library for machine learning projects, it is certainly not the only one for the job. There are quite a number of libraries available in different programming languages with varying support. We'll go through some of the popular ones here.

6.1. Tensorflow/Keras

Tensorflow is an open-source library for dataflow and differentiable programming, widely employed for machine learning applications. Together with its high-level abstraction, Keras, it is a tool of choice for machine learning. They are primarily written in Python and C++ and primarily used in Python. Unlike Spark MLlib, it does not have a polyglot presence.

6.2. Theano

Theano is another Python-based open-source library for manipulating and evaluating mathematical expressions – for instance, matrix-based expressions, which are commonly used in machine learning algorithms. Unlike Spark MLlib, Theano again is primarily used in Python. Keras, however, can be used together with a Theano back end.

6.3. CNTK

Microsoft Cognitive Toolkit (CNTK) is a deep learning framework written in C++ that describes computational steps via a directed graph. It can be used in both Python and C++ programs and is primarily used in developing neural networks. There's a Keras back end based on CNTK available for use that provides the familiar intuitive abstraction.

7. Conclusion

To sum up, in this tutorial we went through the basics of machine learning, including different categories and workflow. We went through the basics of Spark MLlib as a machine learning library available to us.

Furthermore, we developed a simple machine learning application based on the available dataset. We implemented some of the most common steps in the machine learning workflow in our example.

También pasamos por algunos de los pasos avanzados en un proyecto típico de aprendizaje automático y cómo Spark MLlib puede ayudar en ellos. Finalmente, vimos algunas de las bibliotecas alternativas de aprendizaje automático disponibles para que las usemos.

Como siempre, el código se puede encontrar en GitHub.