Keras models in modAL workflows

Thanks to the scikit-learn API of Keras, you can seamlessly integrate Keras models into your modAL workflow. In this tutorial, we shall quickly introduce the scikit-learn API of Keras and see how to do active learning with it. More details on the Keras scikit-learn API can be found here.

The executable script for this example can be found here!

Keras’ scikit-learn API

By default, a Keras model’s interface differs from that of a scikit-learn estimator. However, with its scikit-learn wrapper, it is possible to adapt your model.

[1]:
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
from keras.wrappers.scikit_learn import KerasClassifier

# build function for the Keras' scikit-learn API
def create_keras_model():
    """
    This function compiles and returns a Keras model.
    Should be passed to KerasClassifier in the Keras scikit-learn API.
    """

    model = Sequential()
    model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(10, activation='softmax'))

    model.compile(loss='categorical_crossentropy', optimizer='adadelta', metrics=['accuracy'])

    return model
Using TensorFlow backend.

For our purposes, the classifier we now initialize acts just like any scikit-learn estimator.

[2]:
# create the classifier
classifier = KerasClassifier(create_keras_model)
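
As a quick smoke test, the wrapped classifier exposes the familiar fit/predict/score methods. This is a minimal sketch that is not part of the original notebook; it uses randomly generated dummy data in place of real images purely to illustrate the interface.

import numpy as np

# dummy inputs and one-hot labels, just to exercise the scikit-learn interface
X_dummy = np.random.rand(32, 28, 28, 1).astype('float32')                 # fake "images"
y_dummy = keras.utils.to_categorical(np.random.randint(10, size=32), 10)  # one-hot labels

classifier.fit(X_dummy, y_dummy, epochs=1, verbose=0)  # trains the underlying Keras model
predictions = classifier.predict(X_dummy)              # predicted class labels, shape (32,)
accuracy = classifier.score(X_dummy, y_dummy)          # accuracy, as for any sklearn estimator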

Active learning with Keras

In this example, we are going to use the famous MNIST dataset, which is available as a built-in for Keras.

[3]:
import numpy as np
from keras.datasets import mnist

# read training data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(60000, 28, 28, 1).astype('float32') / 255
X_test = X_test.reshape(10000, 28, 28, 1).astype('float32') / 255
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

# assemble initial data
n_initial = 1000
initial_idx = np.random.choice(range(len(X_train)), size=n_initial, replace=False)
X_initial = X_train[initial_idx]
y_initial = y_train[initial_idx]

# generate the pool
# remove the initial data from the training dataset
X_pool = np.delete(X_train, initial_idx, axis=0)[:5000]
y_pool = np.delete(y_train, initial_idx, axis=0)[:5000]

With the data and the classifier ready, active learning is as easy as always. Because training is very expensive for large neural networks, this time we are going to query the 100 most uncertain instances each time we measure the uncertainty of the pool.

[4]:
from modAL.models import ActiveLearner

# initialize ActiveLearner
learner = ActiveLearner(
    estimator=classifier,
    X_training=X_initial, y_training=y_initial,
    verbose=1
)
Epoch 1/1
1000/1000 [==============================] - 4s 4ms/step - loss: 1.5794 - acc: 0.4790
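
By default, the ActiveLearner ranks the pool with modAL’s uncertainty sampling. If you prefer a different criterion, you can supply one through the query_strategy argument. The following sketch (not part of the original notebook) swaps in modAL’s built-in entropy_sampling:

from modAL.uncertainty import entropy_sampling

# same setup as above, but ranking pool instances by predictive entropy
entropy_learner = ActiveLearner(
    estimator=KerasClassifier(create_keras_model),
    query_strategy=entropy_sampling,
    X_training=X_initial, y_training=y_initial,
    verbose=0
)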

To make sure that you train only on the newly queried instances, pass only_new=True to the .teach() method of the learner. Without it, .teach() appends the queried examples to the training data seen so far and refits the model on all of it, which quickly becomes expensive for a neural network.

[5]:
# the active learning loop
n_queries = 10
for idx in range(n_queries):
    print('Query no. %d' % (idx + 1))
    query_idx, query_instance = learner.query(X_pool, n_instances=100, verbose=0)
    learner.teach(
        X=X_pool[query_idx], y=y_pool[query_idx], only_new=True,
        verbose=1
    )
    # remove queried instance from pool
    X_pool = np.delete(X_pool, query_idx, axis=0)
    y_pool = np.delete(y_pool, query_idx, axis=0)
Query no. 1
Epoch 1/1
100/100 [==============================] - 1s 10ms/step - loss: 2.0987 - acc: 0.3300
Query no. 2
Epoch 1/1
100/100 [==============================] - 1s 7ms/step - loss: 2.1222 - acc: 0.3300
Query no. 3
Epoch 1/1
100/100 [==============================] - 1s 8ms/step - loss: 2.0558 - acc: 0.4900
Query no. 4
Epoch 1/1
100/100 [==============================] - 1s 9ms/step - loss: 1.6943 - acc: 0.4700
Query no. 5
Epoch 1/1
100/100 [==============================] - 1s 12ms/step - loss: 1.5865 - acc: 0.6200
Query no. 6
Epoch 1/1
100/100 [==============================] - 1s 14ms/step - loss: 1.8714 - acc: 0.3500
Query no. 7
Epoch 1/1
100/100 [==============================] - 1s 14ms/step - loss: 1.3940 - acc: 0.6700
Query no. 8
Epoch 1/1
100/100 [==============================] - 1s 14ms/step - loss: 2.1033 - acc: 0.3200
Query no. 9
Epoch 1/1
100/100 [==============================] - 1s 11ms/step - loss: 1.5666 - acc: 0.6700
Query no. 10
Epoch 1/1
100/100 [==============================] - 1s 12ms/step - loss: 2.0238 - acc: 0.2700
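
After the loop, you can check how the actively trained model performs on the held-out test set loaded earlier. A minimal sketch, not part of the original notebook; it assumes learner.score forwards keyword arguments to the wrapped KerasClassifier:

# accuracy of the actively trained model on the MNIST test set
test_accuracy = learner.score(X_test, y_test, verbose=0)
print('Test accuracy: %.4f' % test_accuracy)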