The neural network is not training, the code looks fine, but you are not sure what is going on. This is unfortunately a common behind-the-scenes experience for many people working with neural networks; I have not only experienced it personally but also helped others through similar scenarios. Following these five steps really helps mitigate training deadlocks and improve the overall machine learning process.

1. Visualise the data

Visualising the available data is absolutely crucial, not only to make sure it makes sense but also to understand what the model should be doing. Try to visualise the actual arrays or tensors that go into the model, including the output of any data transformation pipelines such as random crops and rotations. Then try to answer some of these questions:

  • Does it look like what you expected? This is very obvious, but I cannot emphasise enough how common it is to find something fishy. One time we realised that the data transformation pipeline was cropping invalid regions of the input images and just returning a tensor of zeroes! You can imagine how that resulted in the model not training at all.
  • Are the random samples of the data consistent? Don’t just stop at a single data point. Re-run your script to sample random data points and check if they look similar.
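Beyond eyeballing plots, a few cheap programmatic checks catch the worst pipeline bugs, such as the all-zero crop above. As a minimal sketch (the `sanity_check_batch` helper and the random stand-in data are made up for illustration, not from any particular library):

```python
import numpy as np

def sanity_check_batch(batch, name="batch"):
    """Run cheap checks on a batch of model inputs before training."""
    arr = np.asarray(batch, dtype=np.float64)
    assert not np.isnan(arr).any(), f"{name} contains NaNs"
    assert not np.isinf(arr).any(), f"{name} contains infs"
    # An all-zero batch often means a broken transform, e.g. a crop
    # that landed entirely outside the image.
    assert np.abs(arr).sum() > 0, f"{name} is all zeros"
    return arr.min(), arr.max(), arr.mean(), arr.std()

# Re-run with different random samples, as suggested above.
rng = np.random.default_rng(0)
for _ in range(3):
    img = rng.random((3, 32, 32))   # stand-in for a transformed image
    stats = sanity_check_batch(img, "image")
    print("min/max/mean/std:", [round(float(s), 3) for s in stats])
```

Printing the summary statistics alongside the visualisation also makes it easy to spot samples whose range or scale is wildly different from the rest.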

2. Predict with random weights

Once you are confident the data is correct and the neural network is initialised, just let it predict on a random batch. Does it actually run, and does the prediction look reasonable? By reasonable I mean: for classification, does it return roughly random classes? If you run the network and it already outputs a single fixed class, or it gives an error, there is no point in running it over the entire dataset on multiple GPUs only to realise it does not work.

Repeat this process with smaller batches, larger batches, random samples and data points you think are edge cases. For example, for inputs of different lengths in a sequence prediction task, you can check whether the varying lengths are padded and handled correctly. The aim is to see if the model runs.
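Here is one way the "random classes" check can look in practice. The single random linear layer below is a deliberately tiny stand-in for an untrained network; with small random weights, no class should dominate the predictions:

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# A stand-in "untrained network": one random linear layer.
n_classes, n_features = 10, 64
W = rng.normal(scale=0.01, size=(n_features, n_classes))

def predict(x):
    return softmax(x @ W)

# Run a random batch through it and look at the class histogram.
batch = rng.normal(size=(512, n_features))
preds = predict(batch).argmax(axis=-1)
counts = np.bincount(preds, minlength=n_classes)
print("predicted-class counts:", counts)
# If one class soaks up the whole batch, something is off
# (e.g. a badly initialised bias) before training even starts.
```

The same idea extends to the edge cases mentioned above: run the padded short sequence and the long sequence through the model and confirm both produce outputs of the expected shape.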

3. Over-fit to a single batch

Now that the model seems to be working, we need to probe the training loop. Let's take a single batch and hit the train button. The immediate things we are looking for are: does the initial loss make sense, and does it go down after weight updates? Since this is a single batch, it should be fairly quick and easy to see whether the training loop can actually over-fit, i.e. learn.

You really want to get a sense that if we give the training loop a model and some data, it behaves as expected. There might still be bugs specific to the full dataset, such as running out of memory, but at least we are now somewhat confident the core training loop works.
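A minimal sketch of the single-batch over-fit check, using a zero-initialised linear model as a stand-in for the network (all names and sizes here are illustrative). Note the sanity check on the initial loss: for four balanced classes with uniform predictions it should be ln 4 ≈ 1.386:

```python
import numpy as np

rng = np.random.default_rng(0)

# One fixed batch: 32 samples, 16 features, 4 classes.
X = rng.normal(size=(32, 16))
y = rng.integers(0, 4, size=32)
W = np.zeros((16, 4))              # a linear "network", zero-initialised
one_hot = np.eye(4)[y]

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

losses = []
for step in range(200):
    probs = softmax(X @ W)
    # Cross-entropy; with W = 0 the first value should be ln(4) ≈ 1.386.
    loss = -np.log(probs[np.arange(len(y)), y] + 1e-12).mean()
    losses.append(loss)
    grad = X.T @ (probs - one_hot) / len(y)   # gradient for the linear model
    W -= 0.5 * grad                           # plain gradient descent

print(f"loss went from {losses[0]:.3f} to {losses[-1]:.3f}")
```

If the initial loss is far from the expected value, or it refuses to drop on a single repeated batch, the bug is in the loop itself (loss wiring, gradients, optimiser), not in the data.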

4. Collect wrong predictions

The model is training, the loss is going down and things are looking good. Even if you are training a convolutional neural network on the MNIST dataset, there will be data points it cannot correctly predict. Which ones are these? Pick them out, isolate them and use the visualisation from step 1 to inspect them.
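Collecting the failures is a few lines of code. A sketch, with a toy thresholding "model" standing in for the real network (`collect_failures` is a hypothetical helper, not a library function):

```python
import numpy as np

def collect_failures(inputs, labels, predict_fn):
    """Return the indices, inputs and predictions the model got wrong."""
    preds = predict_fn(inputs)
    wrong = np.flatnonzero(preds != labels)
    return wrong, inputs[wrong], preds[wrong]

# Toy example: a "model" that thresholds the mean of each input.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 8))
y = (X.mean(axis=1) > 0).astype(int)
noisy_model = lambda x: (x.mean(axis=1) > 0.1).astype(int)  # slightly biased

idx, failed_inputs, failed_preds = collect_failures(X, y, noisy_model)
print(f"{len(idx)} failures out of {len(X)}")
# Feed failed_inputs back into the step-1 visualisation to inspect them.
```

Keeping the original indices around matters: it lets you trace a failure back to the raw, untransformed data point.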

Depending on the model, visualise its inner components. Attention maps are a really common one. Seeing which elements the model attends to in these failure cases will be very insightful. For example, we once realised that the attention layer in one of our neural networks worked fine on the data points it predicted correctly but attended to random parts of the input in the failure cases. Was the attention layer the reason for the failure, or was something else causing the attention layer to fail? These are the kinds of questions that will guide your design choices and debugging.
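One quantitative way to spot "attending to random parts of the input" is the entropy of the attention weights: near-uniform (random-looking) attention has entropy close to the maximum ln N, while focused attention sits near zero. A sketch with scaled dot-product attention and made-up queries and keys:

```python
import numpy as np

def attention_weights(query, keys):
    """Scaled dot-product attention weights for one query (softmax over keys)."""
    scores = keys @ query / np.sqrt(len(query))
    scores = scores - scores.max()
    w = np.exp(scores)
    return w / w.sum()

def attention_entropy(weights):
    """High entropy ~ diffuse/near-random attention; low ~ focused."""
    return -(weights * np.log(weights + 1e-12)).sum()

rng = np.random.default_rng(0)
keys = rng.normal(size=(20, 32))
focused = attention_weights(keys[3] * 4, keys)              # aligned with one key
diffuse = attention_weights(rng.normal(size=32) * 0.01, keys)  # weak query
print(f"focused entropy {attention_entropy(focused):.2f}, "
      f"diffuse entropy {attention_entropy(diffuse):.2f} "
      f"(max ln 20 = {np.log(20):.2f})")
```

Comparing this entropy between the correct and failure sets gives a number to back up what the attention-map plots suggest.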

5. Organise results

Finally, use an experiment tracking platform such as MLflow, Weights & Biases or Sacred to organise your iterations. You will train many models, tune them, retry, fail, succeed but think you failed, and for everything in between there needs to be an organised way of recording that journey. Experiment tracking not only allows you to collect and record metrics in a centralised fashion but also lets you investigate and compare variations of your models. You don't need to go overboard and use every feature of your chosen platform, such as model saving or hyper-parameter tuning. Pick the features that work for you rather than working for the platform and writing extra code to feed it.
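Whatever platform you pick, the core of what it records per run is small: parameters once, metrics per step, under a unique run ID. A stripped-down stdlib stand-in makes the idea concrete (this `RunLogger` is illustrative only; in practice you would use the platform's own logging calls):

```python
import json
import time
import uuid
from pathlib import Path

class RunLogger:
    """A tiny stand-in for what experiment trackers record per run:
    parameters once, metrics per step, appended to one JSONL file."""

    def __init__(self, logdir="runs"):
        self.run_id = uuid.uuid4().hex[:8]
        self.path = Path(logdir) / f"{self.run_id}.jsonl"
        self.path.parent.mkdir(parents=True, exist_ok=True)

    def _write(self, record):
        record["time"] = time.time()
        with self.path.open("a") as f:
            f.write(json.dumps(record) + "\n")

    def log_params(self, **params):
        self._write({"type": "params", **params})

    def log_metric(self, name, value, step):
        self._write({"type": "metric", "name": name, "value": value, "step": step})

run = RunLogger()
run.log_params(lr=3e-4, batch_size=32)
for step, loss in enumerate([1.39, 0.9, 0.5]):
    run.log_metric("loss", loss, step)
print("logged to", run.path)
```

Once every run is recorded this way, comparing variations of your models becomes a query over files rather than an archaeology dig through terminal scrollback.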

Some of these steps, such as visualising the data and using an experiment tracking platform, are applicable to any machine learning based project. If you can see what your model should be doing and how different variations of your networks compare with each other, then you get a good sense of what might be going wrong and how you can improve it.