A colleague referred me to an interesting coding brainteaser. Paraphrased, the problem statement was as follows:
Given a list of integers, produce an output list in which each element is the product of every integer in the input list except the one at the same index.
For example, given the list [1, 2, 3, 4, 5], the output should be [120, 60, 40, 30, 24].
As a followup, what if you can't use division?
Let's start with the obvious brute-force approach: for each element, multiply over
the whole list, skipping that element's index (example code in Rust).
```rust
fn brute_force(input: &[i64]) -> Vec<i64> {
    input
        .iter()
        .enumerate()
        .map(|(i, _)| {
            // Multiply every element except the one at index i.
            input
                .iter()
                .enumerate()
                .filter_map(|(j, x)| if i != j { Some(x) } else { None })
                .product()
        })
        .collect()
}
```
This is straightforwardly O(n²)—we have two levels of
nested iteration over the input.
The “as a followup” actually gives away another
interesting solution. We can “un-multiply” with the
division operator: compute the product of the whole input,
and then divide each element out of it.
```rust
fn division(input: &[i64]) -> Vec<i64> {
    let prod: i64 = input.iter().product();
    // Note: panics with a divide-by-zero if any element is 0.
    input.iter().map(|x| prod / x).collect()
}
```
Taking the initial product is O(n) and then dividing each
element out of it is another O(n), so the whole thing is
O(n).
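One caveat with division: if any element is 0, the total product is 0 and prod / x panics with a divide-by-zero when it reaches that element, even though the correct output there (the product of all the other elements) is well defined. A sketch of one common fix, assuming we are still allowed to divide, is to count the zeros and take the product of only the nonzero elements. The name division_zero_safe is hypothetical and not part of the original code.

```rust
// Hypothetical zero-safe variant of the division approach (not from the original post).
fn division_zero_safe(input: &[i64]) -> Vec<i64> {
    // How many elements are exactly zero.
    let zeros = input.iter().filter(|&&x| x == 0).count();
    // Product of the nonzero elements only.
    let nonzero_prod: i64 = input.iter().filter(|&&x| x != 0).product();
    input
        .iter()
        .map(|&x| match (zeros, x) {
            // No zeros: plain division works for every element.
            (0, _) => nonzero_prod / x,
            // Exactly one zero, and this is it: the output is the product of the rest.
            (1, 0) => nonzero_prod,
            // Otherwise some other zero is part of the product, so the output is 0.
            _ => 0,
        })
        .collect()
}
```

For example, division_zero_safe(&[1, 2, 0, 4]) returns [0, 0, 8, 0], and any input with two or more zeros yields all zeros.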
So, if we don't use division, can we do better than the O(n²) brute-force method? Obviously, since I'm still writing about it, we can. But how?
The trick here is dynamic programming, one of the most hilariously badly named concepts
in computer science. (Seriously, read the “History”
section of its Wikipedia article.) If we can break a problem down
into subproblems, some of which are repeated, we may be able to
store the subproblem results to avoid recomputing them (this is
called “memoization”). Let's take a look at the
factors of the output elements from the input list [2, 3, 5, 7, 11]
(all primes, so the shared subproducts stand out):
1155: [3, 5, 7, 11]
770: [2, 5, 7, 11]
462: [2, 3, 7, 11]
330: [2, 3, 5, 11]
210: [2, 3, 5, 7]
That's a lot of overlap, even for a toy example! For instance, 2
and 3 are factors of three of the five outputs, so 2 * 3
should only need to be calculated once.
This suggests that we can compute a list of partial products up to
(but not including) each element; our all-prime sample input would give
[ 1,
1 * 2 = 2,
2 * 3 = 6,
6 * 5 = 30,
30 * 7 = 210 ]
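This up-to list is an exclusive prefix product: the running-total idea behind a prefix sum, with multiplication in place of addition. As a small sketch, it can also be written with the standard Iterator::scan, which threads the running product through as state instead of mutating a captured variable. The helper name prefix_products is hypothetical.

```rust
// Hypothetical helper: exclusive prefix products via Iterator::scan.
// Element i of the result is the product of input[..i].
fn prefix_products(input: &[i64]) -> Vec<i64> {
    input
        .iter()
        .scan(1i64, |acc, &x| {
            // Emit the product of everything before x, then fold x in.
            let before = *acc;
            *acc *= x;
            Some(before)
        })
        .collect()
}
```

Calling prefix_products(&[2, 3, 5, 7, 11]) returns [1, 2, 6, 30, 210], matching the list above.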
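We can perform a symmetrical process to get the partial products
down to (but not including) each element, going over the input list in
reverse; for our sample input those are [1155, 385, 77, 11, 1]. Multiplying
the corresponding up-to and down-to subproducts then gives the final result:
1 * 1155 = 1155, 2 * 385 = 770, 6 * 77 = 462, 30 * 11 = 330 and 210 * 1 = 210,
exactly the outputs listed above. Putting it all together, our
code looks like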
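```rust
fn subproducts(input: &[i64]) -> Vec<i64> {
    // up_to[i] = product of input[..i] (everything before index i).
    let mut subproduct: i64 = 1;
    let up_to: Vec<i64> = input
        .iter()
        .map(|x| { let val = subproduct; subproduct *= x; val })
        .collect();
    // Built back to front: down_to[k] = product of the last k elements.
    subproduct = 1;
    let down_to: Vec<i64> = input
        .iter()
        .rev()
        .map(|x| { let val = subproduct; subproduct *= x; val })
        .collect();
    // Reverse down_to so it lines up with up_to, then multiply pairwise.
    up_to.iter().zip(down_to.iter().rev()).map(|(l, r)| l * r).collect()
}
```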
Computing up_to is O(n), since we iterate over the
input once; down_to is also O(n), and the final
zip-and-multiply is O(n), so the whole thing is actually only
O(n). Nice! I haven't benchmarked these, but I suspect this
algorithm is actually faster in practice than the
division one. It performs about 3n multiplications, versus n
multiplications plus n divisions for the division approach, so the two
break even only if a division costs about as much as two
multiplications. On modern processors, integer division is generally
slower than multiplication by much more than that factor of two.
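The subproducts version above allocates two temporary vectors. A common variant of the same prefix/suffix technique (my sketch, not the original code) stores the up-to products directly in the output vector and keeps the down-to product in a single running variable, so it needs only O(1) extra space beyond the output. The name subproducts_in_place is hypothetical. It also skips folding the first element into the running product, since that would compute the product of the whole input, which can overflow even when every output fits.

```rust
// Hypothetical O(1)-extra-space variant of subproducts (not from the original post).
fn subproducts_in_place(input: &[i64]) -> Vec<i64> {
    let n = input.len();
    let mut out = vec![1i64; n];
    // Forward pass: out[i] = product of input[..i] (the up-to products).
    for i in 1..n {
        out[i] = out[i - 1] * input[i - 1];
    }
    // Backward pass: suffix = product of input[i + 1..] (the down-to product).
    let mut suffix = 1i64;
    for i in (1..n).rev() {
        out[i] *= suffix;
        suffix *= input[i];
    }
    // out[0] has no up-to factors, so it is just the down-to product.
    // Handling it here avoids multiplying input[0] into suffix, which would
    // compute the whole-list product and could overflow needlessly.
    if let Some(first) = out.first_mut() {
        *first *= suffix;
    }
    out
}
```

On [1, 2, 3, 4, 5] it returns [120, 60, 40, 30, 24], the same as the other versions; the trade-off is slightly fiddlier index handling in exchange for one allocation instead of three.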
Note that these order analyses gloss over a very important fact:
multiplication actually isn't constant time. The best known algorithms
for multiplying d-digit numbers take O(d log d) time, marginally worse
than linear in the number of digits (equivalently, marginally worse than
O(log N) in the magnitude N of the numbers). This usually doesn't
matter because most of the numbers we deal with fit in a
reasonable constant-size type (the i64 I used in the
code samples goes up to 2^63 − 1, almost 10^19), but because we're
repeatedly multiplying here, we can get very large numbers in our
output. (This is also why I'm not going to try to benchmark: any
input large enough for runtime to exceed measurement error will
also overflow my number type.)
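Because i64 overflow panics in debug builds and, by default, silently wraps in release builds, a version that reports overflow explicitly may be useful. This sketch (a hypothetical subproducts_checked, not from the original) uses the standard i64::checked_mul, which returns None on overflow, and propagates that with the ? operator. It is conservative: it gives up if any intermediate product overflows, which can occasionally happen even when every final output would fit (for example, when the input holds two or more zeros). Switching to i128 or an arbitrary-precision integer type would widen the range instead.

```rust
// Hypothetical overflow-checked variant: None if any intermediate product overflows i64.
fn subproducts_checked(input: &[i64]) -> Option<Vec<i64>> {
    let n = input.len();
    let mut out = vec![1i64; n];
    // Up-to products; `?` returns None as soon as checked_mul overflows.
    for i in 1..n {
        out[i] = out[i - 1].checked_mul(input[i - 1])?;
    }
    // Down-to products, folded into out from the back.
    let mut suffix = 1i64;
    for i in (1..n).rev() {
        out[i] = out[i].checked_mul(suffix)?;
        suffix = suffix.checked_mul(input[i])?;
    }
    // The first output is just the product of everything after it.
    if let Some(first) = out.first_mut() {
        *first = first.checked_mul(suffix)?;
    }
    Some(out)
}
```

For example, subproducts_checked(&[i64::MAX, 2, 1]) returns None, because the last output would be 2 * i64::MAX.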