Smoking Hot Binary Search In Zig

published: [nandalism home]

Inspiration

I was inspired by a recent post on HN, "Beautiful Binary Search in D", to see if I could do the same thing in Zig.

Briefly, the binary search is unrolled into a linear sequence of if-statements that find a key in a sorted array by repeatedly splitting the range in half. This only works if the array size is known in advance, and it yields floor(log2(nelems)) + 1 comparisons (10 for the 1000-element example below), plus a final equality check.

// pasted verbatim from the blog post (https://muscar.eu/shar-binary-search-meta.html)
int bsearch1000(int[1001] xs, int x)
{
  auto i = 0;
  if (xs[512] <= x) { i = 1000 - 512 + 1; }
  if (xs[i + 256] <= x) { i += 256; }
  if (xs[i + 128] <= x) { i += 128; }
  if (xs[i + 64] <= x) { i += 64; }
  if (xs[i + 32] <= x) { i += 32; }
  if (xs[i + 16] <= x) { i += 16; }
  if (xs[i + 8] <= x) { i += 8; }
  if (xs[i + 4] <= x) { i += 4; }
  if (xs[i + 2] <= x) { i += 2; }
  if (xs[i + 1] <= x) { i += 1; }

  if (xs[i] == x) return i;
  else return 0;
}

The main result of the linked post is a D-lang macro which generates code like that above, given any N (e.g. N=1000 above).
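To see the shape of that generated code at a size small enough to read, here is a hand-unrolled Zig version for N = 7, written by analogy with bsearch1000 (bsearch7 and its test are a hypothetical illustration, not from either post). Here floor(log2(7)) = 2 and 2^2 = 4, so there is one probe at xs[4] followed by two halving steps.

```zig
const std = @import("std");

// Hypothetical hand-unrolled search for N = 7 (not from either post).
// 1-based like the D original: xs[1..7] hold the data, xs[0] is unused,
// and 0 is returned for "not found".
fn bsearch7(xs: [8]i32, x: i32) usize {
    var i: usize = 0;
    if (xs[4] <= x) i = 7 - 4 + 1; // jump to the high window xs[4..7]
    if (xs[i + 2] <= x) i += 2;
    if (xs[i + 1] <= x) i += 1;
    return if (xs[i] == x) i else 0;
}

test "bsearch7 finds every stored key" {
    const xs = [_]i32{ 0, 2, 3, 5, 7, 11, 13, 17 };
    var j: usize = 1;
    while (j <= 7) : (j += 1) {
        try std.testing.expectEqual(j, bsearch7(xs, xs[j]));
    }
    try std.testing.expectEqual(@as(usize, 0), bsearch7(xs, 4)); // absent key
}
```

Each doubling of N adds one more line of the `if (xs[i + r] <= x) i += r;` form.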

The Zig Macro

I developed the equivalent macro in Zig (mainly using comptime inline-while).

As usual in Zig, the macro and the code are mixed together in a single function. The function takes two compile-time parameters (marked with comptime): N, the array size, which must be known at compile time, and Elem, the array element type.
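A minimal illustration of what comptime parameters buy us, using a hypothetical sumUnrolled function that has nothing to do with the search itself: because N and Elem are known at compile time, N can size the array type and drive an inline while, and each distinct (N, Elem) pair gets its own specialized copy of the function.

```zig
const std = @import("std");

// Hypothetical example (not from the post): N and Elem are comptime
// parameters, so N can size the array type and drive an unrolled loop.
fn sumUnrolled(comptime N: usize, comptime Elem: type, xs: [N]Elem) Elem {
    var total: Elem = 0;
    comptime var i: usize = 0;
    inline while (i < N) : (i += 1) {
        total += xs[i]; // one add per element, index fixed at compile time
    }
    return total;
}

test "one instantiation per (N, Elem)" {
    try std.testing.expectEqual(@as(i32, 6), sumUnrolled(3, i32, [_]i32{ 1, 2, 3 }));
    try std.testing.expectEqual(@as(f64, 1.5), sumUnrolled(2, f64, [_]f64{ 0.5, 1.0 }));
}
```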

At compile time we do some math on N to work out how many if-statements to generate and which index range is left to search after each comparison, depending on its runtime outcome.
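As a worked example of that math, here is a sketch that computes the numbers for the 1-based layout of the D code, where xs[1..N] holds the data. Plan and plan() are hypothetical helpers, not part of the post's code; evaluating plan(1000) with comptime reproduces the constants hard-coded in bsearch1000.

```zig
const std = @import("std");

// Hypothetical helper: the numbers the unrolled search needs for N elements
// stored 1-based in xs[1..N], as in the D original.
const Plan = struct {
    k: usize, // halving steps after the first comparison = floor(log2(N))
    k2: usize, // 2^k, the largest power of two <= N, probed first
    hi_start: usize, // start of the window taken when xs[k2] <= x
};

fn plan(n: usize) Plan {
    var k: usize = 0;
    var k2: usize = 1;
    while (k2 * 2 <= n) {
        k2 *= 2;
        k += 1;
    }
    return .{ .k = k, .k2 = k2, .hi_start = n - k2 + 1 };
}

test "plan for N = 1000 matches bsearch1000" {
    const p = comptime plan(1000);
    try std.testing.expectEqual(@as(usize, 9), p.k);
    try std.testing.expectEqual(@as(usize, 512), p.k2);
    try std.testing.expectEqual(@as(usize, 489), p.hi_start); // 1000 - 512 + 1
    // The k steps add at most k2 - 1 = 511, so the high window ends at xs[N].
    try std.testing.expectEqual(@as(usize, 1000), p.hi_start + p.k2 - 1);
}
```

So N = 1000 needs k + 1 = 10 comparisons plus the final equality check, matching the ten if-statements in bsearch1000.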

The Elem parameter lets the same search work on different element types. Note that Elem must support '<=' and '==', which rules out, e.g., strings; for those we'd need to pass in a compare function instead. I'll just stick with the integer example from the original blog post, linked above.
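For element types without '<=', the comparison can be passed in as a comptime function, as the note above suggests. This is a sketch under that assumption: bsearchBy, strLess and largestPow2AtMost are hypothetical names, equality is derived from the ordering (neither element orders before the other), and std.mem.lessThan supplies byte-wise string ordering. It keeps the 1-based layout of the original.

```zig
const std = @import("std");

// Largest power of two <= n (n >= 1), i.e. 2^floor(log2(n)).
fn largestPow2AtMost(n: usize) usize {
    var p: usize = 1;
    while (p * 2 <= n) {
        p *= 2;
    }
    return p;
}

// Hypothetical variant of the post's bsearch: same 1-based unrolled search,
// but ordering comes from a comptime `lessThan`, so Elem needs no `<=`.
fn bsearchBy(
    comptime N: usize,
    comptime Elem: type,
    comptime lessThan: fn (Elem, Elem) bool,
    xs: [N + 1]Elem,
    x: Elem,
) usize {
    const k2 = comptime largestPow2AtMost(N);
    var p: usize = 0;
    // `xs[k2] <= x` is written as `!(x < xs[k2])`.
    if (!lessThan(x, xs[k2])) p = N - k2 + 1;
    comptime var rank = k2;
    inline while (rank > 1) {
        rank /= 2;
        if (!lessThan(x, xs[p + rank])) p += rank;
    }
    // Equal means neither element orders before the other.
    const found = !lessThan(xs[p], x) and !lessThan(x, xs[p]);
    return if (found) p else 0;
}

// Byte-wise string ordering from the standard library.
fn strLess(a: []const u8, b: []const u8) bool {
    return std.mem.lessThan(u8, a, b);
}

test "strings via a compare function" {
    // Index 0 is the unused slot of the 1-based layout.
    const words = [_][]const u8{ "", "apple", "banana", "cherry", "date", "fig" };
    try std.testing.expectEqual(@as(usize, 3), bsearchBy(5, []const u8, strLess, words, "cherry"));
    try std.testing.expectEqual(@as(usize, 0), bsearchBy(5, []const u8, strLess, words, "grape"));
}
```

Making the compare function a comptime parameter means each call site still gets its own unrolled copy with the comparison inlinable, just as with '<=' on integers.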

The other two parameters are runtime parameters: the array itself (xs) and the key we are searching for (x). I've used the same variable names as the original post for easier comparison.

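const std = @import("std");
const math = std.math;

fn bsearch(comptime N: usize, comptime Elem: type, xs: [N+1]Elem, x: Elem) usize{
  // k = floor(log2(N)), k2 = 2^k: 9 and 512 for N = 1000, matching the D code.
  const k = comptime math.log2(N);
  const k2 = comptime try math.powi(usize, 2, k);
  var p: usize = 0;
  // First comparison: stay in the low window, or jump to the one ending at xs[N].
  if(xs[k2] <= x) p = N - k2 + 1;
  comptime var rank = k2;
  // Unrolled at compile time into k more if-statements (rank = k2/2, ..., 1).
  comptime var i = k; inline while(i>0) : (i-=1) {
    rank /= 2;
    if(xs[p + rank] <= x) p += rank;
  }
  // 1-based: xs[0] is never a real element, so 0 means "not found".
  return if(xs[p]==x) p else 0;
}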

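The binary search code breaks down like this. At compile time, k (the number of halving steps) and k2 = 2^k are derived from N. One runtime comparison against xs[k2] then picks the starting point p: it stays at 0 if xs[k2] > x, otherwise it jumps to the highest start whose window still ends at the last element. The inline while unrolls into k more if-statements, each halving rank and advancing p by rank when xs[p + rank] <= x. Finally, p is returned if xs[p] equals the key, and 0 otherwise.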

(The original algorithm is 1-based: the data lives in xs[1..N], xs[0] is an unused slot, and 0 is returned for "not found", hence the +1 offsets. It would be nice to change that, but I just slavishly, mindlessly followed the original.)
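One way to drop the 1-based convention, sketched here as an assumption rather than taken from either post: store exactly N elements in xs[0..N-1], probe xs[N - k2] first instead of xs[k2], and return ?usize so that "not found" is null rather than 0. bsearch0 and largestPow2AtMost are hypothetical names.

```zig
const std = @import("std");

// Largest power of two <= n (n >= 1), i.e. 2^floor(log2(n)).
fn largestPow2AtMost(n: usize) usize {
    var p: usize = 1;
    while (p * 2 <= n) {
        p *= 2;
    }
    return p;
}

// Hypothetical 0-based variant: data in xs[0..N-1], null means "not found".
fn bsearch0(comptime N: usize, comptime Elem: type, xs: [N]Elem, x: Elem) ?usize {
    if (N == 0) @compileError("bsearch0 needs at least one element");
    const k2 = comptime largestPow2AtMost(N);
    var p: usize = 0;
    // Since N < 2 * k2, the windows [0, k2) and [N - k2, N) together cover
    // all N slots; this probe picks one of them.
    if (xs[N - k2] <= x) p = N - k2;
    comptime var rank = k2;
    inline while (rank > 1) {
        rank /= 2;
        if (xs[p + rank] <= x) p += rank;
    }
    // p is now the last index with xs[p] <= x (or 0 if there is none).
    return if (xs[p] == x) p else null;
}

test "0-based variant" {
    const xs = [_]i32{ 2, 3, 5, 7, 11, 13, 17 };
    try std.testing.expectEqual(@as(?usize, 0), bsearch0(7, i32, xs, 2));
    try std.testing.expectEqual(@as(?usize, 3), bsearch0(7, i32, xs, 7));
    try std.testing.expectEqual(@as(?usize, 6), bsearch0(7, i32, xs, 17));
    try std.testing.expectEqual(@as(?usize, null), bsearch0(7, i32, xs, 4));
    try std.testing.expectEqual(@as(?usize, null), bsearch0(7, i32, xs, 1));
}
```

Moving the first probe from xs[k2] to xs[N - k2] is what lets the high window end exactly at the last element without the +1 offset.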

Testing

Zig has a nice inline testing feature, so I added a test for the 1000-element case shown in the original blog post. Note that this exercises only N = 1000; on its own it doesn't show that the code works for every array size.

This is the interesting line, where bsearch is called,

    const i = bsearch(1000, isize, tbl, key);

(In the actual code below I've used constants for the size and type. In Zig, a const initialized with a compile-time-known value is itself compile-time-known, so it can be used anywhere a literal could for comptime purposes.)

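The test fills the table with pseudo-random isize values (PRNG seeded with 42), sorts it, then looks up each value in turn and checks that the element at the returned index equals the key. Running it: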

$ zig test bbinsearch.zig
All 1 tests passed.

Yay! It works!

The full test

test {
  const tt = std.testing;
  const Elem = isize;
  var rand_impl = std.rand.DefaultPrng.init(42);
  const tbl_len = 1000;
  var tbl: [tbl_len+1]Elem = undefined;

  // Fill every slot, including the extra one of the 1-based layout,
  // so that no undefined value gets sorted or compared.
  var t: usize = 0; while(t<=tbl_len) : (t+=1) { tbl[t]=rand_impl.random().int(Elem); }
  std.sort.sort(Elem, &tbl, {}, std.sort.asc(Elem));

  // Look up each stored value; the element at the returned index must equal the key.
  var j: usize=0; while(j<tbl_len): (j+=1){
    const key = tbl[j];
    const i = bsearch(tbl_len, Elem, tbl, key);
    try tt.expectEqual(key, tbl[i]);
  }
}
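The test above only exercises N = 1000. Here is a sketch of a broader check, assuming the 1-based bsearch with k = floor(log2(N)) and the high window starting at N - k2 + 1. It is repeated here so the block stands alone, with the power of two coming from a plain loop (the hypothetical largestPow2AtMost) rather than math.log2 and math.powi. An inline for runs the same checks for several sizes, including powers of two and their neighbours, and also looks up keys that are not present.

```zig
const std = @import("std");

// Largest power of two <= n (n >= 1), i.e. 2^floor(log2(n)).
fn largestPow2AtMost(n: usize) usize {
    var p: usize = 1;
    while (p * 2 <= n) {
        p *= 2;
    }
    return p;
}

// 1-based unrolled search: data in xs[1..N], xs[0] unused, 0 = "not found".
fn bsearch(comptime N: usize, comptime Elem: type, xs: [N + 1]Elem, x: Elem) usize {
    const k2 = comptime largestPow2AtMost(N);
    var p: usize = 0;
    if (xs[k2] <= x) p = N - k2 + 1;
    comptime var rank = k2;
    inline while (rank > 1) {
        rank /= 2;
        if (xs[p + rank] <= x) p += rank;
    }
    return if (xs[p] == x) p else 0;
}

test "several array sizes, present and absent keys" {
    inline for ([_]usize{ 1, 2, 3, 7, 8, 9, 100, 511, 512, 513, 1000 }) |N| {
        // xs[1..N] = 10, 20, ..., 10 * N; xs[0] is the unused slot.
        var xs: [N + 1]i64 = undefined;
        xs[0] = 0;
        var v: i64 = 10;
        var j: usize = 1;
        while (j <= N) : (j += 1) {
            xs[j] = v;
            v += 10;
        }
        j = 1;
        while (j <= N) : (j += 1) {
            // Every stored key is found at its own index...
            try std.testing.expectEqual(j, bsearch(N, i64, xs, xs[j]));
            // ...and a key between (or after) stored values is reported as 0.
            try std.testing.expectEqual(@as(usize, 0), bsearch(N, i64, xs, xs[j] + 5));
        }
        // A key below every stored value is also "not found".
        try std.testing.expectEqual(@as(usize, 0), bsearch(N, i64, xs, 5));
    }
}
```

Because N comes from an inline for over a comptime array, each size gets its own unrolled instantiation, which is exactly what the real code would see.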
