Finch Compilation

Finch Notation

Finch.FinchNotation.FinchNode (Type)
FinchNode

A Finch IR node. Finch uses a variant of Concrete Index Notation as an intermediate representation.

The FinchNode struct represents many different Finch IR nodes. The nodes are differentiated by a FinchNotation.FinchNodeKind enum.

Finch.FinchNotation.finch_leaf (Function)
finch_leaf(x)

Return a terminal finch node wrapper around x. A convenience function to determine whether x should be understood by default as a literal, value, or virtual.
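For instance, a minimal sketch (the resulting node kind depends on the input; the comments record expectations rather than verified output):

```julia
using Finch
using Finch.FinchNotation: finch_leaf

node = finch_leaf(42)  # an isbits constant is understood by default as a literal
node.kind              # FinchNode values are differentiated by their kind field
```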

Program Instances

Finch relies heavily on Julia's metaprogramming capabilities (macros and generated functions in particular) to produce code. To review briefly: a macro allows us to inspect the syntax of its arguments and generate replacement syntax, while a generated function allows us to inspect the types of a function's arguments and produce code for the function body.

In normal Finch usage, we might call Finch as follows:

julia> C = Tensor(SparseList(Element(0)));

julia> A = Tensor(SparseList(Element(0)), [0, 2, 0, 0, 3]);


julia> B = Tensor(Dense(Element(0)), [11, 12, 13, 14, 15]);

julia> @finch (C .= 0; for i=_; C[i] = A[i] * B[i] end);


julia> C
SparseList (0) [1:5]
├─ [2]: 24
└─ [5]: 45

The @macroexpand macro allows us to see the result of applying a macro. Let's examine what happens when we use the @finch macro (we've stripped line numbers from the result to clean it up):

julia> (@macroexpand @finch (C .= 0; for i=_; C[i] = A[i] * B[i] end)) |> Finch.striplines |> Finch.regensym
quote
    _res_1 = (Finch.execute)((Finch.FinchNotation.block_instance)((Finch.FinchNotation.block_instance)((Finch.FinchNotation.declare_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:C), (Finch.FinchNotation.finch_leaf_instance)(C)), literal_instance(0)), begin
                        let i = index_instance(i)
                            (Finch.FinchNotation.loop_instance)(i, Finch.FinchNotation.Dimensionless(), (Finch.FinchNotation.assign_instance)((Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:C), (Finch.FinchNotation.finch_leaf_instance)(C)), literal_instance(Finch.FinchNotation.Updater()), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))), (Finch.FinchNotation.literal_instance)(Finch.FinchNotation.initwrite), (Finch.FinchNotation.call_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:*), (Finch.FinchNotation.finch_leaf_instance)(*)), (Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:A), (Finch.FinchNotation.finch_leaf_instance)(A)), literal_instance(Finch.FinchNotation.Reader()), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))), (Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:B), (Finch.FinchNotation.finch_leaf_instance)(B)), literal_instance(Finch.FinchNotation.Reader()), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))))))
                        end
                    end), (Finch.FinchNotation.yieldbind_instance)(variable_instance(:C))), (;))
    begin
        C = _res_1[:C]
    end
    begin
        _res_1
    end
end

In the above output, @finch creates an AST of program instances, then calls Finch.execute on it. A program instance is a struct that contains the program to be executed along with its arguments. Although we could use the above constructors (e.g. loop_instance) to build our own program instances, it is most convenient to use the unexported macro Finch.@finch_program_instance:

julia> using Finch: @finch_program_instance

julia> prgm = Finch.@finch_program_instance (C .= 0; for i=_; C[i] = A[i] * B[i] end; return C)
Finch program instance: begin
  tag(C, Tensor(SparseList(Element(0)))) .= 0
  for i = Dimensionless()
    tag(C, Tensor(SparseList(Element(0))))[tag(i, i)] <<initwrite>>= tag(*, *)(tag(A, Tensor(SparseList(Element(0))))[tag(i, i)], tag(B, Tensor(Dense(Element(0))))[tag(i, i)])
  end
  return (tag(C, Tensor(SparseList(Element(0)))))
end

As we can see, our program instance contains not only the AST to be executed, but also the data to execute the program with. The type of the program instance contains only the program portion; there may be many program instances with different inputs, but the same program type. We can run our program using Finch.execute, which returns a NamedTuple of outputs.

julia> typeof(prgm)
Finch.FinchNotation.BlockInstance{Tuple{Finch.FinchNotation.DeclareInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{0}}, Finch.FinchNotation.LoopInstance{Finch.FinchNotation.IndexInstance{:i}, Finch.FinchNotation.Dimensionless, Finch.FinchNotation.AssignInstance{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Updater()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.initwrite}, Finch.FinchNotation.CallInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:*}, Finch.FinchNotation.LiteralInstance{*}}, Tuple{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:A}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Reader()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}, Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:B}, Tensor{DenseLevel{Int64, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Reader()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}}}}}, Finch.FinchNotation.YieldBindInstance{Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, 
Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}}}}}

julia> C = Finch.execute(prgm).C
SparseList (0) [1:5]
├─ [2]: 24
└─ [5]: 45

This functionality is sufficient for building Finch kernels programmatically. For example, if we wish to define a function pointwise_sum() that takes the pointwise sum of a variable number of vector inputs, we might implement it as follows:

julia> function pointwise_sum(As...)
           B = Tensor(Dense(Element(0)))
           isempty(As) && return B
           i = Finch.FinchNotation.index_instance(:i)
           A_vars = [Finch.FinchNotation.tag_instance(Finch.FinchNotation.variable_instance(Symbol(:A, n)), As[n]) for n in 1:length(As)]
           #create a list of variable instances with different names to hold the input tensors
           ex = @finch_program_instance 0
           for A_var in A_vars
               ex = @finch_program_instance $A_var[i] + $ex
           end
           prgm = @finch_program_instance (B .= 0; for i=_; B[i] = $ex end; return B)
           return Finch.execute(prgm).B
       end
pointwise_sum (generic function with 1 method)

julia> pointwise_sum([1, 2], [3, 4])
Dense [1:2]
├─ [1]: 4
└─ [2]: 6

Virtualization

Finch generates different code depending on the types of the arguments to the program. For example, in the programs below, A and B have different types, so the code generated for each loop is different. In order to execute a program, Finch builds a typed AST (Abstract Syntax Tree), then calls Finch.execute on it. The AST object is an instance of the program, containing both the program and the data to execute it with. The type of the program instance contains only the program portion; there may be many program instances with different inputs, but the same program type. During compilation, Finch uses the type of the program to construct a more ergonomic representation, which is then used to generate code. This process is called "virtualization". All of the Finch AST nodes have both instance and virtual representations. For example, the literal 42 is represented as Finch.FinchNotation.LiteralInstance(42) and then virtualized to literal(42). The virtualization process is implemented by the virtualize function.

julia> A = Tensor(SparseList(Element(0)), [0, 2, 0, 0, 3]);

julia> B = Tensor(Dense(Element(0)), [11, 12, 13, 14, 15]);

julia> s = Scalar(0);

julia> typeof(A)
Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}

julia> typeof(B)
Tensor{DenseLevel{Int64, ElementLevel{0, Int64, Int64, Vector{Int64}}}}

julia> inst = Finch.@finch_program_instance begin
           for i = _
               s[] += A[i]
           end
       end
Finch program instance: for i = Dimensionless()
  tag(s, Scalar{0, Int64}(0))[] <<tag(+, +)>>= tag(A, Tensor(SparseList(Element(0))))[tag(i, i)]
end

julia> typeof(inst)
Finch.FinchNotation.LoopInstance{Finch.FinchNotation.IndexInstance{:i}, Finch.FinchNotation.Dimensionless, Finch.FinchNotation.AssignInstance{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:s}, Scalar{0, Int64}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Updater()}, Tuple{}}, Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:+}, Finch.FinchNotation.LiteralInstance{+}}, Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:A}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Reader()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}}}

julia> Finch.virtualize(:inst, typeof(inst), Finch.JuliaContext())
Finch program: for i = virtual(Finch.FinchNotation.Dimensionless)
  tag(s, virtual(Finch.VirtualScalar))[] <<tag(+, +)>>= tag(A, virtual(Finch.VirtualFiber{Finch.VirtualSparseListLevel}))[tag(i, i)]
end

julia> @finch_code begin
           for i = _
               s[] += A[i]
           end
       end
quote
    s = (ex.bodies[1]).body.lhs.tns.bind
    s_val = s.val
    A_lvl = (ex.bodies[1]).body.rhs.tns.bind.lvl
    A_lvl_ptr = A_lvl.ptr
    A_lvl_idx = A_lvl.idx
    A_lvl_val = A_lvl.lvl.val
    result = nothing
    A_lvl_q = A_lvl_ptr[1]
    A_lvl_q_stop = A_lvl_ptr[1 + 1]
    if A_lvl_q < A_lvl_q_stop
        A_lvl_i1 = A_lvl_idx[A_lvl_q_stop - 1]
    else
        A_lvl_i1 = 0
    end
    phase_stop = min(A_lvl_i1, A_lvl.shape)
    if phase_stop >= 1
        if A_lvl_idx[A_lvl_q] < 1
            A_lvl_q = Finch.scansearch(A_lvl_idx, 1, A_lvl_q, A_lvl_q_stop - 1)
        end
        while true
            A_lvl_i = A_lvl_idx[A_lvl_q]
            if A_lvl_i < phase_stop
                A_lvl_2_val = A_lvl_val[A_lvl_q]
                s_val = A_lvl_2_val + s_val
                A_lvl_q += 1
            else
                phase_stop_3 = min(A_lvl_i, phase_stop)
                if A_lvl_i == phase_stop_3
                    A_lvl_2_val = A_lvl_val[A_lvl_q]
                    s_val += A_lvl_2_val
                    A_lvl_q += 1
                end
                break
            end
        end
    end
    result = ()
    s.val = s_val
    result
end

julia> @finch_code begin
           for i = _
               s[] += B[i]
           end
       end
quote
    s = (ex.bodies[1]).body.lhs.tns.bind
    s_val = s.val
    B_lvl = (ex.bodies[1]).body.rhs.tns.bind.lvl
    B_lvl_val = B_lvl.lvl.val
    result = nothing
    for i_3 = 1:B_lvl.shape
        B_lvl_q = (1 - 1) * B_lvl.shape + i_3
        B_lvl_2_val = B_lvl_val[B_lvl_q]
        s_val = B_lvl_2_val + s_val
    end
    result = ()
    s.val = s_val
    result
end

Users can also create their own virtual nodes to represent their custom types. These types may contain constants and other virtuals, as well as reference variables in the scope of the executing context. Any aspect of virtuals visible to Finch should be considered immutable, but virtuals may reference mutable variables in the scope of the executing context.
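To make this concrete, here is a hedged sketch of how a hypothetical user type might be virtualized. MyScalar and VirtualMyScalar are invented for illustration, and the use of Finch.freshen and ctx.preamble is an assumption modeled on the generated-code patterns shown above, not a documented contract:

```julia
using Finch

# A hypothetical runtime type we want Finch to understand.
mutable struct MyScalar{T}
    val::T
end

# Its virtual counterpart holds only compile-time information: the name of the
# variable bound to the runtime value, and the element type. The virtual itself
# should be treated as immutable, though `ex` names mutable runtime state.
struct VirtualMyScalar
    ex::Symbol
    T::Type
end

# virtualize receives the Julia expression `ex` bound to the runtime value, its
# type, and the compiler context; it returns the virtual node.
function Finch.virtualize(ex, ::Type{MyScalar{T}}, ctx, tag = :myscalar) where {T}
    sym = Finch.freshen(ctx, tag)       # assumed helper: generate a fresh variable name
    push!(ctx.preamble, :($sym = $ex))  # assumed: bind the runtime value in generated code
    VirtualMyScalar(sym, T)
end
```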

Finch.virtualize (Function)
virtualize(ex, T, ctx, [tag])

Return the virtual program corresponding to the Julia expression ex of type T in the JuliaContext ctx. Implementers may support the optional tag argument, which is used to name the resulting virtual variable.

Working with Finch IR

Calling print on a finch program or program instance will print the structure of the program as one would call constructors to build it. For example,

julia> prgm_inst = Finch.@finch_program_instance for i = _
            s[] += A[i]
        end;


julia> println(prgm_inst)
loop_instance(index_instance(i), Finch.FinchNotation.Dimensionless(), assign_instance(access_instance(tag_instance(variable_instance(:s), Scalar{0, Int64}(0)), literal_instance(Finch.FinchNotation.Updater())), tag_instance(variable_instance(:+), literal_instance(+)), access_instance(tag_instance(variable_instance(:A), Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), literal_instance(Finch.FinchNotation.Reader()), tag_instance(variable_instance(:i), index_instance(i)))))

julia> prgm_inst
Finch program instance: for i = Dimensionless()
  tag(s, Scalar{0, Int64}(0))[] <<tag(+, +)>>= tag(A, Tensor(SparseList(Element(0))))[tag(i, i)]
end

julia> prgm = Finch.@finch_program for i = _
               s[] += A[i]
           end;


julia> println(prgm)
loop(index(i), virtual(Finch.FinchNotation.Dimensionless()), assign(access(literal(Scalar{0, Int64}(0)), literal(Finch.FinchNotation.Updater())), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), literal(Finch.FinchNotation.Reader()), index(i))))

julia> prgm
Finch program: for i = virtual(Finch.FinchNotation.Dimensionless)
  Scalar{0, Int64}(0)[] <<+>>= Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))[i]
end

Both the virtual and instance representations of Finch IR implement the SyntaxInterface.jl and AbstractTrees.jl interfaces, so you can use the standard operation, arguments, istree, and children functions to inspect the structure of a program, as well as the rewriters defined by RewriteTools.jl.

julia> using Finch.FinchNotation;


julia> PostOrderDFS(prgm)
PostOrderDFS{FinchNode}(loop(index(i), virtual(Dimensionless()), assign(access(literal(Scalar{0, Int64}(0)), literal(Updater())), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), literal(Reader()), index(i)))))

julia> (@capture prgm loop(~idx, ~ext, ~val))
true

julia> idx
Finch program: i 
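Building on this, a hedged sketch of using the tree interface to collect nodes of a particular kind (the kind comparison is an assumption based on the FinchNodeKind enum described above):

```julia
julia> idxs = [node for node in PostOrderDFS(prgm) if node.kind === FinchNotation.index];

julia> length(idxs)  # one entry per occurrence of index(i) in the printed program above
```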

Tensor internals

Tensor levels are implemented using the following methods:

Finch.default (Function)
default(arr)

Return the initializer for arr. For SparseArrays, this is 0. Often, the default value becomes the fill or background value of a tensor.
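For example, a minimal sketch (assuming the default is simply the fill value supplied at construction):

```julia
julia> A = Tensor(SparseList(Element(0.0)), [0.0, 2.0, 0.0]);

julia> d = Finch.default(A);  # the fill value, 0.0 here
```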

Finch.declare_level! (Function)
declare_level!(lvl, ctx, pos, init)

Initialize and thaw all fibers within lvl, assuming positions 1:pos were previously assembled and frozen. The resulting level has no assembled positions.

Finch.assemble_level! (Function)
assemble_level!(lvl, ctx, pos, new_pos)

Assemble positions pos+1:new_pos in lvl, assuming positions 1:pos were previously assembled.

Finch.reassemble_level! (Function)
reassemble_level!(lvl, ctx, pos_start, pos_end)

Set the previously assembled positions from pos_start to pos_end to level_default(lvl). Not available on all level types, as this presumes updating.

Finch.freeze_level! (Function)
freeze_level!(lvl, ctx, pos, init)

Given the last reference position, pos, freeze all fibers within lvl assuming that we have potentially updated 1:pos.

Finch.level_ndims (Function)
level_ndims(::Type{Lvl})

The result of level_ndims(Lvl) defines ndims for all subfibers in a level of type Lvl.

Finch.level_size (Function)
level_size(lvl)

The result of level_size(lvl) defines the size of all subfibers in the level lvl.

Finch.level_axes (Function)
level_axes(lvl)

The result of level_axes(lvl) defines the axes of all subfibers in the level lvl.

Finch.level_eltype (Function)
level_eltype(::Type{Lvl})

The result of level_eltype(Lvl) defines eltype for all subfibers in a level of type Lvl.

Finch.level_default (Function)
level_default(::Type{Lvl})

The result of level_default(Lvl) defines default for all subfibers in a level of type Lvl.
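As a hedged illustration of the level queries above (lvl accesses the root level of a Tensor; the values noted in comments follow from the docstrings rather than verified output):

```julia
julia> A = Tensor(Dense(SparseList(Element(0.0))), zeros(3, 4));

julia> lvl = A.lvl;

julia> Finch.level_ndims(typeof(lvl))   # subfibers of a Dense-of-SparseList level have 2 dimensions

julia> Finch.level_eltype(typeof(lvl))  # Float64, from Element(0.0)
```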

Virtual Tensor Methods

Finch.declare! (Function)
declare!(tns, ctx, init)

Declare the read-only virtual tensor tns in the context ctx with a starting value of init and return it. Afterwards, the tensor is update-only.

Finch.instantiate (Function)
instantiate(tns, ctx, mode, protos)

Return an object (usually a looplet nest) capable of unfurling the virtual tensor tns. Before executing a statement, each subsequent in-scope access will be initialized with a separate call to instantiate. protos is the list of protocols in each case.

The fallback for instantiate will iteratively move the last element of protos into the arguments of a function. This allows fibers to specialize on the last arguments of protos rather than the first, as Finch is column major.

Finch.freeze! (Function)
freeze!(tns, ctx)

Freeze the update-only virtual tensor tns in the context ctx and return it. This may involve trimming any excess overallocated memory. Afterwards, the tensor is read-only.

Finch.thaw! (Function)
thaw!(tns, ctx)

Thaw the read-only virtual tensor tns in the context ctx and return it. Afterwards, the tensor is update-only.

Finch.unfurl (Function)
unfurl(tns, ctx, ext, protos...)

Return an array object (usually a looplet nest) for lowering the virtual tensor tns. ext is the extent of the looplet. protos is the list of protocols that should be used for each index; one need not unfurl all the indices at once.