TensorFlow and PyTorch compared

Hey there! Since I've used both TensorFlow and PyTorch a bit now, I thought it was time to write a post comparing the two and their respective strengths and weaknesses.

For reference, I've used TensorFlow both for JavaScript (less popular) and for Python (more popular) for a number of different models, relating to both the rainfall radar and social media halves of my PhD. While I definitely have less experience with PyTorch, I feel like I have a good enough grasp of it to give a first impression.

Firstly, let's talk about how PyTorch differs from TensorFlow, and what TensorFlow could learn from it. The key thing I noticed about PyTorch is that it's easily the more flexible of the two. I'm pretty sure that you can create layers and even whole models that do not explicitly define the input and output shapes of the tensors they operate on - e.g. using CNN layers. This gives them a huge amount of power for handling variable-sized images or sentences without additional padding, and would be rather useful in TensorFlow - where, in my experience, you generally end up declaring a specific input shape for every layer.
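To make the variable-size point concrete, here is a minimal sketch, assuming a recent PyTorch install. The tiny model is hypothetical and only for illustration. Note that the number of input channels still has to be declared, but the image height and width do not: the adaptive pooling layer collapses whatever spatial size arrives.

```python
import torch
from torch import nn

# Hypothetical toy classifier: no layer here is told the image height or width.
model = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),  # collapses any H x W down to 1 x 1
    nn.Flatten(),             # (batch, 8, 1, 1) -> (batch, 8)
    nn.Linear(8, 2),          # two output classes (arbitrary choice)
)

small = torch.randn(1, 3, 32, 32)   # one 32x32 RGB image
large = torch.randn(1, 3, 100, 60)  # one 100x60 RGB image

# The same model accepts both sizes without padding or resizing.
out_small = model(small)
out_large = model(large)
print(out_small.shape, out_large.shape)
```

Both calls produce a tensor with one row and two columns, since the adaptive pooling layer removes the dependence on the spatial size before the linear layer is reached.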

Unfortunately, this comes at the cost of complexity. Whereas TensorFlow has a .fit() method, in PyTorch you have to implement the training loop yourself - which, as you can imagine, results in a lot of additional code you have to write and test. This was quite the surprise to me when I first used PyTorch!
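For a sense of what "implement it yourself" involves, here is a minimal sketch of a hand-written PyTorch training loop. The toy data (y = 2x + 1), the learning rate, and the epoch count are all assumptions made up for this example, and a real loop would also need batching, validation, and checkpointing.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy regression data: y = 2x + 1.
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 2 * x + 1

model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()
optimiser = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for epoch in range(200):
    optimiser.zero_grad()          # clear gradients from the previous step
    prediction = model(x)          # forward pass
    loss = loss_fn(prediction, y)  # compare against the targets
    loss.backward()                # backpropagate
    optimiser.step()               # update the weights
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.6f}")
```

In Keras, the body of that loop is roughly what a single call to .fit() does on your behalf.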

The other thing I like about PyTorch is its data processing pipeline and its simplicity. It's easy to understand and essentially guides you to the optimal solution all on its own - leading to greater GPU usage, faster model training times, less waiting around, and tighter improve → run → evaluate & inspect → repeat loops.

While in most cases you need to know the number of items in your dataset in advance, this is not necessarily a bad thing - as it gently guides you to the realisation that by changing the way your dataset is stored, you can significantly improve CPU and disk utilisation by making your dataset more amenable to being processed in parallel.
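As a rough illustration of that pipeline, here is a minimal sketch of a map-style PyTorch dataset. The SquaresDataset class is hypothetical and stands in for whatever loads one item from disk. The __len__ method is where the "number of items in advance" requirement shows up.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: item i is the pair (i, i squared)."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        # The DataLoader uses this to plan batches (and shuffling, if enabled).
        return self.size

    def __getitem__(self, index):
        # Items are independent of each other, so they can be loaded in any
        # order and by any worker.
        value = torch.tensor(float(index))
        return value, value ** 2


dataset = SquaresDataset(10)

# num_workers=0 loads in the main process. Raising it loads items in parallel
# worker processes (on some platforms this needs an
# `if __name__ == "__main__":` guard around the script's entry point).
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

batches = [inputs.shape[0] for inputs, targets in loader]
print(batches)  # [4, 4, 2]
```

Because each item is addressed by index, storing the dataset so that any single item can be read cheaply is what makes the parallel loading pay off.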

TensorFlow on the other hand has a rather complicated data processing pipeline with multiple ways to do things, and I couldn't easily find any clear guidance on building a generic data processing pipeline that didn't make enormous assumptions like "Oh, you want to load images, right? Just use this function!" - which really isn't helpful when you want to do something unusual for a research project.

The tutorials I did find suggest you use a generator function, which can't be parallelised and makes training a model a slow and painful process. Things aren't completely without hope though - TensorFlow's Dataset objects have a .map() method and also an .interleave() method (if I recall correctly) to interleave multiple Dataset objects together - which I believe is a relatively recent addition. This is quite a clever way of doing things, if a bit more complicated than PyTorch's solution.
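Here is a minimal sketch of how .interleave() and .map() fit together, assuming TensorFlow 2.4 or later (where tf.data.AUTOTUNE is available). The make_shard function is hypothetical: it stands in for opening one file, and the "shards" here are just small ranges of numbers.

```python
import tensorflow as tf


def make_shard(start):
    # Hypothetical stand-in for reading one file: three consecutive numbers.
    return tf.data.Dataset.from_tensor_slices(start + tf.range(3, dtype=tf.int64))


# Two pretend shards, starting at 0 and at 100.
shard_starts = tf.data.Dataset.from_tensor_slices(
    tf.constant([0, 100], dtype=tf.int64)
)

dataset = (
    shard_starts
    # Read from both shards at once, alternating between them.
    .interleave(make_shard, cycle_length=2, num_parallel_calls=tf.data.AUTOTUNE)
    # Per-item preprocessing, run in parallel; AUTOTUNE lets TensorFlow pick
    # the number of parallel calls.
    .map(lambda value: value * 2, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(2)
    .prefetch(tf.data.AUTOTUNE)
)

values = [int(v) for batch in dataset for v in batch.numpy()]
print(values)
```

The design idea is that the parallelism lives in the pipeline stages rather than in a Python generator, which is what lets the preprocessing spread across CPU cores.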

It would be nice though if the feature for automatically managing the number of parallel workers was more intelligent. I recently discovered, for example, that it doesn't max out my CPU if I parallelise multiple .map() calls, when it really should look at the current CPU usage and notice that the CPU is sitting e.g. 50% idle.

TensorFlow for Python has a horrible API more generally. It's a confusing mess as there's both TensorFlow and the inbuilt Keras, which means that it's not obvious where that function you need is - or, indeed, which version thereof you want to call. I know it's a holdover from when Keras wasn't bundled with TensorFlow by default, but the API really should be reimagined and tf.keras merged into the main tf namespace somehow.

It can also be unclear what happens when you mix TensorFlow Tensors, NumPy arrays and numbers, and plain Python numbers. In some cases, it's impossible to tell where one begins and the other ends, which is annoying since they all behave differently, so you can get seemingly random error messages when you accidentally mix the types (e.g. "I want a Tensor, not a NumPy array", or "I want a plain Python number, not a NumPy number").
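One way to reduce this kind of confusion is to convert explicitly at the boundaries instead of relying on implicit conversion. The snippet below is a minimal sketch of that habit, assuming TensorFlow 2 in eager mode. The variable names are made up for illustration.

```python
import numpy as np
import tensorflow as tf

tensor = tf.constant([1.0, 2.0, 3.0])  # tf.Tensor, float32 by default
array = np.array([1.0, 2.0, 3.0])      # NumPy array, float64 by default

# NumPy -> Tensor: state the dtype so the two sides are known to match.
as_tensor = tf.convert_to_tensor(array, dtype=tf.float32)
total = tensor + as_tensor

# Tensor -> NumPy array.
back_to_numpy = total.numpy()

# Scalar Tensor -> plain Python number.
back_to_python = float(tf.reduce_sum(total))

print(type(total).__name__, type(back_to_numpy).__name__, type(back_to_python).__name__)
print(back_to_python)  # 12.0
```

This doesn't make the three kinds of value behave the same, but it does make it obvious in the code which one you are holding at each step.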

A great example of what's possible is demonstrated by TensorFlow's own JavaScript bindings - to a point. They are much better organised than the Python library for TensorFlow, although they require explicit memory management and disposal of Tensors (which isn't necessarily a bad thing, though it's difficult to measure any performance improvement without comparing apples and oranges).

The difficulties start though if you want to do anything in even remotely uncharted territory - TensorFlow.js doesn't have as wide a selection of layers as the Python bindings do (multi-headed attention is missing, for example). It also seems to have a number of bugs, meaning you can't just port code from the Python bindings and expect it to work. For example, I tried implementing an autoencoder, but found that it didn't work as I wanted it to - and for the life of me I couldn't find the bug at all (despite extensive searching).

Another annoyance with TensorFlow.js is that the documentation for exactly which CUDA version you need is very poor - and sometimes outright wrong! In addition, there's no table of versions and associated CUDA + cuDNN versions required like there is for TensorFlow for Python.

It is for these reasons that I find myself using Python much more regularly - even if I dislike Python as a language and ecosystem.

At some point, I'd love to build a generic Tensor library on top of GPU.js. It would naturally support almost any GPU (since GPU.js isn't limited to CUDA-capable devices like TensorFlow is - while you can recompile TensorFlow with support for other GPUs, I don't recommend it unless you have lots of time on your hands), be applicable to everything from machine learning to simulation to cellular automata, and run in server, desktop, and browser environments with minimal to no changes to your codebase!


There's no clear answer to whether you should use PyTorch or TensorFlow for your next project. As a rule of thumb, I suggest starting in TensorFlow due to the reduced boilerplate code, and switching to PyTorch if you find yourself with a wacky model that TensorFlow doesn't like very much - or if you want to use a pretrained model that's only available in one or the other.

Having said this, I can certainly recommend experiencing both libraries, as there are valuable things to be learnt from both frameworks. Unfortunately, I can't recommend TensorFlow.js for anything more than basic tensor manipulations (which it is very good at, despite supporting only a limited range of GPUs without recompilation in Node.js) - even though its API is nice and neat (and the Python bindings should take significant inspiration from it).

In the near future - one way or another - I will be posting about contrastive learning here. It's very cool indeed - I just need to wrap my head around and implement the loss function...

If you have experience with handling matrices, please get in touch as I'd really appreciate some assistance :P
