# Restore the trained Keras model from disk; pred_fn below closes over it.
model = tf.keras.models.load_model(filepath=model_path)
@tf.function(
    input_signature=[tf.TensorSpec([1, max_length], dtype=tf.int32)]
)
def pred_fn(encoded_text):
    """Return the max-pooled final-layer activations for one encoded text.

    The input signature fixes the argument to a single int32 sequence of
    shape ``[1, max_length]`` so the function can be traced once and
    exported as a concrete function.

    Args:
        encoded_text: int32 tensor of shape ``[1, max_length]`` holding the
            encoded tokens of one text.

    Returns:
        The model's last-layer output reduced with ``max`` over axis 1
        (the sequence dimension) -- presumably ``[1, hidden_dim]``;
        TODO confirm against the saved model.
    """
    # Create a mask that masks out 0 tokens, and future tokens for next word
    # prediction.
    # NOTE(review): the mask is built from max_length only, not from
    # encoded_text -- confirm that padding (0) tokens are really masked.
    mask = create_padded_lookahead_mask(max_length)
    # Our saved model outputs both its next word predictions and the
    # activations of its final layer. Only the final-layer activations are
    # used (for clustering), so the predictions are discarded.
    _, last_layer_output = model([encoded_text, mask], training=False)
    # Max pool over the seq dimension.
    return tf.reduce_max(last_layer_output, axis=1)
# Trace pred_fn with its fixed input signature and convert the resulting
# concrete function to a TFLite flatbuffer. The original referenced an
# undefined name `fn`; the traced function defined in this file is `pred_fn`.
concrete_fn = pred_fn.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn])
tflite_model = converter.convert()