Capsule Networks - Part III: Reconstruction Loss
May 23, 2023
The last major part of the Capsule Network paper is its novel approach to calculating loss. The code in this article builds on the previous articles, which walked through the capsules and dynamic routing. The network now has to meet three requirements:

- it must accept two inputs;
- it must hold trainable weights inside the loss function;
- it must behave differently during training and inference.

To accomplish this, we need to use custom TensorFlow Keras layers. Both the image and the label are passed into the model for training, because the label is used to guide the reconstruction of the classified digit. The loss function derives one loss term from the predicted output and the label. It also needs the input image so that it can derive a second term from the reconstructed image.

During training, the reconstructed image is based on the actual label, which should better guide the weights in the loss function. During inference, the reconstructed image is based on the model's prediction, so the final loss should be harsher than during training. Whenever the prediction is wrong, the decoder is fed the wrong capsule and reconstructs the wrong digit. The code in this article is quite similar to the previous articles but is implemented in custom TensorFlow Keras layers.
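Written out, the loss computed by custom_loss in the code below is the margin loss plus a down-weighted reconstruction term. The notation is mine:

- v_{bk} is the output vector of digit capsule k for example b.
- T_{bk} is 1 when k is that example's label and 0 otherwise.
- x_b is the flattened input image and \hat{x}_b is the decoder output.
- B is the batch size.

The constants are the ones used in the code (m^+ = 0.9, m^- = 0.1, λ = 0.5, α = 0.0005). As in the code, the reconstruction term is a mean over pixels rather than the sum of squared differences described in the paper.

$$
\mathcal{L} = \frac{1}{B}\sum_{b=1}^{B}\sum_{k=0}^{9}\Big[T_{bk}\max\big(0,\,m^{+}-\lVert\mathbf{v}_{bk}\rVert\big)^{2} + \lambda\,(1-T_{bk})\max\big(0,\,\lVert\mathbf{v}_{bk}\rVert-m^{-}\big)^{2}\Big] + \frac{\alpha}{784\,B}\sum_{b=1}^{B}\sum_{p=1}^{784}\big(x_{bp}-\hat{x}_{bp}\big)^{2}
$$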
We must create a custom TensorFlow Keras layer that implements several methods. One of these methods is the custom loss function. Another outputs the dynamic routing result, which is used both in the loss function and in the layer's main call method. The call method is what triggers the custom loss calculation. Structuring the layer this way allows each of these methods to accept multiple inputs.

The main reason for choosing a custom TensorFlow Keras layer, however, is that we need different behaviors during training and inference. Custom layers accept a "training" argument that Keras sets automatically, depending on whether the model is in training or inference mode. This takes some getting used to. The main nuance is that the call method must pass this training value on to the other method in the class where the training/inference differentiation takes place. Otherwise, that method will use its own default training value.
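Here is a minimal sketch of that pattern, assuming TensorFlow 2.x. The toy layer below is hypothetical: TrainingFlagDemo and pick_targets are my own names and are not part of the model. The sketch shows only how call forwards the training flag to a helper method. The helper picks the true labels while training and the predictions otherwise, mirroring what custom_loss does with the reconstruction targets.

```python
import tensorflow as tf

class TrainingFlagDemo(tf.keras.layers.Layer):
    """Hypothetical toy layer: chooses its targets based on the training flag."""

    def pick_targets(self, labels, preds, training=True):
        # Same idea as custom_loss: true labels while training, predictions otherwise.
        return labels if training else preds

    def call(self, inputs, training=None):
        labels, preds = inputs
        # Forward the flag explicitly; if it were not passed, pick_targets would
        # fall back to its own default (True) and always use the labels.
        targets = self.pick_targets(labels, preds, training=training)
        # add_loss attaches an extra loss term, as clayer does with custom_loss.
        self.add_loss(tf.reduce_mean(tf.cast(targets, tf.float32)))
        return targets

demo = TrainingFlagDemo()
labels = tf.constant([3, 7])
preds = tf.constant([3, 1])
print(demo((labels, preds), training=True).numpy())   # [3 7]
print(demo((labels, preds), training=False).numpy())  # [3 1]
```

During model.fit and model.evaluate, Keras supplies the training value itself, so you only have to make sure it is passed along inside the layer.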
In the previous article, I found that softmaxing over the caps1_n_caps axis during the dynamic routing rounds produced the best results, which I believed coincided with the research paper. Note, though, that the paper defines the routing softmax over the capsules in the layer above, which corresponds to the caps2_n_caps axis here.

The code in this article performs better when softmaxing over the caps2_n_caps axis. Strangely, the best results came from a mix: softmaxing over caps1_n_caps (axis = 1) in every round except the last, and over caps2_n_caps (axis = 2) in the last round. This is what the code does. Please leave an explanation in the comments if you have one. See the code below.
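The four routing rounds in returnlabel() differ only in their inputs and in the softmax axis, so they can be sketched as a loop with one axis per round. This is a NumPy sketch, not the article's TensorFlow code. The function names softmax, squash and route are my own, and the input is random stand-in data. With axes=(1, 1, 1, 2), the loop follows the combination described above. The last lines show the starting routing weights: zeros softmaxed over axis 1 give 1/1152 per weight, while over axis 2 they give 0.1.

```python
import numpy as np

def softmax(x, axis):
    # Numerically stable softmax along a single axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def squash(s, axis=-2, eps=1e-7):
    # Article's squash: unit vector scaled by |s|^2 / (1 + |s|^2), over caps2_n_dims.
    sq = np.sum(s ** 2, axis=axis, keepdims=True)
    return sq / (1.0 + sq) * s / np.sqrt(sq + eps)

def route(u_hat, axes=(1, 1, 1, 2)):
    """Routing by agreement; axes[r] is the softmax axis used in round r.

    u_hat uses the article's layout [batch, caps1_n_caps, caps2_n_caps, caps2_n_dims, 1].
    """
    raw = np.zeros(u_hat.shape[:3] + (1, 1))                # equal raw weights at the start
    for r, axis in enumerate(axes):
        weights = softmax(raw, axis=axis)                   # routing weights for this round
        s = np.sum(weights * u_hat, axis=1, keepdims=True)  # weighted sum over caps1
        v = squash(s)                                       # [batch, 1, caps2, dims, 1]
        if r < len(axes) - 1:                               # the last round needs no agreement
            raw = raw + np.sum(u_hat * v, axis=-2, keepdims=True)
    return v

rng = np.random.default_rng(0)
u_hat = rng.normal(size=(2, 1152, 10, 16, 1))  # random stand-in for caps2_predicted
v = route(u_hat)
print(v.shape)  # (2, 1, 10, 16, 1)

zeros = np.zeros((1, 1152, 10, 1, 1))
print(softmax(zeros, axis=1)[0, 0, 0, 0, 0])  # 1/1152, about 0.000868
print(softmax(zeros, axis=2)[0, 0, 0, 0, 0])  # 0.1
```

To try the paper's convention, pass axes=(2, 2, 2, 2). Each call then normalizes over the ten digit capsules instead of the 1152 primary capsules.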
import tensorflow as tf
(trainimages, trainlabels), (testimages, testlabels) = tf.keras.datasets.mnist.load_data()
trains, tests = 1000,1000 # depends on how much time you have and how fast your Python environment can process the data
trainimages, trainlabels = trainimages[:trains].reshape(trains,28,28,1), trainlabels[:trains] # the conv2d layer needs a channel dimension on the last axis; photos usually have 3 color channels, but MNIST has only 1
testimages, testlabels = testimages[:tests].reshape(tests,28,28,1), testlabels[:tests]
tftrainimages = tf.cast(trainimages, dtype=tf.float32) / 255.0 # scale pixels to [0, 1] so they are comparable with the decoder's sigmoid output
tftestimages = tf.cast(testimages, dtype=tf.float32) / 255.0
epsilon = 1e-7
convoutsize = 6 # width and height of the 2nd convolutional output (28 -> 20 -> 6); not including filters
caps1_n_dims = 8 # number of values or "dims" in each capsule vector created by reshaping the conv2d output
caps1_n_maps = 32 # number of capsule maps; conv1 and conv2 both use caps1_n_maps * caps1_n_dims = 256 filters
caps1_n_caps = caps1_n_maps * convoutsize**2 # number of primary capsules (32 * 6 * 6 = 1152), each with caps1_n_dims values or "dims"
caps2_n_caps = 10 # number of label classifications = 10; for 0 - 9
caps2_n_dims = 16 # this is for the next set of capsules
n_hidden1 = 512
n_hidden2 = 1024
n_output = 28 * 28
class matmult(tf.keras.layers.Layer): # must create a custom Keras Layer when initializing custom weights; this allows the weights to be trainable
    def __init__(self, w, **kwargs):
        super(matmult, self).__init__(**kwargs)
        self.w = w
        self.multip = tf.keras.layers.Multiply()
    def call(self, inputs):
        matmul = self.multip([self.w, inputs]) # broadcast itemwise multiplication; [1,1152,10,16,8] * [batch_size,1152,1,1,8] = [batch_size,1152,10,16,8]
        return tf.reduce_sum(matmul, axis = -1, keepdims = True) # summing the last axis completes the matmul between W & caps1_output_expanded2; output [batch_size,1152,10,16,1]
class clayer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.conv1d = tf.keras.layers.Conv2D(filters = caps1_n_maps * caps1_n_dims, kernel_size = 9, strides = 1, padding = 'valid', activation = 'relu')
        self.conv2d = tf.keras.layers.Conv2D(filters = caps1_n_maps * caps1_n_dims, kernel_size = 9, strides = 2, padding = 'valid', activation = 'relu')
        self.convfiltreshape = tf.keras.layers.Reshape(target_shape=[caps1_n_caps, caps1_n_dims])
        initializer = tf.sqrt(2/(caps1_n_caps * caps1_n_dims)) # he-normal initializer
        W_init = tf.random.normal(shape=(1, caps1_n_caps, caps2_n_caps, caps2_n_dims, caps1_n_dims), stddev=initializer, dtype=tf.float32)
        self.W = tf.Variable(W_init, trainable = True)
        self.matmultd = matmult(w = self.W)
        self.caps1_output_expanded2d = tf.keras.layers.Reshape([caps1_n_caps, 1, 1, caps1_n_dims])
        self.reconstruction_mask_reshapedd = tf.keras.layers.Reshape([1, caps2_n_caps, 1, 1]) # [batch_size,1,10,1,1]; Keras Reshape excludes the batch axis, so a leading -1 would add an extra axis
        self.decoder_inputd = tf.keras.layers.Reshape([caps2_n_caps * caps2_n_dims]) # [batch_size,160]
        self.hidden1d = tf.keras.layers.Dense(n_hidden1, activation=tf.nn.relu, trainable = True)
        self.hidden2d = tf.keras.layers.Dense(n_hidden2, activation=tf.nn.relu, trainable = True)
        self.decoder_outputd = tf.keras.layers.Dense(n_output, activation=tf.nn.sigmoid, trainable = True)
        self.X_flatd = tf.keras.layers.Flatten()
    def returnlabel(self, input0):
        conv1 = self.conv1d(input0)
        conv2 = self.conv2d(conv1)
        caps1_raw = self.convfiltreshape(conv2)
        squared_norm = tf.reduce_sum(tf.square(caps1_raw), axis=-1, keepdims=True)
        safe_norm = tf.sqrt(squared_norm + epsilon) # length of each vector; epsilon makes it numerically safe for backpropagation and significantly improves model performance
        unit_vector = caps1_raw / safe_norm # caps1_n_caps unit vectors; I got slightly better results without the squash_factor outside the dynamic routing layers
        squash_factor = squared_norm / (1. + squared_norm)
        squashed_unit_vector = squash_factor * unit_vector
        caps1_output_expanded2 = self.caps1_output_expanded2d(squashed_unit_vector) # tf.reshape creates a warning because it is not an official Keras layer; for the official Keras Reshape layer, don't include batch_size in the dimensions, it takes care of batch_size automatically
        caps2_predicted = self.matmultd(caps1_output_expanded2) # [batch_size,1152,10,16,1]; 10 classification predictions per primary capsule, built from caps1_n_caps, caps1_n_dims, caps2_n_dims
        # routing round 1
        raw_weights_round_1 = tf.zeros([batch_size, caps1_n_caps, caps2_n_caps, 1, 1], dtype=tf.float32) # batch_size must be defined before the model is built; important that all weights are equal here
        routing_weights_round_1 = tf.nn.softmax(raw_weights_round_1, axis = 1) # softmax over the caps1_n_caps axis: every weight starts at 1/1152
        weighted_predictions_round_1 = routing_weights_round_1 * caps2_predicted
        weighted_sum = tf.reduce_sum(weighted_predictions_round_1, axis=1, keepdims=True) # sums across caps1_n_caps axis; output shape [batch_size,1,10,16,1]
        weighted_sum_sumsquared = tf.reduce_sum(tf.square(weighted_sum), axis=-2, keepdims=True) # squash function across caps2_n_dims axis
        weighted_sum_len = tf.sqrt(weighted_sum_sumsquared + epsilon)
        weighted_sum_uv = weighted_sum / weighted_sum_len
        squash_factor = weighted_sum_sumsquared / (1. + weighted_sum_sumsquared) # dynamic routing doesn't work without the squash factor
        caps2_output_round_1 = squash_factor * weighted_sum_uv # output shape [batch_size,1,10,16,1]
        squashmult_round_1 = caps2_predicted * caps2_output_round_1 # itemwise multiplication broadcast across caps1_n_caps axis; output shape [batch_size,1152,10,16,1]
        agreement_round_1 = tf.reduce_sum(squashmult_round_1, axis = -2, keepdims = True) # sum across caps2_n_dims axis; output shape [batch_size,1152,10,1,1]
        # routing round 2
        raw_weights_round_2 = raw_weights_round_1 + agreement_round_1 # adding agreement values to the tf.zeros() baseline, so raw_weights_round_2 = agreement_round_1
        routing_weights_round_2 = tf.nn.softmax(raw_weights_round_2, axis = 1) # softmax over the caps1_n_caps axis
        weighted_predictions_round_2 = routing_weights_round_2 * caps2_predicted
        weighted_sum_round_2 = tf.reduce_sum(weighted_predictions_round_2, axis=1, keepdims=True)
        weighted_sum_sumsquared_round_2 = tf.reduce_sum(tf.square(weighted_sum_round_2), axis=-2, keepdims=True) # squash function
        weighted_sum_len_round_2 = tf.sqrt(weighted_sum_sumsquared_round_2 + epsilon)
        weighted_sum_uv_round_2 = weighted_sum_round_2 / weighted_sum_len_round_2
        squash_factor_round_2 = weighted_sum_sumsquared_round_2 / (1. + weighted_sum_sumsquared_round_2)
        caps2_output_round_2 = squash_factor_round_2 * weighted_sum_uv_round_2
        squashmult_round_2 = caps2_predicted * caps2_output_round_2
        agreement_round_2 = tf.reduce_sum(squashmult_round_2, axis = -2, keepdims = True)
        # routing round 3
        raw_weights_round_3 = raw_weights_round_2 + agreement_round_2 # adding this round's agreement to the accumulated raw weights
        routing_weights_round_3 = tf.nn.softmax(raw_weights_round_3, axis = 1) # softmax over the caps1_n_caps axis
        weighted_predictions_round_3 = routing_weights_round_3 * caps2_predicted
        weighted_sum_round_3 = tf.reduce_sum(weighted_predictions_round_3, axis=1, keepdims=True)
        weighted_sum_sumsquared_round_3 = tf.reduce_sum(tf.square(weighted_sum_round_3), axis=-2, keepdims=True) # squash function
        weighted_sum_len_round_3 = tf.sqrt(weighted_sum_sumsquared_round_3 + epsilon)
        weighted_sum_uv_round_3 = weighted_sum_round_3 / weighted_sum_len_round_3
        squash_factor_round_3 = weighted_sum_sumsquared_round_3 / (1. + weighted_sum_sumsquared_round_3)
        caps2_output_round_3 = squash_factor_round_3 * weighted_sum_uv_round_3
        squashmult_round_3 = caps2_predicted * caps2_output_round_3
        agreement_round_3 = tf.reduce_sum(squashmult_round_3, axis = -2, keepdims = True)
        # routing round 4; no agreement needed for the dynamic routing output
        raw_weights_round_4 = raw_weights_round_3 + agreement_round_3 # adding this round's agreement to the accumulated raw weights
        routing_weights_round_4 = tf.nn.softmax(raw_weights_round_4, axis = 2) # seems to work best when the last round is done on axis = 2 (caps2_n_caps), while previous rounds use axis = 1
        weighted_predictions_round_4 = routing_weights_round_4 * caps2_predicted
        weighted_sum_round_4 = tf.reduce_sum(weighted_predictions_round_4, axis=1, keepdims=True)
        weighted_sum_sumsquared_round_4 = tf.reduce_sum(tf.square(weighted_sum_round_4), axis=-2, keepdims=True) # squash function
        weighted_sum_len_round_4 = tf.sqrt(weighted_sum_sumsquared_round_4 + epsilon)
        weighted_sum_uv_round_4 = weighted_sum_round_4 / weighted_sum_len_round_4
        squash_factor_round_4 = weighted_sum_sumsquared_round_4 / (1. + weighted_sum_sumsquared_round_4)
        caps2_output_round_4 = squash_factor_round_4 * weighted_sum_uv_round_4
        return caps2_output_round_4 # [batch_size,1,10,16,1]
    def custom_loss(self, input0, input2, training = True):
        caps2_output = self.returnlabel(input0)
        squared_norm = tf.reduce_sum(tf.square(caps2_output), axis=-2, keepdims=False)
        y_proba = tf.sqrt(squared_norm + epsilon) # capsule lengths; shape [batch_size,1,10,1]
        y_proba_argmax = tf.argmax(y_proba, axis = 2)
        ypred = tf.squeeze(y_proba_argmax, axis = [-1,-2]) # need to specify axes to ensure the output is the same shape as labels
        m_plus = 0.9
        m_minus = 0.1
        lambda_ = 0.5
        alpha = 0.0005
        caps2_output_norm = tf.sqrt(squared_norm + epsilon)
        T = tf.one_hot(input2, depth=caps2_n_caps)
        present_error_raw = tf.square(tf.maximum(0., m_plus - tf.cast(caps2_output_norm, dtype = tf.float32)))
        present_error = tf.squeeze(present_error_raw, axis = [1, -1]) # [batch_size,10], even when batch_size is 1
        absent_error_raw = tf.square(tf.maximum(0., tf.cast(caps2_output_norm, dtype = tf.float32) - m_minus))
        absent_error = tf.squeeze(absent_error_raw, axis = [1, -1])
        L = tf.cast(T, dtype = tf.float32) * present_error + lambda_ * (1.0 - tf.cast(T, dtype = tf.float32)) * absent_error
        margin_loss = tf.reduce_mean(tf.reduce_sum(L, axis=1))
        if training: # during training, reconstruct from the true label
            reconstruction_targets = input2
            print("praise God train")
        else: # during inference, reconstruct from the prediction
            reconstruction_targets = ypred
            print("praise God test")
        reconstruction_mask = tf.one_hot(reconstruction_targets, depth=caps2_n_caps)
        reconstruction_mask_reshaped = self.reconstruction_mask_reshapedd(reconstruction_mask)
        caps2_output_masked = caps2_output * reconstruction_mask_reshaped # keeps only the target capsule's 16 values
        decoder_input = self.decoder_inputd(caps2_output_masked)
        hidden1 = self.hidden1d(decoder_input)
        hidden2 = self.hidden2d(hidden1)
        decoder_output = self.decoder_outputd(hidden2)
        X_flat = self.X_flatd(input0)
        squared_difference = tf.square(X_flat - decoder_output)
        reconstruction_loss = tf.reduce_mean(squared_difference)
        loss = margin_loss + alpha * reconstruction_loss
        return loss
    def call(self, input0, input2, training = True):
        caps2_output = self.returnlabel(input0)
        squared_norm = tf.reduce_sum(tf.square(caps2_output), axis=-2, keepdims=False)
        y_proba = tf.sqrt(squared_norm + epsilon)
        y_proba_argmax = tf.argmax(y_proba, axis = 2)
        y_pred = tf.squeeze(y_proba_argmax, axis = [-1,-2]) # need to specify axes to ensure the output is the same shape as labels
        self.add_loss(self.custom_loss(input0, input2, training = training)) # pass the training value on to the custom loss function; otherwise the custom loss will use its default boolean value
        return y_pred
batch_size = tf.Variable(10) # returnlabel() reads this global when the model is built below, so it must be defined first
input0 = tf.keras.Input(shape = (28,28,1))
input2 = tf.keras.Input(shape = (), dtype = tf.uint8) # must have shape (), not (1,), or the model won't work
outputs = clayer()(input0, input2)
model = tf.keras.Model(inputs = [input0,input2], outputs = outputs)
def customacc(ytrue, ypred):
    correct = tf.equal(tf.cast(ytrue, tf.int64), tf.cast(ypred, tf.int64)) # labels and predictions can arrive with different dtypes
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    return accuracy
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
metrics=[customacc])
model.fit([tftrainimages, trainlabels], trainlabels, validation_data = ([tftestimages, testlabels], testlabels), batch_size = batch_size, epochs = 1)
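The training/inference switch in custom_loss comes down to which index is used to mask caps2_output before it reaches the decoder. Below is a small NumPy sketch of that masking step and of the margin loss, using the article's [batch, 1, 10, 16, 1] layout. The helper names decoder_input and margin_loss are my own, and the toy capsule values are made up for illustration. In the second example, the label is 7 but capsule 1 is the longest, so the prediction is wrong.

```python
import numpy as np

def decoder_input(caps2_output, targets, n_classes=10):
    """Zero every digit capsule except the target one, then flatten to [batch, 160]."""
    mask = np.eye(n_classes)[targets]            # one-hot, [batch, 10]
    mask = mask.reshape(-1, 1, n_classes, 1, 1)  # [batch, 1, 10, 1, 1]
    return (caps2_output * mask).reshape(len(targets), -1)

def margin_loss(caps2_output, labels, m_plus=0.9, m_minus=0.1, lambda_=0.5, eps=1e-7):
    # Capsule lengths, [batch, 10]
    lengths = np.sqrt(np.sum(caps2_output ** 2, axis=-2) + eps).reshape(len(labels), -1)
    T = np.eye(lengths.shape[1])[labels]
    present = np.maximum(0.0, m_plus - lengths) ** 2
    absent = np.maximum(0.0, lengths - m_minus) ** 2
    return np.mean(np.sum(T * present + lambda_ * (1 - T) * absent, axis=1))

caps = np.zeros((2, 1, 10, 16, 1))
caps[0, 0, 3, 0, 0] = 0.95  # example 0: capsule 3 is long, label 3
caps[1, 0, 1, 0, 0] = 0.95  # example 1: capsule 1 is long, but the label is 7
labels = np.array([3, 7])
preds = np.argmax(np.linalg.norm(caps, axis=-2).reshape(2, -1), axis=1)  # [3 1]

train_in = decoder_input(caps, labels)  # training: mask with the true labels
infer_in = decoder_input(caps, preds)   # inference: mask with the predictions
print(preds, train_in[1].any(), infer_in[1, 16])  # [3 1] False 0.95
print(margin_loss(caps, labels))
```

For the misclassified example, the training mask passes capsule 7 (here all zeros) to the decoder. The inference mask passes capsule 1, so the decoder is asked to draw the wrong digit. This is why the inference-time loss tends to be harsher.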
The same results can be obtained with a custom TensorFlow Keras model (a subclass of tf.keras.Model), which is provided at the end of this article. You will have to set the random seed as in the first capsule network article.
The last two articles in this series have become less about capsules and more about practical tools to add to one's ML arsenal. The researchers reported that individual dimensions of the digit capsules consistently came to represent traits of a handwritten digit such as width, skew, and stroke thickness.

This series has shown that the reshape after the convolutional layers simply regroups filters into smaller groups while keeping each group tied to the same location in the image. Each of these filter values represents a combination of pixel features at a particular location on the input image. The weights in the reconstruction loss function may structure themselves around the convolutional layer output, after the second capsule layer and dynamic routing transform it, so as to represent width, skew, or thickness. This could be why the reconstruction loss was introduced. It provides a way to recreate digits from the weights in the model while still depending on the convolutional layer output. Low-level characteristics such as pixel combinations can then be combined, strengthened, and transformed into high-level characteristics such as the width, skew, and thickness of the digit as a whole.

Check out Aurélien Géron's work, which provided the groundwork for the code in this article. He also explored the effect of the capsules on the reconstructed image, as described in the paper.
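The claim that the reshape only regroups filters, while keeping each group tied to one image location, can be checked directly. This NumPy sketch assumes the article's 6 x 6 x 256 conv2 output and a row-major reshape, which is the ordering the Keras Reshape layer uses. Each activation is tagged with its (row, col, filter) index so it can be traced after the reshape. The helper capsule_origin is my own.

```python
import numpy as np

# Stand-in for one conv2 output: 6 x 6 locations, 256 filters (32 maps x 8 dims).
rows, cols, filters, dims = 6, 6, 256, 8
r, c, f = np.meshgrid(np.arange(rows), np.arange(cols), np.arange(filters), indexing="ij")
conv2 = np.stack([r, c, f], axis=-1)                          # [6, 6, 256, 3]: (row, col, filter) tags
caps1 = conv2.reshape(rows * cols * filters // dims, dims, 3)  # [1152, 8, 3]

def capsule_origin(k):
    """Return the set of (row, col) locations and the filter range feeding capsule k."""
    locations = {(int(a), int(b)) for a, b in caps1[k, :, :2]}
    return locations, int(caps1[k, :, 2].min()), int(caps1[k, :, 2].max())

print(capsule_origin(0))   # ({(0, 0)}, 0, 7)
print(capsule_origin(33))  # ({(0, 1)}, 8, 15)
```

Every primary capsule comes out as 8 consecutive filters at a single location, and each location contributes 32 capsules (k = location * 32 + filter // 8).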
Capsule networks appear to provide an alternative to the brute-force stacking of convolutional layers to arrive at a higher-level understanding of images. Such stacking is the norm in object detection and classification models, some of which use over a hundred layers, largely convolutional.
import tensorflow as tf
(trainimages, trainlabels), (testimages, testlabels) = tf.keras.datasets.mnist.load_data()
trains, tests = 1000,1000
trainimages, trainlabels = trainimages[:trains].reshape(trains,28,28,1), trainlabels[:trains] # the conv2d layer needs a channel dimension on the last axis; photos usually have 3 color channels, but MNIST has only 1
testimages, testlabels = testimages[:tests].reshape(tests,28,28,1), testlabels[:tests]
tftrainimages = tf.cast(trainimages, dtype=tf.float32) / 255.0 # scale pixels to [0, 1] so they are comparable with the decoder's sigmoid output
tftestimages = tf.cast(testimages, dtype=tf.float32) / 255.0
tftrainlabels = tf.cast(trainlabels, dtype=tf.int8)
tftestlabels = tf.cast(testlabels, dtype=tf.int8)
epsilon = 1e-7
convoutsize = 6 # width and height of the 2nd convolutional output (28 -> 20 -> 6); not including filters
caps1_n_dims = 8 # number of values or "dims" in each capsule vector created by reshaping the conv2d output
caps1_n_maps = 32 # number of capsule maps; conv1 and conv2 both use caps1_n_maps * caps1_n_dims = 256 filters
caps1_n_caps = caps1_n_maps * convoutsize**2 # number of primary capsules (32 * 6 * 6 = 1152), each with caps1_n_dims values or "dims"
caps2_n_caps = 10 # number of label classifications = 10; for 0 - 9
caps2_n_dims = 16 # this is for the next set of capsules
n_hidden1 = 512
n_hidden2 = 1024
n_output = 28 * 28
class matmult(tf.keras.layers.Layer): # must create a custom Keras Layer when initializing custom weights; this allows the weights to be trainable
    def __init__(self, w, **kwargs):
        super(matmult, self).__init__(**kwargs)
        self.w = w
        self.multip = tf.keras.layers.Multiply()
    def call(self, inputs, training = True):
        matmul = self.multip([self.w, inputs]) # broadcast itemwise multiplication; [1,1152,10,16,8] * [batch_size,1152,1,1,8] = [batch_size,1152,10,16,8]
        return tf.reduce_sum(matmul, axis = -1, keepdims = True) # summing the last axis completes the matmul between W & caps1_output_expanded2; output [batch_size,1152,10,16,1]
class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.conv1d = tf.keras.layers.Conv2D(filters = caps1_n_maps * caps1_n_dims, kernel_size = 9, strides = 1, padding = 'valid', activation = 'relu')
        self.conv2d = tf.keras.layers.Conv2D(filters = caps1_n_maps * caps1_n_dims, kernel_size = 9, strides = 2, padding = 'valid', activation = 'relu')
        self.convfiltreshape = tf.keras.layers.Reshape(target_shape=[caps1_n_caps, caps1_n_dims])
        initializer = tf.sqrt(2/(caps1_n_caps * caps1_n_dims)) # he-normal initializer
        W_init = tf.random.normal(shape=(1, caps1_n_caps, caps2_n_caps, caps2_n_dims, caps1_n_dims), stddev=initializer, dtype=tf.float32)
        self.W = tf.Variable(W_init, trainable = True)
        self.matmultd = matmult(w = self.W)
        self.caps1_output_expanded2d = tf.keras.layers.Reshape([caps1_n_caps, 1, 1, caps1_n_dims])
        self.reconstruction_mask_reshapedd = tf.keras.layers.Reshape([1, caps2_n_caps, 1, 1]) # [batch_size,1,10,1,1]; Keras Reshape excludes the batch axis, so a leading -1 would add an extra axis
        self.decoder_inputd = tf.keras.layers.Reshape([caps2_n_caps * caps2_n_dims]) # [batch_size,160]
        self.hidden1d = tf.keras.layers.Dense(n_hidden1, activation=tf.nn.relu, trainable = True)
        self.hidden2d = tf.keras.layers.Dense(n_hidden2, activation=tf.nn.relu, trainable = True)
        self.decoder_outputd = tf.keras.layers.Dense(n_output, activation=tf.nn.sigmoid, trainable = True)
        self.X_flatd = tf.keras.layers.Flatten()
    def returnlabel(self, input0, training = True):
        conv1 = self.conv1d(input0)
        conv2 = self.conv2d(conv1)
        caps1_raw = self.convfiltreshape(conv2)
        squared_norm = tf.reduce_sum(tf.square(caps1_raw), axis=-1, keepdims=True)
        safe_norm = tf.sqrt(squared_norm + epsilon) # length of each vector; epsilon makes it numerically safe for backpropagation and significantly improves model performance
        unit_vector = caps1_raw / safe_norm # caps1_n_caps unit vectors; I got slightly better results without the squash_factor outside the dynamic routing layers
        squash_factor = squared_norm / (1. + squared_norm)
        squashed_unit_vector = squash_factor * unit_vector
        caps1_output_expanded2 = self.caps1_output_expanded2d(squashed_unit_vector) # tf.reshape creates a warning because it is not an official Keras layer; for the official Keras Reshape layer, don't include batch_size in the dimensions, it takes care of batch_size automatically
        caps2_predicted = self.matmultd(caps1_output_expanded2) # [batch_size,1152,10,16,1]; 10 classification predictions per primary capsule, built from caps1_n_caps, caps1_n_dims, caps2_n_dims
        # routing round 1
        raw_weights_round_1 = tf.zeros([batch_size, caps1_n_caps, caps2_n_caps, 1, 1], dtype=tf.float32) # batch_size must be defined before the model is called; important that all weights are equal here
        routing_weights_round_1 = tf.nn.softmax(raw_weights_round_1, axis = 1) # softmax over the caps1_n_caps axis: every weight starts at 1/1152
        weighted_predictions_round_1 = routing_weights_round_1 * caps2_predicted
        weighted_sum = tf.reduce_sum(weighted_predictions_round_1, axis=1, keepdims=True) # sums across caps1_n_caps axis; output shape [batch_size,1,10,16,1]
        weighted_sum_sumsquared = tf.reduce_sum(tf.square(weighted_sum), axis=-2, keepdims=True) # squash function across caps2_n_dims axis
        weighted_sum_len = tf.sqrt(weighted_sum_sumsquared + epsilon)
        weighted_sum_uv = weighted_sum / weighted_sum_len
        squash_factor = weighted_sum_sumsquared / (1. + weighted_sum_sumsquared) # dynamic routing doesn't work without the squash factor
        caps2_output_round_1 = squash_factor * weighted_sum_uv # output shape [batch_size,1,10,16,1]
        squashmult_round_1 = caps2_predicted * caps2_output_round_1 # itemwise multiplication broadcast across caps1_n_caps axis; output shape [batch_size,1152,10,16,1]
        agreement_round_1 = tf.reduce_sum(squashmult_round_1, axis = -2, keepdims = True) # sum across caps2_n_dims axis; output shape [batch_size,1152,10,1,1]
        # routing round 2
        raw_weights_round_2 = raw_weights_round_1 + agreement_round_1 # adding agreement values to the tf.zeros() baseline, so raw_weights_round_2 = agreement_round_1
        routing_weights_round_2 = tf.nn.softmax(raw_weights_round_2, axis = 1) # softmax over the caps1_n_caps axis
        weighted_predictions_round_2 = routing_weights_round_2 * caps2_predicted
        weighted_sum_round_2 = tf.reduce_sum(weighted_predictions_round_2, axis=1, keepdims=True)
        weighted_sum_sumsquared_round_2 = tf.reduce_sum(tf.square(weighted_sum_round_2), axis=-2, keepdims=True) # squash function
        weighted_sum_len_round_2 = tf.sqrt(weighted_sum_sumsquared_round_2 + epsilon)
        weighted_sum_uv_round_2 = weighted_sum_round_2 / weighted_sum_len_round_2
        squash_factor_round_2 = weighted_sum_sumsquared_round_2 / (1. + weighted_sum_sumsquared_round_2)
        caps2_output_round_2 = squash_factor_round_2 * weighted_sum_uv_round_2
        squashmult_round_2 = caps2_predicted * caps2_output_round_2
        agreement_round_2 = tf.reduce_sum(squashmult_round_2, axis = -2, keepdims = True)
        # routing round 3
        raw_weights_round_3 = raw_weights_round_2 + agreement_round_2 # adding this round's agreement to the accumulated raw weights
        routing_weights_round_3 = tf.nn.softmax(raw_weights_round_3, axis = 1) # softmax over the caps1_n_caps axis
        weighted_predictions_round_3 = routing_weights_round_3 * caps2_predicted
        weighted_sum_round_3 = tf.reduce_sum(weighted_predictions_round_3, axis=1, keepdims=True)
        weighted_sum_sumsquared_round_3 = tf.reduce_sum(tf.square(weighted_sum_round_3), axis=-2, keepdims=True) # squash function
        weighted_sum_len_round_3 = tf.sqrt(weighted_sum_sumsquared_round_3 + epsilon)
        weighted_sum_uv_round_3 = weighted_sum_round_3 / weighted_sum_len_round_3
        squash_factor_round_3 = weighted_sum_sumsquared_round_3 / (1. + weighted_sum_sumsquared_round_3)
        caps2_output_round_3 = squash_factor_round_3 * weighted_sum_uv_round_3
        squashmult_round_3 = caps2_predicted * caps2_output_round_3
        agreement_round_3 = tf.reduce_sum(squashmult_round_3, axis = -2, keepdims = True)
        # routing round 4; no agreement needed for the dynamic routing output
        raw_weights_round_4 = raw_weights_round_3 + agreement_round_3 # adding this round's agreement to the accumulated raw weights
        routing_weights_round_4 = tf.nn.softmax(raw_weights_round_4, axis = 2) # last round softmaxes over the caps2_n_caps axis, while previous rounds use axis = 1
        weighted_predictions_round_4 = routing_weights_round_4 * caps2_predicted
        weighted_sum_round_4 = tf.reduce_sum(weighted_predictions_round_4, axis=1, keepdims=True)
        weighted_sum_sumsquared_round_4 = tf.reduce_sum(tf.square(weighted_sum_round_4), axis=-2, keepdims=True) # squash function
        weighted_sum_len_round_4 = tf.sqrt(weighted_sum_sumsquared_round_4 + epsilon)
        weighted_sum_uv_round_4 = weighted_sum_round_4 / weighted_sum_len_round_4
        squash_factor_round_4 = weighted_sum_sumsquared_round_4 / (1. + weighted_sum_sumsquared_round_4)
        caps2_output_round_4 = squash_factor_round_4 * weighted_sum_uv_round_4
        return caps2_output_round_4 # [batch_size,1,10,16,1]
    def custom_loss(self, input0, input2, training = True):
        caps2_output = self.returnlabel(input0)
        squared_norm = tf.reduce_sum(tf.square(caps2_output), axis=-2, keepdims=False)
        y_proba = tf.sqrt(squared_norm + epsilon) # capsule lengths; shape [batch_size,1,10,1]
        y_proba_argmax = tf.argmax(y_proba, axis = 2)
        ypred = tf.squeeze(y_proba_argmax, axis = [-1,-2])
        m_plus = 0.9
        m_minus = 0.1
        lambda_ = 0.5
        alpha = 0.0005
        caps2_output_norm = tf.sqrt(squared_norm + epsilon)
        T = tf.one_hot(tf.cast(input2, dtype=tf.int8), depth=caps2_n_caps)
        present_error_raw = tf.square(tf.maximum(0., m_plus - tf.cast(caps2_output_norm, dtype = tf.float32)))
        present_error = tf.squeeze(present_error_raw, axis = [1, -1]) # [batch_size,10], even when batch_size is 1
        absent_error_raw = tf.square(tf.maximum(0., tf.cast(caps2_output_norm, dtype = tf.float32) - m_minus))
        absent_error = tf.squeeze(absent_error_raw, axis = [1, -1])
        L = tf.cast(T, dtype = tf.float32) * present_error + lambda_ * (1.0 - tf.cast(T, dtype = tf.float32)) * absent_error
        margin_loss = tf.reduce_mean(tf.reduce_sum(L, axis=1))
        if training: # during training, reconstruct from the true label
            reconstruction_targets = tf.cast(input2, dtype=tf.int8)
            print("praise God Training")
        else: # during inference, reconstruct from the prediction
            reconstruction_targets = tf.cast(ypred, dtype=tf.int8)
            print("still praise God during testing")
        reconstruction_mask = tf.one_hot(reconstruction_targets, depth=caps2_n_caps)
        reconstruction_mask_reshaped = self.reconstruction_mask_reshapedd(reconstruction_mask)
        caps2_output_masked = caps2_output * reconstruction_mask_reshaped # keeps only the target capsule's 16 values
        decoder_input = self.decoder_inputd(caps2_output_masked)
        hidden1 = self.hidden1d(decoder_input)
        hidden2 = self.hidden2d(hidden1)
        decoder_output = self.decoder_outputd(hidden2)
        X_flat = self.X_flatd(input0)
        squared_difference = tf.square(X_flat - decoder_output)
        reconstruction_loss = tf.reduce_mean(squared_difference)
        loss = margin_loss + alpha * reconstruction_loss
        return loss
    def call(self, combinput, training = True): # a Keras Model's call takes one input, which can be a tuple of sub-inputs
        input0, input2 = combinput
        caps2_output = self.returnlabel(input0)
        squared_norm = tf.reduce_sum(tf.square(caps2_output), axis=-2, keepdims=False)
        y_proba = tf.sqrt(squared_norm + epsilon)
        y_proba_argmax = tf.argmax(y_proba, axis = 2)
        y_pred = tf.squeeze(y_proba_argmax, axis = [-1,-2])
        self.add_loss(self.custom_loss(input0, input2, training = training)) # pass the training value on to the custom loss function
        return y_pred
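batch_size = tf.Variable(10) # returnlabel() reads this global, so it must exist before the model is first called; this also lets this block run on its own
input0 = tf.keras.Input(shape = (28,28,1))
input2 = tf.keras.Input(shape = ())
outputs = MyModel()((input0, input2)) # tuple of two inputs to satisfy the single input of the model's call method; outputs is not used below, the trained model is a fresh instance
model = MyModel()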
def customacc(ytrue, ypred):
    correct = tf.equal(tf.cast(ytrue, tf.int64), tf.cast(ypred, tf.int64)) # labels and predictions can arrive with different dtypes
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    return accuracy
model.compile(optimizer=tf.keras.optimizers.Adam(0.001), metrics=[customacc])
model.fit([tftrainimages, tftrainlabels], tftrainlabels, validation_data = ([tftestimages, tftestlabels], tftestlabels), batch_size = batch_size, epochs = 1)
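The decoder input in custom_loss is built with Keras Reshape layers. As a comment in returnlabel() notes, the target shape of the Keras Reshape layer leaves out the batch axis. A leading -1 therefore does not stand for the batch the way it does in tf.reshape; it adds a new axis. This small check assumes TensorFlow 2.x and uses toy zero tensors, since only the shapes matter. It shows the shapes involved and what broadcasting then does to a [batch, 1, 10, 16, 1] capsule output.

```python
import tensorflow as tf

mask = tf.zeros([4, 10])                    # one-hot masks for a batch of 4
caps2_output = tf.zeros([4, 1, 10, 16, 1])  # digit capsule output for the same batch

good = tf.keras.layers.Reshape([1, 10, 1, 1])(mask)     # batch axis is kept automatically
bad = tf.keras.layers.Reshape([-1, 1, 10, 1, 1])(mask)  # -1 becomes a new axis of size 1
print(good.shape)  # (4, 1, 10, 1, 1)
print(bad.shape)   # (4, 1, 1, 10, 1, 1)

# With the extra axis, broadcasting pairs every example's mask with every example's capsules.
paired = caps2_output * bad
print(paired.shape)  # (4, 4, 1, 10, 16, 1)

# tf.reshape, by contrast, infers the -1 as the batch size.
print(tf.reshape(mask, [-1, 1, 10, 1, 1]).shape)  # (4, 1, 10, 1, 1)
```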