OCTOBER 2023, by STIJN WOESTENBORGHS and CEDRIC VAN HECK

MOJO DOES GIVE SUPERPOWERS




TL;DR: Implementation of a simple gradient descent problem in Python, NumPy, JAX, C++ (binding with Python) and Mojo. My goal is to make a fair evaluation of the out-of-the-box, raw performance of a tech stack choice. None of the implementations is optimal. But what I hope to show is what execution speeds to expect out of the box, how complex each implementation is, and which ones have the possibility of squeezing out every bit of performance the hardware has to offer.

Faster than Python, faster than NumPy, faster than JAX and faster than C++: that was my experience when testing out Mojo for the first time. Seeing all the disruptions in the LLM space lately has been very interesting, and it is not stopping. Multimodal LLMs are about to be released, and the idea of a set of AI agents that prompt themselves is an interesting thought experiment in itself. What is not obvious at first sight is that all these advancements in the AI space are also powered by the ongoing disruptions underneath. And Mojo is definitely one of those.

A NEW PROGRAMMING LANGUAGE, WHY?


My roots are in machine learning on embedded devices and I still love the journey of deploying AI models and optimization algorithms at the edge. In that kind of setting you absolutely care about performance, and because of the lack of performance in Python I have been pushed to C++ or other alternatives multiple times. Every time, it increases the complexity of a project drastically: once you realize that a ‘Python prototype’ isn’t going to cut it, you have to revise the tech stack completely, rewrite portions of your code or build complex bridges to some other language (like most high-performance Python libraries do). But I also absolutely love Python: the efficient coding environment, short development cycles and easy evaluation/visualization of a model’s performance are what get your application deployed faster. And testing ideas quickly is key.


Then, a couple of months ago, Modular announced Mojo, a new programming language for AI developers that is on its journey to become a full superset of Python. It promises the usability of Python with the performance of C, while still giving access to the complete Python ecosystem of libraries. As if that doesn’t sound like the holy grail already, this blogpost series intrigued me to test it myself.

GRADIENT DESCENT


The example implements a simple numerical optimization algorithm, as this is one of the central techniques in machine learning. All the code for this post is available on GitHub, so if you want, you can redo the experiments and create the visualizations yourself.

Imagine having a drone swarm and wanting to figure out the x, y and z coordinates of all of the drones. But you are only able to measure the distances between the drones, by calculating how long it takes to send and receive messages between each pair of them.

Using gradient descent, one can solve for the coordinates of all the drones by minimizing a loss function that penalizes the difference between the estimated distances and your measurements.
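To make the setup concrete, here is a small NumPy sketch of how such a measured distance matrix could be simulated; the function and variable names are mine, not the ones used in the repo:

```python
import numpy as np

def distance_matrix(points):
    # points: (N, 3) array of drone coordinates.
    # Returns the (N, N) matrix of pairwise Euclidean distances,
    # with zeros on the diagonal (a drone is at distance 0 from itself).
    diff = points[:, None, :] - points[None, :, :]
    return np.linalg.norm(diff, axis=-1)

# Hypothetical stand-in for the time-of-flight measurements: generate some
# "true" positions (unknown in practice) and pretend we measured D.
true_positions = np.random.default_rng(42).uniform(-1.0, 1.0, size=(10, 3))
D = distance_matrix(true_positions)
```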


As the main goal here is to compare the performance of programming languages, I’m skipping over this quickly; a detailed explanation of how it works is at the end. One important note that should be made here already is that the gradient was calculated using symbolic differentiation. This is different from automatic differentiation (the focus of JAX and what PyTorch & TensorFlow are using), but it makes the example simpler and easier to implement.

 

IMPLEMENTATIONS


My goal is to make a fair evaluation of the out-of-the-box, raw performance of a tech stack choice. None of the implementations is optimal. But what I hope to show is what execution speeds to expect out of the box, how complex each implementation is, and which ones have the possibility of squeezing out every bit of performance the hardware has to offer. That being said, these are the different ways the example was implemented:


  • Python native: only using libraries from the Python standard library. The Python ecosystem is huge, but the more performant libraries use another, higher-performance language like C++ underneath; those libraries have taken that complexity away from you.
  • Python NumPy: well-known Python library for matrix operations, using C/Fortran in the backend. (A minimal NumPy sketch of the loop each variant implements follows this list.)
  • JAX: Python code that is JIT (just-in-time) compiled via XLA for high-performance machine learning projects. JAX is great at automatic differentiation and at compiling for CPU as well as accelerators such as GPUs and TPUs. So it’s important to take into consideration that we calculated the gradient symbolically and are only running on CPU.
  • C++ binding (Eigen): a custom C++ binding to Python. The function is called from Python but uses the C++ eigen-3.4.0 library underneath to carry out the gradient descent task. (Compiled with g++ and -O3 compiler optimizations.) If you are interested in how this works: I have another blogpost about just that.
  • Mojo 🔥
  • C++ binding (parallel gradient): C++ allows for additional performance optimizations. The gradient descent algorithm allows calculating the gradient in parallel.
  • Mojo (parallel gradient): Same thing, Mojo allows for additional performance optimizations as well. The gradient can be calculated in parallel.
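As a reference point for the snippets shown later, this is roughly the shape every variant has, sketched here in NumPy with the symbolic gradient explained at the end of the post. The constants, learning rate and initialization are illustrative, not the exact values used in the benchmarks:

```python
import numpy as np

def loss(X, D):
    # Sum of squared differences between estimated and measured distances.
    diff = X[:, None, :] - X[None, :, :]          # (N, N, dim)
    dist = np.linalg.norm(diff, axis=-1)          # (N, N)
    return np.sum((dist - D) ** 2)

def gradient(X, D):
    # Symbolic gradient of the loss with respect to every point.
    # Row i only depends on point i, the other points and the fixed
    # measurements, so the rows are independent of each other.
    diff = X[:, None, :] - X[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    np.fill_diagonal(dist, 1.0)                   # avoid division by zero on the diagonal
    coeff = (dist - D) / dist
    np.fill_diagonal(coeff, 0.0)                  # a point contributes nothing to its own row
    return 4.0 * np.sum(coeff[:, :, None] * diff, axis=1)

def gradient_descent(D, dim=3, lr=1e-3, iterations=1000, seed=0):
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(D.shape[0], dim))        # random initial guess of the coordinates
    for _ in range(iterations):
        X -= lr * gradient(X, D)                  # step in the negative gradient direction
    return X
```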



After implementing them and running tests with an increasing order of the problem and a consistent 1000 iterations each, this was the result. Specifically, an increasing order here means increasing the number of points and increasing the dimensionality from 2D to 3D. Amazing to see that Mojo came out on top, every single time. And it was not just a little better: note the logarithmic scale on the graph below.

The parallelized implementations were the fastest, except for low orders (N=10). The machine I tested on had 20 logical CPUs available, and it doesn’t really make sense to fire up 20 threads if the gradient only has 10 rows: there are only 10 row computations that can run in parallel.


In the higher order cases (N=1000), Mojo outperformed Python by a factor of ~6000+ and NumPy by a factor of ~2500+. But to be fair, we are not doing machine learning in native Python or NumPy today. So leaving those cases out and rescaling to a linear y-scale again results in this.

 

Having a look at some code snippets gives an idea of how easy or complex it was to get there. It is almost incredible that Mojo is that much better, this easily. Giving more and more people the power to access the full potential of the hardware is transformational. Just imagine what it can do to the server cost of bigger and smaller AI companies. I guess Mojo does give you superpowers.


[Code snippets: Python native · Python NumPy · JAX · Mojo · Mojo (parallel gradient) · C++ (Eigen) · C++ (parallel gradient) · C++ to Python binding]

Again, be aware that none of the solutions is optimal, but it proves the main point. And yes, the language is still in its early stages and lacks some notable features. But the Mojo roadmap is for sure something to look forward to. Mojo, you got me hooked. 🔥


THE GRADIENT DESCENT ALGORITHM


Digging a little deeper into the gradient descent algorithm will explain why it is possible to calculate the gradient in parallel.


So the goal is to figure out the coordinates Xpoint = (x, y, z) of each point, but you are only able to measure the distance between each pair of points. These measurements give you a distance matrix. The distance between a point and itself is always 0, and the distance between point 1 and point 2 is the corresponding D12.
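For N points, that symmetric matrix looks something like this:

$$
D = \begin{pmatrix}
0 & D_{12} & D_{13} & \cdots & D_{1N} \\
D_{12} & 0 & D_{23} & \cdots & D_{2N} \\
D_{13} & D_{23} & 0 & \cdots & D_{3N} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
D_{1N} & D_{2N} & D_{3N} & \cdots & 0
\end{pmatrix}
$$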

Figuring out the 3D coordinates of all points X comes down to estimating their positions, calculating the distance between each pair and comparing it with your measurement. A new estimate of X can then be calculated using gradient descent: taking a step in the negative gradient direction of the loss function.
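In symbols, with η the learning rate, this is the standard update rule:

$$
X \leftarrow X - \eta \, \nabla L(X)
$$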

The loss function will look something like this:
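One way to write it (whether each pair is counted once or twice only changes a constant factor):

$$
L(X) = \sum_{i} \sum_{j} \big( \lVert X_i - X_j \rVert - D_{ij} \big)^2
$$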

Here |Xi - Xj| is the estimated distance between two points and Dij the measured distance between them. So minimizing the loss means minimizing the squared difference between your estimate and your measurement.


The gradient can be calculated by automatic differentiation (as PyTorch and TensorFlow do) or by symbolically differentiating the loss function. I’m choosing symbolic differentiation here as this reduces the time to implement the example, but I’m already looking at how to try out autodiff as well.
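For comparison, this is roughly what the autodiff route would look like in JAX. It is a sketch of the alternative, not what the benchmark used (the benchmarked JAX version uses the symbolic gradient):

```python
import jax
import jax.numpy as jnp

def loss(X, D):
    # Same loss as before, written with jax.numpy so it can be traced.
    diff = X[:, None, :] - X[None, :, :]
    # A small epsilon keeps the diagonal differentiable (norm of a zero vector).
    dist = jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + 1e-12)
    return jnp.sum((dist - D) ** 2)

# jax.grad derives the gradient function automatically, no symbolic math needed.
grad_fn = jax.jit(jax.grad(loss))
```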


Looking at one row at a time, each row of the gradient can be calculated independently:
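Differentiating the loss above with respect to a single point Xi gives, up to a constant factor that depends on how the pairs are counted:

$$
\frac{\partial L}{\partial X_i} \;\propto\; \sum_{j \neq i} \frac{\lVert X_i - X_j \rVert - D_{ij}}{\lVert X_i - X_j \rVert}\,\big(X_i - X_j\big)
$$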


So the complete gradient matrix will be the following. 
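That is, the rows above stacked into an N-by-dimension matrix, one row per point:

$$
\nabla L(X) = \begin{pmatrix} \partial L / \partial X_1 \\ \partial L / \partial X_2 \\ \vdots \\ \partial L / \partial X_N \end{pmatrix}
$$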

Note that because we can calculate each row of the gradient independently, the gradient lends itself to parallel computation, which plays a crucial role in execution speed.
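As an illustration of that row independence, here is a conceptual Python sketch of a row-parallel gradient. In pure Python the GIL limits what a thread pool can gain; the C++ and Mojo implementations parallelize over native threads, which is where the real speedup comes from:

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np

def gradient_row(i, X, D):
    # Gradient of the loss with respect to point i only. It reads X and D
    # but never another row's result, which is why rows can run in parallel.
    diff = X[i] - X                       # (N, dim)
    dist = np.linalg.norm(diff, axis=1)   # (N,)
    dist[i] = 1.0                         # avoid division by zero for j == i
    coeff = (dist - D[i]) / dist
    coeff[i] = 0.0
    return 4.0 * coeff @ diff             # (dim,)

def parallel_gradient(X, D, workers=8):
    # Dispatch one task per row; the results are stacked back in order.
    with ThreadPoolExecutor(max_workers=workers) as pool:
        rows = pool.map(lambda i: gradient_row(i, X, D), range(len(X)))
    return np.stack(list(rows))
```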

RECOGNITION


I built this together with my friend Cedric Van Heck, who helped a lot in conceptualizing the post and writing the JAX implementation. One can achieve so much more when multiple eyes are looking at the same problem.