Tips for training (large numbers of) AI models

As part of my PhD, I'm training AI models. The specifics as to what for don't particularly matter for this post (though if you're curious I recommend my PhD update blog post series). Over the last year or so, I've found myself training a lot of AI models, and dealing with a lot of data. In this post, I'm going to talk about some of the things I've found helpful and some of the things I've found are best avoided. Note that this is just a snapshot of my current practices - they will probably change gradually over time.

I've been working with Tensorflow.js and Tensorflow for Python on various Linux systems. If you're on another OS or not working with AI then what I say here should still be somewhat relevant.


First up: a quick word on datasets. While this post is mainly about AI models, datasets are important too. Keeping them organised is vitally important. Keeping all the metadata that is associated with them is also vitally important. Keeping a good directory hierarchy is the best way to achieve this.

I also recommend sticking with a standard format that's easy to parse using your preferred language - and preferably lots of other languages too. JSON Lines is my personal favourite format for data - potentially compressed with Gzip if the filesize is very large.
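As a rough illustration of the format, here is a minimal Python sketch that writes and reads a Gzip-compressed JSON Lines file using only the standard library. The function names and the example records are hypothetical and chosen purely for illustration.

```python
import gzip
import json
from typing import Iterable, Iterator


def write_jsonl_gz(path: str, records: Iterable[dict]) -> None:
    """Write one JSON object per line to a gzip-compressed text file."""
    with gzip.open(path, "wt", encoding="utf-8") as handle:
        for record in records:
            handle.write(json.dumps(record) + "\n")


def read_jsonl_gz(path: str) -> Iterator[dict]:
    """Lazily yield one parsed object per non-empty line."""
    with gzip.open(path, "rt", encoding="utf-8") as handle:
        for line in handle:
            if line.strip():
                yield json.loads(line)
```

Because each line is an independent JSON object, a file like this can be streamed record by record rather than loaded into memory all at once, which matters once the dataset gets large.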

AI Models

There are multiple facets to the problem of wrangling AI models:

  1. Code that implements the model itself and supporting code
  2. Checkpoints from the training process
  3. Analysis results from analysing such models

All of these are important for different reasons - and are also affected by where it is that you're going to be training your model.

By far the most important thing I recommend doing is using Git with a remote such as GitHub and committing regularly. I can't stress enough how critical this is - it's the best way to both keep a detailed history of the code you've written and keep a backup at the same time. It also makes working on multiple computers easy. Getting into the habit of using Git for any project (doesn't matter what it is) will make your life a lot easier. At the beginning of a programming session, pull down your changes. Then, as you work, commit your changes and describe them properly. Finally, push your changes to the remote after committing to keep them backed up.

Coming in at a close second is implementing a command line interface with the ability to change the behaviour of your model. This includes:

This is invaluable for running many different variants of your model quickly to compare results. It is also very useful when training your model in headless environments, such as High Performance Computers (HPCs) like Viper at my University.

For HPCs that use Slurm, a great tip here is that when you call sbatch on your job file (e.g. sbatch path/to/jobfile.job), it will preserve your environment. This lets you pass in job-specific parameters by writing a script like this:

#!/usr/bin/env bash
#SBATCH -n 4
#SBATCH --gres=gpu:1
#SBATCH -o %j.%N.%a.out
#SBATCH -e %j.%N.%a.err
#SBATCH -p gpu05,gpu
#SBATCH --time=5-00:00:00
#SBATCH --mem=25600
# 25600 = 25GiB memory required

# Viper uses Trinity ClusterVision
module load utilities/multi
module load readline/7.0
module load gcc/10.2.0
module load cuda/11.5.0

module load python/anaconda/4.6/miniconda/3.7

echo ">>> Installing requirements";
conda run -n py38 pip install -r requirements.txt;
echo ">>> Training model";
/usr/bin/env time --verbose conda run -n py38 src/ ${PARAMS}
echo ">>> exited with code $?";

...which you can call like so:

PARAMS="--size 4 --example 'something else' --input path/to/file --output outputs/20211002-resnet" sbatch path/to/jobfile.job
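On the Python side, the flags in the call above could be parsed with something like the following sketch using argparse from the standard library. The flag names are taken from the example call, but their types, defaults, and help text are assumptions, as is the args_from_env helper, which is an alternative way of reading PARAMS that the job script above does not use.

```python
import argparse
import os
import shlex


def build_parser() -> argparse.ArgumentParser:
    """Build a parser for the flags used in the example call (types are assumed)."""
    parser = argparse.ArgumentParser(description="Train one variant of a model.")
    parser.add_argument("--size", type=int, default=4, help="Hypothetical model size setting")
    parser.add_argument("--example", default="", help="Hypothetical free-text option")
    parser.add_argument("--input", required=True, help="Path to the input dataset file")
    parser.add_argument("--output", required=True, help="Directory for output artefacts")
    return parser


def args_from_env(var: str = "PARAMS") -> argparse.Namespace:
    """Parse arguments from an environment variable, honouring shell-style quotes."""
    return build_parser().parse_args(shlex.split(os.environ.get(var, "")))
```

One thing to watch for: an unquoted ${PARAMS} expansion in bash is split on whitespace without interpreting the quotes embedded in the value, so an argument like 'something else' would arrive as two separate words. Reading the variable inside the program with shlex.split, as args_from_env does, is one way to sidestep that.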

You may end up finding you have rather a lot of code behind your model - especially for data preprocessing depending on your dataset. To handle this, I go by 2 rules of thumb:

  1. If a source file of any language is more than 300 lines long, it should be split into multiple files
  2. If a collection of files do a thing together rather nicely, they belong in a separate Git repository.

To elaborate on these, having source code files become very long makes them difficult to maintain, understand, and re-use in future projects. Splitting them up makes your life much easier.

Going further, modularising your code is also an amazing paradigm to work with. I've broken out many parts of the codebases I've implemented for my PhD as open-source projects on npm (the Node Package Manager) - most notably applause-cli, terrain50, terrain50-cli, nimrod-data-downloader, and twitter-academic-downloader.

By making them open-source, I'm not only making my research and methods more transparent and easier for others to independently verify, but I'm also allowing others to benefit from them (and potentially improve them) too! As they say, there's no need to re-invent the wheel.

Eventually, I will be making the AI models I'm implementing for my PhD open-source too - but this will take some time as I want to ensure that the models actually work before doing so (I've got 1 model I implemented fully and documented too, but in the end it turned out to have a critical bug that means the whole thing is useless...).

Saving checkpoints from the training process of your model is also essential. I recommend doing so at the end of each epoch. As part of this, it's also useful to have a standard format for your output artefacts from the training process. Ideally, these artefacts can be used to identify precisely which dataset and hyperparameters a model and its checkpoints were trained with.

At the moment, my models output something like this:

+ output_dir/
    + summary.txt       Summary of the layers of the model and their output shapes
    + metrics.tsv       TSV file containing training/validation loss/accuracy and epoch numbers
    + settings.toml     The TOML settings that the model was trained with
    + checkpoints/      Directory containing the checkpoints - 1 per epoch
        + checkpoint_e1_val_acc0.699.hdf5   Example checkpoint filename [Tensorflow for Python]
        + 0/            OR, if using Tensorflow.js instead of Tensorflow for Python, 1 directory per checkpoint
    + this_run.log      Logfile for this run [depends on where the program is being executed]
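A layout like the one above could be produced by a few small helpers along these lines. This is only a sketch using the Python standard library: the function names and the metrics column names are hypothetical, and actually saving the model weights is left to whichever framework you're using.

```python
import csv
from pathlib import Path


def prepare_output_dir(output_dir: str) -> Path:
    """Create the output directory and its checkpoints/ subdirectory."""
    root = Path(output_dir)
    (root / "checkpoints").mkdir(parents=True, exist_ok=True)
    return root


def checkpoint_path(root: Path, epoch: int, val_acc: float) -> Path:
    """Build a checkpoint filename in the style shown in the listing above."""
    return root / "checkpoints" / f"checkpoint_e{epoch}_val_acc{val_acc:.3f}.hdf5"


def append_metrics(root: Path, epoch: int, metrics: dict) -> None:
    """Append one row per epoch to metrics.tsv, writing a header on first use."""
    path = root / "metrics.tsv"
    is_new = not path.exists()
    with path.open("a", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle, delimiter="\t", lineterminator="\n")
        if is_new:
            writer.writerow(["epoch", *metrics.keys()])
        writer.writerow([epoch, *metrics.values()])
```

Calling append_metrics at the end of every epoch means that metrics.tsv is still usable if a run is interrupted partway through.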

settings.toml leads me on to settings files. Personally I use TOML for mine, and I use 2 files:

Having a config file is handy when you have multiple dataset input files that rarely change. Generally speaking you want to ensure that you minimise the number of CLI arguments that you have to specify when running your model, as then it reduces cognitive load when you're training many variants of a model at once (I've found that wrangling dozens of different dataset files and model variants is hard enough to focus on and keep organised :P).
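One possible way to combine 2 settings files is sketched below, under the assumption that one file holds defaults and the other holds overrides for a particular run. That split is an assumption made for illustration, and both function names are hypothetical. tomllib is part of the Python standard library from version 3.11 onwards.

```python
def deep_merge(defaults: dict, overrides: dict) -> dict:
    """Return a new dict in which values from overrides win, merging nested tables."""
    merged = dict(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


def load_settings(defaults_path: str, overrides_path: str) -> dict:
    """Load 2 TOML files and merge them (assumed split: defaults + overrides)."""
    import tomllib  # standard library from Python 3.11 onwards

    # tomllib.load() requires the file to be opened in binary mode
    with open(defaults_path, "rb") as handle:
        defaults = tomllib.load(handle)
    with open(overrides_path, "rb") as handle:
        overrides = tomllib.load(handle)
    return deep_merge(defaults, overrides)
```

Writing the merged result out alongside the checkpoints then gives a single record of exactly which settings a run used.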

Analysis results are the final aspect here that it's important to keep organised - and the area in which I have the least experience. I've found it's important to keep track of which model checkpoint it was that the analysis was done with and which dataset said model was trained on. Keeping the entire chain of dataflow clear and easy to follow is difficult because the analysis one does is usually ad-hoc, and often has to be repeated many times on different model variants.

For this, so far I generate statistics and some graphs on the command line. If you're not already familiar with the terminal / command line of your machine, I can recommend checking out my earlier post Learn Your Terminal, which has a bunch of links to tutorials for this. In addition, jq is an amazing tool for manipulating JSON data. It's not installed by default on most systems, but it's available in most default repositories and well worth the install.

For some graphs, I use Gnuplot. Usually though this is only for more complex plots, as it takes a moment to write a .plt file to generate the graph I want.

I'm still looking for a good tool that makes it easy to generate basic graphs from the command line, so please get in touch if you've found one.

I'm also considering integrating some of the basic analysis into my model training program itself, such that it generates e.g. confusion matrices automatically as part of the training process. matplotlib seems to do the job here for plotting graphs in Python, but I have yet to find an equivalent library for JavaScript. Again, if you've found one please get in touch by leaving a comment below.
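For reference, the numbers behind a confusion matrix are simple enough to compute without any plotting library at all. The following is a minimal sketch in plain Python, with hypothetical function and label names, in which rows are the true class and columns are the predicted class.

```python
from collections import Counter
from typing import Sequence


def confusion_matrix(true_labels: Sequence, predicted_labels: Sequence, classes: Sequence) -> list:
    """Count (true, predicted) pairs. Rows are true classes, columns are predictions."""
    if len(true_labels) != len(predicted_labels):
        raise ValueError("true_labels and predicted_labels must be the same length")
    counts = Counter(zip(true_labels, predicted_labels))
    return [[counts[(true, predicted)] for predicted in classes] for true in classes]


def matrix_to_tsv(matrix: list, classes: Sequence) -> str:
    """Render the matrix as TSV text with a header row and a label per row."""
    lines = ["\t".join(["true\\predicted", *map(str, classes)])]
    for label, row in zip(classes, matrix):
        lines.append("\t".join([str(label), *map(str, row)]))
    return "\n".join(lines)
```

Saving the TSV form next to the checkpoint it was computed from would also help with keeping track of which analysis belongs to which model.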


In this post, I've talked about some of the things I've found helpful so far while I've been training models. From using Git and standardising output artefacts to implementing command line interfaces and wrangling datasets, it turns out that implementing the core AI model itself is actually only a very small part of an AI project.

Hopefully this post has given you some insight into the process of developing an AI model / AI-powered system. While I've been doing some of these things since before I started my PhD (like Git), others have taken me a while to figure out - so I've noted them down here so that you don't have to spend ages figuring out the same things!

If you've got some good tips you'd like to share on developing AI models (or if you've found the tips here in this blog post helpful!), please do share them below.
