Optimization Tips for Finch

It's easy to ask Finch to run the same operation in different ways. However, different approaches have different performance. The right approach really depends on your particular situation. Here's a collection of general approaches that help Finch generate faster code in most cases.

Concordant Iteration

By default, Finch stores arrays in column major order (first index fast). When the storage order of an array in a Finch expression corresponds to the loop order, we call this concordant iteration. For example, the following expression represents a concordant traversal of a sparse matrix, as the outer loops access the higher levels of the tensor tree:

A = Tensor(Dense(SparseList(Element(0.0))), fsparse([2, 3, 4, 1, 3], [1, 1, 1, 3, 3], [1.1, 2.2, 3.3, 4.4, 5.5], (4, 3)))
s = Scalar(0.0)
@finch for j=_, i=_ ; s[] += A[i, j] end

# output

NamedTuple()

We can investigate the generated code with @finch_code. This code iterates over only the nonzeros in order. If our matrix is m × n with nnz nonzeros, this takes O(n + nnz) time.

@finch_code for j=_, i=_ ; s[] += A[i, j] end

# output

quote
s = (ex.bodies[1]).body.body.lhs.tns.bind
s_val = s.val
A_lvl = (ex.bodies[1]).body.body.rhs.tns.bind.lvl
A_lvl_2 = A_lvl.lvl
A_lvl_ptr = A_lvl_2.ptr
A_lvl_idx = A_lvl_2.idx
A_lvl_2_val = A_lvl_2.lvl.val
for j_3 = 1:A_lvl.shape
A_lvl_q = (1 - 1) * A_lvl.shape + j_3
A_lvl_2_q = A_lvl_ptr[A_lvl_q]
A_lvl_2_q_stop = A_lvl_ptr[A_lvl_q + 1]
if A_lvl_2_q < A_lvl_2_q_stop
A_lvl_2_i1 = A_lvl_idx[A_lvl_2_q_stop - 1]
else
A_lvl_2_i1 = 0
end
phase_stop = min(A_lvl_2_i1, A_lvl_2.shape)
if phase_stop >= 1
if A_lvl_idx[A_lvl_2_q] < 1
A_lvl_2_q = Finch.scansearch(A_lvl_idx, 1, A_lvl_2_q, A_lvl_2_q_stop - 1)
end
while true
A_lvl_2_i = A_lvl_idx[A_lvl_2_q]
if A_lvl_2_i < phase_stop
A_lvl_3_val = A_lvl_2_val[A_lvl_2_q]
s_val = A_lvl_3_val + s_val
A_lvl_2_q += 1
else
phase_stop_3 = min(phase_stop, A_lvl_2_i)
if A_lvl_2_i == phase_stop_3
A_lvl_3_val = A_lvl_2_val[A_lvl_2_q]
s_val += A_lvl_3_val
A_lvl_2_q += 1
end
break
end
end
end
end
result = ()
s.val = s_val
result
end

When the loop order does not correspond to storage order, we call this discordant iteration. For example, if we swap the loop order in the above example, then Finch needs to randomly access each sparse column for each row i. We end up needing to find each (i, j) pair because we don't know whether it will be zero until we search for it. In all, this takes time O(n * m * log(nnz)), much less efficient! We shouldn't randomly access sparse arrays unless we really need to and they support it efficiently!

Note the double for loop in the following code

@finch_code for i=_, j=_ ; s[] += A[i, j] end # DISCORDANT, DO NOT DO THIS

# output

