January 05, 2014

Branch Prediction and Why Clever Isn't Always Better

Recently I was implementing a research prototype for locality-sensitive hashing in Python with NumPy. As part of this work, I needed a fairly simple function to create a bitmask. The one catch, however, is that the function would need to be called billions of times in a loop. It quickly became apparent that a pure-Python implementation would be far too slow (and why should it be fast? Python isn’t designed for bit twiddling). Fortunately, it’s pretty easy to extend Python (and NumPy!) by writing a native library (more on that in another post). So I set off to write this function in C.

get_mask

The function, get_mask, would take two arguments: a bitfield represented as an n-bit integer, index, and an n-length array of integers, order. The order array describes a permutation of the bits of index. The output would be index’s bits shuffled into the positions specified by order. For example:

Imagine, for the sake of illustration, that our integers are only four bits wide. We call get_mask(5, [1,3,2,0]) and it returns 6. To explain, we look at the binary representation of index, in this case 0b0101. A new integer, out, is initialized to zero to hold our output.

Since the zeroth and second bits of index are 1 (where bit 0 is the LSB), we look at the zeroth and second elements of the array order. The values in order tell us which bits in the output to turn on. We can then say that output = (1 << order[0]) | (1 << order[2]) = (1 << 1) | (1 << 2) = 0b0110 = 6.

The algorithm, in C, follows:

unsigned long long get_mask(unsigned long long index, int* order) {
  unsigned long long out = 0;
  int j = 0;
  while (index > 0) {
    if (index & 0x1) {
      out |= 1ULL << order[j];  /* 1ULL so the shift happens in 64 bits (order[j] can exceed 31) */
    }
    j += 1;
    index >>= 1;
  }
  return out;
}
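To make the toy example concrete, here is a small test harness (my addition, purely illustrative) that exercises the four-bit example above, assuming the get_mask definition sits above it in the same file:

#include <stdio.h>

int main(void) {
  int order[4] = {1, 3, 2, 0};
  /* index = 5 = 0b0101: bits 0 and 2 are set, so the output sets
     bits order[0] = 1 and order[2] = 2, i.e. 0b0110 = 6 */
  printf("%llu\n", get_mask(5, order));
  return 0;
}

Compiled together with the function above, this prints 6, matching the walkthrough.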

Analysis

Running the above C implementation of get_mask produced marked improvements over the Python implementation. However, profiling revealed a significant amount of time was still being spent in the execution of get_mask. It seemed likely we could do better.

When I wrote get_mask, I thought the if (index & 0x1) check would be clever. Any time index & 1 evaluated to false, order + j would not have to be computed or dereferenced (1 instruction), no bit shift would occur (1 instruction), and no bitwise OR would occur (1 instruction): a savings on the order of 3 clock cycles for every zero bit. If, on average, half of the bits in a 64-bit integer are 0, the conditional should save about 32 × 3 = 96 clock cycles per call.

Unfortunately, I failed to take the branch predictor and instruction pipelining into account. To keep their pipelines full, CPUs fetch instructions from the near future of the computation, which requires guessing which way conditionals are going to branch. If the guess is incorrect, the CPU must flush its pipeline and start over at the point where it guessed wrong. On recent Intel chips, that penalty is about 15 cycles.

In the while loop above, it’s essentially impossible to reliably predict which way the conditional will go, so in the average case the branch is mispredicted about 50% of the time. With a 64-bit integer, the loop runs up to 64 iterations, so the pipeline gets flushed roughly 32 times per call at about 15 clock cycles each, or about 480 cycles.

Remember, we calculated earlier that skipping out |= 1ULL << order[j] when it isn’t needed saves about 96 cycles. It turns out that the branch mispredictions cost nearly 400 more cycles than the conditional saves!

Fortunately, it is easy to conceive of an equivalent function that has no unpredictable conditional inside its loop: if (index & 0x1) out |= 1ULL << order[j]; is replaced with out |= (index & 1) << order[j]. When index & 1 is 0, a zero is shifted and OR’d in, which has no effect on the output. Exactly what we wanted!

Fixed:

unsigned long long get_mask(unsigned long long index, int* order) {
  unsigned long long out = 0;
  int j = 0;
  while (index > 0) {
    out |= (index & 1) << order[j++];  /* a 0 bit shifts in a 0: no effect, and no branch */
    index >>= 1;
  }
  return out;
}
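If you want to see the effect yourself, a rough micro-benchmark along the following lines makes the gap visible on most machines. The names get_mask_branchy and get_mask_branchless, the input count, the permutation, and the use of clock() are all my own choices for illustration, not measurements from the original experiment:

#include <stdio.h>
#include <stdlib.h>
#include <time.h>

/* the original, branchy version */
static unsigned long long get_mask_branchy(unsigned long long index, int* order) {
  unsigned long long out = 0;
  int j = 0;
  while (index > 0) {
    if (index & 0x1) {
      out |= 1ULL << order[j];
    }
    j += 1;
    index >>= 1;
  }
  return out;
}

/* the branch-free version */
static unsigned long long get_mask_branchless(unsigned long long index, int* order) {
  unsigned long long out = 0;
  int j = 0;
  while (index > 0) {
    out |= (index & 1) << order[j++];
    index >>= 1;
  }
  return out;
}

int main(void) {
  enum { N = 10000000 };
  int order[64];
  for (int i = 0; i < 64; i++) order[i] = 63 - i;  /* an arbitrary permutation */

  /* random 64-bit inputs so the branch in the loop is unpredictable */
  unsigned long long* inputs = malloc(N * sizeof *inputs);
  if (!inputs) return 1;
  srand(42);
  for (int i = 0; i < N; i++) {
    inputs[i] = ((unsigned long long)rand() << 32) ^ (unsigned long long)rand();
  }

  unsigned long long sink = 0;  /* keep results live so the calls aren't optimized away */

  clock_t t0 = clock();
  for (int i = 0; i < N; i++) sink ^= get_mask_branchy(inputs[i], order);
  clock_t t1 = clock();
  for (int i = 0; i < N; i++) sink ^= get_mask_branchless(inputs[i], order);
  clock_t t2 = clock();

  printf("branchy:    %.3f s\n", (double)(t1 - t0) / CLOCKS_PER_SEC);
  printf("branchless: %.3f s\n", (double)(t2 - t1) / CLOCKS_PER_SEC);
  printf("checksum:   %llu\n", sink);

  free(inputs);
  return 0;
}

The absolute numbers depend heavily on the compiler, optimization level, and CPU (an aggressive optimizer may even turn the branchy version into branch-free code on its own), so treat this only as a way to observe the trend, or pair it with a profiler that counts branch misses.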

Ultimately, there are more optimizations that can be done. To avoid function call overhead, we can tell GCC to integrate the function’s code into each caller by declaring it static inline __attribute__((always_inline)), as sketched below. In my case, the millions of get_mask calls typically occur with the same order and incrementing index values, allowing me to use a lookup table (initially generated by the function above) and some other caching strategies as well. But that’s not what this article is about.
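For reference, the inlined declaration would look something like this (a sketch of the GCC/Clang idiom, not copied verbatim from my actual code):

/* GCC/Clang: force the compiler to inline this function into every caller */
static inline __attribute__((always_inline))
unsigned long long get_mask(unsigned long long index, int* order) {
  unsigned long long out = 0;
  int j = 0;
  while (index > 0) {
    out |= (index & 1) << order[j++];
    index >>= 1;
  }
  return out;
}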

The takeaway here should be to always consider branching in your code carefully, particularly when the branching occurs within a tight loop. If the outcome of the condition is likely to change from iteration to iteration, expensive branch mispredictions will almost certainly force the CPU to refill its pipeline. Try to avoid this whenever possible, even if the code required to do so looks “slightly less efficient.”