Doing a 3-way dataset split in Tensorflow // PhD Aside 3
Heya!
It's been a while since I talked about my PhD, so I wanna change that today!
Teaching post coming just as soon as I find the energy to write it.
Anyway, one of the things I have been asked to do as part of my thesis corrections is to split my dataset into 3 parts and then run some random experiments.
This seemed like an odd request to me when there are more important things I need to be doing to stabilise the models I'm training, but in the process of implementing support for this in a number of my models I hit upon a snag:
Tensorflow doesn't have native support for 3-way dataset splits!
If Tensorflow (tf) doesn't have support for it, then clearly it can't be that important :P
In all seriousness though, it did mean that I needed a solution and after searching around and getting nowhere very quickly (you really don't need to go to all this trouble to solve this one), I decided that something must be done!
So, I ended up implementing a thing that I thought was rather cool, so I thought I'd share it here.
Other parts in this PhD aside series:
- PhD Aside 2: Jupyter Lab / Notebook First Impressions
- PhD Aside: Reading a file descriptor line-by-line from multiple Node.js processes
More coming soon as I get distracted and find cool shiny things while doing my PhD! If you want more PhD-related posts, do check the PhD tag here on my blog.
Splitting into multiple pieces
As a quick reminder, when training an AI we usually split the dataset we're training it on into 2 parts:
- Training (usually ~80%)
- Validation (usually ~20%)
....the model only gets to learn from the training data, and we hold the validation data back so we can test the model on it later. If it does a lot worse on the validation dataset than on the training dataset, then we can reasonably conclude that the model isn't generalising very well to new data/samples/etc. it hasn't seen before.
However, sometimes people decide it's a great idea to split a dataset into 3 parts rather than 2.
The reasoning behind this - as far as I'm aware - is that while you're working on optimising your model by trying different architectures, hyperparameters, etc, you are in a sense optimising for the model's performance on the validation dataset.
So, to this end it is suggested that by having a third split - called the test split - one can evaluate the model again and make really really sure how well it's generalising (or not) to new data.
This is just as much as I know at the moment and it doesn't really make sense in my head, so if you have a better way of explaining it please do leave a comment below.
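For context, `tf.data` doesn't ship a three-way splitter, so the actual splitting usually has to be done by hand. A common approach (a sketch of the general technique, not the code from my thesis - the function name and fractions are mine) is to chain `take()` and `skip()`:

```python
import tensorflow as tf

def three_way_split(dataset, size, train_frac=0.8, val_frac=0.1):
	"""Split a tf.data.Dataset into train/validation/test parts.
	`size` is the total element count, since tf.data can't always infer it."""
	train_count = int(size * train_frac)
	val_count = int(size * val_frac)
	train = dataset.take(train_count)
	val = dataset.skip(train_count).take(val_count)
	test = dataset.skip(train_count + val_count)
	return train, val, test

# Example: 100 elements → 80 / 10 / 10
ds = tf.data.Dataset.range(100)
train, val, test = three_way_split(ds, 100)
```

If the dataset is ordered, shuffle it (with a fixed seed, so the splits are reproducible between runs) *before* splitting, or the test split won't be representative.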
Yeah, Tensorflow sucks
.....maybe not really (though CUDA / GPU support does suck very much >_<). The thing to remember about AI model frameworks like Tensorflow is that data loading efficiency is everything. A tangent for another time perhaps (is your model training slowly? then this is most likely your issue!), but what matters here is that the signature of tf.keras.Model.fit() (the function that actually trains the shiny new model you've just created / loaded from disk / etc) looks a bit like this:
model.fit(
	x, # probably either a tf.Tensor or a tf.data.Dataset
	y, # probably the same as x immediately above
	validation_split=0.0, # Percentage of x to treat as the validation dataset
	validation_data=None, # The validation dataset - see above
	callbacks=None, # A list of functions to call at different times - this will become important later
	# .....
)
....this is just terrible design, if you ask me. x and y here are the input(s) and ground truth labels respectively.
Single letter variable names should not be allowed!
Unfortunately, the only option we have for inputting another dataset to calculate metrics on (e.g. cross-entropy, dice coefficient, intersection-over-union - IoU, etc etc etc) is to pass a validation dataset - there's no option to pass e.g. a list() of datasets to evaluate on, which is a shame as it would almost make sense.
Cracking the glass
So, we're at a loss, right? What to do? We have 3 nice neat tf.data.Datasets ready to go and no way to ensure we get metrics calculated reliably on the 3rd one.
The solution here that I came up with is to write a custom callback that manually iterates over the dataset and calculates the metrics, before then sneakily appending them to TF's main metrics log system so that TF never suspects a thing, and writes them out along with all the other metrics for us :D
Learning to work within the framework of your choice is a key part of the process. Don't just try and hack your way around it - most frameworks - TF included - provide many different ways to manipulate tensors etc if you learn how they work.
In this case, we know that Tensorflow has a callback system, because TF ships with a number of default callbacks like tf.keras.callbacks.CSVLogger (which can also output as TSV, my favourite file format when I'm not using jsonl).
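As an aside, getting TSV out of CSVLogger is just a matter of its separator argument - something like this (the filename is made up for illustration):

```python
import tensorflow as tf

# CSVLogger writes one row of metrics per epoch; separator="\t" makes it TSV.
tsv_logger = tf.keras.callbacks.CSVLogger(
	"training_log.tsv",
	separator="\t",  # default is ","
	append=True      # don't clobber the log when resuming training
)
# Later: model.fit(..., callbacks=[tsv_logger])
```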
Writing a custom callback is a 2 step process:
- Write a custom class that inherits from tf.keras.callbacks.Callback
- Instantiate an instance of our new class and pass it to tf.keras.Model.fit()
Let's go through these 1 by 1.
Writing a callback
As mentioned, inheriting from tf.keras.callbacks.Callback is the aim of the game here. We can make it really quite lightweight and do it within 10 lines of code:
import tensorflow as tf

class CallbackExtraValidation(tf.keras.callbacks.Callback):
	def __init__(self):
		super(CallbackExtraValidation, self).__init__()
	
	def on_epoch_end(self, epoch, logs=None):
		pass
The way callbacks work in TF is that they are a class with a bunch of methods. The main tf.keras.callbacks.Callback class defines a bunch of empty methods, and then you override the ones you're interested in.
Some examples of things you can add a callback for:
- The start/end of training as a whole
- The start/end of every batch (sometimes called a step, but a step can sometimes mean multiple batches at once)
- The start/end of every epoch (which most of the time but not always is before/after the entire dataset has been seen by the model once)
- The start/end of validation
....and so on.
You get the picture: it's a way that you can add simple hooks that let you run custom functions that do stuff at a time of your choosing.
In our case, we know that we wanna evaluate our model on an extra dataset at the end of every epoch, so we override the on_epoch_end method.
Check the docs for a comprehensive look at the possible functions you can override here:
https://devdocs.io/tensorflow~2.9/keras/callbacks/callback
Now we have a custom function running when we want it to, it's just a case of grabbing the dataset(s) and evaluating the model with em:
import tensorflow as tf

class CallbackExtraValidation(tf.keras.callbacks.Callback):
	def __init__(self, datasets, verbose="auto"):
		super(CallbackExtraValidation, self).__init__()
		self.datasets = datasets # Dictionary in the form { string: tf.data.Dataset }, where `string` here is the name of the dataset
		self.verbose = verbose
	
	def on_epoch_end(self, epoch, logs=None):
		for name, dataset in self.datasets.items():
			metrics = self.model.evaluate(
				dataset,
				verbose=self.verbose,
				return_dict=True
			)
Excellent! See, we can even throw in iterating over a whole dictionary of datasets with minimal effort (I'm looking at you, Tensorflow!).
Okay, so we have metrics by cheesing the issue and asking TF to evaluate a model. We even have a reference to the model in question provided for us by Tensorflow - how very kind!
....but how do we hoodwink Tensorflow and slip the extra metrics we've calculated into the main metrics stream without it noticing?
The solution has actually been provided by Tensorflow itself: The logs argument that is passed to on_epoch_end is the metrics log for that current epoch, so we can just update it.
We don't wanna overwrite any of the existing metrics though, so the solution is just to prepend a static string to the name of each metric as we copy it over:
for metric_name, metric_value in metrics.items():
	logs[f"{name}_{metric_name}"] = metric_value
See? Easy peasy!
That's all there is to it.
Of course, a few additional checks need to be added because I'm paranoid and Tensorflow sometimes passes None instead of the logs instance for a giggle, but it does work reliably.
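Assembled, with those paranoid checks bolted on, the class might look something like this (a sketch - the version linked below is the definitive one):

```python
import tensorflow as tf

class CallbackExtraValidation(tf.keras.callbacks.Callback):
	"""Evaluates the model on extra datasets at the end of every epoch,
	sneaking the resulting metrics into TF's main metrics log."""
	def __init__(self, datasets, verbose="auto"):
		super(CallbackExtraValidation, self).__init__()
		self.datasets = datasets # { name: tf.data.Dataset }
		self.verbose = verbose
	
	def on_epoch_end(self, epoch, logs=None):
		if logs is None:
			# TF sometimes passes None instead of the logs dict for a giggle
			print("WARNING: logs is None; can't record extra metrics")
			return
		for name, dataset in self.datasets.items():
			if dataset is None:
				continue # allow e.g. { "test": None } placeholders
			metrics = self.model.evaluate(
				dataset,
				verbose=self.verbose,
				return_dict=True
			)
			for metric_name, metric_value in metrics.items():
				logs[f"{name}_{metric_name}"] = metric_value
```

Because Keras runs user callbacks before its built-in History/logging callbacks, the extra keys we stuff into logs get picked up and written out like any other metric.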
Find the full version of this class with those checks here (though the link miiight break if I rejig the repo's directory structure or rename the class etc):
How to pass to .fit()?
Just for completeness, this is how you pass the above in .fit():
model.fit(
	dataset_input,
	dataset_labels,
	callbacks = [
		CallbackExtraValidation({
			"test": my_extra_dataset
		})
	]
)
....in short, the callbacks argument of tf.keras.Model.fit() takes a list() of tf.keras.callbacks.Callback instances. Simply instantiate an instance of your custom callback class, and you're away, just as if it were an official part of the Tensorflow framework!
Conclusion
This has been a not-so-quick post (I swear I meant this to be shorter.....!) about doing a 3-way dataset split in Tensorflow because Tensorflow doesn't support it natively.
I hope you found this useful - let me know if you have any other questions or comments - I'm happy to answer them - both in the comments below and as subjects for future blog posts!
The end of the teaching is on the horizon (though I'm having to do a teaching course which is meh because it's eating into my time again) - so I'm hoping that my energy levels will start to recover once I've found a new rhythm for this new semester that is just starting.