quote
s = (ex.bodies[1]).body.body.lhs.tns.bind
s_val = s.val
A_lvl = (ex.bodies[1]).body.body.rhs.tns.bind.lvl
A_lvl_2 = A_lvl.lvl
A_lvl_ptr = A_lvl_2.ptr
A_lvl_idx = A_lvl_2.idx
A_lvl_2_val = A_lvl_2.lvl.val
@warn "Performance Warning: non-concordant traversal of A[i, j] (hint: most arrays prefer column major or first index fast, run in fast mode to ignore this warning)"
for i_3 = 1:A_lvl_2.shape
for j_3 = 1:A_lvl.shape
A_lvl_q = (1 - 1) * A_lvl.shape + j_3
A_lvl_2_q = A_lvl_ptr[A_lvl_q]
A_lvl_2_q_stop = A_lvl_ptr[A_lvl_q + 1]
if A_lvl_2_q < A_lvl_2_q_stop
A_lvl_2_i1 = A_lvl_idx[A_lvl_2_q_stop - 1]
else
A_lvl_2_i1 = 0
end
phase_stop = min(i_3, A_lvl_2_i1)
if phase_stop >= i_3
if A_lvl_idx[A_lvl_2_q] < i_3
A_lvl_2_q = Finch.scansearch(A_lvl_idx, i_3, A_lvl_2_q, A_lvl_2_q_stop - 1)
end
while true
A_lvl_2_i = A_lvl_idx[A_lvl_2_q]
if A_lvl_2_i < phase_stop
A_lvl_3_val = A_lvl_2_val[A_lvl_2_q]
s_val = A_lvl_3_val + s_val
A_lvl_2_q += 1
else
phase_stop_3 = min(phase_stop, A_lvl_2_i)
if A_lvl_2_i == phase_stop_3
A_lvl_3_val = A_lvl_2_val[A_lvl_2_q]
s_val += A_lvl_3_val
A_lvl_2_q += 1
end
break
end
end
end
end
end
result = ()
s.val = s_val
result
end

TL;DR: As a quick heuristic, if your array indices are all in alphabetical order, then the loop indices should be reverse alphabetical.

Appropriate Fill Values

The @finch macro requires the user to specify an output format. This is the most flexibile approach, but can sometimes lead to densification unless the output fill value is appropriate for the computation.

For example, if A is m × n with nnz nonzeros, the following Finch kernel will densify B, filling it with m * n stored values:

A = Tensor(Dense(SparseList(Element(0.0))), fsparse([2, 3, 4, 1, 3], [1, 1, 1, 3, 3], [1.1, 2.2, 3.3, 4.4, 5.5], (4, 3)))
B = Tensor(Dense(SparseList(Element(0.0)))) #DO NOT DO THIS, B has the wrong fill value
@finch (B .= 0; for j=_, i=_; B[i, j] = A[i, j] + 1 end; return B)
countstored(B)

# output

12

Since A is filled with 0.0, adding 1 to the fill value produces 1.0. However, B can only represent a fill value of 0.0. Instead, we should specify 1.0 for the fill.

A = Tensor(Dense(SparseList(Element(0.0))), fsparse([2, 3, 4, 1, 3], [1, 1, 1, 3, 3], [1.1, 2.2, 3.3, 4.4, 5.5], (4, 3)))
B = Tensor(Dense(SparseList(Element(1.0))))
@finch (B .= 1; for j=_, i=_; B[i, j] = A[i, j] + 1 end; return B)
countstored(B)

# output

5

Static Versus Dynamic Values

In order to skip some computations, Finch must be able to determine the value of program variables. Continuing our above example, if we obscure the value of 1 behind a variable x, Finch can only determine that x has type Int, not that it is 1.

A = Tensor(Dense(SparseList(Element(0.0))), fsparse([2, 3, 4, 1, 3], [1, 1, 1, 3, 3], [1.1, 2.2, 3.3, 4.4, 5.5], (4, 3)))
B = Tensor(Dense(SparseList(Element(1.0))))
x = 1 #DO NOT DO THIS, Finch cannot see the value of x anymore
@finch (B .= 1; for j=_, i=_; B[i, j] = A[i, j] + x end; return B)
countstored(B)

# output

12

However, there are some situations where you may want a value to be dynamic. For example, consider the function saxpy(x, a, y) = x .* a .+ y. Because we do not know the value of a until we run the function, we should treat it as dynamic, and the following implementation is reasonable:

function saxpy(x, a, y)
z = Tensor(SparseList(Element(0.0)))
@finch (z .= 0; for i=_; z[i] = a * x[i] + y[i] end; return z)
end

Use Known Functions

Unless you declare the properties of your functions using Finch's User-Defined Functions interface, Finch doesn't know how they work. For example, using a lambda obscures the meaning of *.

