Pruning of neural networks with TensorFlow
The purpose of magnitude-based weight pruning is to gradually zero out the less significant weights of the model during the training phase,
thus obtaining a certain degree of sparsity in the weight matrices (both kernel and bias).
The sparsity of a matrix refers to the presence of elements equal to zero in it: the more zero elements there are,
the sparser the matrix is. A sparse matrix brings advantages in terms of both memory occupation and computation.
As far as memory is concerned, a sparse matrix can be compressed more easily thanks to its redundant elements (the zeros),
and this is the case treated in this post. In general, sparse arrays can also be stored differently from traditional
NxM arrays, for example by storing lists of the nonzero values with their associated indexes, but that approach is not discussed here.
Regarding computation, there is some room for improvement because multiplications by elements equal to zero
could be skipped, but this case is not treated in this post either: the focus here is on making
the weight matrices sparse in order to obtain greater compression at the cost of a limited loss in the inference (prediction) quality of the model.
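To make the link between sparsity and compression concrete, here is a minimal NumPy sketch, independent of TensorFlow. The helper names `prune_by_magnitude` and `sparsity_of`, the random 64x64 matrix and the use of `zlib` are all assumptions made for illustration; they are not taken from the code of this post, which prunes gradually during training rather than in one shot.

```python
import zlib

import numpy as np


def prune_by_magnitude(weights, sparsity):
    """Return a copy of `weights` with the smallest-magnitude fraction set to zero."""
    magnitudes = np.abs(weights).ravel()
    k = int(round(sparsity * magnitudes.size))  # number of weights to zero out
    if k == 0:
        return weights.copy()
    threshold = np.partition(magnitudes, k - 1)[k - 1]  # k-th smallest magnitude
    pruned = weights.copy()
    pruned[np.abs(pruned) <= threshold] = 0.0
    return pruned


def sparsity_of(weights):
    """Fraction of elements that are exactly zero."""
    return float(np.mean(weights == 0))


rng = np.random.default_rng(0)
dense = rng.normal(size=(64, 64))  # hypothetical weight matrix
pruned = prune_by_magnitude(dense, sparsity=0.5)

# Compress the raw bytes of both matrices: the zeros make the pruned one more redundant.
dense_bytes = len(zlib.compress(dense.tobytes()))
pruned_bytes = len(zlib.compress(pruned.tobytes()))
print(f"sparsity after pruning: {sparsity_of(pruned):.2f}")
print(f"compressed size: dense={dense_bytes} bytes, pruned={pruned_bytes} bytes")
```

Both matrices occupy the same number of bytes in memory; only the compressed size changes, which is the effect this post sets out to measure on the saved model files.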
The post presents a magnitude-based weight pruning solution implemented via the TensorFlow Model Optimization library
and shows three examples of pruning applied to three different types of networks: a fully connected network (an MLP that performs a regression), a long short-term memory network
(an LSTM that performs time series forecasting) and a convolutional network (a CNN that performs image classification).
The code described in this post requires Python 3 and uses TensorFlow 2.x (on either CPU or GPU) with Keras (which is already integrated within TensorFlow 2);
in addition to the already mentioned TensorFlow Model Optimization, it requires other libraries such as NumPy, scikit-learn, pandas, Matplotlib and TensorFlow Datasets.
To get the code, please see the section Download of the complete code at the end of this post.
The code of the solution
The heart of the proposed solution is the file trim_insignificant_weights.py, which implements two classes and some functions.
The classes are:

AttemptConfig: an element of the search grid (implemented by the examples shown later), which proceeds by attempts; this class has two properties: the name of the attempt and the TensorFlow Model Optimization object that implements the pruning policy. The currently supported policies are PolynomialDecay and ConstantSparsity.
AttemptInfo: a set of properties that store various information about a trained model. Specifically: the total number of weights, the number of weights equal to zero (and, by difference, the number of nonzero weights), the size of the original model saved as .h5, the size of the zipped .h5 file, the size of the .tflite file, the size of the zipped .tflite file with the corresponding compression factors, and also other information regarding inference, such as the predicted test values and the error between the predicted and expected values.
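The compression factors mentioned above can be read as the percentage of space saved by zipping a file. The sketch below assumes the formula 100 * (1 - zipped / unzipped); the function name `compression_factor` is hypothetical, but the formula reproduces the percentages reported later in this post for the unpruned model of example #1.

```python
def compression_factor(unzipped_size, zipped_size):
    """Percentage of space saved by zipping: 100 * (1 - zipped / unzipped)."""
    return 100.0 * (1.0 - zipped_size / unzipped_size)


# Sizes in bytes as reported in the output of example #1 for the original (unpruned) model.
h5_factor = compression_factor(37272, 18155)
tflite_factor = compression_factor(7232, 5781)
print(f"h5: {h5_factor:.2f}%  tflite: {tflite_factor:.2f}%")  # h5: 51.29%  tflite: 20.06%
```

With this definition, a higher compression factor means a smaller zipped file relative to the unzipped one.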

The functions are:

print_attempt_infos: takes as input a list of AttemptInfo objects, obtained from the complete execution of the search grid, and writes to standard output, in a user-friendly way, the information contained in each AttemptInfo object in the list.
inspect_weigths: takes as input a Keras model, inspects the weights (both kernel and bias) and writes to standard output the number of weights for each layer of the model, indicating how many of them are equal to zero and how many are nonzero.
retrieve_size_of_model: takes as input a Keras model and returns the size of the .h5 file obtained by saving the model and the size of the same file once zipped.
retrieve_size_of_lite_model: takes as input a Keras model and returns the size of the .tflite file obtained by converting and saving the model for TensorFlow Lite; as above, it also returns the size of the same file once zipped.
build_pruning_model: takes as input an original Keras model (one that has not undergone any pruning process) and returns a wrapper of the model, obtained by applying the method prune_low_magnitude, that prepares the model to undergo pruning during the training phase.
retrieve_callbacks_for_pruning: returns the UpdatePruningStep callback needed for the pruning training phase.
extract_pruned_model: takes as input a pruning wrapper of a model, removes the wrapper via strip_pruning and returns the underlying model, which is ready for inference.
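The two measurements behind these helpers, counting zero weights and measuring the size of a zipped file, can be sketched without TensorFlow. The function names `count_zero_weights` and `zipped_size` are hypothetical and the toy arrays stand in for the list returned by a Keras model's `get_weights()`; the actual functions in trim_insignificant_weights.py may be implemented differently.

```python
import os
import tempfile
import zipfile

import numpy as np


def count_zero_weights(weight_arrays):
    """Return (total, zeros, nonzeros) over a list of NumPy weight arrays."""
    total = sum(w.size for w in weight_arrays)
    zeros = sum(int(np.count_nonzero(w == 0)) for w in weight_arrays)
    return total, zeros, total - zeros


def zipped_size(path):
    """Zip the file at `path` into a temporary archive and return the archive size in bytes."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        zip_path = os.path.join(tmp_dir, "model.zip")
        with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
            archive.write(path, arcname=os.path.basename(path))
        return os.path.getsize(zip_path)


# Toy stand-in for model.get_weights(): one kernel and one bias.
kernel = np.array([[0.0, 0.3], [-0.7, 0.0]], dtype=np.float32)
bias = np.array([0.1, 0.0], dtype=np.float32)
total, zeros, nonzeros = count_zero_weights([kernel, bias])
print(f"total={total} zeros={zeros} nonzeros={nonzeros}")

# Toy stand-in for a saved model file: 10,000 zero bytes, i.e. highly redundant content.
with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as raw_file:
    raw_file.write(bytes(10_000))
    raw_path = raw_file.name
raw_size = os.path.getsize(raw_path)
zip_size = zipped_size(raw_path)
os.remove(raw_path)
print(f"raw={raw_size} bytes, zipped={zip_size} bytes")
```

In the real workflow the file passed to such a helper would be the saved .h5 or .tflite model rather than a synthetic file.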
The examples
The examples shown in this post all follow the same pattern: the example code prepares a dataset (generated synthetically in examples #1 and #2,
while example #3 uses a dataset from TensorFlow Datasets),
which is divided into two parts: one for training and the other for testing (sometimes there is a third part for validation);
then a neural network model is built and trained.
This model is called the original model, and information about it is stored in an instance of the class AttemptInfo.
At this point the search grid is created: it is a list of differently initialized instances of the classes PolynomialDecay and ConstantSparsity; each configuration is stored in an instance of AttemptConfig.
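To give an idea of what the two pruning policies prescribe, the sketch below evaluates the target sparsity as a function of the training step in plain Python. The polynomial formula follows the one documented for PolynomialDecay in TensorFlow Model Optimization (with its default exponent of 3); the function names, the clamping of the step and the numeric values (10% initial sparsity, 50% final sparsity, 1000 steps) are illustrative assumptions and not the settings used by the examples of this post.

```python
def polynomial_decay_sparsity(step, initial_sparsity, final_sparsity,
                              begin_step, end_step, power=3):
    """Target sparsity at `step` for a polynomial-decay schedule."""
    # Simplification: clamp the step so the formula stays within [begin_step, end_step].
    step = min(max(step, begin_step), end_step)
    progress = (step - begin_step) / (end_step - begin_step)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power


def constant_sparsity(step, target_sparsity, begin_step):
    """Target sparsity at `step` for a constant schedule (no pruning before begin_step)."""
    return target_sparsity if step >= begin_step else 0.0


for step in (0, 250, 500, 750, 1000):
    poly = polynomial_decay_sparsity(step, 0.10, 0.50, begin_step=0, end_step=1000)
    const = constant_sparsity(step, 0.50, begin_step=0)
    print(f"step {step:4d}: polynomial={poly:.3f} constant={const:.3f}")
```

The polynomial schedule ramps up quickly at the beginning and flattens towards the final sparsity, while the constant schedule asks for the same sparsity throughout training.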
The code in the example loops over the search grid and, for each AttemptConfig, creates a pruning wrapper of the original model
by calling the function build_pruning_model
and trains this wrapped model using the same hyperparameters as the training of the original model,
but with an extra callback obtained by calling retrieve_callbacks_for_pruning.
During training, the model undergoes a process of pruning; once training is over, the function extract_pruned_model
is called to remove the wrapper and obtain the underlying model, on which inference on the test data is performed.
Finally, a new instance of AttemptInfo stores all the information about that model, in particular the size of the zipped .h5 file,
the size of the zipped .tflite file and the error calculated by comparing the predictions on the test data with the real test values.
The various instances of AttemptInfo are collected in a list and, at the end of the example script, the information from the various attempts is displayed
so that you can compare the compression factor obtained by each attempt against the corresponding loss in model quality.
For examples #1 and #2 a Cartesian graph with two curves is also shown for each attempt: the test dataset in green and the prediction of the current attempt in red;
this gives an insight into the loss of quality as the amount of pruning increases.
Example #1: fully connected neural network
The code for this example is the file example1.py.
The dataset used by this example is a synthetic dataset generated as follows:
import numpy as np

fx_gen_ds = lambda x: x**2  # generating function of the dataset
x_dataset = np.arange(-2., 2, 0.005, dtype=float)  # x from -2 up to (but excluding) 2
y_dataset = fx_gen_ds(x_dataset)
which is trivially a parabolic curve whose equation is $y=x^2$ with $x \in [-2, 2]$.
To execute this Python script, run the following command:
$ python example1.py
In the output, right at the beginning, we see the structure of the model, which is a standard fully connected network,
that is, an MLP (Multi-Layer Perceptron) implemented through Dense layers, with a total of 4289 trainable weights.
Here is the network structure:
Model: "mlp_regression_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 1)]               0
_________________________________________________________________
dense (Dense)                (None, 32)                64
_________________________________________________________________
dense_1 (Dense)              (None, 64)                2112
_________________________________________________________________
dense_2 (Dense)              (None, 32)                2080
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 33
=================================================================
Total params: 4,289
Trainable params: 4,289
Non-trainable params: 0
_________________________________________________________________
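The parameter counts in the summary can be checked by hand: a Dense layer has one kernel weight for each (input, unit) pair plus one bias per unit. The short sketch below redoes the arithmetic; the helper name `dense_params` is hypothetical.

```python
def dense_params(inputs, units):
    """Weights of a Dense layer: inputs * units kernel entries plus one bias per unit."""
    return inputs * units + units


layer_sizes = [1, 32, 64, 32, 1]  # input width followed by the units of each Dense layer
per_layer = [dense_params(i, u) for i, u in zip(layer_sizes, layer_sizes[1:])]
print(per_layer, sum(per_layer))  # [64, 2112, 2080, 33] 4289
```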
Next comes the training of the original model, which consists of 100 epochs with a batch size of 80 elements, using the Adam optimizer
and MeanSquaredError as the loss function. The properties of this model (which we call the original model because it is not pruned)
are visible in the output of the program:
Model: original (unpruned)
Total number of weights: 4289
Total number of nonzero weights: 4280
Total number of zero weights: 9
Unzipped h5 size: 37272 bytes
Zipped h5 size: 18155 bytes (compression factor: 51.29%)
Unzipped tflite size: 7232 bytes
Zipped tflite size: 5781 bytes (compression factor: 20.06%)
Error (loss) value: 1.379212E-04
from which we can see that the error calculated on the test dataset is very low ($1.379212 \cdot 10^{-4}$) and that there are very few weights equal to zero (fewer than 10, practically none), as expected
because the model has not been pruned; the compression factor on the .h5 file is $51.29 \%$,
while the one on the .tflite file is $20.06 \%$.
The following image shows at a glance that the quality of inference is very good.
Then follows the execution of the search grid, which performs 11 pruning attempts: 6 with variously initialized PolynomialDecay
and 5 with variously initialized ConstantSparsity. As a sample, we show the result of one of the 11 attempts, namely const sparsity 0.5
(the results of all attempts are available in the standard output):
Model: const sparsity 0.5
Total number of weights: 4289
Total number of nonzero weights: 1775
Total number of zero weights: 2514
Unzipped h5 size: 37272 bytes
Zipped h5 size: 10832 bytes (compression factor: 70.94%)
Unzipped tflite size: 7232 bytes
Zipped tflite size: 2953 bytes (compression factor: 59.17%)
Error (loss) value: 3.109528E-04
from which we can see that the error calculated on the test dataset remains very low ($3.109528 \cdot 10^{-4}$), even if a bit higher than that of the original model, as expected
because pruning reduces the quality of the model; the weights equal to zero are 2514 out of 4289 (so more than half), the compression factor on the .h5 file is $70.94 \%$,
while that on the .tflite file is $59.17 \%$.
The following image shows at a glance that the quality of inference is quite good.
When all attempts are finished, the example script shows the recap of all attempts; the first element of the recap is relative to the original model.
*** Final recap ***
Attempt name         Size h5  (Comp. %)  Error (loss)
original (unpruned)    18155  ( 51.29%)  1.379212e-04
poly decay 10/50       11958  ( 67.92%)  1.981978e-03
poly decay 20/50       11928  ( 68.00%)  9.621178e-05
poly decay 30/60       10544  ( 71.71%)  2.460897e-04
poly decay 30/70        9039  ( 75.75%)  2.291273e-03
poly decay 40/50       12254  ( 67.12%)  8.707970e-05
poly decay 10/90        5782  ( 84.49%)  3.172360e-02
const sparsity 0.1     10858  ( 70.87%)  5.406314e-04
const sparsity 0.4     10856  ( 70.87%)  4.125351e-04
const sparsity 0.5     10832  ( 70.94%)  3.109528e-04
const sparsity 0.6     10476  ( 71.89%)  2.269561e-04
const sparsity 0.9      5792  ( 84.46%)  9.419761e-04
from which we deduce that, in general, as the compression factor increases, the error calculated on the test dataset increases
and therefore the quality of the inference decreases, although the trend is not strictly monotonic (for instance, poly decay 20/50 and poly decay 40/50 obtain a lower error than the original model).
As a sample, the following image shows a model that has undergone heavy pruning; consequently, the quality of its inference
is significantly worse than that of the model previously shown.
Note: Given the stochastic nature of the training phase, your specific results may vary. Consider running the example a few times.
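One possible way to read the recap is to fix an error budget and look for the attempt with the smallest zipped file that stays within it. The sketch below does this on the figures of the recap of example #1; the function name `smallest_within_budget` and the two budgets are hypothetical choices made for illustration, not part of the example script.

```python
# (attempt name, zipped .h5 size in bytes, error on the test dataset) from the recap of example #1.
recap = [
    ("original (unpruned)", 18155, 1.379212e-04),
    ("poly decay 10/50", 11958, 1.981978e-03),
    ("poly decay 20/50", 11928, 9.621178e-05),
    ("poly decay 30/60", 10544, 2.460897e-04),
    ("poly decay 30/70", 9039, 2.291273e-03),
    ("poly decay 40/50", 12254, 8.707970e-05),
    ("poly decay 10/90", 5782, 3.172360e-02),
    ("const sparsity 0.1", 10858, 5.406314e-04),
    ("const sparsity 0.4", 10856, 4.125351e-04),
    ("const sparsity 0.5", 10832, 3.109528e-04),
    ("const sparsity 0.6", 10476, 2.269561e-04),
    ("const sparsity 0.9", 5792, 9.419761e-04),
]


def smallest_within_budget(rows, max_error):
    """Return the row with the smallest zipped size whose error does not exceed max_error."""
    candidates = [row for row in rows if row[2] <= max_error]
    return min(candidates, key=lambda row: row[1]) if candidates else None


for budget in (1e-3, 3e-4):
    name, size, error = smallest_within_budget(recap, max_error=budget)
    print(f"budget {budget:.0e}: {name} ({size} bytes, error {error:.6e})")
```

With these figures a budget of 1e-3 selects const sparsity 0.9, while the stricter budget of 3e-4 selects const sparsity 0.6; since training is stochastic, a different run may lead to a different choice.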
Example #2: long short-term memory neural network
The code for this example is the file example2.py.
The dataset used by this example is a synthetic time series generated as follows:
import numpy as np

ft_gen_ts = lambda t: 2.0 * np.sin(t/10.0)  # generating function of the time series
t_train = np.arange(0, 200, 0.5, dtype=float)  # training time axis
y_train_timeseries = ft_gen_ts(t_train)
t_test = np.arange(200, 400, 0.5, dtype=float)  # test time axis, right after the training one
y_test_timeseries = ft_gen_ts(t_test)
which is trivially a sine wave whose equation is $y=2 \sin \frac{t}{10}$, with $t \in [0, 200)$ for training and $t \in [200, 400)$ for testing.
To execute this Python script, run the following command:
$ python example2.py
In the output, right at the beginning, we see the structure of the model, which is a network with an LSTM layer followed by a Dense layer,
suitable for calculating a forecast; it has a total of 26321 trainable weights.
Here is the network structure:
Model: "long_short_term_memory_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 6, 1)]            0
_________________________________________________________________
lstm (LSTM)                  (None, 80)                26240
_________________________________________________________________
dense (Dense)                (None, 1)                 81
=================================================================
Total params: 26,321
Trainable params: 26,321
Non-trainable params: 0
_________________________________________________________________
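The 26240 weights of the LSTM layer can also be checked by hand: an LSTM has four gates, and each gate has an input kernel, a recurrent kernel and a bias. The helper names in the sketch below are hypothetical.

```python
def lstm_params(input_dim, units):
    """Weights of an LSTM layer: 4 gates, each with input kernel, recurrent kernel and bias."""
    return 4 * (input_dim * units + units * units + units)


def dense_params(inputs, units):
    """Weights of a Dense layer: kernel plus bias."""
    return inputs * units + units


lstm = lstm_params(input_dim=1, units=80)
dense = dense_params(inputs=80, units=1)
print(lstm, dense, lstm + dense)  # 26240 81 26321
```

Note that the length of the time window (6) does not appear in the formula: the same weights are reused at every time step.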
This is followed by the training of the original model, which consists of 80 epochs with a time window of 6 elements and a batch size of 50 elements,
using the Adam optimizer and MeanSquaredError as the loss function.
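A time window of 6 elements means that each training sample consists of 6 consecutive values of the series, which matches the input shape (None, 6, 1) of the model. The sketch below shows one common way to build such windows with NumPy; the function name `make_windows` and the choice of the next value as the target (one-step-ahead forecast) are assumptions, since the windowing code of example2.py is not shown in this post.

```python
import numpy as np


def make_windows(series, window):
    """Split a 1-D series into (inputs, targets): each input holds `window`
    consecutive values and the target is the value that follows them."""
    inputs = np.stack([series[i:i + window] for i in range(len(series) - window)])
    targets = series[window:]
    return inputs[..., np.newaxis], targets  # trailing axis: one feature per time step


t_train = np.arange(0, 200, 0.5, dtype=float)
y_train = 2.0 * np.sin(t_train / 10.0)
x_windows, y_next = make_windows(y_train, window=6)
print(x_windows.shape, y_next.shape)  # (394, 6, 1) (394,)
```

The 400 points of the training series yield 394 samples, because the first 6 values have no preceding full window to be predicted from.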
The properties of this model (which we call the original model because it is not pruned)
are visible in the output of the program:
Model: original (unpruned)
Total number of weights: 26321
Total number of nonzero weights: 26321
Total number of zero weights: 0
Unzipped h5 size: 122096 bytes
Zipped h5 size: 99952 bytes (compression factor: 18.14%)
Unzipped tflite size: 42480 bytes
Zipped tflite size: 29678 bytes (compression factor: 30.14%)
Error (loss) value: 1.289916E-01
from which we see that the error calculated on the test forecast is relatively low ($1.289916 \cdot 10^{-1}$) and that no weight is equal to zero, as expected
because the model has not been pruned; the compression factor on the .h5 file is $18.14 \%$,
while that on the .tflite file is $30.14 \%$.
The following image shows at a glance that the quality of the forecast is good.
Then follows the execution of the search grid, which performs 11 pruning attempts: 6 with variously initialized PolynomialDecay
and 5 with variously initialized ConstantSparsity. As a sample, we show the result of one of the 11 attempts, namely poly decay 40/50
(the results of all attempts are available in the standard output):
Model: poly decay 40/50
Total number of weights: 26321
Total number of nonzero weights: 13322
Total number of zero weights: 12999
Unzipped h5 size: 122096 bytes
Zipped h5 size: 63464 bytes (compression factor: 48.02%)
Unzipped tflite size: 42480 bytes
Zipped tflite size: 21277 bytes (compression factor: 49.91%)
Error (loss) value: 8.439505E-01
from which we can see that the error calculated on the test dataset is still low ($8.439505 \cdot 10^{-1}$), even if higher than that of the original model, as expected
because pruning reduces the quality of the model; the weights equal to zero are 12999 out of 26321 (so almost half), the compression factor on the .h5 file is $48.02 \%$,
while that on the .tflite file is $49.91 \%$.
The following image shows at a glance that the quality of the forecast is relatively good.
When all attempts are finished, the example script shows the recap of all attempts; the first element of the recap is relative to the original model.
*** Final recap ***
Attempt name         Size h5  (Comp. %)  Error (loss)
original (unpruned)    99952  ( 18.14%)  1.289916e-01
poly decay 10/50       62316  ( 48.96%)  8.029442e-01
poly decay 20/50       62268  ( 49.00%)  1.122432e-01
poly decay 30/60       53409  ( 56.26%)  2.103334e+00
poly decay 30/70       43866  ( 64.07%)  3.067956e+00
poly decay 40/50       63464  ( 48.02%)  8.439505e-01
poly decay 10/90       22154  ( 81.86%)  4.263138e+00
const sparsity 0.1     96429  ( 21.02%)  2.983370e+00
const sparsity 0.4     72670  ( 40.48%)  3.378339e+00
const sparsity 0.5     63657  ( 47.86%)  3.714817e-01
const sparsity 0.6     54506  ( 55.36%)  4.406884e+00
const sparsity 0.9     22818  ( 81.31%)  4.847150e+00
from which it can be deduced that, in general, as the compression factor increases, the error calculated on the test dataset increases and therefore the quality of the forecast decreases.
As a sample, the following image shows a model that has undergone heavy pruning; consequently, the quality of its inference is notably worse than that of the model previously shown.
Note: Given the stochastic nature of the training phase, your specific results may vary. Consider running the example a few times.
Example #3: convolutional neural network
The code for this example is the file example3.py.
The dataset used by this example is the tf_flowers dataset from TensorFlow Datasets;
here is the code that downloads it:
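import tensorflow_datasets as tfds

# 80% of the images for training, 10% for validation and 10% for testing
(train_ds, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)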
total_number_of_training_images = 8
total_number_of_validation_images = 6
total_number_of_test_images = 7
To execute this Python script, run the following command:
$ python example3.py
It is recommended to run this script with a GPU available, even a modest one, as running on CPU can take a long time.
In the output, right at the beginning, we see the structure of the model, which is a network with a series of Conv2D and MaxPooling2D layers followed by a Flatten layer, then a Dense layer, then a Dropout layer (to reduce overfitting) and finally another Dense layer; it has a total of 1658565 trainable weights.
Here is the structure of the network:
Model: "cnn_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 180, 180, 3)]     0
_________________________________________________________________
conv2d (Conv2D)              (None, 178, 178, 32)      896
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 89, 89, 32)        0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 87, 87, 32)        9248
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 43, 43, 32)        0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 41, 41, 32)        9248
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 20, 20, 32)        0
_________________________________________________________________
flatten (Flatten)            (None, 12800)             0
_________________________________________________________________
dense (Dense)                (None, 128)               1638528
_________________________________________________________________
dropout (Dropout)            (None, 128)               0
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 645
=================================================================
Total params: 1,658,565
Trainable params: 1,658,565
Non-trainable params: 0
_________________________________________________________________
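The shapes and parameter counts in this summary are consistent with 3x3 convolutions without padding and 2x2 max pooling, as the sketch below verifies. The kernel and pool sizes are inferred from the summary itself (for example, 896 = 3 * 3 * 3 * 32 + 32) rather than read from example3.py, and the helper names are hypothetical.

```python
def conv2d_params(kernel_size, in_channels, filters):
    """Weights of a Conv2D layer: one kernel per (input channel, filter) pair plus one bias per filter."""
    return kernel_size * kernel_size * in_channels * filters + filters


def conv2d_valid_output(size, kernel_size):
    """Spatial size after a convolution without padding and with stride 1."""
    return size - kernel_size + 1


def max_pool_output(size, pool_size=2):
    """Spatial size after non-overlapping max pooling."""
    return size // pool_size


size, channels, total = 180, 3, 0
for filters in (32, 32, 32):  # three Conv2D + MaxPooling2D stages
    total += conv2d_params(3, channels, filters)
    size = max_pool_output(conv2d_valid_output(size, 3))
    channels = filters

flattened = size * size * channels
total += flattened * 128 + 128  # dense
total += 128 * 5 + 5            # dense_1 (5 flower classes)
print(size, flattened, total)  # 20 12800 1658565
```

Almost all of the weights (1638528 out of 1658565) sit in the first Dense layer, so that is where pruning has the largest effect on the size of the model.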
This is followed by the training of the original model, which consists of 15 epochs with a batch size of 100 elements,
using the Adam optimizer and SparseCategoricalCrossentropy(from_logits=True) as the loss function.
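Because the loss is configured with from_logits=True, the last Dense layer of the network outputs raw scores (logits) rather than probabilities; the predicted class is simply the index of the highest score. The sketch below shows how a classification error rate can be derived from such logits with NumPy. The function name `error_rate` and the toy values are assumptions for illustration and are not taken from example3.py.

```python
import numpy as np


def error_rate(logits, labels):
    """Fraction of samples whose highest-scoring class differs from the true label."""
    predictions = np.argmax(logits, axis=1)
    return float(np.mean(predictions != labels))


# Toy logits for 4 images and 5 classes; the third image is misclassified (class 1 instead of 4).
logits = np.array([[2.0, 0.1, -1.0, 0.0, 0.3],
                   [0.2, 0.1, 3.0, 0.0, -0.5],
                   [0.0, 1.5, 0.2, 0.1, 0.4],
                   [1.0, 0.9, 0.8, 0.7, 0.6]])
labels = np.array([0, 2, 4, 0])

error = error_rate(logits, labels)
print(f"error rate: {error:.2f}, accuracy: {1.0 - error:.2f}")  # error rate: 0.25, accuracy: 0.75
```

No softmax is needed for this computation, since softmax preserves the ordering of the scores.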
Note: The error value shown in the output of the program, however, is not the value of the loss function obtained by comparing the test dataset with the prediction,
as it was for examples #1 and #2, but the number of incorrectly classified images divided by the total number of images in the test dataset.
The properties of this model (which we call the original model because it is not pruned) are visible in the output of the program:
Model: original (unpruned)
Total number of weights: 1658565
Total number of nonzero weights: 1658565
Total number of zero weights: 0
Unzipped h5 size: 6665816 bytes
Zipped h5 size: 6151486 bytes (compression factor: 7.72%)
Unzipped tflite size: 1669728 bytes
Zipped tflite size: 1358843 bytes (compression factor: 18.62%)
Error (loss) value: 3.950954E-01
from which we see that the error calculated on the classification of the test dataset is $3.950954 \cdot 10^{-1}$ and that no weight is equal to zero, as expected
because the model has not been pruned; the compression factor on the .h5 file is $7.72 \%$,
while that on the .tflite file is $18.62 \%$.
Then follows the execution of the search grid, which performs 11 pruning attempts: 6 with variously initialized PolynomialDecay
and 5 with variously initialized ConstantSparsity. As a sample, we show the result of one of the 11 attempts, namely poly decay 40/50
(the results of all attempts are available in the standard output):
Model: poly decay 40/50
Total number of weights: 1658565
Total number of nonzero weights: 829624
Total number of zero weights: 828941
Unzipped h5 size: 6665816 bytes
Zipped h5 size: 3859572 bytes (compression factor: 42.10%)
Unzipped tflite size: 1669728 bytes
Zipped tflite size: 978984 bytes (compression factor: 41.37%)
Error (loss) value: 4.087193E-01
from which we can see that the error calculated on the test dataset is still close to that of the original model ($4.087193 \cdot 10^{-1}$), even if a bit higher, as expected
because pruning reduces the quality of the model; the weights equal to zero are 828941 out of 1658565 (so almost half), the compression factor on the .h5 file is $42.10 \%$,
while that on the .tflite file is $41.37 \%$.
When all attempts are finished, the example script shows the recap of all attempts; the first element of the recap is relative to the original model.
*** Final recap ***
Attempt name         Size h5  (Comp. %)  Error (loss)
original (unpruned)  6151486  (  7.72%)  3.950954e-01
poly decay 10/50     3781226  ( 43.27%)  4.223433e-01
poly decay 20/50     3786548  ( 43.19%)  4.168937e-01
poly decay 30/60     3238918  ( 51.41%)  4.523161e-01
poly decay 30/70     2602091  ( 60.96%)  4.441417e-01
poly decay 40/50     3859572  ( 42.10%)  4.087193e-01
poly decay 10/90     1286371  ( 80.70%)  4.741144e-01
const sparsity 0.1   5891419  ( 11.62%)  4.604905e-01
const sparsity 0.4   4400747  ( 33.98%)  4.386921e-01
const sparsity 0.5   3830963  ( 42.53%)  4.414169e-01
const sparsity 0.6   3246370  ( 51.30%)  4.441417e-01
const sparsity 0.9   1285383  ( 80.72%)  4.659401e-01
from which it can be deduced that, in general, as the compression factor increases, the error calculated on the test dataset increases
and therefore the quality of the classification decreases.
Note: Given the stochastic nature of the training phase, your specific results may vary. Consider running the example a few times.
Download of the complete code
The complete code is available on GitHub.
These materials are distributed under the MIT license; feel free to use, share, fork and adapt them as you see fit.
Also, please feel free to submit pull requests and bug reports to this GitHub repository or contact me on my social media channels, available in the top right corner of this page.