def debiasing_loss_func(x, x_pred, y_label, y_logit, z_mu, z_logsigma, kl_weight=0.005):
    """Return the scalar debiasing loss: classification loss plus a masked VAE loss.

    The VAE term (reconstruction + ``kl_weight`` * KL) is only added for
    samples whose label equals 1; every sample contributes its sigmoid
    cross-entropy classification loss. The result is averaged over the batch.

    Args:
        x: Input batch. Assumed to be rank 4, e.g. (batch, H, W, C), because
            MSE averages the last axis and axes 1 and 2 are averaged
            afterwards -- TODO confirm against caller.
        x_pred: Reconstruction of ``x``, same shape as ``x``.
        y_label: Binary labels, shape (batch,) or (batch, 1).
        y_logit: Classification logits, same shape as ``y_label``.
        z_mu: Latent means, assumed (batch, latent_dim).
        z_logsigma: Latent log-scale term, same shape as ``z_mu``. The KL
            formula below uses ``exp(z_logsigma)`` where a variance would
            appear, i.e. it is effectively treated as a log-variance.
        kl_weight: Weight of the KL term relative to the reconstruction term.

    Returns:
        A scalar tensor with the mean total loss over the batch.
    """
    # Per-sample reconstruction error: MSE averages over the last axis, the
    # remaining axes 1 and 2 are averaged here.
    reconstruction_loss = tf.reduce_mean(tf.keras.losses.MSE(x, x_pred), axis=(1, 2))

    # Per-sample KL divergence to a unit Gaussian, summed over latent dims.
    kl_loss = 0.5 * tf.reduce_sum(
        tf.exp(z_logsigma) + tf.square(z_mu) - 1.0 - z_logsigma, axis=1
    )
    vae_loss = kl_weight * kl_loss + reconstruction_loss

    classification_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=y_label, logits=y_logit
    )

    # Propagate debiasing (VAE) gradients only on datapoints labelled 1.
    gradient_mask = tf.cast(tf.equal(y_label, 1), tf.float32)

    # vae_loss has shape (batch,). If the labels/logits arrive as (batch, 1),
    # multiplying directly would broadcast to (batch, batch) and mix every
    # sample's VAE loss into every other sample, defeating the mask. Flatten
    # the per-sample classification terms so all three align element-wise.
    classification_loss = tf.reshape(classification_loss, [-1])
    gradient_mask = tf.reshape(gradient_mask, [-1])

    # Total loss: classification for everyone, VAE only where the mask is 1.
    total_loss = tf.reduce_mean(classification_loss + gradient_mask * vae_loss)
    return total_loss