A Brief Introduction to the Number Theoretic Transform (NTT)

Published:

Introduction

What this is

  • This post will give an introduction to the number theoretic transform (NTT), beginning with a short foray into the Fourier Transform.
  • Author’s Note: The LaTeX symbol for “not equivalent”, \not\equiv, isn’t working under the current LaTeX hack I’m using. I will instead use $\neq$ inside congruences, which I also dislike greatly.

What this is not

  • A replacement for a rigorous mathematics or computer science course.
  • A fulfilling journey into Fourier series and Fourier transforms.
  • A cat.

Fourier acronym soup: What are the DFT and FFT?

A transform

In general, a transformation is a function that maps a set to itself. In other words, this can be written as $T\colon X \to X$. Transformations come in many forms. For example, you may be familiar with geometric transformations, such as a translation or a rotation.

Today, we’ll be discussing specifically integral transforms, which, as you may guess, map a function from its original function space to another function space using integration. This can be written mathematically as the following:

$$(Tf)(u) = \int_{t_1}^{t_2} f(t)\, K(t, u)\, dt$$

We are applying the transform $T$ to the function $f(t)$, which outputs another function $(Tf)(u)$. This may also be referred to as the Fredholm equation of the first kind, if you really want to impress your friendly neighborhood mathematician [1]. $K(t, u)$ is usually referred to as the kernel of the integral transform.

A (continuous) Fourier Transform

A Fourier transform decomposes functions dependent on space or time into new functions that instead depend on spatial or temporal frequency [2]. It is a generalization of the Fourier series, which is a way to represent a periodic function as the sum of sine and cosine functions. In that regard, it is similar to the Taylor series, which instead uses monomial terms, as seen below for $e^x$:

$$e^x = \sum_{k=0}^{\infty} \frac{x^k}{k!} = 1 + x + \frac{x^2}{2!} + \frac{x^3}{3!} + \cdots$$

For more examples and visualizations of the Fourier Series/Transform, I will direct you to the Fourier Series Wikipedia page, and 3Blue1Brown’s “But what is the Fourier Transform?”.

Mathematically speaking, a Fourier transform is a special case of the integral transform. Here, we set:

$t_1 = -\infty$, $t_2 = \infty$, and $K(t, u) = e^{-2\pi i t u}$. Which gives us:

$$\hat{f}(u) = \int_{-\infty}^{\infty} f(t)\, e^{-2\pi i t u}\, dt.$$

A discrete Fourier Transform

Finally, we arrive at the Discrete Fourier Transform (DFT). Recall that discrete is usually used to contrast the term “continuous.” Instead of looking at the interval $(-\infty, \infty)$, we’ll pick $N$ complex numbers, i.e. a finite set of instances. From our original equation,

$$\hat{f}(u) = \int_{-\infty}^{\infty} f(t)\, e^{-2\pi i t u}\, dt,$$

we can rewrite it by picking a set of $N$ finite points $x_0, x_1, \ldots, x_{N-1}$, which gives us:

$$X_k = \sum_{n=0}^{N-1} x_n \cdot e^{-2\pi i k n / N}, \qquad k = 0, \ldots, N-1.$$

Since we have chosen from a finite set of data points, the data will be treated as if it were periodic. This means that $x_{n+N}$ is equivalent to $x_n$, or that $X_{k+N}$ is the same as $X_k$.

Aside: What is the primitive $N$th root of unity?

We say a number $\omega$ is an $N$th root of unity if $\omega^N = 1$. A number $\omega$ is a primitive $N$th root of unity if it is a root of unity and $N$ is the smallest positive integer for which $\omega^N = 1$. It is standard for the DFT to use $\omega = e^{-2\pi i / N}$. Indeed, $\omega^N = e^{-2\pi i} = 1$, and $\omega^k$ does not equal $1$ for $1 \le k < N$.
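
As a quick numerical sanity check, we can verify both properties for, say, $N = 8$:

Code: Root of unity check
import cmath

N = 8
omega = cmath.exp(-2j * cmath.pi / N)

# omega^N lands back on 1 (up to floating-point error)...
assert abs(omega ** N - 1) < 1e-9
# ...and no smaller positive power does.
assert all(abs(omega ** k - 1) > 1e-9 for k in range(1, N))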


Finally, this can be written more succinctly as:

$$X_k = \sum_{n=0}^{N-1} x_n \cdot \omega_N^{kn}, \quad \text{for } k = 0, \ldots, N-1.$$

Using the fundamental frequency, one can then set $\omega_N = e^{-2\pi i / N}$, the primitive $N$th root of unity.

Why bring this up? The NTT is a generalization of this DFT. Understanding this will provide some insight into why we can model NTT algorithms directly from DFT algorithms.

Code: Naive DFT
import math


def naive_dft(a):
    n = len(a)
    out = [0] * n
    
    def cexp(x):
        return complex(math.cos(x), math.sin(x))

    for i in range(n):
        for j in range(n):
            omega = cexp(-2 * math.pi * i * j / n)
            out[i] = out[i] + a[j] * omega

    return out
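
As a tiny usage example, the DFT of a unit impulse should be (numerically) $1$ in every frequency bin:

Code: Naive DFT example
out = naive_dft([1, 0, 0, 0])
assert all(abs(v - 1) < 1e-9 for v in out)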

A faster Fourier Transform

The DFT formula requires $O(N^2)$ operations. There are exactly $N$ outputs $X_k$, each of which requires summing over $N$ terms. FFT algorithms compute the same result in $O(N \log N)$ operations. The classic FFT is the Cooley-Tukey algorithm, which uses a divide-and-conquer approach, recursively decomposing a DFT of size $N = N_1 N_2$ into smaller DFTs of sizes $N_1$ and $N_2$. These are then multiplied by the complex roots of unity, also known as twiddle factors [3]. Interestingly enough, this algorithm was first devised by Carl Friedrich Gauss 160 years before Cooley and Tukey independently rediscovered it in the 1960s [4].

I highly recommend checking out the original paper titled “An Algorithm for the Machine Calculation of Complex Fourier Series” (1965) by Cooley and Tukey as well. It is relatively short, and provides remarkable insight into their thought process [7]. Below is an iterative version of the algorithm using decimation in time (splitting into sums over even and odd indices).

Code: Cooley-Tukey FFT
import math


def reverse_bits(number, bit_length):
    # Reverses the bits of `number` up to `bit_length`.
    reversed = 0
    for i in range(0, bit_length):
        if (number >> i) & 1: 
            reversed |= 1 << (bit_length - 1 - i)
    return reversed

def cexp(x):
    return complex(math.cos(x), math.sin(x))

def cooley_tukey_fft(a):
    # Radix-2 decimation-in-time FFT.
    n = len(a)
    out = list(a)  # Copy so we don't mutate the caller's list.

    for i in range(n):
        rev_i = reverse_bits(i, n.bit_length() - 1)
        if rev_i > i:
            out[i], out[rev_i] = out[rev_i], out[i]

    log2n = math.log2(n)
    # The length of the input array 
    # `a` should be a power of 2.
    assert log2n.is_integer()
    
    iterations = int(log2n)
    M = 2
    for _ in range(iterations):
        for i in range(0, n, M):
            g = 0
            for j in range(0, M >> 1):
                k = i + j + (M >> 1)
                U = out[i + j]
                V = out[k] * cexp(-2 * math.pi * g / n)
                out[i + j] = U + V
                out[k] = U - V
                g = g + n // M
        M <<= 1

    return out
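
As a quick check, the fast version should agree with the naive DFT (up to floating-point error) on any power-of-two-length input:

Code: FFT sanity check
import random

data = [complex(random.random(), random.random()) for _ in range(8)]
reference = naive_dft(data)
fast = cooley_tukey_fft(list(data))
assert all(abs(r - f) < 1e-9 for r, f in zip(reference, fast))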

Since we’re just taking the DFT and dividing it into smaller portions, the same strategy is easily applicable to other generalized forms of the DFT. This includes the NTT, which we’ll see later on. One interesting fact I learned while reading about this is that there is no proven lower bound for the FFT’s flop count [5].

What is the number theoretic transform?

As briefly mentioned before, the number theoretic transform (NTT) is a generalization of the discrete Fourier transform (DFT) that uses the primitive $n$th root of unity in a quotient ring $\mathbb{Z}/q\mathbb{Z}$ instead of the field of complex numbers. Instead of using $\omega = e^{-2\pi i / n}$, we use $\omega = g^b \bmod q$, where $g$ is a generator of the multiplicative group $(\mathbb{Z}/q\mathbb{Z})^{\times}$, $q$ is a prime number, and $b$ is an integer that is guaranteed to exist (you’ll soon see why).

Dirichlet’s pretty cool theorem

The Dirichlet prime number theorem states that for any two positive coprime integers $a$ and $n$, there are infinitely many primes of the form $a + kn$, where $k$ is also a positive integer. For the NTT, we will set up the equation $q = bn + 1$, where $q$ is a prime number, $n$ is the length of the input, and $b$ is an arbitrary positive integer of our choosing. Note that $q$ should also be larger than $n$ and each value in the input array to avoid overflow.
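
In code, a small helper along these lines can search for such a prime. This is just a sketch of the idea (the function below isn’t used by the rest of the post):

Code: Finding an NTT-friendly prime
from sympy import isprime


def find_ntt_prime(n, minimum):
    # Search for the smallest prime q = b*n + 1 with q >= minimum.
    # Dirichlet's theorem guarantees such a prime exists.
    b = max(1, (minimum + n - 2) // n)  # ceil((minimum - 1) / n)
    while not isprime(b * n + 1):
        b += 1
    return b * n + 1, b

For example, find_ntt_prime(4, 17) returns (17, 4), the values used in the examples below.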

Aside: What are multiplicative groups?

A multiplicative group of integers modulo $n$, denoted as $(\mathbb{Z}/n\mathbb{Z})^{\times}$, is the set of integers coprime to $n$, but the operation is multiplication modulo $n$. In contrast, $\mathbb{Z}/n\mathbb{Z}$ consists of the elements $\{0, 1, \ldots, n-1\}$ with addition as the operation. We could multiply the elements of $\mathbb{Z}/n\mathbb{Z}$, but we wouldn’t obtain a group. A group must satisfy certain axioms, one of these being “each element must have an inverse to produce the identity element.” For example, $0$ does not have a multiplicative inverse, i.e. there does not exist an $x$ to satisfy the equation $0 \cdot x \equiv 1 \pmod{n}$. Instead, we want to confine our attention to those elements which do have multiplicative inverses, or units. We say $x$ is a unit in $\mathbb{Z}/n\mathbb{Z}$ if $x$ and $n$ are coprime. Let’s take an example, say $n = 7$. We can produce a multiplication table (shown here for the elements $1$ through $4$) to see more concretely that every product is also a unit.

×  1  2  3  4
1  1  2  3  4
2  2  4  6  1
3  3  6  2  5
4  4  1  5  2

Each cell $(i, j)$ is the value $i \cdot j \bmod 7$. For example, $3 \cdot 4 \equiv 5 \pmod{7}$.

Then, $2 \cdot 4 \equiv 1 \pmod{7}$, and $3 \cdot 5 \equiv 1 \pmod{7}$, so every element has a multiplicative inverse.

More generally, if $q$ is prime, then all positive integers smaller than $q$ are relatively prime to $q$, since the greatest common divisor $\gcd(x, q) = 1$ for $0 < x < q$. This means we know the size of $(\mathbb{Z}/q\mathbb{Z})^{\times}$ will always be $q - 1$. Why’s this important? Well, our goal is to find a generator for the primitive $n$th root of unity, which we’ve defined as $\omega = g^b \bmod q$.

We need to show first that $\omega^n \equiv 1 \pmod{q}$. We know $q = bn + 1$, i.e. $bn = q - 1$. We also will use Euler’s theorem, which defines $\varphi(q)$ as the number of positive integers up to $q$ that are coprime with $q$, so that $a^{\varphi(q)} \equiv 1 \pmod{q}$ whenever $\gcd(a, q) = 1$. Then, for an input of length $n$, we have that:

$$\omega^n = (g^b)^n = g^{bn} = g^{q-1} = g^{\varphi(q)} \equiv 1 \pmod{q}.$$

Second, we need to show that $\omega^k \neq 1 \pmod{q}$ for $0 < k < n$. Well, $g$ is a generator of $(\mathbb{Z}/q\mathbb{Z})^{\times}$, so the smallest positive power of $g$ congruent to $1$ is $g^{q-1} = g^{bn}$, and for $0 < k < n$ we have $0 < bk < bn$.

So that $\omega^k = g^{bk} \neq 1 \pmod{q}$, as required. Thus, we’ve shown $\omega = g^b$ is indeed a primitive $n$th root of unity.


Finding the primitive nth root of unity

Let’s define $n$, the length of our input, as 4, so that we have the equation $q = 4b + 1$. Then, we’ll pick an arbitrary value, say $b = 4$, so that $q = 17$. Great! We now have $n = 4$, $b = 4$, and $q = 17$. Now we can either find a generator from the multiplicative group of $q$, or we can find the primitive root directly. For this example, we’ll take the latter approach. By definition, a primitive $n$th root of unity $\omega$ in $\mathbb{Z}/q\mathbb{Z}$ holds the following two conditions:

  1. $\omega^n \equiv 1 \pmod{q}$
  2. $\omega^{n/p} \neq 1 \pmod{q}$ for each prime factor $p$ of $n$.

So, we’re looking for an integer $\omega$ such that the conditions above hold. First, we know that $n = 4$, so that our only prime factor is $p = 2$. This means we want $\omega^4 \equiv 1 \pmod{17}$ and $\omega^{4/2} = \omega^2 \neq 1 \pmod{17}$. Let $\omega = 4$. Then $4^4 = 256 \equiv 1 \pmod{17}$, and $4^2 = 16 \neq 1 \pmod{17}$. Thus, we now have our primitive $n$th root of unity! Or, in this case, our primitive $4$th root of unity.
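
The same search can also be written as a short brute-force routine; this one is just a sketch (it isn’t used by the code below) and leans on SymPy for the factorization:

Code: Searching for a primitive nth root of unity
from sympy import primefactors


def find_primitive_nth_root(n, q):
    # Brute-force search for a primitive n-th root of unity in Z/qZ.
    for candidate in range(2, q):
        if pow(candidate, n, q) != 1:
            continue
        # No proper power n/p may already hit 1.
        if all(pow(candidate, n // p, q) != 1 for p in primefactors(n)):
            return candidate
    return None

With $n = 4$ and $q = 17$, this returns $4$, matching the value found by hand above.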

Finding a generator of the multiplicative group

Another approach is finding a generator $g$ of the multiplicative group $(\mathbb{Z}/q\mathbb{Z})^{\times}$ and then acquiring the primitive $n$th root of unity from this. This is slightly simpler, since we instead look within the finite range $[1, q - 1]$. By definition, a generator $g$ of $(\mathbb{Z}/q\mathbb{Z})^{\times}$ must hold the following condition: For each unique prime factor $p$ of $q - 1$,

$$g^{(q-1)/p} \neq 1 \pmod{q}.$$

Again, this can be seen more clearly with an example. We’ll let $q = 17$, so that the only unique prime factor of $q - 1 = 16$ is $2$. Let’s choose an arbitrary value in $[1, q - 1]$, say $g = 3$. Then,

$$3^{16/2} = 3^8 = 6561 \equiv 16 \neq 1 \pmod{17}.$$

Thus, $3$ is a generator! From the generator $g$, we can calculate the primitive $n$th root of unity by using the coprime number $b$ guaranteed in Dirichlet’s theorem: $q = bn + 1$, i.e. $b = (q - 1)/n$. The primitive $n$th root of unity would be $\omega = g^b \bmod q = 3^4 \bmod 17 = 13$. In code, we can write a function to pre-compute the necessary twiddle factors for a given array length $n$ and prime number $q$. As you’ll see, I am using SymPy to conduct some of the more trivial mathematical computations.

Code: Generator
from sympy.ntheory import isprime, primitive_root


def generate_twiddle_factors(n, q):
    # Produces the `n` omegas (twiddle factors) for the NTT:
    # omega = x^b (mod q), where x is a generator of the
    # multiplicative group of the prime field q.
    assert isprime(q)
    
    x = primitive_root(q)
    
    # Applying Dirichlet's theorem, 
    # we have: q = bn + 1.
    b = (q - 1) // n
    
    omega = (x ** b) % q

    omegas = [1]
    for i in range(n):
        # Multiply (mod q) by the previous value.
        omegas.append((omegas[i] * omega) % q)

    return omegas[:n]  # Drop the last, needless value.
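
With the running example ($n = 4$, $q = 17$), this should reproduce the values derived above, since SymPy’s primitive_root(17) returns the generator 3:

Code: Twiddle factor example
omegas = generate_twiddle_factors(4, 17)
print(omegas)  # [1, 13, 16, 4]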

A naive NTT

I mentioned earlier that NTT is simply a generalization of the DFT. In other words, a lot of the algorithm remains the same; we are simply changing our roots of unity.

Code: Naive NTT
def naive_ntt(a, q, omegas):
    n = len(a)
    out = [0] * n
    
    for i in range(n):
        for j in range(n):
            out[i] = (out[i] + a[j] * omegas[(i * j) % n]) % q
    return out
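
For example, transforming the array $[1, 2, 3, 4]$ with $q = 17$ and the twiddle factors computed above should give:

Code: Naive NTT example
omegas = generate_twiddle_factors(4, 17)  # [1, 13, 16, 4]
print(naive_ntt([1, 2, 3, 4], 17, omegas))  # [10, 6, 15, 7]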

A fast NTT

Similarly, we can apply the principles of FFT algorithms to the number theoretic transform. Below is an iterative Cooley-Tukey version, also Radix-2 DIT.

Code: Cooley-Tukey NTT
import math


def reverse_bits(number, bit_length):
    # Reverses the bits of `number` up to `bit_length`.
    reversed = 0
    for i in range(0, bit_length):
        if (number >> i) & 1: 
            reversed |= 1 << (bit_length - 1 - i)
    return reversed

def cooley_tukey_ntt(a, q, omegas):
    # Radix-2 decimation-in-time NTT.
    n = len(a)
    out = list(a)  # Copy so we don't mutate the caller's list.

    for i in range(n):
        rev_i = reverse_bits(i, n.bit_length() - 1)
        if rev_i > i:
            out[i], out[rev_i] = out[rev_i], out[i]

    log2n = math.log2(n)
    # The length of the input array `a` should be a power of 2.
    assert log2n.is_integer()
    
    iterations = int(log2n)
    M = 2
    for _ in range(iterations):
        for i in range(0, n, M):
            g = 0
            for j in range(0, M >> 1):
                k = i + j + (M >> 1)
                U = out[i + j]
                V = out[k] * omegas[g]
                out[i + j] = (U + V) % q
                out[k] = (U - V) % q
                g = g + n // M
        M <<= 1

    return out
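
As with the FFT, the fast version should match the naive transform exactly:

Code: NTT sanity check
omegas = generate_twiddle_factors(4, 17)
a = [1, 2, 3, 4]
assert cooley_tukey_ntt(list(a), 17, omegas) == naive_ntt(a, 17, omegas)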

Round tripping it

After taking the NTT of an array $a$, we can simply apply the inverse NTT to recover it, since $\mathrm{INTT}(\mathrm{NTT}(a)) = a$. At a high level, we’re applying the inverses of each twiddle factor $\omega$ (and scaling by $n^{-1}$ at the end). Provided below is an algorithm described in Longa et al.’s “Speeding up the Number Theoretic Transform for Faster Ideal Lattice-Based Cryptography,” which instead uses $\phi$s as twiddle factors (as well as a few other small optimizations) [6]. Here, $\phi^2 \equiv \omega \pmod{q}$.

Code: Round trip
import math


def cooley_tukey_ntt_opt(a, n, q, phis):
    """Cooley-Tukey DIT algorithm with an extra optimization.
    We can avoid computing bit reversed order with each call by
    pre-computing the phis in bit-reversed order.
    Requires:
     `phis` are provided in bit-reversed order.
     `n` is a power of two.
     `q` is equivalent to `1 mod 2n`.
    Reference:
       https://www.microsoft.com/en-us/research/wp-content/
       uploads/2016/05/RLWE-1.pdf
    """

    assert q % (2 * n) == 1, f'{q} is not equivalent to 1 mod {2 * n}'
    assert (n & (n - 1) == 0) and n > 0, f'n: {n} is not a power of 2.'

    t = n
    m = 1
    while m < n:
        t >>= 1
        for i in range(0, m):
            j1 = i * (t << 1)
            j2 = j1 + t - 1
            S = phis[m + i]
            for j in range(j1, j2 + 1):
                U = a[j]
                V = a[j + t] * S
                a[j] = (U + V) % q
                a[j + t] = (U - V) % q
        m <<= 1
    return a


def gentleman_sande_intt_opt(a, n, q, inv_phis):
    """Gentleman-Sande INTT butterfly algorithm.
    Assumes that inverse phis are stored in bit-reversed order.
    Reference:
       https://www.microsoft.com/en-us/research/wp-content/
       uploads/2016/05/RLWE-1.pdf
    """
    t = 1
    m = n
    while (m > 1):
        j1 = 0
        h = m >> 1
        for i in range(h):
            j2 = j1 + t - 1
            S = inv_phis[h + i]
            for j in range(j1, j2 + 1):
                U = a[j]
                V = a[j + t]
                a[j] = (U + V) % q
                a[j + t] = ((U - V) * S) % q
            j1 += (t << 1)
        t <<= 1
        m >>= 1

    # Scale by n^{-1} (mod q); since q is prime, Fermat's little
    # theorem gives n^{-1} = n^(q-2) (mod q).
    n_inv = pow(n, q - 2, q)
    return [(i * n_inv) % q for i in a]

def get_bit_reversed(c, n, q):
    cc = c.copy()
    for i in range(n):
        rev_i = reverse_bits(i, n.bit_length() - 1)
        if rev_i > i:
            cc[i], cc[rev_i] = cc[rev_i], cc[i]

    return cc


def gen_phis(omegas, q):
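    # Each phi is a modular square root of the corresponding omega
    # (so phi * phi = omega mod q), computed via Tonelli-Shanks below.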
    def legendre(x, q):
        return pow(x, (q - 1) // 2, q)

    def tonelli_shanks(x, q):
        # Finds the `sqrt(x) mod q`.
        # Source: https://rosettacode.org/wiki/Tonelli-Shanks_algorithm
        Q = q - 1
        s = 0
        while Q % 2 == 0:
            Q //= 2
            s += 1
        if s == 1:
            return pow(x, (q + 1) // 4, q)
        for z in range(2, q):
            if q - 1 == legendre(z, q):
                break
        c = pow(z, Q, q)
        r = pow(x, (Q + 1) // 2, q)
        t = pow(x, Q, q)
        m = s
        t2 = 0
        while (t - 1) % q != 0:
            t2 = (t * t) % q
            for i in range(1, m):
                if (t2 - 1) % q == 0:
                    break
                t2 = (t2 * t2) % q
            b = pow(c, 1 << (m - i - 1), q)
            r = (r * b) % q
            c = (b * b) % q
            t = (t * c) % q
            m = i
        return r

    return [tonelli_shanks(x, q) for x in omegas]
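
Putting the pieces together on the running example, a round trip might look like the sketch below. The inverse twiddle factors aren’t computed by the functions above, so here I simply take the modular inverse of each (bit-reversed) phi; reverse_bits is the helper defined earlier in the post.

Code: Round trip example
n, q = 4, 17
omegas = generate_twiddle_factors(n, q)             # [1, 13, 16, 4]
phis = get_bit_reversed(gen_phis(omegas, q), n, q)

# Inverse phis: the modular inverse of each phi, in the same
# (bit-reversed) order. q is prime, so phi^(q-2) is the inverse.
inv_phis = [pow(phi, q - 2, q) for phi in phis]

a = [1, 2, 3, 4]
forward = cooley_tukey_ntt_opt(list(a), n, q, phis)
recovered = gentleman_sande_intt_opt(forward, n, q, inv_phis)
assert recovered == a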

Conclusion

In this post we discussed the Number Theoretic Transform, starting from the definition of a transformation, working through the Fourier Transform, and eventually touching upon a few mathematical principles which allow the NTT to exist. I would strongly suggest looking at the Project Nayuki blog post, as it provides a deeper mathematical understanding as well as some comprehensive examples. This is where a lot of my NTT-related learning started!

You may also find a variant of the code in Cornell Capra’s repository nttstuff. It is not well-documented, nor is it guaranteed to be correct. It was mostly used as a stepping stone to build an accelerator generator for the NTT pipeline. PRs are certainly welcome.

Please feel free to reach out to me at cpg49 at cornell dot edu with any questions, comments, or concerns!


References

[1]: Porter, F. “Integral Equations”, Revision 051012. Link.

[2]: https://en.wikipedia.org/wiki/Fourier_transform

[3]: W. M. Gentleman and G. Sande. 1966. Fast Fourier Transforms: for fun and profit. In Proceedings of the November 7-10, 1966, fall joint computer conference (AFIPS ‘66 (Fall)). Association for Computing Machinery, New York, NY, USA, 563–578. DOI:https://doi.org/10.1145/1464291.1464352

[4]: Heideman, M.T., Johnson, D.H. & Burrus, C.S. Gauss and the history of the fast Fourier transform. Arch. Hist. Exact Sci. 34, 265–277 (1985). https://doi.org/10.1007/BF00348431

[5]: S. G. Johnson and M. Frigo, “A Modified Split-Radix FFT With Fewer Arithmetic Operations,” in IEEE Transactions on Signal Processing, vol. 55, no. 1, pp. 111-119, Jan. 2007, doi: 10.1109/TSP.2006.882087.

[6]: Longa P., Naehrig M. (2016) Speeding up the Number Theoretic Transform for Faster Ideal Lattice-Based Cryptography. In: Foresti S., Persiano G. (eds) Cryptology and Network Security. CANS 2016. Lecture Notes in Computer Science, vol 10052. Springer, Cham. https://doi.org/10.1007/978-3-319-48965-0_8

[7]: Cooley J.W., Tukey J.W. 1965. “An Algorithm for the Machine Calculation of Complex Fourier Series,” in Math. Comp. 19, 297-301. https://doi.org/10.1090/S0025-5718-1965-0178586-1