Smoking Hot Binary Search In Zig
Inspiration
I was inspired by a recent post on HN, "Beautiful Binary Search in D", to see if I could do the same thing in Zig.
Briefly, a binary search algorithm is unrolled into a linear sequence of if-statements which find a key in a sorted array by repeatedly splitting ranges in half.
This only works if we know the array size in advance, and we end up with roughly log2(nelems) if-statements.
    // pasted verbatim from the blog post (https://muscar.eu/shar-binary-search-meta.html)
    int bsearch1000(int[1001] xs, int x) {
        auto i = 0;
        if (xs[512] <= x) { i = 1000 - 512 + 1; }
        if (xs[i + 256] <= x) { i += 256; }
        if (xs[i + 128] <= x) { i += 128; }
        if (xs[i + 64] <= x) { i += 64; }
        if (xs[i + 32] <= x) { i += 32; }
        if (xs[i + 16] <= x) { i += 16; }
        if (xs[i + 8] <= x) { i += 8; }
        if (xs[i + 4] <= x) { i += 4; }
        if (xs[i + 2] <= x) { i += 2; }
        if (xs[i + 1] <= x) { i += 1; }
        if (xs[i] == x) return i;
        else return 0;
    }
The main result of the linked post is a D-lang macro which generates code like that above, given any N (e.g. N=1000 above).
The Zig Macro
I developed the equivalent macro in Zig (mainly using comptime and inline while).
As usual in Zig, the macro and the code are mixed together in a single function. The function takes 2 compile-time constant parameters (marked with comptime). These are N, the array size (which must be known at compile time), and Elem, the array element type.
We do some math on the array size to work out how many if-statements we must generate and what array ranges we must produce depending on the runtime outcome of the if-statement comparisons.
The Elem type is just so we can use the search on different types. Note that Elem must be comparable with '<=' and '==', which precludes e.g. strings; for those we'd need to pass in a compare function instead. I'll just stick with the example given in the original blog post, linked above.
The other 2 parameters are the runtime parameters, the array itself and the key we are searching for. (I've used the same variable names as the original post for easier comparison).
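To make the "math on the array size" concrete, here is a small sketch of the numbers involved for N=1000. The helpers floorPow2 and halvings are hypothetical names invented for this illustration; they use plain loops rather than the std.math calls in the function itself, so treat this as a way to check the arithmetic and not as part of the macro.

```zig
const std = @import("std");

// Hypothetical helper (not from the post): largest power of two <= n, for n >= 1.
fn floorPow2(n: usize) usize {
    var p: usize = 1;
    while (p * 2 <= n) p *= 2;
    return p;
}

// Hypothetical helper: how many times k2 can be halved before reaching 1,
// i.e. how many `xs[p + rank] <= x` comparisons get generated.
fn halvings(k2: usize) usize {
    var rank = k2;
    var count: usize = 0;
    while (rank > 1) : (rank /= 2) count += 1;
    return count;
}

test "size arithmetic for N = 1000" {
    const N: usize = 1000;
    const k2 = comptime floorPow2(N - 1); // first index that gets probed
    const k = comptime halvings(k2); // number of unrolled halving steps
    const first_offset = (N - 1) - k2 + 1; // value of p if the first probe succeeds

    try std.testing.expectEqual(@as(usize, 512), k2);
    try std.testing.expectEqual(@as(usize, 9), k);
    try std.testing.expectEqual(@as(usize, 488), first_offset);
    // 1 set-up comparison + k halving comparisons + 1 final equality check.
    try std.testing.expectEqual(@as(usize, 11), 1 + k + 1);
}
```

The total of 11 matches the number of if-statements in the D listing above. The first offset comes out as 488 here, where the D listing hard-codes 1000 - 512 + 1 = 489.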
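    const std = @import("std");
    const math = std.math;

    fn bsearch(comptime N: usize, comptime Elem: type, xs: [N+1]Elem, x: Elem) usize {
        // comptime: k = log2(N-1) rounded down, k2 = 2^k
        const k = comptime math.log2(N-1);
        var p: usize = 0;
        const k2 = comptime try math.powi(usize, 2, k);

        // set-up step: k2 and N are comptime, the comparison runs at runtime
        if (xs[k2] <= x) p = (N - 1) - k2 + 1;

        // unrolled at compile time into k if-statements
        comptime var rank = k2;
        comptime var i = k;
        inline while (i > 0) : (i -= 1) {
            rank /= 2;
            if (xs[p + rank] <= x) p += rank;
        }

        // 0 doubles as "not found"
        return if (xs[p] == x) p else 0;
    }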
The binary search code is broken down like this:
- comptime: calculate k, the base 2 log of N-1 (rounded down), and then get k2 = 2^k from it (integer math)
- generate a single 'set up' if-statement which doesn't follow the regular pattern. Parts of this are comptime i.e. k2 and N, but the if-statement will run at runtime.
- comptime: an inline-while loop runs at compile time to generate 'k' if-statements. These include the comptime indexes generated by repeatedly halving k2 (i.e. 2^k).
- finally p should be the index of the key, or else the key is not in the array; we check for that case and return 0 if it is not found.
(the original algorithm is 1-based, hence the return 0 and 1-offsetting everywhere. It would be nice to change that but I just slavishly, mindlessly followed the original.)
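The two changes mentioned above but not made in this post (a comparison function for types such as strings, and 0-based results with an explicit "not found") could look roughly like the sketch below. The names bsearchBy, lessEq and strLessEq are made up for illustration, the array is assumed to hold exactly N sorted elements, and the power-of-two step is computed with a plain comptime loop instead of math.log2 and math.powi.

```zig
const std = @import("std");

// Sketch only: `bsearchBy` and `lessEq` are illustrative names, not from the post.
// Assumes xs holds exactly N sorted elements (indices 0..N-1).
fn bsearchBy(
    comptime N: usize,
    comptime Elem: type,
    comptime lessEq: fn (Elem, Elem) bool,
    xs: [N]Elem,
    x: Elem,
) ?usize {
    if (N < 2) @compileError("bsearchBy needs at least 2 elements");

    // comptime: largest power of two <= N-1
    const k2 = comptime blk: {
        var pow: usize = 1;
        while (pow * 2 <= N - 1) pow *= 2;
        break :blk pow;
    };

    var p: usize = 0;
    if (lessEq(xs[k2], x)) p = (N - 1) - k2 + 1;

    // unrolled at compile time, one comparison per halving of k2
    comptime var rank = k2;
    inline while (rank > 1) {
        rank /= 2;
        if (lessEq(xs[p + rank], x)) p += rank;
    }

    // "equal" means "<=" holds in both directions
    return if (lessEq(xs[p], x) and lessEq(x, xs[p])) p else null;
}

// Example comparison for strings, built on std.mem.order.
fn strLessEq(a: []const u8, b: []const u8) bool {
    return std.mem.order(u8, a, b) != .gt;
}

test "string keys, 0-based, null on miss" {
    const words = [4][]const u8{ "apple", "banana", "cherry", "date" };
    try std.testing.expectEqual(@as(?usize, 2), bsearchBy(4, []const u8, strLessEq, words, "cherry"));
    try std.testing.expectEqual(@as(?usize, 0), bsearchBy(4, []const u8, strLessEq, words, "apple"));
    try std.testing.expectEqual(@as(?usize, null), bsearchBy(4, []const u8, strLessEq, words, "fig"));
}
```

Returning ?usize removes the ambiguity between "found at index 0" and "not found". The price is a second comparison call for the final equality check.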
Testing
Zig has a nice inline testing feature, so I added a test for the 1000-element case shown in the original blog post. Strictly speaking this only confirms that the code works for N=1000, not for all possible array sizes ...
This is the interesting line, where bsearch is called,
const i = bsearch(1000, isize, tbl, key);
(In the actual code below I've used constants for the size and type. Zig treats constants the same as literal values for the purpose of comptime.)
The test does the following:
- generates an array of 1000 random integers (the isize type in Zig is pointer-sized, i.e. the same width as i64 on a 64-bit target).
- sorts the table (required for binary search).
- loops over the entire table, taking each element in turn and then searching for it, the hope being that we get the same index back from bsearch().
    $ zig test bbinsearch.zig
    All 1 tests passed.
Yay! It works!
The full test
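    // lives in the same file as bsearch, so `std` is already imported
    test {
        const tt = std.testing;
        var rand_impl = std.rand.DefaultPrng.init(42);

        const tbl_len = 1000;
        var tbl: [tbl_len+1]isize = undefined;
        const Elem = isize;

        // fill the first tbl_len slots with random values
        var t: usize = 0;
        while (t < tbl_len) : (t += 1) {
            tbl[t] = rand_impl.random().int(Elem);
        }
        // give the spare slot a defined value, since the sort covers the whole array
        tbl[tbl_len] = std.math.maxInt(Elem);

        std.sort.sort(Elem, &tbl, {}, std.sort.asc(Elem));

        // look up every element and check we land on an equal value
        var j: usize = 0;
        while (j < tbl_len) : (j += 1) {
            const key = tbl[j];
            const i = bsearch(tbl_len, Elem, tbl, key);
            try tt.expectEqual(key, tbl[i]);
        }
    }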
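The test above exercises a single size. As a rough sketch of how more sizes could be covered, the block below instantiates the search for several values of N with inline for. It carries its own copy of bsearch so that it compiles on its own; that copy computes the power of two with a plain comptime loop (my substitution for math.log2 and math.powi), and checkSize is a hypothetical helper. The data is deterministic (xs[i] = 2*i), so every odd key is guaranteed to be missing.

```zig
const std = @import("std");

// Same shape as the bsearch in this post, except that k2 is computed with a
// plain comptime loop (an assumption made to keep the sketch self-contained).
fn bsearch(comptime N: usize, comptime Elem: type, xs: [N + 1]Elem, x: Elem) usize {
    const k2 = comptime blk: {
        var pow: usize = 1;
        while (pow * 2 <= N - 1) pow *= 2;
        break :blk pow;
    };
    var p: usize = 0;
    if (xs[k2] <= x) p = (N - 1) - k2 + 1;
    comptime var rank = k2;
    inline while (rank > 1) {
        rank /= 2;
        if (xs[p + rank] <= x) p += rank;
    }
    return if (xs[p] == x) p else 0;
}

// Hypothetical helper: checks one array size using sorted data xs[i] = 2*i.
fn checkSize(comptime N: usize) !void {
    var xs: [N + 1]usize = undefined;
    var i: usize = 0;
    while (i <= N) : (i += 1) xs[i] = 2 * i;

    var j: usize = 0;
    while (j < N) : (j += 1) {
        // every stored key is found at its own index
        try std.testing.expectEqual(j, bsearch(N, usize, xs, xs[j]));
        // odd keys are never stored, so the search reports 0 ("not found")
        try std.testing.expectEqual(@as(usize, 0), bsearch(N, usize, xs, 2 * j + 1));
    }
}

test "several array sizes" {
    inline for (.{ 2, 3, 4, 5, 7, 8, 9, 100, 1000 }) |n| {
        try checkSize(n);
    }
}
```

Each N in the list produces its own unrolled copy of the function, so a long list of sizes costs compile time, not run time.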