A = Tensor(Dense(SparseList(Element(0.0))), fsparse([2, 3, 4, 1, 3], [1, 1, 1, 3, 3], [1.1, 2.2, 3.3, 4.4, 5.5], (4, 3)))
B = ones(4, 3)
C = Scalar(0.0)
f(x, y) = x * y # DO NOT DO THIS, Obscures *
@finch (C .= 0; for j=_, i=_; C[] += f(A[i, j], B[i, j]) end; return C)

# output

(C = Scalar{0.0, Float64}(16.5),)

Checking the generated code, we see that this code is indeed densifying (notice the for-loop which repeatedly evaluates f(B[i, j], 0.0)).

@finch_code (C .= 0; for j=_, i=_; C[] += f(A[i, j], B[i, j]) end; return C)

# output

quote
C = ((ex.bodies[1]).bodies[1]).tns.bind
A_lvl = (((ex.bodies[1]).bodies[2]).body.body.rhs.args[1]).tns.bind.lvl
A_lvl_2 = A_lvl.lvl
A_lvl_ptr = A_lvl_2.ptr
A_lvl_idx = A_lvl_2.idx
A_lvl_2_val = A_lvl_2.lvl.val
B = (((ex.bodies[1]).bodies[2]).body.body.rhs.args[2]).tns.bind
sugar_1 = size((((ex.bodies[1]).bodies[2]).body.body.rhs.args[2]).tns.bind)
B_mode1_stop = sugar_1[1]
B_mode2_stop = sugar_1[2]
B_mode1_stop == A_lvl_2.shape || throw(DimensionMismatch("mismatched dimension limits ($(B_mode1_stop) !=$(A_lvl_2.shape))"))
B_mode2_stop == A_lvl.shape || throw(DimensionMismatch("mismatched dimension limits ($(B_mode2_stop) !=$(A_lvl.shape))"))
C_val = 0
for j_4 = 1:B_mode2_stop
A_lvl_q = (1 - 1) * A_lvl.shape + j_4
A_lvl_2_q = A_lvl_ptr[A_lvl_q]
A_lvl_2_q_stop = A_lvl_ptr[A_lvl_q + 1]
if A_lvl_2_q < A_lvl_2_q_stop
A_lvl_2_i1 = A_lvl_idx[A_lvl_2_q_stop - 1]
else
A_lvl_2_i1 = 0
end
phase_stop = min(B_mode1_stop, A_lvl_2_i1)
if phase_stop >= 1
i = 1
if A_lvl_idx[A_lvl_2_q] < 1
A_lvl_2_q = Finch.scansearch(A_lvl_idx, 1, A_lvl_2_q, A_lvl_2_q_stop - 1)
end
while true
A_lvl_2_i = A_lvl_idx[A_lvl_2_q]
if A_lvl_2_i < phase_stop
for i_6 = i:-1 + A_lvl_2_i
val = B[i_6, j_4]
C_val = (Main).f(0.0, val) + C_val
end
A_lvl_3_val = A_lvl_2_val[A_lvl_2_q]
val = B[A_lvl_2_i, j_4]
C_val += (Main).f(A_lvl_3_val, val)
A_lvl_2_q += 1
i = A_lvl_2_i + 1
else
phase_stop_3 = min(phase_stop, A_lvl_2_i)
if A_lvl_2_i == phase_stop_3
for i_8 = i:-1 + phase_stop_3
val = B[i_8, j_4]
C_val += (Main).f(0.0, val)
end
A_lvl_3_val = A_lvl_2_val[A_lvl_2_q]
val = B[phase_stop_3, j_4]
C_val += (Main).f(A_lvl_3_val, val)
A_lvl_2_q += 1
else
for i_10 = i:phase_stop_3
val = B[i_10, j_4]
C_val += (Main).f(0.0, val)
end
end
i = phase_stop_3 + 1
break
end
end
end
phase_start_3 = max(1, 1 + A_lvl_2_i1)
if B_mode1_stop >= phase_start_3
for i_12 = phase_start_3:B_mode1_stop
val = B[i_12, j_4]
C_val += (Main).f(0.0, val)
end
end
end
C.val = C_val
(C = C,)
end


Type Stability

Julia code runs fastest when the compiler can infer the types of all intermediate values. Finch does not check that the generated code is type-stable. In situations where tensors have nonuniform index or element types, or the computation itself might involve multiple types, one should check that the output of @finch_kernel code is type-stable with @code_warntype.