Visualising Tensorflow model summaries
It's no secret that my PhD is based in machine learning / AI (my specific title is "Using Big Data and AI to dynamically map flood risk"). Recently, a problem that has been plaguing me is how to quickly understand, at a high level, the architecture of new (and old) models I come across. I could read the paper a model comes from in detail (and I do this regularly), but a model is much easier to understand if I can visualise it as a flowchart.
To remedy this, I've written a neat little tool that does just that. When you're training a new AI model, it's common to print out a summary of the model (I make a habit of always doing this for later inspection, comparison, and debugging), like this:
# setup_model() stands in for whatever code builds your model
model = setup_model()
model.summary()
This might print something like the following:
Model: "sequential"
_________________________________________________________________
Layer (type)                Output Shape              Param #
=================================================================
embedding (Embedding)       (None, 500, 16)           160000
lstm (LSTM)                 (None, 32)                6272
dense (Dense)               (None, 1)                 33
=================================================================
Total params: 166,305
Trainable params: 166,305
Non-trainable params: 0
_________________________________________________________________
(Source: ChatGPT)
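Because the summary is just column-aligned text, pulling it apart is mostly a job for regular expressions. As a rough illustration, here's a minimal standalone Python sketch (not the tool's own source) that extracts each layer's name, type, output shape and parameter count from rows like the ones above, then checks that they add up to the reported total. It assumes one layer per line with a single-line output shape; `ROW` and `parse_rows` are hypothetical names for illustration.

```python
import re

# The layer rows from the sample summary above, as plain text.
SUMMARY = """\
embedding (Embedding)       (None, 500, 16)           160000
lstm (LSTM)                 (None, 32)                6272
dense (Dense)               (None, 1)                 33
"""

# Hypothetical row pattern: name (Type)   (output shape)   param count
ROW = re.compile(r"^\s*(\S+) \((\w+)\)\s+(\(.*?\))\s+([\d,]+)\s*$")


def parse_rows(text):
    """Return (name, layer_type, output_shape, params) for each layer row."""
    rows = []
    for line in text.splitlines():
        match = ROW.match(line)
        if match:  # headers, separators and totals simply don't match
            name, layer_type, shape, params = match.groups()
            rows.append((name, layer_type, shape, int(params.replace(",", ""))))
    return rows


rows = parse_rows(SUMMARY)
for row in rows:
    print(row)
# The per-layer counts should sum to the "Total params" line (166,305).
print("Total params:", sum(row[3] for row in rows))
```

Because `\s+` matches one or more spaces, the pattern also copes with summaries whose column padding has been squashed down to single spaces, e.g. after copying the text from a web page.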
This is just a simple model, but it is common for larger ones to have hundreds of layers, like this one I'm currently playing with as part of my research:
Woah, that's some model! It must be really complicated. How are we supposed to make sense of it?
If you look closely, you'll notice that its summary has an extra "Connected to" column, listing the layer(s) that feed into each layer. This is because it isn't a linear tf.keras.Sequential() model. We can use that column to plot a flowchart!
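To make that concrete, here's a minimal Python sketch that turns the "Connected to" column into a list of (parent, child) edges, which is all a flowchart really needs. The sample rows are made up for illustration and assume the tf.keras 2.x layout, where each entry looks like `dense[0][0]` (layer name, node index, tensor index). Other versions lay this column out differently (for example, wrapping multiple inputs onto continuation lines), which this sketch doesn't handle. `parse_edges` is a hypothetical helper, not part of Tensorflow.

```python
import re

# Hypothetical rows from a functional (non-Sequential) model's summary,
# assuming the tf.keras 2.x style "Connected to" entries like 'dense[0][0]'.
SUMMARY = """\
 input_1 (InputLayer)           [(None, 32)]         0           []
 dense (Dense)                  (None, 16)           528         ['input_1[0][0]']
 dense_1 (Dense)                (None, 16)           528         ['input_1[0][0]']
 concatenate (Concatenate)      (None, 32)           0           ['dense[0][0]', 'dense_1[0][0]']
"""

# The layer's own name comes first on the row, followed by its (Type).
LAYER = re.compile(r"^\s*([\w.]+) \(\w+\)")
# Each "Connected to" entry looks like name[node_index][tensor_index].
PARENT = re.compile(r"([\w.]+)\[\d+\]\[\d+\]")


def parse_edges(text):
    """Return (parent, child) pairs describing which layer feeds which."""
    edges = []
    for line in text.splitlines():
        match = LAYER.match(line)
        if not match:
            continue  # skip separators and blank lines
        child = match.group(1)
        # Only search after the name/type so the layer can't match itself.
        for parent in PARENT.findall(line[match.end():]):
            edges.append((parent, child))
    return edges


for parent, child in parse_edges(SUMMARY):
    print(f"{parent} -> {child}")
```

For these sample rows the result is four edges: input_1 feeds both dense layers, and both of those feed concatenate.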
My tool does exactly that: it parses the Tensorflow summary and compiles it into a flowchart, which is drawn with a graphing library called nomnoml. Here's a sample:
It's a purely web-based tool, so no data leaves your client. Try it out for yourself here:
https://starbeamrainbowlabs.com/labs/tfsummaryvis/
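If you'd like to script the final step yourself, nomnoml diagrams are written in a small text language in which `[a] -> [b]` draws an arrow from box `a` to box `b`, and a `#direction:` directive sets the layout direction. Here's a minimal standalone Python sketch (not the tool's own source) that turns a list of (parent, child) layer pairs into nomnoml source you can paste into a nomnoml renderer. `to_nomnoml` and the sample edges are hypothetical; layer names containing nomnoml's special characters such as `[`, `]`, `|` or `;` would need escaping, which this sketch doesn't do.

```python
def to_nomnoml(edges, direction="down"):
    """Build nomnoml source text: one "[a] -> [b]" line per edge."""
    lines = [f"#direction: {direction}"]  # lay the flowchart out top-to-bottom
    for parent, child in edges:
        lines.append(f"[{parent}] -> [{child}]")
    return "\n".join(lines)


# Hypothetical edges for a small branching model.
edges = [
    ("input_1", "dense"),
    ("input_1", "dense_1"),
    ("dense", "concatenate"),
    ("dense_1", "concatenate"),
]
print(to_nomnoml(edges))
```

nomnoml lays out the boxes itself, so a layer with several inputs (like concatenate here) simply ends up with several incoming arrows.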
For the curious, the source code for this tool is available here: