Capsule Networks - Part II: Dynamic Routing
May 22, 2023
In the previous article, we worked through the Capsule Net paper to understand what capsules are and how they work. Dynamic routing, or routing-by-agreement, is another technique used in the paper. The only difference between the code in this article and the previous one is the process from caps2_predicted to the outputs. In the previous article, caps2_predicted is summed across the caps1_n_caps axis; the length across caps2_n_dims is then computed, which yields a probability score for each output digit classification.
The summary of the previous article, using my notation, is below. "*+" means matrix multiplication, and "@" marks the axis that undergoes the sum or other function named before it (so "sum@8" is a sum over the axis of size 8, and "len@16" is a vector length over the axis of size 16). The brackets after the name of each layer indicate its shape. This notation is only for keeping track of what is going on conceptually between layers, starting from caps2_predicted and ending at the outputs.
caps2_predicted[batch_size,1152,10,16,1] = caps1_output_expanded[batch_size,1152,1,1,sum@8] *+ weights[batch_size,1152,10,16,sum@8]
caps2_predicted_reducedsum[batch_size,1,10,16,1] = caps2_predicted[batch_size,sum@1152,10,16,1]
outputs[batch_size, 10] = caps2_predicted_reducedsum[batch_size,1,10,len@16,1].squeeze()
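As a minimal sketch of that earlier flow with dummy tensors (the toy batch size of 2 is my assumption; the reduce and length calls mirror the full code further down):
import tensorflow as tf
batch_size, caps1_n_caps, caps2_n_caps, caps2_n_dims = 2, 1152, 10, 16 # toy batch size is an assumption
caps2_predicted = tf.random.normal([batch_size, caps1_n_caps, caps2_n_caps, caps2_n_dims, 1])
reduced = tf.reduce_sum(caps2_predicted, axis=1, keepdims=True) # sum across caps1_n_caps -> [2,1,10,16,1]
lengths = tf.sqrt(tf.reduce_sum(tf.square(reduced), axis=-2, keepdims=True) + 1e-7) # length across caps2_n_dims -> [2,1,10,1,1]
outputs = tf.reshape(lengths, [batch_size, caps2_n_caps]) # squeeze to [2,10] probability-like scores
print(outputs.shape) # (2, 10)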
To demonstrate the recursive nature of dynamic routing, the code provided repeats the process for up to 4 rounds. The conceptual summary ends at caps2_output_round_2, which is converted to the outputs layer. Dynamic routing can end at any round; however, there comes a point where additional rounds no longer improve the model.
caps2_predicted[batch_size,1152,10,16,1] = caps1_output_expanded[batch_size,1152,1,1,sum@8] *+ weights[batch_size,1152,10,16,sum@8]
caps2_predicted_reducedsum[batch_size,1,10,16,1] = caps2_predicted[batch_size,sum@1152,10,16,1]
caps2_output_round_1[batch_size,1,10,16,1] = caps2_predicted_reducedsum[batch_size,1,10,squashuv@16,1]
agreement_round_1[batch_size,1152,10,1,1] = caps2_predicted[batch_size,1152,10,sum@16,1] *+ caps2_output_round_1[batch_size,1,10,sum@16,1]
raw_weights_round_2[batch_size,1152,10,1,1] = agreement_round_1[batch_size,1152,10,1,1] + raw_weights_round_1[batch_size,1152,10,1,1]
routing_weights_round_2[batch_size,1152,10,1,1] = raw_weights_round_2[batch_size,softmax@1152,10,1,1]
weighted_predictions_round_2[batch_size,1152,10,16,1] = routing_weights_round_2[batch_size,1152,10,itemwise@1,1] * caps2_predicted[batch_size,1152,10,16,1]
weighted_sum_round_2[batch_size,1,10,16,1] = weighted_predictions_round_2[batch_size,sum@1152,10,16,1]
caps2_output_round_2[batch_size,1,10,16,1] = weighted_sum_round_2[batch_size,1,10,squashuv@16,1]
outputs[batch_size, 10] = caps2_output_round_2[batch_size,1,10,len@16,1].squeeze()
In this summary, caps2_predicted undergoes a more involved process than in the previous article. In the first round it is itemwise multiplied by the initial routing weights, but since those weights all share the same value after the softmax of the all-zero raw weights, that step is dropped from the summary. It is important to grasp that the current round's raw_weights are determined by adding the previous round's raw_weights and agreement values. Also, notice how caps2_predicted_reducedsum and caps2_output_round_2 are equivalent in shape, and both undergo the same transformation into the outputs layer. Dynamic routing first applies a softmax to the raw_weights and multiplies the result with caps2_predicted. Regarding the softmax of the raw_weights in dynamic routing, the paper states: "The coupling coefficients between capsule i and all the capsules in the layer above sum to 1 and are determined by a “routing softmax” whose initial logits bij are the log prior probabilities that capsule i should be coupled to capsule j". In this step we are specifically trying to determine routing_weights that capture a meaningful connection between the first and second capsule layers. In effect, applying a softmax along the caps1_n_caps axis gives the probability that a specific caps1_n_caps capsule should be coupled to a specific caps2_n_caps final digit classification output. This is counterintuitive in the sense that a softmax usually runs over the output classification axis, "caps2_n_caps". In fact, testing shows that softmaxing over the caps2_n_caps axis during the dynamic routing rounds produces significantly worse results.
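A tiny illustration of that axis choice (toy shapes only, not part of the model): with all-zero logits, a softmax over axis 1 spreads the coupling weight uniformly across the 1152 primary capsules, while a softmax over axis 2 would spread it across the 10 digit classes.
import tensorflow as tf
raw = tf.zeros([1, 1152, 10, 1, 1]) # all-equal initial logits
over_caps1 = tf.keras.layers.Softmax(axis=1)(raw) # each coupling starts at 1/1152
over_caps2 = tf.keras.layers.Softmax(axis=2)(raw) # what softmaxing the digit axis would give: 1/10 each
print(float(over_caps1[0, 0, 0, 0, 0]), float(over_caps2[0, 0, 0, 0, 0])) # ~0.00087 vs 0.1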
Once this multiplication and softmax with the raw_weights occurs, the code in this article and the previous article both sum across the caps1_n_caps axis. This summation gathers useful information into the second capsule layer (the final digit classification outputs), organized along the second capsule layer's dimension axis, caps2_n_dims. Unlike the previous article, dynamic routing then calculates the unit vector across caps2_n_dims and multiplies it by the squash factor. The larger the values in caps2_predicted, the larger the ratio applied to the unit vector and the longer the squashed vector. The paper states: "We therefore use a non-linear 'squashing' function to ensure that short vectors get shrunk to almost zero length and long vectors get shrunk to a length slightly below 1". The length of these squashed vectors across caps2_n_dims is then determined and represents the probability score for each final output digit classification.
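The squash step is written out inline in the code below; as a sketch, the same few lines can be collected into a helper (the function name squash is mine, not from the code):
import tensorflow as tf

def squash(v, axis=-2, epsilon=1e-7):
    # scales the unit vector by ||v||^2 / (1 + ||v||^2): short vectors shrink toward zero, long vectors approach length 1
    squared_norm = tf.reduce_sum(tf.square(v), axis=axis, keepdims=True)
    unit_vector = v / tf.sqrt(squared_norm + epsilon)
    return (squared_norm / (1. + squared_norm)) * unit_vector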
We can see that after matrix multiplying caps2_predicted and caps2_output_round_1, summed across the caps2_n_dims axis, we arrive at the agreement values that get added to raw_weights_round_1. In essence, caps2_predicted is multiplied by a transformed version of itself (caps2_output_round_1). This agreement value is then added to the current round's raw weights to determine the raw weights of the following round. Notice that caps2_predicted was initially summed across caps1_n_caps to produce caps2_output_round_1; to determine agreement, caps2_output_round_1 is then broadcast back across caps1_n_caps through itemwise multiplication. And whereas caps2_predicted keeps the caps2_n_dims axis, the product is eventually summed across caps2_n_dims to yield the agreement values.
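In tensor terms, the agreement step is an itemwise product that broadcasts caps2_output_round_1 back across the 1152 primary capsules, followed by a sum over caps2_n_dims (dummy tensors below; the batch size of 2 is my assumption):
import tensorflow as tf
caps2_predicted = tf.random.normal([2, 1152, 10, 16, 1])
caps2_output_round_1 = tf.random.normal([2, 1, 10, 16, 1])
# broadcast over caps1_n_caps, then dot-product along caps2_n_dims
agreement_round_1 = tf.reduce_sum(caps2_predicted * caps2_output_round_1, axis=-2, keepdims=True)
print(agreement_round_1.shape) # (2, 1152, 10, 1, 1)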
The complexity of dynamic routing can't be explained by trying to understand the mathematical logic between each pair of layers. As the paper states: "Active capsules at one level make predictions, via transformation matrices". It could be that all that is occurring is a series of logical, mathematical transformations from one capsule layer to the next, programmed so that useful loss/error information can be backpropagated through the correct reverse transformations. Although no weights are trainable during dynamic routing, the gradients may be amplified by the recursive nature of "routing by agreement," which sends stronger loss information to the trainable weights layer. Doing more rounds can strengthen these signals but eventually shows diminishing returns on the model's accuracy.
Just like in the first round, the raw_weights are softmaxed across the caps1_n_caps axis, itemwise multiplied by caps2_predicted, and summed across the caps1_n_caps axis; the unit vector across caps2_n_dims is then calculated and multiplied by the squash factor. This is transformed into outputs by determining the length of each final output digit classification vector across caps2_n_dims.
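The four unrolled rounds in the code below can also be written as a loop. Here is a minimal sketch of the same procedure under the same shapes; the function name dynamic_routing and the loop structure are mine, not from the article's code.
import tensorflow as tf

def dynamic_routing(caps2_predicted, batch_size, caps1_n_caps=1152, caps2_n_caps=10, n_rounds=4, epsilon=1e-7):
    # caps2_predicted: [batch_size, caps1_n_caps, caps2_n_caps, caps2_n_dims, 1]
    raw_weights = tf.zeros([batch_size, caps1_n_caps, caps2_n_caps, 1, 1])
    for r in range(n_rounds):
        routing_weights = tf.nn.softmax(raw_weights, axis=1) # softmax over the caps1_n_caps axis
        weighted_sum = tf.reduce_sum(routing_weights * caps2_predicted, axis=1, keepdims=True) # sum over caps1_n_caps
        sq_norm = tf.reduce_sum(tf.square(weighted_sum), axis=-2, keepdims=True) # squared length over caps2_n_dims
        caps2_output = (sq_norm / (1. + sq_norm)) * weighted_sum / tf.sqrt(sq_norm + epsilon) # squash
        if r < n_rounds - 1: # no agreement needed after the final round
            agreement = tf.reduce_sum(caps2_predicted * caps2_output, axis=-2, keepdims=True)
            raw_weights = raw_weights + agreement
    return caps2_output # [batch_size, 1, caps2_n_caps, caps2_n_dims, 1]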
The nuance of capsule networks is the intentionality of which axes are being multiplied and summed. After examining the mathematical logic involved in dynamic routing, my opinion is that it is analogous to manipulating the learning rate for a specific part of the model: the caps2_predicted layer. In the next article, there will be trainable weights that come after the dynamic routing rounds and therefore don't receive the amplified signals that dynamic routing produces. Although the complexity of dynamic routing seems novel, testing shows that it is not absolutely necessary for capsule networks to demonstrate learning. Though conceptually unnecessary for understanding capsules, dynamic routing does produce significantly better results by strengthening the relevant connections between layers. See the commented code below.
import tensorflow as tf
(trainimages, trainlabels), (testimages, testlabels) = tf.keras.datasets.mnist.load_data()
epsilon = 1e-7
batch_size = 10 # since batch_size is used in the model to create the routing weights, this needs to be declared at the beginning and can't be changed
trains, tests = 1000,1000
trainimages, trainlabels = trainimages[:trains].reshape(trains,28,28,1), trainlabels[:trains] # needs a channel dimension on the last axis for the conv2d layer or else it will not work; photos usually have 3 channels for color, but not mnist
testimages, testlabels = testimages[:tests].reshape(tests,28,28,1), testlabels[:tests]
ohtrainlabels = tf.one_hot(trainlabels,10) # one-hot encodes the labels; just makes training easier
ohtestlabels = tf.one_hot(testlabels,10)
convoutsize = 6 # length and height of the 2nd convolutional output; not including filters
caps1_n_dims = 8 # the number of values or "dims" in the unit vectors we are creating after the conv2d reshaping
caps1_n_maps = 32 # multiplied by caps1_n_dims to set the number of filters for conv1 and conv2, and used to determine the number of capsules we are creating; conv1 and conv2 should use the same number of filters
caps1_n_caps = caps1_n_maps * convoutsize**2 # this is the number of capsules we are creating which have caps1_n_dims number of values or "dims"
inputs = tf.keras.Input(shape = (trainimages.shape[1],trainimages.shape[2],1))
conv1 = tf.keras.layers.Conv2D(filters = caps1_n_maps * caps1_n_dims, kernel_size = 9, strides = 1, activation = tf.nn.relu)(inputs) #creates pixel feature map
conv2 = tf.keras.layers.Conv2D(filters = caps1_n_maps * caps1_n_dims, kernel_size = 9, strides = 2, activation = tf.nn.relu)(conv1) #this layer is a pixel feature combination map
caps1_raw = tf.keras.layers.Reshape([caps1_n_caps, caps1_n_dims])(conv2) # reshapes into caps1_n_caps number of capsules of size caps1_n_dims
squared_norm = tf.reduce_sum(tf.square(caps1_raw), axis=-1, keepdims=True)
safe_norm = tf.sqrt(squared_norm + epsilon) #represents len of vector; eps makes it numerically safe for backpropagation and significantly improves model performance
unit_vector = caps1_raw / safe_norm # represents caps1_n_caps number of unit vectors; #getting slightly better results without using squash_factor outside dynamic routing layers
# squash_factor = squared_norm / (1. + squared_norm)
# squashed_unit_vector = squash_factor * unit_vector
caps2_n_caps = 10 # number of label classifications = 10; for 0 - 9
caps2_n_dims = 16 # this is for the next set of capsules
caps1_output_expanded2 = tf.keras.layers.Reshape([caps1_n_caps, 1, 1, caps1_n_dims])(unit_vector) # tf.reshape creates a warning because it is not an official Keras Layer; for the official Keras Reshape layer, don't include batch_size in the dimensions or it doesn't work; it takes care of batch_size automatically
# caps1_output_tiled = tf.tile(caps1_output_expanded2, [1,1, caps2_n_caps, caps2_n_dims, 1]) # since not an official keras Layer, must include batch_size in dimensions; keras.layers.Multiply can still itemwise multiply over 2 dimensions; [?,1152,1,1,8] * [?,1152,10,16,8]
class matmult(tf.keras.layers.Layer): # must create a custom Keras Layer when initializing custom weights; this allows the weights to be trainable
    def __init__(self, **kwargs):
        super(matmult, self).__init__(**kwargs)
        initializer = tf.sqrt(2/(caps1_n_caps * caps1_n_dims)) #he-normal initializer
        W_init = tf.random.normal(shape=(1, caps1_n_caps, caps2_n_caps, caps2_n_dims, caps1_n_dims), stddev=initializer, dtype=tf.float32)
        self.W = tf.Variable(W_init, trainable = True)
        self.multip = tf.keras.layers.Multiply()
    def call(self, inputs):
        matmult = self.multip([self.W, inputs]) # double-dim itemwise multiplication with broadcasting; [1,1152,10,16,8] * [batch_size,1152,1,1,8] = [batch_size,1152,10,16,8]
        return tf.reduce_sum(matmult, axis = -1, keepdims = True) # matmul between W & caps1_output_expanded2
caps2_predicted = matmult()(caps1_output_expanded2) #[batch_size,1152,10,16,1]; each of the 1152 primary capsules predicts a caps2_n_dims vector for each of the 10 output classes
# routing round 1
raw_weights_round_1 = tf.zeros([batch_size, caps1_n_caps, caps2_n_caps, 1, 1], dtype=tf.float32) #need to define batch_size here; important that all weights are equal here
routing_weights_round_1 = tf.keras.layers.Softmax(axis = 1)(raw_weights_round_1) # all couplings equal at baseline (1/1152 each, since the softmax runs over the caps1_n_caps axis); for whatever reason this works better than softmaxing over the caps2_n_caps axis
weighted_predictions = tf.keras.layers.Multiply()([routing_weights_round_1, caps2_predicted]) # itemwise multiplication, broadcasting across the caps2_n_dims axis; [batch_size, 1152, 10, 1, 1] * [batch_size, 1152, 10, 16, 1] = [batch_size, 1152, 10, 16, 1]; multiplies caps2_predicted by the same ratio (1/1152) everywhere
weighted_sum = tf.reduce_sum(weighted_predictions, axis=1, keepdims=True) # sums across caps1_n_caps axis; output shape [batch_size,1,10,16,1]
weighted_sum_sumsquared = tf.reduce_sum(tf.square(weighted_sum), axis=-2, keepdims=True) #squash function across caps2_n_dims axis;
weighted_sum_len = tf.sqrt(weighted_sum_sumsquared + epsilon)
weighted_sum_uv = weighted_sum / weighted_sum_len
squash_factor = weighted_sum_sumsquared / (1. + weighted_sum_sumsquared) #dynamic routing doesn't work without squash factor
caps2_output_round_1 = tf.keras.layers.Multiply()([squash_factor, weighted_sum_uv]) # output shape [batch_size,1,10,16,1]
squashmult_round_1 = tf.keras.layers.Multiply()([caps2_predicted, caps2_output_round_1]) # itemwise multiplication across caps1_n_caps axis; output shape [batch_size,1152,10,16,1]
agreement_round_1 = tf.reduce_sum(squashmult_round_1, axis = -2, keepdims = True) # sum across caps2_n_dims axis; output shape [batch_size,1152,10,1,1]
# routing round 2
raw_weights_round_2 = tf.keras.layers.Add()([raw_weights_round_1, agreement_round_1]) # adding agreement values to tf.zeros() baseline; raw_weights_round_2 = agreement_round_1
routing_weights_round_2 = tf.keras.layers.Softmax(axis = 1)(raw_weights_round_2)
weighted_predictions_round_2 = tf.keras.layers.Multiply()([routing_weights_round_2, caps2_predicted])
weighted_sum_round_2 = tf.reduce_sum(weighted_predictions_round_2, axis=1, keepdims=True)
weighted_sum_sumsquared_round_2 = tf.reduce_sum(tf.square(weighted_sum_round_2), axis=-2, keepdims=True) #squash function
weighted_sum_len_round_2 = tf.sqrt(weighted_sum_sumsquared_round_2 + epsilon)
weighted_sum_uv_round_2 = weighted_sum_round_2 / weighted_sum_len_round_2
squash_factor_round_2 = weighted_sum_sumsquared_round_2 / (1. + weighted_sum_sumsquared_round_2)
caps2_output_round_2 = squash_factor_round_2 * weighted_sum_uv_round_2
squashmult_round_2 = tf.keras.layers.Multiply()([caps2_predicted, caps2_output_round_2])
agreement_round_2 = tf.reduce_sum(squashmult_round_2, axis = -2, keepdims = True)
# routing round 3
raw_weights_round_3 = tf.keras.layers.Add()([raw_weights_round_2, agreement_round_2])
routing_weights_round_3 = tf.keras.layers.Softmax(axis = 1)(raw_weights_round_3)
weighted_predictions_round_3 = tf.keras.layers.Multiply()([routing_weights_round_3, caps2_predicted])
weighted_sum_round_3 = tf.reduce_sum(weighted_predictions_round_3, axis=1, keepdims=True)
weighted_sum_sumsquared_round_3 = tf.reduce_sum(tf.square(weighted_sum_round_3), axis=-2, keepdims=True) #squash function
weighted_sum_len_round_3 = tf.sqrt(weighted_sum_sumsquared_round_3 + epsilon)
weighted_sum_uv_round_3 = weighted_sum_round_3 / weighted_sum_len_round_3
squash_factor_round_3 = weighted_sum_sumsquared_round_3 / (1. + weighted_sum_sumsquared_round_3)
caps2_output_round_3 = squash_factor_round_3 * weighted_sum_uv_round_3
squashmult_round_3 = tf.keras.layers.Multiply()([caps2_predicted, caps2_output_round_3])
agreement_round_3 = tf.reduce_sum(squashmult_round_3, axis = -2, keepdims = True)
# routing round 4; no agreement needed for dynamic routing output
raw_weights_round_4 = tf.keras.layers.Add()([raw_weights_round_3, agreement_round_3])
routing_weights_round_4 = tf.keras.layers.Softmax(axis = 1)(raw_weights_round_4)
weighted_predictions_round_4 = tf.keras.layers.Multiply()([routing_weights_round_4, caps2_predicted])
weighted_sum_round_4 = tf.reduce_sum(weighted_predictions_round_4, axis=1, keepdims=True)
weighted_sum_sumsquared_round_4 = tf.reduce_sum(tf.square(weighted_sum_round_4), axis=-2, keepdims=True) #squash function
weighted_sum_len_round_4 = tf.sqrt(weighted_sum_sumsquared_round_4 + epsilon)
weighted_sum_uv_round_4 = weighted_sum_round_4 / weighted_sum_len_round_4
squash_factor_round_4 = weighted_sum_sumsquared_round_4 / (1. + weighted_sum_sumsquared_round_4)
caps2_output_round_4 = squash_factor_round_4 * weighted_sum_uv_round_4
squared_norm_final = tf.reduce_sum(tf.square(caps2_output_round_4), axis=-2, keepdims=True)
y_proba = tf.sqrt(squared_norm_final + epsilon) # this represents mathematical length of vector of size caps2_n_dims dims; once again epsilon dramatically improves performance because of tf.sqrt()
outputs = tf.keras.layers.Reshape([10])(y_proba) # flattens to a vector of 10 class scores by removing the dimensions of size 1
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
loss=tf.keras.losses.CategoricalCrossentropy(),
metrics=[tf.keras.metrics.CategoricalAccuracy()])
model.fit(trainimages, ohtrainlabels, validation_data = (testimages, ohtestlabels), batch_size = batch_size, epochs = 1)