Shrike.jl
Shrike is a Julia package for building ensembles of random projection trees. Random projection trees are a generalization of KD-trees and are used to quickly approximate nearest neighbors or to build k-nearest-neighbor graphs. They adapt to the low-dimensional structure that is often present in high-dimensional data.
The implementation here is based on the MRPT algorithm. This package also includes optimizations for knn-graph creation and has built-in support for multithreading.
Installation
To install, type the following in the REPL (the ] key switches to package mode):
] add https://github.com/djpasseyjr/Shrike.jl
or, equivalently, run:
using Pkg
Pkg.add(url="https://github.com/djpasseyjr/Shrike.jl")
Build an Index
To build an ensemble of random projection trees, use the ShrikeIndex type:
using Shrike
maxk = 100
X = rand(100, 10000)
shi = ShrikeIndex(X, maxk; depth=8, ntrees=10)
The type accepts a matrix of data, X, where each column represents a data point. maxk is the maximum number of nearest neighbors you will be able to find with this index; it is used to set a safe depth for the trees. You can also construct an index without this parameter if you need to.

depth is the number of times each random projection tree splits the data. Leaf nodes in a tree contain about npoints / 2^depth data points. Increasing depth increases speed but decreases accuracy. By default, the index sets depth as large as possible. ntrees controls the number of trees in the ensemble. More trees means more accuracy but more memory.

In this case, since we need an index that can find the 100 nearest neighbors, setting depth equal to 8 would produce leaf nodes with fewer than 100 points (10000 / 2^8 is about 39). The index infers this from maxk and sets depth to be as large as possible given maxk. Here, depth = 6 (about 156 points per leaf; depth 7 would leave only about 78).
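The depth arithmetic above can be sketched as follows. This is a hypothetical helper, not Shrike's internal code; it assumes balanced median splits, so that a leaf at depth d holds at least floor(npoints / 2^d) points.

```julia
# Hypothetical helper (not part of Shrike's API): the deepest tree depth for
# which every leaf still holds at least `maxk` points, assuming balanced
# median splits (a leaf at depth d then has at least floor(npoints / 2^d) points).
function safe_depth(npoints::Integer, maxk::Integer; depth::Real = Inf)
    maxk <= npoints || throw(ArgumentError("maxk must not exceed npoints"))
    deepest = 0
    # Go one level deeper while the smallest leaf would still have >= maxk points
    while npoints >> (deepest + 1) >= maxk
        deepest += 1
    end
    # A requested depth is honored only if it is shallow enough
    return depth == Inf ? deepest : min(Int(depth), deepest)
end

safe_depth(10_000, 100; depth = 8)  # 6, matching the example above
safe_depth(10_000, 100; depth = 4)  # 4, shallow enough to keep
```

Integer shifts are used instead of log2 so that rounding never pushes the depth one level too deep.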
To query the index for the approximate 10 nearest neighbors of a point, use:
k = 10
q = X[:, 1]
approx_nn = ann(shi, q, k; vote_cutoff=2)
- The vote_cutoff parameter sets how many "votes" a point needs in order to be included in the final linear search. Each tree votes for the points in the leaf node that the query falls into, so if the leaves contain few points and there are few trees, the odds of a point receiving more than one vote are low. Increasing vote_cutoff speeds up the algorithm but may reduce accuracy. When depth is large and ntrees is less than 5, it is recommended to set vote_cutoff = 1.
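The voting step described above can be sketched as follows. This is a minimal illustration that assumes each tree has already routed the query to a leaf; vote_and_search and its leaves argument are hypothetical names, not part of Shrike's API.

```julia
using LinearAlgebra

# Hypothetical sketch of MRPT-style voting (not Shrike's internal code).
# `leaves[t]` holds the indexes of the data points in the leaf that tree `t`
# routed the query `q` to; `X` stores one data point per column.
function vote_and_search(X::AbstractMatrix, q::AbstractVector,
                         leaves::Vector{Vector{Int}}, k::Int; vote_cutoff::Int = 1)
    votes = zeros(Int, size(X, 2))
    for leaf in leaves, i in leaf
        votes[i] += 1                 # each tree votes once for every point in its leaf
    end
    candidates = findall(>=(vote_cutoff), votes)
    # Exact linear search restricted to candidates with enough votes
    dists = [norm(X[:, i] - q) for i in candidates]
    order = sortperm(dists)
    return candidates[order[1:min(k, length(order))]]
end

X = [0.0 1.0 2.0 3.0 10.0]                # five 1-D points, one per column
q = [0.9]
leaves = [[1, 2, 3], [2, 3, 4], [2, 5]]   # leaves from three hypothetical trees
vote_and_search(X, q, leaves, 2; vote_cutoff = 2)  # [2, 3]
```

With vote_cutoff = 2 only points 2 and 3 reach the linear search; with vote_cutoff = 1 all five points are scanned, which is slower but can only improve accuracy.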
KNN-Graphs
This package includes fast algorithms to generate k-nearest-neighbor graphs and has specialized functions for this purpose. It uses neighbor-of-neighbor exploration (outlined here) to efficiently improve the accuracy of a knn-graph.
Nearest neighbor graphs are used to give a sparse topology to large datasets. Their structure can be used to project the data onto a lower-dimensional manifold, to cluster data points with community detection algorithms, or to perform other analyses.
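The neighbor-of-neighbor idea can be sketched in a few lines: for each point, its current neighbors and their neighbors form a candidate pool, and the k closest candidates become the new neighbor list. explore_once is a hypothetical function written for illustration under that assumption, not Shrike's implementation.

```julia
using LinearAlgebra

# Hypothetical sketch of one round of neighbor-of-neighbor exploration
# (a generic version of the idea, not Shrike's code).
# `nn` is an npoints x k array of neighbor indexes; X holds one point per column.
function explore_once(X::AbstractMatrix, nn::Matrix{Int})
    npoints, k = size(nn)
    newnn = similar(nn)
    for i in 1:npoints
        # Candidates: current neighbors plus the neighbors of those neighbors
        cands = Set{Int}(nn[i, :])
        for j in nn[i, :]
            union!(cands, nn[j, :])
        end
        delete!(cands, i)                     # a point is not its own neighbor
        c = collect(cands)
        d = [norm(X[:, i] - X[:, j]) for j in c]
        newnn[i, :] = c[sortperm(d)[1:k]]     # keep the k closest candidates
    end
    return newnn
end

X = [0.0 1.0 3.0 7.0]            # four 1-D points
nn = reshape([3, 1, 2, 3], 4, 1) # a poor initial 1-NN guess
explore_once(X, nn)              # every point now has its true nearest neighbor
```

Because each round only looks at a small candidate pool, it is much cheaper than rebuilding the neighbor lists from scratch.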
To generate nearest neighbor graphs:
using Shrike
X = rand(100, 10000)
shi = ShrikeIndex(X; depth=6, ntrees=5)
k = 10
g = knngraph(shi, k; vote_cutoff=1, ne_iters=1, gtype=SimpleDiGraph)
- The vote_cutoff parameter sets how many "votes" a point needs in order to be included in a linear search.
- The ne_iters parameter controls how many iterations of neighbor exploration the algorithm performs. Successive iterations are increasingly fast. It is recommended to use more iterations of neighbor exploration when the number of trees is small, and fewer when many trees are used.
- The gtype parameter allows the user to specify a LightGraphs.jl graph type to return. gtype=identity returns a sparse adjacency matrix.
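To illustrate what a sparse adjacency matrix for a knn-graph looks like, here is a sketch that converts an npoints by k array of neighbor indexes into one with SparseArrays. The helper name knn_adjacency is hypothetical, and the orientation (row i points to its neighbors) and unit edge weights are assumptions, not a description of exactly what Shrike returns for gtype=identity.

```julia
using SparseArrays

# Hypothetical sketch: A[i, j] = 1 when j is among the k neighbors of i.
# Orientation and entry values are assumptions made for illustration.
function knn_adjacency(nn::AbstractMatrix{<:Integer})
    npoints, k = size(nn)
    rows = repeat(1:npoints, inner = k)   # row i repeated k times
    cols = vec(permutedims(nn))           # neighbors of point 1, then of point 2, ...
    return sparse(rows, cols, ones(Int, npoints * k), npoints, npoints)
end

nn = [2 3; 1 3; 1 2]   # 3 points, k = 2
A = knn_adjacency(nn)  # 3 x 3 sparse matrix with k * npoints = 6 stored edges
```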
If an array of nearest neighbor indices is preferred,
nn = allknn(shi, k; vote_cutoff=1, ne_iters=0)
can be used to generate a shi.npoints by k array of integer indexes, where nn[i, :] contains the nearest neighbors of X[:, i]. The keyword arguments work the same way as in knngraph (outlined above).
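A brute-force reference with the same array layout is handy for checking the approximate results. The sketch below is a hypothetical helper, not part of Shrike; it excludes each point from its own neighbor list, and whether allknn does the same is not stated here.

```julia
using LinearAlgebra

# Hypothetical brute-force reference (not part of Shrike): returns an
# npoints x k array whose row i holds the indexes of the k exact nearest
# neighbors of X[:, i]. Each point is excluded from its own neighbor list.
function exact_allknn(X::AbstractMatrix, k::Int)
    npoints = size(X, 2)
    nn = Matrix{Int}(undef, npoints, k)
    for i in 1:npoints
        d = [norm(X[:, i] - X[:, j]) for j in 1:npoints]
        d[i] = Inf                          # never pick the point itself
        nn[i, :] = partialsortperm(d, 1:k)  # indexes of the k smallest distances
    end
    return nn
end

exact_allknn([0.0 1.0 3.0 7.0], 2)  # [2 3; 1 3; 2 1; 3 2]
```

This costs O(npoints^2) distance evaluations, which is exactly the work the random projection trees are meant to avoid on large datasets.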
Threading
Shrike has built-in support for multithreading. To allocate multiple threads, start Julia with the --threads flag:
user@sys:~$ julia --threads 4
To see this at work, consider a small-scale example:
user@sys:~$ cmd="using Shrike; shi=ShrikeIndex(rand(100, 10000)); @time knngraph(shi, 10, ne_iters=1)"
user@sys:~$ julia -e "$cmd"
12.373127 seconds (8.66 M allocations: 4.510 GiB, 6.85% gc time, 18.88% compilation time)
user@sys:~$ julia --threads 4 -e "$cmd"
6.306410 seconds (8.67 M allocations: 4.498 GiB, 13.12% gc time, 31.64% compilation time)
(This assumes that Shrike is installed.)
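The speedup above comes from spreading independent per-point work over the available threads. The following is a generic sketch of that pattern with Base.Threads, not Shrike's threaded code; threaded_nearest is a hypothetical name.

```julia
using LinearAlgebra
using Base.Threads

# Hypothetical sketch of threading independent per-point work (not Shrike's code).
# Each iteration writes only to its own slot of `nearest`, so no locks are needed.
function threaded_nearest(X::AbstractMatrix)
    npoints = size(X, 2)
    nearest = zeros(Int, npoints)
    @threads for i in 1:npoints
        best, bestd = 0, Inf
        for j in 1:npoints
            j == i && continue
            d = norm(X[:, i] - X[:, j])
            if d < bestd
                best, bestd = j, d
            end
        end
        nearest[i] = best
    end
    return nearest
end

# nthreads() reports how many threads julia was started with
threaded_nearest([0.0 1.0 3.0 7.0])
```

Started with julia --threads 4, the loop iterations are split across four threads; with a single thread the same code runs serially and gives the same answer.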
Benchmark
This package was compared to the original mrpt C++ implementation (on which this algorithm was based), annoy, a popular package for approximate nearest neighbors, and NearestNeighbors.jl, a Julia package for nearest neighbor search. The benchmarks were written in the spirit of ann-benchmarks, a repository for comparing different approximate nearest neighbor algorithms. The datasets used for the benchmark were taken directly from ann-benchmarks; the HDF5 files in question are FashionMNIST, SIFT, MNIST and GIST. The benchmarks below were run on a compute cluster, restricting all algorithms to a single thread.
In this plot, up and to the right is better (faster queries, better recall). Each point represents a parameter combination. For full documentation of the parameters run and the timing methods, consult the original scripts in the benchmark/ directory.
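Recall, the accuracy axis of these plots, is commonly computed as the fraction of the true k nearest neighbors that the approximate method returned, averaged over queries. The helper below is a hypothetical sketch of that common definition, not code from the benchmark/ scripts, which may define recall differently.

```julia
# Hypothetical helper (not from the benchmark/ scripts): fraction of the true
# neighbors that were found. Row i of each array is the neighbor list for query i.
function recall(approx::AbstractMatrix{<:Integer}, exact::AbstractMatrix{<:Integer})
    size(approx) == size(exact) || throw(DimensionMismatch("arrays must match"))
    hits = sum(length(intersect(approx[i, :], exact[i, :])) for i in axes(exact, 1))
    return hits / length(exact)
end

recall([2 3; 1 4], [2 3; 1 3])  # 0.75: 3 of the 4 true neighbors were found
```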
This plot illustrates that, for this dataset, Shrike performs better on most parameter combinations. This contrasts with SIFT, below, where some parameter combinations are not as strong. We speculate that this has to do with the high dimensionality of the points in FashionMNIST (d = 784) compared to the lower dimensionality of SIFT (d = 128).
It is important to note that NearestNeighbors.jl was designed to return the exact k nearest neighbors as quickly as possible and does not approximate, hence its high accuracy and lower speed.
The takeaway here is that Shrike is fast! It is possibly a little faster than the original C++ implementation. Go Julia! We should note that Shrike was not benchmarked against state-of-the-art algorithms for approximate nearest neighbor search. These algorithms are faster than annoy and mrpt, but unfortunately, the developers of Shrike aren't familiar with them.
Function Documentation
Shrike.ShrikeIndex
ShrikeIndex(data::AbstractArray{T, 2}, max_k; depth::Union{Int, Float64}=Inf, ntrees::Int=5) -> shi
Keyword argument version of the constructor that includes intended number of nearest neighbors.
If the default depth is used, the constructor sets the tree depth as deep as possible given max_k. This way, the accuracy/memory tradeoff is determined directly by ntrees and the desired vote_cutoff (vote_cutoff is a parameter passed to ann or knngraph).
If an argument is passed for depth, the constructor attempts to use the supplied depth, but guarantees that the tree is shallow enough to ensure that each leaf has at least max_k points. (Without this check, the index may return fewer than k neighbors when queried.)
Parameters
data: A (d x n) array. Each column is a data point with dimension d.
max_k: The maximum number of neighbors that will be queried. If you intend to use the ShrikeIndex to approximate at most 10 nearest neighbors of a point, set max_k = 10. This argument is used to infer the deepest possible tree depth so as to maximize speed.
Keyword Arguments
ntrees: The number of trees in the index. More trees means more accuracy, more memory and less speed. Use this to tune the speed/accuracy tradeoff.
depth: The number of splits in each tree. A depth of 0 means only a root, a depth of 1 means the root has two children, etc.
Shrike.ShrikeIndex
ShrikeIndex(data::Array{T, 2}, depth::Int, ntrees::Int) where T -> ensemble
Constructor for an ensemble of sparse random projection trees with voting. Returns a ShrikeIndex (an ensemble of multiple random projection trees).
Type Fields
data::Array{T, 2}: Contains all data points
npoints::Int: Number of data points
ndims::Int: Dimensionality of the data
depth::Int: Maximum depth of the tree. (Depth of 0 means only a root, depth of 1 means the root has two children.)
ntrees::Int: Number of trees to make
random_vectors::AbstractArray: The random projections used to make the trees
splits::Array{T, 2}: The split values for each node in each tree, stored in a 2D array
indexes::Array{Array{Int,1}, 2}: 2D array of data point indexes at each leaf node in each tree
Note that the RP forest does not store indexes at non-leaf nodes.
Follows the implementation outlined in:
Fast Nearest Neighbor Search through Sparse Random Projections and Voting. Ville Hyvönen, Teemu Pitkänen, Sotirios Tasoulis, Elias Jääsaari, Risto Tuomainen, Liang Wang, Jukka Ilmari Corander, Teemu Roos. Proceedings of the 2016 IEEE Conference on Big Data (2016)
with some modifications.
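To make the fields above concrete, here is a sketch of a single random projection tree in the spirit of MRPT: one random vector per level, a median split value for every internal node, and index lists only at the leaves. It uses dense Gaussian vectors rather than sparse ones, and RPTree, build_tree and traverse are hypothetical names; this is not Shrike's internal layout.

```julia
using LinearAlgebra, Random, Statistics

# Hypothetical single-tree sketch (not Shrike's internals). Dense Gaussian
# projections stand in for sparse ones. `splits` stores one median per internal
# node in heap order: the children of node n are 2n and 2n + 1.
struct RPTree
    R::Matrix{Float64}            # ndims x depth, one projection vector per level
    splits::Vector{Float64}       # 2^depth - 1 split values
    leaves::Vector{Vector{Int}}   # 2^depth leaf index lists
end

function build_tree(X::AbstractMatrix, depth::Int; rng::AbstractRNG = Random.default_rng())
    size(X, 2) >= 2^depth || throw(ArgumentError("need at least 2^depth points"))
    R = randn(rng, size(X, 1), depth)
    P = R' * X                          # depth x npoints projections
    splits = zeros(2^depth - 1)
    nodes = [collect(1:size(X, 2))]     # index sets of the nodes on the current level
    for level in 1:depth
        next = Vector{Vector{Int}}()
        for (offset, idx) in enumerate(nodes)
            node = 2^(level - 1) - 1 + offset   # heap position of this node
            p = P[level, idx]
            s = median(p)                       # split at the median projection
            splits[node] = s
            push!(next, idx[p .<= s], idx[p .> s])
        end
        nodes = next
    end
    return RPTree(R, splits, nodes)
end

# Route a query from the root to a leaf and return that leaf's point indexes.
function traverse(tree::RPTree, q::AbstractVector)
    depth = size(tree.R, 2)
    node = 1
    for level in 1:depth
        p = dot(view(tree.R, :, level), q)
        node = p <= tree.splits[node] ? 2node : 2node + 1
    end
    return tree.leaves[node - (2^depth - 1)]
end

X = rand(Random.MersenneTwister(1), 5, 1000)
tree = build_tree(X, 4; rng = Random.MersenneTwister(2))
leaf = traverse(tree, rand(5))   # candidate points for this query
```

An ensemble repeats this with independent random vectors per tree, and the leaves returned for a query are what the voting step counts.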
Shrike.ann
ann(shi::ShrikeIndex{T}, q::Array{T, 2}, k::Int; vote_cutoff=1) where T -> knn_idx
For a query point q, find the approximate k nearest neighbors from the data stored in the ShrikeIndex. The vote_cutoff parameter signifies how many "votes" a point needs in order to be included in a linear search. Increasing vote_cutoff speeds up the algorithm but may reduce accuracy.
Shrike.knngraph
knngraph(shi::ShrikeIndex{T}, k::Int; vote_cutoff::Int=1, ne_iters::Int=0, gtype::G) where {T, G} -> g
Returns a graph with shi.npoints nodes and k * shi.npoints edges, with each data point connected to its nearest neighbors.
Parameters
shi: random projection forest of the desired data
k: the desired number of nearest neighbors
vote_cutoff: signifies how many "votes" a point needs in order to be included in a linear search through leaf nodes. Increasing vote_cutoff speeds up the algorithm but may reduce accuracy. Passing too large a vote_cutoff results in the algorithm resetting vote_cutoff to equal the number of trees.
ne_iters: assigns the number of iterations of neighbor exploration to use. Defaults to zero. Neighbor exploration is a way to increase knn-graph accuracy.
gtype: the type of graph to construct. Defaults to SimpleDiGraph. gtype=identity returns a sparse adjacency matrix.
Shrike.allknn
allknn(shi::ShrikeIndex{T}, k::Int; vote_cutoff::Int=1, ne_iters::Int=0) where T -> approxnn_array
Returns a shi.npoints by k array of approximate nearest neighbor indexes. That is, approxnn_array[i, :] contains the indexes of the k nearest neighbors of shi.data[:, i].
Parameters
1. The ne_iters parameter assigns the number of iterations of neighbor exploration to use. Neighbor exploration is an inexpensive way to increase accuracy.
2. The vote_cutoff parameter signifies how many "votes" a point needs in order to be included in a linear search. Increasing vote_cutoff speeds up the algorithm but may reduce accuracy. Passing too large a vote_cutoff results in the algorithm resetting vote_cutoff to equal the number of trees.