Node and Tree Operations

This example demonstrates how to create and manipulate expression trees using the Node type.

First, let's create a node to reference feature=1 of our dataset:

using DynamicExpressions, Random

x = Node{Float64}(; feature=1)
x1

We can also create constant nodes using the val keyword:

const_1 = Node{Float64}(; val=1.0)
1.0

Now, let's declare some operators to use in our expression tree.

Note that declaring an OperatorEnum updates a global mapping from each operator to its index in the operator list. This is purely for convenience: most of the time, you would either operate directly on the OperatorEnum, as with eval_tree_array, or use Expression objects, which store the operators alongside the expression.

operators = OperatorEnum(; unary_operators=(sin, exp), binary_operators=(+, -, *, /))
OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(sin), typeof(exp)}}((+, -, *, /), (sin, exp))
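As a rough illustration of what an operator-to-index mapping means, the sketch below looks up each operator's position in plain Julia tuples ordered like the OperatorEnum above. The names UNARY_OPS, BINARY_OPS, and op_index are hypothetical, and this is not how DynamicExpressions implements its global mapping.

```julia
# Hypothetical operator tables mirroring the OperatorEnum above; this is not
# how DynamicExpressions stores its global mapping internally.
const UNARY_OPS = (sin, exp)
const BINARY_OPS = (+, -, *, /)

# Position of `op` in `ops`, compared by identity; `nothing` if it is absent.
op_index(op, ops::Tuple) = findfirst(f -> f === op, ops)

op_index(-, BINARY_OPS)   # 2
op_index(exp, UNARY_OPS)  # 2
op_index(cos, UNARY_OPS)  # nothing
```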

Now, let's create another variable:

y = Node{Float64}(; feature=2)
x2

And we can now create expression trees:

tree = (x + y) * const_1 - sin(x)
((x1 + x2) * 1.0) - sin(x1)
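The reason x + y returns a new tree rather than a number is operator overloading. Below is a minimal plain-Julia sketch of that idea under stated assumptions: ToyNode, variable, constant, and branch are hypothetical names, and each overloaded operator simply records its index in a fixed operator tuple. It is not DynamicExpressions' actual Node implementation.

```julia
# Hypothetical toy node type for illustration; not DynamicExpressions.Node.
struct ToyNode
    degree::Int               # 0 = leaf, 1 = unary, 2 = binary
    feature::Int              # feature index for variable leaves; 0 otherwise
    val::Float64              # value for constant leaves
    op::Int                   # index into UNARY_OPS or BINARY_OPS; 0 for leaves
    children::Vector{ToyNode}
end

const UNARY_OPS = (sin, exp)
const BINARY_OPS = (+, -, *, /)

variable(i) = ToyNode(0, i, 0.0, 0, ToyNode[])
constant(v) = ToyNode(0, 0, Float64(v), 0, ToyNode[])

# Build a branch node, storing the operator's position in the matching table.
function branch(f, kids::ToyNode...)
    ops = length(kids) == 1 ? UNARY_OPS : BINARY_OPS
    return ToyNode(length(kids), 0, 0.0, findfirst(g -> g === f, ops), collect(ToyNode, kids))
end

# Overload operators so that applying them to nodes builds a new branch node
# (recording the operator's index) instead of computing a number.
Base.:+(a::ToyNode, b::ToyNode) = branch(+, a, b)
Base.:-(a::ToyNode, b::ToyNode) = branch(-, a, b)
Base.:*(a::ToyNode, b::ToyNode) = branch(*, a, b)
Base.:/(a::ToyNode, b::ToyNode) = branch(/, a, b)
Base.sin(a::ToyNode) = branch(sin, a)
Base.exp(a::ToyNode) = branch(exp, a)

x, y, const_1 = variable(1), variable(2), constant(1.0)
tree = (x + y) * const_1 - sin(x)
(typeof(tree), tree.degree, tree.op)  # (ToyNode, 2, 2): a binary node storing `-`
```

Because every overload returns a ToyNode, the whole tree has the same type as its leaves, which is the type-stability property discussed next.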

The resulting tree has the same type as the variables and constants it was built from, meaning we have type stability:

typeof(tree), typeof(x)
(Node{Float64}, Node{Float64})

We can also use scalars directly; they are converted to constant nodes:

tree2 = 2x - sin(x)
(2.0 * x1) - sin(x1)
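One plausible way to accept plain numbers is to add mixed methods that wrap the number in a constant leaf and then defer to the node-node method. The sketch below defines a hypothetical ToyNode type with hypothetical helpers (variable, constant, branch); it is an illustration of the idea, not the library's code.

```julia
# Hypothetical toy node type for illustration; not DynamicExpressions.Node.
struct ToyNode
    degree::Int               # 0 = leaf, 1 = unary, 2 = binary
    feature::Int              # feature index for variable leaves; 0 otherwise
    val::Float64              # value for constant leaves
    op::Int                   # index into UNARY_OPS or BINARY_OPS; 0 for leaves
    children::Vector{ToyNode}
end

const UNARY_OPS = (sin, exp)
const BINARY_OPS = (+, -, *, /)

variable(i) = ToyNode(0, i, 0.0, 0, ToyNode[])
constant(v) = ToyNode(0, 0, Float64(v), 0, ToyNode[])

# Build a branch node, storing the operator's position in the matching table.
function branch(f, kids::ToyNode...)
    ops = length(kids) == 1 ? UNARY_OPS : BINARY_OPS
    return ToyNode(length(kids), 0, 0.0, findfirst(g -> g === f, ops), collect(ToyNode, kids))
end

Base.:*(a::ToyNode, b::ToyNode) = branch(*, a, b)
Base.:-(a::ToyNode, b::ToyNode) = branch(-, a, b)
Base.sin(a::ToyNode) = branch(sin, a)

# Wrap plain numbers in constant leaves, then reuse the node-node methods.
Base.:*(a::Real, b::ToyNode) = constant(a) * b
Base.:*(a::ToyNode, b::Real) = a * constant(b)
Base.:-(a::Real, b::ToyNode) = constant(a) - b
Base.:-(a::ToyNode, b::Real) = a - constant(b)

x = variable(1)
tree2 = 2x - sin(x)                 # 2x is parsed as 2 * x
tree2.children[1].children[1].val   # 2.0: the integer 2 became a Float64 constant
```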

As you may have noticed, the tree is printed as an expression. We can control this with the string_tree function, which also lets us pass the operators explicitly:

string_tree(tree, operators; variable_names=["x", "y"])
"((x + y) * 1.0) - sin(x)"

The string_tree function also lets us control how each branch node and each leaf node (variable or constant) is printed.
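To show how such a printer can work, here is a recursive sketch over a hypothetical ToyNode type. The parenthesization rule (wrap binary operations except at the top level or directly inside a function call) is an assumption chosen to reproduce the string_tree output shown above; it is not the documented behavior of string_tree, and toy_string is a hypothetical name.

```julia
# Hypothetical toy node type for illustration; not DynamicExpressions.Node.
struct ToyNode
    degree::Int               # 0 = leaf, 1 = unary, 2 = binary
    feature::Int              # feature index for variable leaves; 0 otherwise
    val::Float64              # value for constant leaves
    op::Int                   # index into UNARY_OPS or BINARY_OPS; 0 for leaves
    children::Vector{ToyNode}
end

const UNARY_OPS = (sin, exp)
const BINARY_OPS = (+, -, *, /)

variable(i) = ToyNode(0, i, 0.0, 0, ToyNode[])
constant(v) = ToyNode(0, 0, Float64(v), 0, ToyNode[])

# Build a branch node, storing the operator's position in the matching table.
function branch(f, kids::ToyNode...)
    ops = length(kids) == 1 ? UNARY_OPS : BINARY_OPS
    return ToyNode(length(kids), 0, 0.0, findfirst(g -> g === f, ops), collect(ToyNode, kids))
end

# Print a tree recursively. Binary operations are parenthesized unless they
# are the whole expression (`top=true`) or the direct argument of a function.
function toy_string(node::ToyNode; variable_names=nothing, top=true)
    if node.degree == 0
        node.feature == 0 && return string(node.val)   # constant leaf
        return variable_names === nothing ? "x$(node.feature)" : variable_names[node.feature]
    elseif node.degree == 1
        inner = toy_string(node.children[1]; variable_names, top=true)
        return "$(nameof(UNARY_OPS[node.op]))($inner)"
    else
        l = toy_string(node.children[1]; variable_names, top=false)
        r = toy_string(node.children[2]; variable_names, top=false)
        s = "$l $(nameof(BINARY_OPS[node.op])) $r"
        return top ? s : "($s)"
    end
end

x, y, const_1 = variable(1), variable(2), constant(1.0)
tree = branch(-, branch(*, branch(+, x, y), const_1), branch(sin, x))
toy_string(tree)                              # "((x1 + x2) * 1.0) - sin(x1)"
toy_string(tree; variable_names=["x", "y"])   # "((x + y) * 1.0) - sin(x)"
```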

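Many operations are available on tree objects, such as evaluating them over batched data. The data matrix has one row per feature and one column per sample, so the 2-by-5 matrix below holds 5 samples of 2 features and yields 5 outputs: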

rng = Random.MersenneTwister(0)
tree2(randn(rng, Float64, 2, 5), operators)
5-element Vector{Float64}:
  0.730116119600978
 -0.3602934782862816
  0.6196873148394606
  0.06499319903820032
 -0.5365732772138014
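Batched evaluation can be sketched as a recursion that returns one value per sample and broadcasts each operator elementwise. This is a simplified stand-in, not the library's evaluator: ToyNode, branch, and toy_eval are hypothetical, and the sketch assumes X is laid out as features × samples, as in the call above.

```julia
# Hypothetical toy node type for illustration; not DynamicExpressions.Node.
struct ToyNode
    degree::Int               # 0 = leaf, 1 = unary, 2 = binary
    feature::Int              # feature index for variable leaves; 0 otherwise
    val::Float64              # value for constant leaves
    op::Int                   # index into UNARY_OPS or BINARY_OPS; 0 for leaves
    children::Vector{ToyNode}
end

const UNARY_OPS = (sin, exp)
const BINARY_OPS = (+, -, *, /)

variable(i) = ToyNode(0, i, 0.0, 0, ToyNode[])
constant(v) = ToyNode(0, 0, Float64(v), 0, ToyNode[])

# Build a branch node, storing the operator's position in the matching table.
function branch(f, kids::ToyNode...)
    ops = length(kids) == 1 ? UNARY_OPS : BINARY_OPS
    return ToyNode(length(kids), 0, 0.0, findfirst(g -> g === f, ops), collect(ToyNode, kids))
end

# Evaluate a tree on X, a (features × samples) matrix, returning one value per sample.
function toy_eval(node::ToyNode, X::AbstractMatrix{<:Real})
    if node.degree == 0
        # Constants are repeated for every sample; variables read their feature's row.
        return node.feature == 0 ? fill(node.val, size(X, 2)) : float.(X[node.feature, :])
    elseif node.degree == 1
        f = UNARY_OPS[node.op]
        return f.(toy_eval(node.children[1], X))
    else
        f = BINARY_OPS[node.op]
        return f.(toy_eval(node.children[1], X), toy_eval(node.children[2], X))
    end
end

x = variable(1)
tree2 = branch(-, branch(*, constant(2), x), branch(sin, x))  # 2x - sin(x)
X = [0.0 1.0 2.0;
     5.0 5.0 5.0]                    # 2 features × 3 samples
toy_eval(tree2, X)                   # 3-element Vector{Float64}, one per column
```

Feature 2 is present in X but unused by tree2, just as in the 2-by-5 example above.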

Now, how does this actually work? How do these functions traverse the tree?

The core operation is the tree_mapreduce function, which applies a function to each node in the tree and then combines the results. Unlike a standard mapreduce, tree_mapreduce lets you specify different maps for branch nodes and leaf nodes. Also unlike a standard mapreduce, the reduction function must handle a variable number of inputs: it receives the mapped branch node followed by the result for each of its children.
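The description above can be made concrete with a small recursive sketch. The names ToyNode and toy_mapreduce are hypothetical; the sketch assumes leaves receive only the leaf map, while each branch node's mapped value is passed to the reduction followed by its children's reduced results, which reproduces the node and leaf counts computed in this section. The real tree_mapreduce has options (such as f_on_shared) that this omits.

```julia
# Hypothetical toy node type for illustration; not DynamicExpressions.Node.
struct ToyNode
    degree::Int               # 0 = leaf, 1 = unary, 2 = binary
    feature::Int              # feature index for variable leaves; 0 otherwise
    val::Float64              # value for constant leaves
    op::Int                   # index into UNARY_OPS or BINARY_OPS; 0 for leaves
    children::Vector{ToyNode}
end

const UNARY_OPS = (sin, exp)
const BINARY_OPS = (+, -, *, /)

variable(i) = ToyNode(0, i, 0.0, 0, ToyNode[])
constant(v) = ToyNode(0, 0, Float64(v), 0, ToyNode[])

# Build a branch node, storing the operator's position in the matching table.
function branch(f, kids::ToyNode...)
    ops = length(kids) == 1 ? UNARY_OPS : BINARY_OPS
    return ToyNode(length(kids), 0, 0.0, findfirst(g -> g === f, ops), collect(ToyNode, kids))
end

# Leaves get `f_leaf`. A branch gets `f_branch`, and that value is passed to `op`
# together with the already-reduced result of each child subtree.
function toy_mapreduce(f_leaf, f_branch, op, node::ToyNode)
    node.degree == 0 && return f_leaf(node)
    child_results = [toy_mapreduce(f_leaf, f_branch, op, c) for c in node.children]
    return op(f_branch(node), child_results...)
end

# With a single map, use it for both leaves and branches.
toy_mapreduce(f, op, node::ToyNode) = toy_mapreduce(f, f, op, node)

x, y, const_1 = variable(1), variable(2), constant(1.0)
tree = branch(-, branch(*, branch(+, x, y), const_1), branch(sin, x))

toy_mapreduce(node -> 1, +, tree)                     # 8 nodes
toy_mapreduce(leaf -> 1, branch_node -> 0, +, tree)   # 4 leaves
```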

Let's see an example. Say we just want to count the nodes in the tree:

num_nodes = tree_mapreduce(node -> 1, +, tree)
8

The + handles both the case of one child and of two children, since it receives the mapped branch node followed by the result for each child. We didn't need to specify a separate branch function here, but we can do that too:

num_leafs = tree_mapreduce(leaf_node -> 1, branch_node -> 0, +, tree)
4

This counts the leaf nodes in the tree; for tree, these are x, y, const_1, and x again.

Within these functions you can access fields of the Node type to create more complex operations; just be careful not to access fields that are undefined for a given node (be sure to read the API specification).

Most operations can be built with this simple pattern, even tree evaluation and expression printing. (It also supports graph-like expressions, such as GraphNode, via the f_on_shared keyword.)

As a more complex example, let's compute the depth of a tree. Here we need a more complicated reduction operation, built around max:

depth = tree_mapreduce(
    node -> 1, (parent, children...) -> 1 + max(children...), x + sin(sin(exp(x)))
)
5

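Here, max(children...) handles both the case of one child and of two children, and the explicit 1 + adds one level for each branch node. The parent argument is unused: it is the branch node already mapped to 1, while each element of children is the depth already computed for that child's subtree. Leaves are simply mapped to 1.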
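For comparison, the same depth can be written as a direct recursion, where a leaf has depth 1 and a branch adds 1 to its deepest child. This sketch uses a hypothetical ToyNode type and a hypothetical toy_depth function, and is only meant to confirm the value 5 for x + sin(sin(exp(x))).

```julia
# Hypothetical toy node type for illustration; not DynamicExpressions.Node.
struct ToyNode
    degree::Int               # 0 = leaf, 1 = unary, 2 = binary
    feature::Int              # feature index for variable leaves; 0 otherwise
    val::Float64              # value for constant leaves
    op::Int                   # index into UNARY_OPS or BINARY_OPS; 0 for leaves
    children::Vector{ToyNode}
end

const UNARY_OPS = (sin, exp)
const BINARY_OPS = (+, -, *, /)

variable(i) = ToyNode(0, i, 0.0, 0, ToyNode[])

# Build a branch node, storing the operator's position in the matching table.
function branch(f, kids::ToyNode...)
    ops = length(kids) == 1 ? UNARY_OPS : BINARY_OPS
    return ToyNode(length(kids), 0, 0.0, findfirst(g -> g === f, ops), collect(ToyNode, kids))
end

# A leaf has depth 1; a branch adds 1 to the depth of its deepest child.
toy_depth(node::ToyNode) =
    node.degree == 0 ? 1 : 1 + maximum(toy_depth, node.children)

x = variable(1)
expr = branch(+, x, branch(sin, branch(sin, branch(exp, x))))  # x + sin(sin(exp(x)))
toy_depth(expr)  # 5: +, sin, sin, exp, then the leaf x
```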

Many operations do not need to handle branching explicitly, so many of Julia's typical collection operations work directly on trees. For example, we can collect every node in the tree into a list, in depth-first order:

collect(tree)
8-element Vector{Node{Float64}}:
 ((x1 + x2) * 1.0) - sin(x1)
 (x1 + x2) * 1.0
 x1 + x2
 x1
 x2
 1.0
 sin(x1)
 x1
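The ordering above is a depth-first pre-order: each node comes before its children, and children are visited left to right. A minimal sketch of that traversal follows, using a hypothetical ToyNode type and a hypothetical toy_collect function rather than the library's iteration code.

```julia
# Hypothetical toy node type for illustration; not DynamicExpressions.Node.
struct ToyNode
    degree::Int               # 0 = leaf, 1 = unary, 2 = binary
    feature::Int              # feature index for variable leaves; 0 otherwise
    val::Float64              # value for constant leaves
    op::Int                   # index into UNARY_OPS or BINARY_OPS; 0 for leaves
    children::Vector{ToyNode}
end

const UNARY_OPS = (sin, exp)
const BINARY_OPS = (+, -, *, /)

variable(i) = ToyNode(0, i, 0.0, 0, ToyNode[])
constant(v) = ToyNode(0, 0, Float64(v), 0, ToyNode[])

# Build a branch node, storing the operator's position in the matching table.
function branch(f, kids::ToyNode...)
    ops = length(kids) == 1 ? UNARY_OPS : BINARY_OPS
    return ToyNode(length(kids), 0, 0.0, findfirst(g -> g === f, ops), collect(ToyNode, kids))
end

# Pre-order traversal: record the node, then recurse into children left to right.
function toy_collect(node::ToyNode, out::Vector{ToyNode}=ToyNode[])
    push!(out, node)
    for c in node.children
        toy_collect(c, out)
    end
    return out
end

x, y, const_1 = variable(1), variable(2), constant(1.0)
tree = branch(-, branch(*, branch(+, x, y), const_1), branch(sin, x))
nodes = toy_collect(tree)
[n.degree for n in nodes]  # [2, 2, 2, 0, 0, 0, 1, 0]
```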

Note that the first node in this list is the root node, which is the subtraction operation:

tree == first(collect(tree))
true

We can look at its degree (number of children) and operator index:

tree.degree, tree.op
(0x02, 0x02)

And compare the operator index to our list of binary operators, where index 2 is the subtraction:

operators.binops
(+, -, *, /)
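Going the other way, a node's degree selects which operator list to read, and its op field is a position in that list. A hedged sketch with hypothetical names (decode_op, UNARY_OPS, BINARY_OPS), assuming degree 1 refers to unary operators and degree 2 to binary ones, as in the OperatorEnum above:

```julia
# Hypothetical operator tables mirroring the OperatorEnum above.
const UNARY_OPS = (sin, exp)
const BINARY_OPS = (+, -, *, /)

# Map a node's (degree, op) pair back to the operator it stores.
function decode_op(degree::Integer, op::Integer)
    degree == 1 && return UNARY_OPS[Int(op)]
    degree == 2 && return BINARY_OPS[Int(op)]
    throw(ArgumentError("leaf nodes (degree 0) carry no operator"))
end

decode_op(0x02, 0x02)  # the `-` function, matching tree.degree and tree.op above
decode_op(1, 1)        # sin
```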

Many other collection operations are available. For example, we can sum a function over every node; here each leaf (degree 0) contributes 1.5, so the four leaves give 6.0:

sum(node -> node.degree == 0 ? 1.5 : 0.0, tree)
6.0

We can even use any, which exits the depth-first traversal early as soon as the predicate returns true:

any(node -> node.degree == 2, tree)
true
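To see the early exit, the sketch below performs a pre-order search on a hypothetical ToyNode tree and counts how many nodes it examines before returning. It illustrates the traversal order described above; toy_any is a hypothetical name, and this is not the library's implementation of any.

```julia
# Hypothetical toy node type for illustration; not DynamicExpressions.Node.
struct ToyNode
    degree::Int               # 0 = leaf, 1 = unary, 2 = binary
    feature::Int              # feature index for variable leaves; 0 otherwise
    val::Float64              # value for constant leaves
    op::Int                   # index into UNARY_OPS or BINARY_OPS; 0 for leaves
    children::Vector{ToyNode}
end

const UNARY_OPS = (sin, exp)
const BINARY_OPS = (+, -, *, /)

variable(i) = ToyNode(0, i, 0.0, 0, ToyNode[])
constant(v) = ToyNode(0, 0, Float64(v), 0, ToyNode[])

# Build a branch node, storing the operator's position in the matching table.
function branch(f, kids::ToyNode...)
    ops = length(kids) == 1 ? UNARY_OPS : BINARY_OPS
    return ToyNode(length(kids), 0, 0.0, findfirst(g -> g === f, ops), collect(ToyNode, kids))
end

# Depth-first (pre-order) search that returns as soon as `pred` holds.
# `visited` counts the nodes examined, to make the early exit visible.
function toy_any(pred, node::ToyNode, visited::Ref{Int}=Ref(0))
    visited[] += 1
    pred(node) && return true
    for c in node.children
        toy_any(pred, c, visited) && return true
    end
    return false
end

x, y, const_1 = variable(1), variable(2), constant(1.0)
tree = branch(-, branch(*, branch(+, x, y), const_1), branch(sin, x))

visited = Ref(0)
toy_any(n -> n.degree == 2, tree, visited)  # true
visited[]                                   # 1: the root is already binary

visited = Ref(0)
toy_any(n -> n.degree == 1, tree, visited)  # true
visited[]                                   # 7: sin(x1) is the 7th node in pre-order
```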

We can also randomly sample nodes with NodeSampler, which accepts a filter. Here we restrict sampling to unary (degree 1) nodes, and the only such node in tree is sin(x1):

rand(rng, NodeSampler(; tree, filter=node -> node.degree == 1))
sin(x1)
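Filtered sampling can be sketched as: collect all nodes, keep those passing the filter, and draw one at random. This sketch assumes uniform sampling among the kept nodes, which may differ from what NodeSampler actually does; ToyNode, toy_collect, and toy_sample are hypothetical names.

```julia
using Random

# Hypothetical toy node type for illustration; not DynamicExpressions.Node.
struct ToyNode
    degree::Int               # 0 = leaf, 1 = unary, 2 = binary
    feature::Int              # feature index for variable leaves; 0 otherwise
    val::Float64              # value for constant leaves
    op::Int                   # index into UNARY_OPS or BINARY_OPS; 0 for leaves
    children::Vector{ToyNode}
end

const UNARY_OPS = (sin, exp)
const BINARY_OPS = (+, -, *, /)

variable(i) = ToyNode(0, i, 0.0, 0, ToyNode[])
constant(v) = ToyNode(0, 0, Float64(v), 0, ToyNode[])

# Build a branch node, storing the operator's position in the matching table.
function branch(f, kids::ToyNode...)
    ops = length(kids) == 1 ? UNARY_OPS : BINARY_OPS
    return ToyNode(length(kids), 0, 0.0, findfirst(g -> g === f, ops), collect(ToyNode, kids))
end

# Pre-order list of all nodes in the tree.
function toy_collect(node::ToyNode, out::Vector{ToyNode}=ToyNode[])
    push!(out, node)
    for c in node.children
        toy_collect(c, out)
    end
    return out
end

# Draw one node uniformly at random among those for which `filter` returns true.
function toy_sample(rng::AbstractRNG, tree::ToyNode; filter=(_ -> true))
    candidates = [n for n in toy_collect(tree) if filter(n)]
    isempty(candidates) && throw(ArgumentError("no node passes the filter"))
    return rand(rng, candidates)
end

rng = Random.MersenneTwister(0)
x, y, const_1 = variable(1), variable(2), constant(1.0)
tree = branch(-, branch(*, branch(+, x, y), const_1), branch(sin, x))
picked = toy_sample(rng, tree; filter=n -> n.degree == 1)  # the only unary node, sin(x1)
```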

This page was generated using Literate.jl.