Creating a TF Lite model with RETVec

To use RETVec with TF Lite, you need tensorflow_text 2.13 or higher and tensorflow 2.13 or higher. See the TensorFlow installation instructions for how to upgrade. This notebook shows how to create, save, and run a TF Lite compatible model that includes the RETVec tokenizer.
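
Before going further, it can help to confirm that the installed packages meet the 2.13 minimum. The following is a minimal sketch using only the standard library. The helper names version_tuple and check_minimum are made up for this illustration, and the parsing is deliberately simple: it looks only at the first three numeric components of the version string.

```python
import re
from importlib import metadata


def version_tuple(version, width=3):
    """Extract the leading numeric components, e.g. '2.13.0rc1' -> (2, 13, 0)."""
    parts = [int(p) for p in re.findall(r"\d+", version)[:width]]
    return tuple(parts + [0] * (width - len(parts)))  # pad '2.13' to (2, 13, 0)


def check_minimum(package, minimum):
    """Return a one-line status string for an installed (or missing) package."""
    try:
        installed = metadata.version(package)
    except metadata.PackageNotFoundError:
        return f"{package}: not installed (need >= {minimum})"
    ok = version_tuple(installed) >= version_tuple(minimum)
    status = "OK" if ok else f"too old, need >= {minimum}"
    return f"{package}: {installed} ({status})"


for pkg in ("tensorflow", "tensorflow-text"):
    print(check_minimum(pkg, "2.13"))
```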

# installing retvec if needed
try:
    import retvec
except ImportError:
    !pip install retvec

try:
    import tensorflow_text
except ImportError:
    !pip install tensorflow-text

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'  # silence TF INFO messages
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers

# import the RETVec tokenizer layer
from retvec.tf import RETVecTokenizer

For RETVec to be compatible with TF Lite, it’s crucial to enable the setting use_tf_lite_compatible_ops=True. This adjustment ensures that the layer employs tensorflow_text.utf8_binarize along with whitespace splitting for dividing text into words, a method that TF Lite natively supports.
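
To make the idea concrete, here is a simplified pure-Python sketch of what whitespace splitting followed by binarization produces. It is not the tensorflow_text implementation: the function names are hypothetical, and the word length of 16 characters, the 24 bits per character, and the low-bit-first ordering are assumptions chosen for illustration.

```python
def binarize_word(word, word_length=16, bits_per_char=24):
    """Turn one word into a flat list of 0/1 values, one group of bits per character."""
    bits = []
    for ch in word[:word_length]:  # truncate long words
        code_point = ord(ch)
        # emit the code point's bits, lowest bit first (assumed ordering)
        bits.extend((code_point >> i) & 1 for i in range(bits_per_char))
    # zero-pad short words so every word yields the same number of values
    bits.extend([0] * (word_length * bits_per_char - len(bits)))
    return bits


def binarize_text(text, word_length=16, bits_per_char=24):
    """Split on whitespace, then binarize each word."""
    return [binarize_word(w, word_length, bits_per_char) for w in text.split()]


vectors = binarize_text("This is an example text")
print(len(vectors), "words,", len(vectors[0]), "values per word")
# 5 words, 384 values per word
```

Because each word is encoded from its characters rather than looked up in a vocabulary, no vocabulary file has to ship with the model.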

# string inputs need shape (1,) and dtype tf.string
inputs = layers.Input(shape=(1, ), name="input", dtype=tf.string)

# add RETVec tokenizer layer with `use_tf_lite_compatible_ops`
x = RETVecTokenizer(model='retvec-v1', use_tf_lite_compatible_ops=True)(inputs)

# build the rest of the model as usual
x = layers.Dense(256, activation='relu')(x)
x = layers.Dense(64, activation='relu')(x)
outputs = layers.Dense(4, activation='sigmoid', name="output")(x)
model = tf.keras.Model(inputs, outputs)

model.summary()
#output
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input (InputLayer)          [(None, 1)]               0         
                                                                 
 ret_vec_tokenizer (RETVecT  (None, 128, 256)          230144    
 okenizer)                                                       
                                                                 
 dense (Dense)               (None, 128, 256)          65792     
                                                                 
 dense_1 (Dense)             (None, 128, 64)           16448     
                                                                 
 output (Dense)              (None, 128, 4)            260       
                                                                 
=================================================================
Total params: 312644 (1.19 MB)
Trainable params: 82500 (322.27 KB)
Non-trainable params: 230144 (899.00 KB)
_________________________________________________________________
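
The trainable parameter counts in the summary can be checked by hand: a Dense layer applied to the last axis has (inputs × units) weights plus one bias per unit. A quick sketch of that arithmetic, with a hypothetical helper name and the layer sizes read off the summary above:

```python
def dense_params(n_inputs, n_units):
    """Weights plus biases for a fully connected layer."""
    return n_inputs * n_units + n_units


# (inputs, units) per Dense layer; the tokenizer emits 256 features per word
layer_sizes = [(256, 256), (256, 64), (64, 4)]
counts = [dense_params(i, u) for i, u in layer_sizes]
print(counts)       # [65792, 16448, 260]
print(sum(counts))  # 82500, matching the trainable total
```

The remaining 230,144 parameters belong to the pre-trained RETVec embedding model, which is frozen and therefore listed as non-trainable.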

Convert the model and run inference in TF Lite

With the model built, we can now convert it to a TF Lite model by following the TensorFlow Lite conversion instructions. For more details on using TensorFlow Lite, see the TensorFlow Lite guide.

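# Save the model in SavedModel format so the converter can load it
# (the directory name is illustrative; any writable path works)
save_path = 'demo_models/tf_lite_retvec'
model.save(save_path)

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(save_path)  # path to the SavedModel directory
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
]
converter.allow_custom_ops = True  # RETVec relies on TF Text custom ops
tflite_model = converter.convert()

from tensorflow.lite.python import interpreter
import tensorflow_text as tf_text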

# create TF lite interpreter with TF Text ops registered
interp = interpreter.InterpreterWithCustomOps(
    model_content=tflite_model,
    custom_op_registerers=tf_text.tflite_registrar.SELECT_TFTEXT_OPS)
interp.allocate_tensors()

# run inference with model
input_data = np.array(['This is an example text'])

tokenize = interp.get_signature_runner('serving_default')
output = tokenize(input=input_data)
print('TensorFlow Lite result = ', output['output'])
#output
TensorFlow Lite result =  [[[0.46520743 0.5190651  0.3716683  0.43701836]
  [0.6337548  0.42784083 0.5022397  0.55659497]
  [0.53433377 0.53684425 0.42378557 0.4369351 ]
  [0.406101   0.5063563  0.41558668 0.31651068]
  [0.55106455 0.54234135 0.49299878 0.23038922]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]
  [0.4838743  0.53117436 0.36593238 0.43153659]]]
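
The converter returns the model as an in-memory bytes object, so deploying it usually starts with writing those bytes to a .tflite file. A minimal sketch, assuming a hypothetical helper name and file name. The placeholder bytes stand in for the tflite_model produced above so that the snippet runs on its own.

```python
import tempfile
from pathlib import Path


def save_tflite(model_bytes, path):
    """Write a serialized TF Lite model to disk and return its size in bytes."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)  # create the folder if needed
    path.write_bytes(model_bytes)
    return path.stat().st_size


# placeholder content; in the notebook you would pass `tflite_model` instead
out_dir = Path(tempfile.mkdtemp())
out_file = out_dir / "models" / "retvec_demo.tflite"
size = save_tflite(b"placeholder model bytes", out_file)
print(f"wrote {size} bytes to {out_file}")
```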

Using RETVec to train an emotion classifier

RETVec is a cutting-edge text vectorizer that works directly on raw text inputs, which makes it well suited to building robust classification models. According to our research paper, models trained with RETVec achieve better classification performance with fewer parameters and are more resilient to adversarial attacks and typos. Because RETVec is small (about 200,000 parameters rather than millions), it is a good fit for training and deploying compact, high-performing on-device models. It is supported in TensorFlow Lite through custom ops in TensorFlow Text, and a JavaScript implementation of RETVec allows web deployment through TensorFlow.js. This notebook walks through quickly training and using a text emotion classifier with TensorFlow.
Let’s begin!

# installing needed dependencies
try:
    import retvec
except ImportError:
    !pip install retvec  # is retvec installed?

try:
    import datasets
except ImportError:
    !pip install datasets  # used to get the dataset

try:
    import matplotlib
except ImportError:
    !pip install matplotlib

Let’s import the libraries

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'  # silence TF INFO messages
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers
from datasets import load_dataset
from matplotlib import pyplot as plt

# import the RETVec tokenizer layer
from retvec.tf import RETVecTokenizer

Create dataset

We are going to use the GoEmotions dataset to create a multi-class emotion classifier: https://ai.googleblog.com/2021/10/goemotions-dataset-for-fine-grained.html

# downloading dataset
dataset = load_dataset('go_emotions')
# get class name mapping and number of classes
CLASSES = dataset['train'].features['labels'].feature.names
NUM_CLASSES = len(CLASSES)
print(f"num classes {NUM_CLASSES}")
print(CLASSES)
#output
num classes 28
['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']
# preparing data
x_train = tf.constant(dataset['train']['text'], dtype=tf.string)
# build the label vectors by hand: an example can carry several labels, so a plain one-hot encoding is not enough
y_train = np.zeros((len(x_train),NUM_CLASSES))
for idx, ex in enumerate(dataset['train']['labels']):
    for val in ex:
        y_train[idx][val] = 1
# test data
x_test = tf.constant(dataset['test']['text'], dtype=tf.string)
y_test = np.zeros((len(x_test),NUM_CLASSES))
for idx, ex in enumerate(dataset['test']['labels']):
    for val in ex:
        y_test[idx][val] = 1
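
The two loops above build multi-hot label vectors: one row per example, with a 1 in every column whose class index appears in that example's label list. A toy illustration in plain Python, with made-up label lists and a hypothetical helper name:

```python
def multi_hot(label_lists, num_classes):
    """Encode lists of class indices as rows of 0/1 flags."""
    rows = []
    for labels in label_lists:
        row = [0] * num_classes
        for idx in labels:
            row[idx] = 1  # an example may switch on several classes
        rows.append(row)
    return rows


toy_labels = [[0, 2], [1], [2, 3, 0]]
for row in multi_hot(toy_labels, num_classes=4):
    print(row)
# [1, 0, 1, 0]
# [0, 1, 0, 0]
# [1, 0, 1, 1]
```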

Model

One of the main advantages of RETVec is that both the tokenizer and the downstream text model take raw strings as input, with no separate pre-processing step. This significantly simplifies training and inference, particularly for models meant to run on-device.

Notes:

  • Using strings directly as input requires an input shape of (1,) and dtype tf.string.
  • RETVecTokenizer() is used in its default configuration, which truncates input at 128 words and uses a small pre-trained word embedding model to embed the words. You can experiment with shorter or longer inputs by changing the sequence_length parameter. The word embedding model offers significant improvements in adversarial and typo robustness. To use the RETVec character tokenizer only, set model=None.
  • To use native TF ops only for TF Lite compatibility, set use_tf_lite_compatible_ops=True. See the TF Lite notebook for more details on how to convert a RETVec-based model to a TF Lite model that can run on-device.

# string inputs need shape (1,) and dtype tf.string
inputs = layers.Input(shape=(1, ), name="token", dtype=tf.string)

# add RETVec tokenizer layer with default settings -- this is all you have to do to build a model with RETVec!
x = RETVecTokenizer(model='retvec-v1')(inputs)

# standard two layer LSTM
x = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(x)
x = layers.Bidirectional(layers.LSTM(64))(x)
outputs = layers.Dense(NUM_CLASSES, activation='sigmoid')(x)
model = tf.keras.Model(inputs, outputs)
model.summary()
#output
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 token (InputLayer)          [(None, 1)]               0         
                                                                 
 ret_vec_tokenizer_1 (RETVe  (None, 128, 256)          230144    
 cTokenizer)                                                     
                                                                 
 bidirectional_2 (Bidirecti  (None, 128, 128)          164352    
 onal)                                                           
                                                                 
 bidirectional_3 (Bidirecti  (None, 128)               98816     
 onal)                                                           
                                                                 
 dense_1 (Dense)             (None, 28)                3612      
                                                                 
=================================================================
Total params: 496924 (1.90 MB)
Trainable params: 266780 (1.02 MB)
Non-trainable params: 230144 (899.00 KB)
_________________________________________________________________
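
The LSTM parameter counts follow from the layer structure: each LSTM has four gates, each with a weight matrix over the concatenated input and hidden state plus a bias, and the Bidirectional wrapper doubles that. A sketch of the arithmetic with hypothetical helper names, using the sizes from the summary above:

```python
def lstm_params(n_inputs, n_units):
    """Four gates, each with input weights, recurrent weights, and a bias."""
    return 4 * ((n_inputs + n_units) * n_units + n_units)


def bidirectional_lstm_params(n_inputs, n_units):
    """A forward and a backward LSTM of the same size."""
    return 2 * lstm_params(n_inputs, n_units)


first = bidirectional_lstm_params(256, 64)   # input: 256 RETVec features per word
second = bidirectional_lstm_params(128, 64)  # input: 2 * 64 outputs of the first layer
head = 128 * 28 + 28                         # Dense layer over the 28 classes
print(first, second, head, first + second + head)
# 164352 98816 3612 266780
```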

Let’s train it

# compile and train the model as usual
batch_size = 256
epochs = 25
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
history = model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size,
                    validation_data=(x_test, y_test))
Epoch 1/25
170/170 [==============================] - 18s 79ms/step - loss: 0.1763 - acc: 0.2889 - val_loss: 0.1476 - val_acc: 0.2959
Epoch 2/25
170/170 [==============================] - 10s 62ms/step - loss: 0.1480 - acc: 0.2987 - val_loss: 0.1447 - val_acc: 0.3088
Epoch 3/25
170/170 [==============================] - 12s 69ms/step - loss: 0.1420 - acc: 0.3298 - val_loss: 0.1359 - val_acc: 0.3532
Epoch 4/25
170/170 [==============================] - 11s 66ms/step - loss: 0.1350 - acc: 0.3648 - val_loss: 0.1294 - val_acc: 0.3910
Epoch 5/25
170/170 [==============================] - 10s 62ms/step - loss: 0.1296 - acc: 0.3930 - val_loss: 0.1248 - val_acc: 0.4111
Epoch 6/25
170/170 [==============================] - 11s 62ms/step - loss: 0.1251 - acc: 0.4137 - val_loss: 0.1207 - val_acc: 0.4301
Epoch 7/25
170/170 [==============================] - 12s 68ms/step - loss: 0.1213 - acc: 0.4323 - val_loss: 0.1181 - val_acc: 0.4450
Epoch 8/25
170/170 [==============================] - 12s 69ms/step - loss: 0.1176 - acc: 0.4504 - val_loss: 0.1161 - val_acc: 0.4470
Epoch 9/25
170/170 [==============================] - 12s 70ms/step - loss: 0.1150 - acc: 0.4578 - val_loss: 0.1135 - val_acc: 0.4559
Epoch 10/25
170/170 [==============================] - 12s 69ms/step - loss: 0.1124 - acc: 0.4670 - val_loss: 0.1105 - val_acc: 0.4691
Epoch 11/25
170/170 [==============================] - 12s 69ms/step - loss: 0.1101 - acc: 0.4776 - val_loss: 0.1092 - val_acc: 0.4765
Epoch 12/25
170/170 [==============================] - 12s 69ms/step - loss: 0.1082 - acc: 0.4841 - val_loss: 0.1079 - val_acc: 0.4793
Epoch 13/25
170/170 [==============================] - 12s 69ms/step - loss: 0.1064 - acc: 0.4902 - val_loss: 0.1072 - val_acc: 0.4820
Epoch 14/25
170/170 [==============================] - 12s 69ms/step - loss: 0.1050 - acc: 0.4949 - val_loss: 0.1057 - val_acc: 0.4855
Epoch 15/25
170/170 [==============================] - 12s 69ms/step - loss: 0.1031 - acc: 0.5037 - val_loss: 0.1053 - val_acc: 0.4841
Epoch 16/25
170/170 [==============================] - 12s 69ms/step - loss: 0.1020 - acc: 0.5063 - val_loss: 0.1042 - val_acc: 0.4924
Epoch 17/25
170/170 [==============================] - 12s 69ms/step - loss: 0.1005 - acc: 0.5105 - val_loss: 0.1045 - val_acc: 0.4922
Epoch 18/25
170/170 [==============================] - 12s 69ms/step - loss: 0.0993 - acc: 0.5153 - val_loss: 0.1019 - val_acc: 0.5038
Epoch 19/25
170/170 [==============================] - 12s 70ms/step - loss: 0.0978 - acc: 0.5220 - val_loss: 0.1025 - val_acc: 0.5040
Epoch 20/25
170/170 [==============================] - 12s 69ms/step - loss: 0.0970 - acc: 0.5235 - val_loss: 0.1019 - val_acc: 0.5115
Epoch 21/25
170/170 [==============================] - 12s 69ms/step - loss: 0.0960 - acc: 0.5286 - val_loss: 0.1013 - val_acc: 0.5089
Epoch 22/25
170/170 [==============================] - 12s 70ms/step - loss: 0.0947 - acc: 0.5340 - val_loss: 0.1016 - val_acc: 0.5091
Epoch 23/25
170/170 [==============================] - 12s 69ms/step - loss: 0.0937 - acc: 0.5373 - val_loss: 0.1017 - val_acc: 0.5054
Epoch 24/25
170/170 [==============================] - 12s 68ms/step - loss: 0.0928 - acc: 0.5426 - val_loss: 0.1010 - val_acc: 0.5106
Epoch 25/25
170/170 [==============================] - 12s 69ms/step - loss: 0.0915 - acc: 0.5474 - val_loss: 0.0999 - val_acc: 0.5165
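
The loss in the log is binary cross-entropy, which treats each of the 28 sigmoid outputs as an independent yes/no prediction and averages the per-class losses. A small pure-Python sketch of the per-example computation on made-up numbers. The function is illustrative rather than the Keras implementation, and the clipping constant is an assumption.

```python
import math


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over the classes of a single example."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1 - eps)  # keep log() away from 0
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(y_true)


# toy example with 4 classes, two of them active
target = [1, 0, 0, 1]
predicted = [0.9, 0.1, 0.2, 0.6]
print(round(binary_crossentropy(target, predicted), 4))
# 0.2362
```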
# visualize the training curves
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.legend(['acc', 'val_acc'])
plt.title('Accuracy')
plt.show()

Save & Reload Keras Model

Let’s save our model, then test it on some examples.

# saving the model
save_path = 'demo_models/emotion_model'
model.save(save_path)

# reload it; compile=False is enough since we only run inference
model = tf.keras.models.load_model(save_path, compile=False)

def predict_emotions(txt, threshold=0.5):
    # an example can have several emotions, so report every prediction above the threshold
    preds = model(tf.constant([txt]))[0]
    out = 0
    for i in range(NUM_CLASSES):
        if preds[i] > threshold:
            emotion_name = CLASSES[i]
            emotion_prob = round(float(preds[i]) * 100, 1)
            print(f"{emotion_name} ({emotion_prob})%")
            out += 1
    if not out:
        print("neutral")

Let’s try the function:

txt = "I enjoy having a good icecream."
predict_emotions(txt)
#output
joy (92.2)%
# the model works even with typos, substitutions, and emojis!
txt = "I enjoy hving a g00d ic3cream!!! 🍦"
predict_emotions(txt)
#output
joy (91.9)%
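
predict_emotions prints its results, which is convenient in a notebook but awkward when the predictions feed other code. Here is a possible variant of the thresholding step that returns (label, percentage) pairs instead. It is written against a plain list of probabilities so that it runs without the model. The function name and the toy inputs are made up for illustration.

```python
def emotions_above_threshold(probs, classes, threshold=0.5):
    """Return (name, percent) for every class scoring above the threshold, best first."""
    hits = [(name, round(float(p) * 100, 1))
            for name, p in zip(classes, probs) if p > threshold]
    hits.sort(key=lambda item: item[1], reverse=True)
    # mirror the original function: fall back to "neutral" when nothing passes
    return hits or [("neutral", None)]


toy_classes = ['anger', 'joy', 'love', 'sadness']
toy_probs = [0.03, 0.92, 0.61, 0.10]
print(emotions_above_threshold(toy_probs, toy_classes))
# [('joy', 92.0), ('love', 61.0)]
```

With the trained model, the probabilities would come from model(tf.constant([txt]))[0] and the class names from CLASSES.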
