Working up from the foundations

 

Did you ever wonder how many dependencies we consume to deliver a modern application? Trust us, it is a
pretty huge number. Doing so saves us time and, more importantly, spares us the complexity of implementing
all of the boilerplate code and trivial algorithms ourselves. Let us admit in the same paragraph that we
are not experts in every kind of coding. Typically, all frameworks for neural networks carry the following
key considerations –
1. Operating on arrays, or, moving up the abstraction ladder, tensors.
2. Techniques to leverage GPUs and TPUs, with pipeline jobs and functions that target whichever of these
processing units is available.
3. Automatic differentiation and higher-order gradients.
4. Higher-level functions that build and train the entire neural network (read this as management or
automation depending on your vantage point).
Did you wonder where we are driving you with this? If you did, let us raise the curtain and call for JAX. This
is a new framework that claims to be fast and furious with mathematics. Built using the Python language,
this framework solves problems similar to those NumPy solves. If you ask whether there was some discontent with NumPy, the
immediate answer will be no. But if you let that question gestate in your head, you will notice the answer tilt
towards probably yes. This tilt is due to the growing availability of GPUs and TPUs. Many of the
modern frameworks create pipelines and jobs that utilise those processing units, but the granular
calculations, which are typically a step in those jobs/pipelines, are still performed on the CPU.
JAX

Enter JAX: it moves these fundamental and granular calculations to the GPU and TPU when they are available,
and it falls back to the CPU gracefully when those processing units are not. This minimises the need for
orchestration in frameworks and makes them less complex. Amongst its key characteristics, JAX focuses
essentially on automatic differentiation, or gradient determination. It accelerates developers with a
repertoire of purely functional APIs to address the rest of the characteristics of developing a neural network. This
is a dependency that you will want to import going forward, and also in existing code. Here is a quick
example of how to convert an existing NumPy array to a JAX array (we will do another curtain raiser soon) –
import numpy as np
import jax.numpy as jnp

x = np.random.rand(6000, 6000)
# Now let us convert this to a JAX array like this –
j = jnp.array(x)
print("numpy – " + str(type(x)) + ", jax – " + str(type(j)))
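Before we look at the printed types, here is a small, hedged aside on the graceful device fallback we mentioned above: you can ask JAX which backend it picked up using jax.devices() and jax.default_backend(). The exact device names printed will differ from machine to machine.
import jax

# Which devices did JAX find? On a machine without an accelerator this prints
# something like [CpuDevice(id=0)]; with a GPU or TPU attached, the same code,
# unchanged, lists those devices instead.
print(jax.devices())
print(jax.default_backend())  # 'cpu', 'gpu' or 'tpu'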
Coming back to the printed types, you will notice that the type of the JAX array is reported as DeviceArray, whereas the NumPy one is
simply an ndarray. There is a change in the name, but internally there is more to it. DeviceArray is built
on XLA, which stands for Accelerated Linear Algebra. This is a compiler that was built within TensorFlow to
accelerate models built using TensorFlow without code changes. DeviceArray is similar to ndarray in terms of storage
but different in the way it is operated upon. The first difference is that,
unlike ndarray, DeviceArray is immutable. Once assigned a value, it cannot be changed unless we call
upon the JAX API to perform the modification. Let us put that as an example –
# This is possible with numpy
x[0,0] = -10
print(x)
You will receive output similar to the following –
[[-10. 0.07 …]

[…]]
With JAX, since arrays are immutable, you will have to do something like this –
j = j.at[0, 0].set(-10)
Underneath this operation, in memory, JAX creates a copy of the array instead of modifying the existing
one as NumPy does. Now, the question is whether that is beneficial – the answer, apparently, is that it depends. The benefit
of immutable types is not realised through this ability to modify; rather, their strength is best demonstrated by
the ability to create and discard large n-dimensional arrays quickly. With the example above, we
demonstrate the finer print of how to achieve something that you have been doing with NumPy already.
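To make the copy semantics concrete, here is a small, hedged sketch (the arrays a and b are our own, and the exact error message text will vary by JAX version): in-place assignment is refused, while .at[...].set(...) hands back a brand-new array and leaves the original untouched.
import jax.numpy as jnp

a = jnp.zeros((2, 2))

# NumPy-style in-place assignment is refused on a JAX array ...
try:
    a[0, 0] = -10
except TypeError as e:
    print(e)              # the message points you towards the .at[] API

# ... while .at[...].set(...) returns a new array; a itself is unchanged.
b = a.at[0, 0].set(-10)
print(a[0, 0], b[0, 0])   # 0.0 -10.0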

What we want to convey is simple – JAX is a drop-in enhancement to NumPy, not a replacement. Why
NumPy? Because JAX is, in fact, a set of composable function transformations for Python+NumPy code. As with any library, the
best way to capitalise on the benefits of the framework is by wearing the hat of the designer who designed
the framework.
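As a hedged illustration of what "composable function transformations" means (the loss function and the values below are our own, not from the text): grad turns a plain jnp function into its gradient function, and jit wraps the result without changing how it is called.
import jax
import jax.numpy as jnp

# An ordinary Python function written against the NumPy-like jnp API.
def loss(w):
    return jnp.sum((3.0 * w - 1.0) ** 2)

# Two transformations composed: grad produces d(loss)/dw, jit compiles it.
dloss = jax.jit(jax.grad(loss))
print(dloss(jnp.ones(4)))   # [12. 12. 12. 12.]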
JAX style of thinking
JAX is NumPy friendly, with a similar interface and similar method names. However, it is layered: jax.numpy is built
on jax.lax, which in turn compiles to XLA instructions. The NumPy interface of JAX is forgiving,
whereas jax.lax is more demanding and a stickler for the rules of coding. So,
jnp.add(4,4.67)
# will work with jax.numpy
The above code will convert the 4 to a float32 implicitly and forgive the developer for even attempting a
mixed-type addition. If you had tried this with jax.lax –
import jax.lax as la
la.add(4,4.67)
This will yield a TypeError immediately upon execution. For an experienced developer there is more
control with the lax module than with the numpy one; for a starter, or a developer migrating from the NumPy library,
jax.numpy is a welcoming module. Thus, with the knowledge of this layering, you can target the
appropriate level in your code to get the appropriate level of control.
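As a hedged sketch of what that extra control looks like in practice: one way to satisfy the stricter lax call is to make the type promotion explicit yourself, so both operands share a dtype before the addition.
import jax.numpy as jnp
from jax import lax

# jax.numpy promotes the int to float32 for you ...
print(jnp.add(4, 4.67))

# ... jax.lax expects both operands to share a dtype, so promote explicitly.
print(lax.add(jnp.float32(4), jnp.float32(4.67)))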
The next feature, and the most important one that screams for developer adoption, is Just In Time (JIT) compilation.
For developers bridging into the data engineering and science world from the C# and Java world, this term
will be familiar, and they will probably be surprised that it was not here for so long. By introducing JIT into
JAX, Google has moved a cornerstone feature into the data language and engineering space. Consider a function that adds 0.2
to every element in the array 100 times, using the least efficient approach, and is thus a killer of
performant code.
def shift(a):
    # Deliberately inefficient: 100 separate element-wise additions of 0.2.
    for i in range(100):
        a = a + 0.2
    return a
Let us use the %timeit magic to see how the code snippet performs, first without JIT and later with it.
%timeit shift(j)
Depending on the workload at the time and your machine's configuration, you might see a different value. On
the machine where we were running the Jupyter notebook, we noticed it took on the order of milliseconds.
Now we introduce JIT.
from jax import jit
jshift = jit(shift)
%timeit jshift(j)
The result was the same by value, but the operation completed in the order of microseconds. That is a
10^3 order of performance boost, which hints that costly operations can be further optimised
by JIT-compiling the functions. As you get excited by that 10^3 magnitude of performance boost,
do keep in mind that JIT is not a sword that kills every performance-hogging monster. It works best when the size
of the array is known upfront. That loops us back to the first point on array mutability that we tabled. Thus,
the JAX style of thinking requires you to structure, or in better words design, the problem in a manner where the many
capabilities of JAX can be leveraged to deliver the amazing framework or neural network you are building.
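To close, here is a hedged sketch of that caveat (the array names and sizes are our own): jit specialises the compiled program to the shape of its input, so the first call for each new shape pays a tracing and compilation cost, and block_until_ready() is called so that each step actually finishes before the next line runs.
from jax import jit
import jax.numpy as jnp

@jit                        # decorator form, equivalent to jshift = jit(shift)
def shift(a):
    for i in range(100):
        a = a + 0.2
    return a

small = jnp.ones((10, 10))
large = jnp.ones((6000, 6000))

shift(small).block_until_ready()   # first call: traces and compiles for (10, 10)
shift(small).block_until_ready()   # same shape: reuses the compiled XLA program
shift(large).block_until_ready()   # new shape: traces and compiles again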