Multi-linear interpolations in Julia: Part 1
20 June 2024
Linear interpolations, sometimes called lerps, allow us to smoothly estimate values between two data points in a way that joins the data together. Multi-linear interpolations are the same thing applied to data points that lie on a grid. The simplest case is an X, Y grid with four data points, each with a value Z, where we want to find the value of Z smoothly between the known data points.
The simplest way of doing this is to decompose the interpolation along the axes. That is, we first interpolate between pairs of points along the X axis, and then, using those results, along the Y axis.
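As a reminder of the one-dimensional building block, here is a minimal sketch (the helper names `lerp` and `lerp_weight` are just for illustration):
# linear interpolation between z1 and z2 with weight t in [0, 1]
lerp(z1, z2, t) = z1 * (1 - t) + z2 * t
# the weight is the normalised distance of x between the two grid coordinates
lerp_weight(x, x1, x2) = (x - x1) / (x2 - x1)
lerp(0.0, 1.0, lerp_weight(0.5, 0.0, 1.0))
# 0.5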
Putting together some example data:
using StaticArrays

# assemble a square grid (a Matrix in this case)
grid = reshape([
    SVector(0.0, 0.0), SVector(1.0, 0.0),
    SVector(0.0, 1.0), SVector(1.0, 1.0),
], (2, 2))
# assign a value to each grid point
values = reshape([0.0, 1.0, -1.0, 2.0], (2, 2))
Now, for some arbitrary point, say `(0.5, 0.3)`, we want to know what the interpolated value is. We'll assume the grid is regular (i.e. that the change in X or Y between any two grid positions is constant).
function interpolate_2d(grid::AbstractArray, values::AbstractArray, x)
    # find the first grid index along each axis that exceeds x
    x1 = @views findfirst(i -> i[1] > x[1], grid[:, 1])
    y1 = @views findfirst(i -> i[2] > x[2], grid[1, :])
    # check for out of bounds (above the grid domain)
    i1 = isnothing(x1) ? size(grid, 1) : x1
    i2 = isnothing(y1) ? size(grid, 2) : y1
    p1, p2 = grid[i1-1, i2-1], grid[i1, i2]
    # calculate interpolation weights, clamped so we never extrapolate
    weight = @. clamp((x - p1) / (p2 - p1), 0, 1)
    # interpolate over x, once for each y row
    Xz1 = values[i1-1, i2-1] * (1 - weight[1]) + weight[1] * values[i1, i2-1]
    Xz2 = values[i1-1, i2] * (1 - weight[1]) + weight[1] * values[i1, i2]
    # interpolate over y
    Xz1 * (1 - weight[2]) + Xz2 * weight[2]
end
# invoking our function
interpolate_2d(grid, values, SVector(0.5, 0.3))
# 0.5
The above serves as an example implementation, but it has limitations. The check to see if `findfirst` returns `nothing` imposes a boundary condition: if `x` is greater than the grid domain, the `clamp` on the weight ensures we return the last known data value. That is, we do not extrapolate the interpolation. The other limitation is that we do not do a similar check for the lower grid boundary, so asking for an interpolation with a coordinate below `0` will result in an error.
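We can demonstrate both behaviours directly:
# above the upper boundary the weight clamps to 1, so we
# effectively evaluate at x = 1
interpolate_2d(grid, values, SVector(1.5, 0.3))
# 1.3
# below the lower boundary the index arithmetic under-runs the grid
interpolate_2d(grid, values, SVector(-0.5, 0.3))
# ERROR: BoundsError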
The problem
I need to be able to handle N-dimensional multi-linear interpolations (where N is reasonably small, < 7 or so) that can be used in a numerical optimization problem. This introduces two additional criteria the interpolations must meet:
- No allocations: allocating in tight loops is inviting the Devil (the GC) for dinner
- Automatic-differentiation enabled: make optimization algorithms go fast
Finally, I want the interpolation algorithms to work on arbitrary data structures. If I want to interpolate `Float64` values or some custom struct, I want to be able to reuse much of the solution, and only implement how a one-dimensional linear interpolation works for my data type. This also gives downstream users the ability to plug in their custom data types and still be able to interpolate.
As with anything performance focused, we should measure the performance of this function to give us a target:
using BenchmarkTools
@btime interpolate_2d($grid, $values, SVector(0.5, 0.3))
# 4.945 ns (0 allocations: 0 bytes)
That's pretty fast. Let's see how we go.
Iterating towards a solution
Before we begin, let's change how we represent the data. Since we are enforcing a regular grid, instead of storing the coordinate of each grid point, we can store only the axes. The above example becomes:
X = [0.0, 1.0]
Y = [0.0, 1.0]
grid = (X, Y)
# if we have N values associated with each grid point
# the values array then has shape (N, 2, 2)
values = reshape([0.0, 1.0, -1.0, 2.0], (1, 2, 2))
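To make the layout concrete: the value at grid point `(i, j)` now lives at `values[:, i, j]`, with coordinates `(X[i], Y[j])`:
# the grid point (X[2], Y[1]) == (1.0, 0.0) carries the value 1.0
values[:, 2, 1]
# [1.0]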
The obvious thing to try is to cache each reduction in dimension. That is, to interpolate 2D data as above, we first had to reduce to a 1D problem; for a 3D interpolation, we first reduce to a 2D one. That means we need a place to cache `N-1` dimensions' worth of intermediary points.
Let's start with the naive approach:
struct InterpolationCache{D,T,N}
    cache::Array{T,N}
    function InterpolationCache{D}(values::AbstractArray{T,N}) where {D,N,T}
        cache::Array{T,N-1} = zeros(T, size(values)[1:N-1])
        new{D,T,N-1}(cache)
    end
end
The cache is an array whose dimension is one lower than the dimension of the data. For example
InterpolationCache{2}(values)
# InterpolationCache{2, Float64, 2}([0.0 0.0])
The cache holds two values, one for each of the grid points we need for the final lerp.
Indexing the values cache in a way that feels natural can be done with a few `@generated` functions. These are functions whose bodies are built during the Julia compiler's inference step, which lets us assemble the function body based on the argument types:
@generated function _get_all_slice(values::AbstractArray{T,N}, i) where {T,N}
    rem = [:(:) for _ in 1:N-1]
    :(@views values[$(rem...), i])
end

@generated function _get_dim_slice(values::AbstractArray{T,N}, ::Val{M}) where {T,N,M}
    rem = [:(:) for _ in 1:(N - M)]
    inds = [:(1) for i in 1:M]
    :(@views values[$(rem...), $(inds...)])
end
The first gives us a view into the `i`th slice along the last dimension. For two dimensional data this is `values[:, i]`, whereas for three dimensional data it is `values[:, :, i]`. The second gives us a way to access a reduced dimension slice. Again, for two and three dimensions this maps roughly to `values[:, 1]` and `values[:, 1, 1]` for `M = 1` and `M = 2` respectively.

We scatter `@views` everywhere so that we only ever peer into memory and never copy. We can use these functions in much the same way as `getindex`.
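For example, on our `(1, 2, 2)` values array these behave like ordinary views:
_get_all_slice(values, 2) == view(values, :, :, 2)
# true
_get_dim_slice(values, Val(1)) == view(values, :, :, 1)
# true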
Let's go through how they will be used. An in-place, non-extrapolating 1D interpolation can now be written using `_get_all_slice`:
function _inplace_interpolate!(
    out,
    grid::AbstractArray,
    values::AbstractArray,
    x
)
    i2 = searchsortedfirst(grid, x)
    # edge case checks: clamp to the boundary slice instead of extrapolating
    if i2 == 1
        out .= _get_all_slice(values, firstindex(grid))
        return out
    end
    if i2 > lastindex(grid)
        out .= _get_all_slice(values, lastindex(grid))
        return out
    end
    i1 = i2 - 1
    x1 = grid[i1]
    x2 = grid[i2]
    # interpolation weight
    weight = (x - x1) / (x2 - x1)
    y1 = _get_all_slice(values, i1)
    y2 = _get_all_slice(values, i2)
    @. out = y1 * (1 - weight) + y2 * weight
    out
end
If the previous interpolation function made sense, then hopefully this one will too. The key differences are that we now address both the upper and lower boundaries, and that we write the output into an `out` buffer instead of allocating a new result.
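A quick check on some 1D data, with shapes matching the `(N, points...)` layout:
xs = [0.0, 1.0, 2.0]
vals = reshape([0.0, 10.0, 20.0], (1, 3))
out = zeros(1)
_inplace_interpolate!(out, xs, vals, 0.25)
# [2.5]
_inplace_interpolate!(out, xs, vals, -1.0)
# [0.0] (clamped to the lower boundary value)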
Now, the way we construct our `N-1` dimensional cache is by taking smaller and smaller dimensional views of our pre-allocated array. For this we use `_get_dim_slice`:
function _make_cache_slices(cache::InterpolationCache{D}) where {D}
    itr = ((Val{i}() for i in 0:D-1)...,)
    map(itr) do i
        _get_dim_slice(cache.cache, i)
    end
end
There's a bit of trickery going on here to make this type stable. In particular, we build the `Val` objects whilst splatting into the tuple, since this is then easier for the compiler to constant fold. Since `_get_dim_slice` always indexes the reduced dimensions at `1`, each slice is nested inside the previous one, so we can write the output of each step into the same buffer, avoiding redundant allocations.
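For the 2D cache above this produces two nested views into the same matrix:
cache = InterpolationCache{2}(values)
slices = _make_cache_slices(cache)
# slices[1] is the full (1, 2) view; slices[2] is its first column
size(slices[1]), size(slices[2])
# ((1, 2), (1,))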
Now we get to the meat of the operation:
function interpolate!(
    cache::InterpolationCache{D},
    grids::NTuple{D,<:AbstractArray},
    values::AbstractArray,
    x::NTuple{D},
) where {D}
    itr = (1:D...,)
    slices = _make_cache_slices(cache)
    vs = (values, slices...)
    for K in zip(slices, vs, itr)
        c, v, i = K
        _inplace_interpolate!(c, grids[i], v, x[i])
    end
    slices[D]
end
This should hopefully be pretty straightforward. We simply loop over each dimension, do the interpolation to build the reduced dimension problem, and continue. The output is then written into the `D`th cache view. Note that each step strips the trailing dimension of the current values array, so `grids[1]` and `x[1]` pair with the last axis of `values`; hence the coordinate order in the call below is reversed relative to the 2D example.
cache = InterpolationCache{2}(values)
interpolate!(cache, grid, values, (0.3, 0.5))
# 1-element view(::Matrix{Float64}, :, 1) with eltype Float64:
# 0.4999999999999999
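The same machinery handles higher dimensions too. As a quick sketch, interpolating the centre of a 2×2×2 grid (the variable names here are illustrative):
X3 = Y3 = Z3 = [0.0, 1.0]
vals3 = reshape(collect(1.0:8.0), (1, 2, 2, 2))
cache3 = InterpolationCache{3}(vals3)
interpolate!(cache3, (X3, Y3, Z3), vals3, (0.5, 0.5, 0.5))
# 1-element view with value 4.5 (the mean of the eight corner values)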
But we are still allocating despite this cache:
BenchmarkTools.Trial: 10000 samples with 986 evaluations.
Range (min … max): 52.279 ns … 149.027 μs ┊ GC (min … max): 0.00% … 99.89%
Time (median): 61.093 ns ┊ GC (median): 0.00%
Time (mean ± σ): 114.479 ns ± 1.504 μs ┊ GC (mean ± σ): 22.91% ± 5.47%
▂██▆▃▃▂▁▁ ▁▂ ▃▆▇▅▂▁ ▁▁▁ ▁ ▂
▆█████████▇▆▅▅▆█▇▅▆▇▆▆▇▇▇▆▆███▆▅▆▅▃▇███████▇▇████████▇▆▇▅▅▅▅▅ █
52.3 ns Histogram: log(frequency) by time 174 ns <
Memory estimate: 336 bytes, allocs estimate: 5.
This is because `interpolate!` is dramatically type unstable. The variables in the for loop are boxed, which means the call to `_inplace_interpolate!` cannot easily be optimized.
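One way to see the boxing for yourself is to inspect the inferred types:
# Any-typed variables in the loop body betray the instability
@code_warntype interpolate!(cache, grid, values, (0.3, 0.5))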
Something quite simple we can do is unroll the for loop in the `interpolate!` function to encourage type stability. We replace the for loop with
_unroll_for(Val{D}(), (zip(slices, vs, itr)...,)) do K
    c, v, i = K
    _inplace_interpolate!(c, grids[i], v, x[i])
end
We can implement this loop unroll with a handy function to keep around:
"""
_unroll_for(f, ::Val{N}, iter)
Unroll a for loop when there are a compile-time known number of items (i.e.
`length(iter) == N`).
"""
@generated function _unroll_for(f, ::Val{N}, iter) where {N}
exprs = [:(f(iter[$i])) for i = 1:N]
quote
$(exprs...)
end
end
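As a quick sanity check, the unrolled call behaves like a plain loop over a tuple:
_unroll_for(println, Val(3), (10, 20, 30))
# 10
# 20
# 30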
And now:
BenchmarkTools.Trial: 10000 samples with 997 evaluations.
Range (min … max): 21.449 ns … 59.831 ns ┊ GC (min … max): 0.00% … 0.00%
Time (median): 22.461 ns ┊ GC (median): 0.00%
Time (mean ± σ): 22.777 ns ± 1.176 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
▂▅█▇▇▅▂▁▇▇▁ ▃▂▁▄▂▁
▃▇███████████████████▆▆▄▄▄▅▄▃▃▄▄▃▃▂▂▂▁▁▁▂▂▂▂▃▃▃▅▄▄▃▂▂▁▁▁▁▁▁ ▄
21.4 ns Histogram: frequency by time 26 ns <
Memory estimate: 0 bytes, allocs estimate: 0.
Excellent, we see a 2x performance increase and the allocations vanish.
Next steps
The cache now works, and we can more or less make this work for arbitrary user data by having the user define their own `zero` function for building the cache, and their own `_inplace_interpolate!` method. But that's a big ask. Furthermore, the cache at the moment does not support AD.
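To sketch what that currently asks of a user, one minimal route is to give the type enough arithmetic for the existing broadcasted lerp to work (the `Payload` type and its methods here are entirely hypothetical, purely for illustration):
# a hypothetical per-grid-point payload we might want to interpolate
struct Payload
    a::Float64
    b::Float64
end
# the cache constructor calls zeros(T, dims), which needs zero(T)
Base.zero(::Type{Payload}) = Payload(0.0, 0.0)
# the broadcasted lerp `y1 * (1 - weight) + y2 * weight` needs
# scaling by a number and addition
Base.:*(p::Payload, w::Real) = Payload(p.a * w, p.b * w)
Base.:*(w::Real, p::Payload) = p * w
Base.:+(p::Payload, q::Payload) = Payload(p.a + q.a, p.b + q.b)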
In the next part we'll try expanding the cache to work with dual numbers, and provide a friendlier abstraction for wrapping data types.
Later we'll look at trying to optimize performance to see if we can get closer to the basic 2D implementation. This will require a 4x performance improvement.