Shrike.jl

Random Projection Splits

Shrike is a Julia package for building ensembles of random projection trees. Random projection trees are a generalization of KD-Trees and are used to quickly approximate nearest neighbors or build k-nearest-neighbor graphs. They adapt to the low-dimensional structure that is often present in high-dimensional data.

The implementation here is based on the MRPT algorithm. This package also includes optimizations for knn-graph creation and has built-in support for multithreading.
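
As a rough illustration of what a single random projection split does, here is a minimal sketch. It is not Shrike's implementation: random_projection_split is a hypothetical helper, and it assumes a dense Gaussian direction and a median split, whereas MRPT uses sparse random vectors.

```julia
using LinearAlgebra, Random, Statistics

# Hypothetical helper, not part of Shrike's API.
# Splits the columns of X into two halves along a random direction.
function random_projection_split(X::AbstractMatrix, rng::AbstractRNG)
    d = size(X, 1)
    r = randn(rng, d)           # dense random direction (assumption; MRPT uses sparse vectors)
    r ./= norm(r)               # normalize to unit length
    proj = X' * r               # one scalar projection per datapoint (column of X)
    s = median(proj)            # split value for this node
    left = findall(<=(s), proj)   # indices of points at or below the split
    right = findall(>(s), proj)   # indices of points above the split
    return r, s, left, right
end

rng = MersenneTwister(1)
X = rand(rng, 5, 1000)
r, s, left, right = random_projection_split(X, rng)
println(length(left), " points left, ", length(right), " points right")
```

Applying the same split recursively to each half, depth times, gives one tree. An ensemble repeats this with different random directions.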

Installation

To install, type

] add https://github.com/djpasseyjr/Shrike.jl

in the REPL or

using Pkg
Pkg.add(path="https://github.com/djpasseyjr/Shrike.jl")

Build an Index

To build an ensemble of random projection trees use the ShrikeIndex type.

using Shrike
maxk = 100
X = rand(100, 10000)
shi = ShrikeIndex(X, maxk; depth=8, ntrees=10)

The type accepts a matrix of data, X, where each column represents a datapoint.

  1. maxk represents the maximum number of nearest neighbors you will be able to find with this index. maxk is used to set a safe depth for the tree. You can also construct an index without this parameter if you need to.
  2. depth describes the number of times each random projection tree will split the data. Leaf nodes in the tree contain about npoints / 2^depth data points. Increasing depth increases speed but decreases accuracy. By default, the index sets depth as large as possible.
  3. ntrees controls the number of trees in the ensemble. More trees means more accuracy but more memory.

In this case, since we need an index that can find the 100 nearest neighbors, setting depth equal to 8 would result in leaf nodes with fewer than 100 points (10000 / 2^8 ≈ 39). The index will infer this using maxk and set the depth to be as large as possible given maxk. In this case, depth = 6, since 10000 / 2^6 ≈ 156 while 10000 / 2^7 ≈ 78.
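
The depth arithmetic can be checked with a short sketch. The helpers leafsize and safe_depth are hypothetical names used for illustration. They assume the rule "deepest depth whose leaves still hold at least maxk points", which matches the numbers in this example but may differ in detail from what the constructor does internally.

```julia
# Hypothetical helpers, not part of Shrike's API.

# Approximate number of points in each leaf of a tree with `depth` splits.
leafsize(npoints, depth) = npoints / 2^depth

# Largest depth such that a leaf still holds at least `maxk` points.
function safe_depth(npoints::Int, maxk::Int)
    depth = 0
    while leafsize(npoints, depth + 1) >= maxk
        depth += 1
    end
    return depth
end

println(leafsize(10_000, 8))     # about 39 points per leaf, too few for maxk = 100
println(safe_depth(10_000, 100)) # 6
```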

To query the index for the approximate 10 nearest neighbors, use:

k = 10
q = X[:, 1]
approx_nn = ann(shi, q, k; vote_cutoff=2)
  1. The vote_cutoff parameter signifies how many "votes" a point needs in order to be included in a linear search. Each tree "votes" for the points in the leaf node that the query falls into, so if there aren't many points in the leaves and there aren't many trees, the odds of a point receiving more than one vote are low. Increasing vote_cutoff speeds up the algorithm but may reduce accuracy. When depth is large and ntrees is less than 5, it is recommended to set vote_cutoff = 1.
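
The voting step can be sketched as follows. This is an illustration of the idea, not Shrike's code: vote_and_search is a hypothetical function, and leaves stands in for the leaf (one per tree) that the query point lands in.

```julia
using LinearAlgebra

# Hypothetical helper, not part of Shrike's API.
# `leaves` holds, for each tree, the indices of the points sharing a leaf with q.
function vote_and_search(X, q, leaves, k; vote_cutoff=1)
    votes = Dict{Int,Int}()
    for leaf in leaves, i in leaf
        votes[i] = get(votes, i, 0) + 1      # each tree votes for the points in its leaf
    end
    # Only points with enough votes enter the linear search.
    candidates = [i for (i, v) in votes if v >= vote_cutoff]
    dists = [norm(X[:, i] - q) for i in candidates]
    order = sortperm(dists)
    return candidates[order[1:min(k, length(order))]]
end

X = Float64[0 1 2 3 4 5; 0 0 0 0 0 0]        # six points on a line
q = [0.9, 0.0]
leaves = [[1, 2, 3], [2, 3, 4], [2, 5, 6]]   # leaves from three trees
println(vote_and_search(X, q, leaves, 2; vote_cutoff=2))  # searches only points 2 and 3
println(vote_and_search(X, q, leaves, 2; vote_cutoff=1))  # searches all six points
```

With vote_cutoff=2 the linear search covers two candidates instead of six, which is where the speedup comes from. The cost is that point 1, the true second-nearest neighbor, received only one vote and is missed.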

KNN-Graphs

This package includes fast algorithms to generate k-nearest-neighbor graphs and has specialized functions for this purpose. It uses neighbor of neighbor exploration (outlined here) to efficiently improve the accuracy of a knn-graph.
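
One iteration of neighbor exploration can be sketched like this. The function explore_neighbors is a hypothetical illustration, not the package's routine. It assumes nn is an npoints by k matrix of neighbor indices in which row i does not contain i and has no repeats.

```julia
using LinearAlgebra

# Hypothetical helper, not part of Shrike's API.
# For each point, the neighbors of its current neighbors become extra candidates,
# and the k closest candidates are kept.
function explore_neighbors(X, nn)
    n, k = size(nn)
    new_nn = similar(nn)
    for i in 1:n
        cands = Set{Int}(nn[i, :])
        for j in nn[i, :]
            union!(cands, nn[j, :])          # neighbors of neighbors
        end
        delete!(cands, i)                    # a point is not its own neighbor
        c = collect(cands)
        d = [norm(X[:, i] - X[:, j]) for j in c]
        new_nn[i, :] = c[sortperm(d)[1:k]]
    end
    return new_nn
end

X = Float64[0 1 2 3 4 5]                     # six points on a line (1 x 6 matrix)
nn = [2 4; 1 3; 2 4; 3 5; 4 6; 5 3]          # rows 1 and 6 each contain one wrong neighbor
better = explore_neighbors(X, nn)
println(better[1, :])                        # point 1 now has neighbors 2 and 3
```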

Nearest neighbor graphs are used to give a sparse topology to large datasets. Their structure can be used to project the data onto a lower dimensional manifold, to cluster datapoints with community detection algorithms or to perform other analyses.

To generate nearest neighbor graphs:

using Shrike
X = rand(100, 10000)
shi = ShrikeIndex(X; depth=6, ntrees=5)
k = 10
g = knngraph(shi, k; vote_cutoff=1, ne_iters=1, gtype=SimpleDiGraph)
  1. The vote_cutoff parameter signifies how many "votes" a point needs in order to be included in a linear search.
  2. ne_iters controls how many iterations of neighbor exploration the algorithm will undergo. Successive iterations are increasingly fast. It is recommended to use more iterations of neighbor exploration when the number of trees is small and fewer when many trees are used.
  3. The gtype parameter allows the user to specify a LightGraphs.jl graph type to return. gtype=identity returns a sparse adjacency matrix.

If an array of nearest neighbor indices is preferred,

nn = allknn(shi, k; vote_cutoff=1, ne_iters=0)

can be used to generate a shi.npoints by k array of integer indexes where nn[i, :] corresponds to the nearest neighbors of X[:, i]. The keyword arguments work in the same way as in knngraph (outlined above).
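
If a sparse adjacency matrix is needed from such an index array, a conversion along these lines would work. nn_to_adjacency is a hypothetical helper. The orientation (row i marks the neighbors of point i) is an assumption and may not match what knngraph returns with gtype=identity.

```julia
using SparseArrays

# Hypothetical helper, not part of Shrike's API.
# Converts an npoints x k matrix of neighbor indices into a sparse
# npoints x npoints adjacency matrix with A[i, j] = 1 when j is a neighbor of i.
function nn_to_adjacency(nn::AbstractMatrix{<:Integer})
    n, k = size(nn)
    rows = repeat(1:n, inner=k)          # 1,1,...,2,2,... (k copies of each point)
    cols = vec(permutedims(nn))          # neighbors of point 1, then point 2, ...
    return sparse(rows, cols, ones(Int, n * k), n, n)
end

nn = [2 3; 1 3; 2 1]                     # toy example: 3 points, k = 2
A = nn_to_adjacency(nn)
println(nnz(A))                          # one stored entry per (point, neighbor) pair
```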

Threading

Shrike has built-in support for multithreading. To allocate multiple threads, start Julia with the --threads flag:

user@sys:~$ julia --threads 4

To see this at work, consider a small scale example:

user@sys:~$ cmd="using Shrike; shi=ShrikeIndex(rand(100, 10000)); @time knngraph(shi, 10, ne_iters=1)"
user@sys:~$ julia -e "$cmd"
  12.373127 seconds (8.66 M allocations: 4.510 GiB, 6.85% gc time, 18.88% compilation time)
user@sys:~$ julia  --threads 4 -e "$cmd"
  6.306410 seconds (8.67 M allocations: 4.498 GiB, 13.12% gc time, 31.64% compilation time)

(This assumes that Shrike is installed.)
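
To confirm how many threads a session has, and to see why this kind of workload parallelizes well, consider the following sketch. It is a brute-force stand-in, not Shrike's internals, and brute_force_allknn is a hypothetical name. Each point's neighbor search is independent, so the loop can be split across threads.

```julia
using LinearAlgebra
using Base.Threads: @threads, nthreads

# Hypothetical helper, not part of Shrike's API.
# Exact k nearest neighbors of every column of X, one independent task per point.
function brute_force_allknn(X::AbstractMatrix, k::Int)
    n = size(X, 2)
    nn = Matrix{Int}(undef, n, k)
    @threads for i in 1:n
        d = [norm(X[:, i] - X[:, j]) for j in 1:n]
        d[i] = Inf                           # exclude the point itself
        nn[i, :] = partialsortperm(d, 1:k)   # indices of the k smallest distances
    end
    return nn
end

println("Running with ", nthreads(), " thread(s)")
X = Float64[0 1 3 7 15]                      # five points on a line (1 x 5 matrix)
nn = brute_force_allknn(X, 1)
println(vec(nn))
```

Each iteration writes to its own row of nn, so no locking is needed.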

Benchmark

This package was compared to the original mrpt C++ implementation (on which this algorithm was based), annoy, a popular package for approximate nearest neighbors, and NearestNeighbors.jl, a Julia package for nearest neighbor search. The benchmarks were written in the spirit of ann-benchmarks, a repository for comparing different approximate nearest neighbor algorithms. The datasets used for the benchmark were taken directly from ann-benchmarks. The following are links to the HDF5 files in question: FashionMNIST, SIFT, MNIST and GIST. The benchmarks below were run on a compute cluster, restricting all algorithms to a single thread.

FashionMNIST Speed Comparison

In this plot, up and to the right is better (faster queries, better recall). Each point represents a parameter combination. For full documentation of the parameters used and the timing methods, consult the original scripts located in the benchmark/ directory.

This plot illustrates that for this dataset, on most parameter combinations, Shrike has better performance. Compare this to SIFT, below, where some parameter combinations are not as strong. We speculate that this has to do with the high dimensionality of points in FashionMNIST (d=784), compared to the lower dimensionality of SIFT (d=128).

SIFT Speed Comparison

It is important to note that NearestNeighbors.jl was designed to return the exact k-nearest-neighbors as quickly as possible, and does not approximate, hence the high accuracy and lower speed.
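
For readers who want to reproduce an accuracy measurement, one common definition of recall compares an approximate result with the exact one. The sketch below uses that definition with hypothetical helpers exact_knn and recall. The benchmark scripts may compute recall differently.

```julia
using LinearAlgebra

# Hypothetical helpers, not part of Shrike's API.

# Exact k nearest neighbors of q among the columns of X, by brute force.
function exact_knn(X, q, k)
    d = [norm(X[:, i] - q) for i in 1:size(X, 2)]
    return collect(partialsortperm(d, 1:k))
end

# Fraction of the exact neighbors that the approximate result recovered.
recall(approx, exact) = length(intersect(approx, exact)) / length(exact)

X = Float64[0 1 2 3; 0 0 0 0]            # four points on a line
q = [0.1, 0.0]
exact = exact_knn(X, q, 2)               # columns 1 and 2 are closest to q
println(recall([1, 3], exact))           # 0.5: one of the two true neighbors was found
```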

The takeaway here is that Shrike is fast! It is possibly a little faster than the original C++ implementation. Go Julia! We should note that Shrike was not benchmarked against state-of-the-art algorithms for approximate nearest neighbor search. These algorithms are faster than annoy and mrpt, but unfortunately, the developers of Shrike aren't familiar with these algorithms.

Function Documentation

Shrike.ShrikeIndexMethod
ShrikeIndex(data::AbstractArray{T, 2}, max_k; depth::Union{Int, Float64}=Inf, ntrees::Int=5) -> shi

Keyword argument version of the constructor that includes intended number of nearest neighbors.

If the default depth is used, the constructor sets the tree depth as deep as possible given max_k. This way, the accuracy/memory tradeoff is determined directly by ntrees and the desired vote_cutoff (vote_cutoff is a parameter passed to ann or knngraph).

If an argument is passed for depth, the constructor attempts to use the supplied depth, but guarantees that the depth of the tree is shallow enough to ensure that each leaf has at least k points. (Without this check, the index may return fewer than k neighbors when queried.)

Parameters

  1. data: A (d x n) array. Each column is a datapoint with dimension d.
  2. max_k: The maximum number of neighbors that will be queried. If you intend to use the ShrikeIndex to approximate at most 10 nearest neighbors of a point, set max_k = 10. This argument is used to infer the deepest tree depth possible so as to maximize speed.

Keyword Arguments

  1. ntrees: The number of trees in the index. More trees means more accuracy, more memory and less speed. Use this to tune the speed/accuracy tradeoff.
  2. depth: The number of splits in the tree. Depth of 0 means only a root, depth of 1 means root has two children, etc.

Shrike.ShrikeIndexMethod
ShrikeIndex(data::Array{T, 2}, depth::Int, ntrees::Int) where T -> ensemble

Constructor for ensemble of sparse random projection trees with voting. Returns ShrikeIndex type. (An ensemble of multiple random projection trees.)

Type Fields

  • data::Array{T, 2}: Contains all data points
  • npoints::Int: Number of data points
  • ndims::Int: Dimensionality of the data
  • depth::Int: maximum depth of the tree. (Depth of 0 means only a root, depth of 1 means root has two children)
  • ntrees::Int: Number of trees to make
  • random_vectors::AbstractArray: The random projections used to make the tree
  • splits::Array{T, 2}: The split values for each node in each tree stored in a 2D array
  • indexes::Array{Array{Int,1}, 2}: 2D array of datapoint indexes at each leaf node in each tree.

Note that RP forest does not store indexes at non-leaf nodes.
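
To make the array-backed layout concrete, here is a sketch of routing a query to a leaf in one tree. The layout is an assumption for illustration only (one random vector per level, heap-style node numbering with children of node i at 2i and 2i + 1). leaf_index is a hypothetical function and does not describe how ShrikeIndex indexes its splits.

```julia
using LinearAlgebra

# Hypothetical helper with an assumed layout, not part of Shrike's API.
# random_vectors[:, level] is the projection direction used at that level;
# splits[node] is the split value of internal node `node` (root is node 1).
function leaf_index(q, random_vectors, splits, depth)
    node = 1
    for level in 1:depth
        proj = dot(random_vectors[:, level], q)
        node = proj <= splits[node] ? 2node : 2node + 1   # go left or right
    end
    return node - 2^depth + 1            # leaf number in 1:2^depth
end

random_vectors = [1.0 1.0]               # 1-dimensional data, depth 2
splits = [0.0, -1.0, 1.0]                # split values for nodes 1, 2 and 3
for x in (-2.0, -0.5, 0.5, 2.0)
    println(x, " lands in leaf ", leaf_index([x], random_vectors, splits, 2))
end
```

A tree of depth 2 has 2^2 = 4 leaves, consistent with the convention that depth 0 is a lone root. Only those leaves would need to store datapoint indexes.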

Follows the implementation outlined in:

Fast Nearest Neighbor Search through Sparse Random Projections and Voting. Ville Hyvönen, Teemu Pitkänen, Sotirios Tasoulis, Elias Jääsaari, Risto Tuomainen, Liang Wang, Jukka Ilmari Corander, Teemu Roos. Proceedings of the 2016 IEEE Conference on Big Data (2016)

with some modifications.

Shrike.annMethod
ann(shi::ShrikeIndex{T}, q::Array{T, 2}, k::Int; vote_cutoff=1) where T -> knn_idx

For a query point q, find the approximate k nearest neighbors from the data stored in the ShrikeIndex. The vote_cutoff parameter signifies how many "votes" a point needs in order to be included in a linear search. Increasing vote_cutoff speeds up the algorithm but may reduce accuracy.

Shrike.knngraphMethod
knngraph(shi::ShrikeIndex{T}, k::Int; vote_cutoff::Int=1, ne_iters::Int=0, gtype::G) where {T, G} -> g

Returns a graph with shi.npoints nodes and k * shi.npoints edges, in which each datapoint is connected to its k nearest neighbors.

Parameters

  1. shi: random projection forest built from the desired data
  2. k: the desired number of nearest neighbors
  3. vote_cutoff: signifies how many "votes" a point needs in order to be included in a linear search through leaf nodes. Increasing vote_cutoff speeds up the algorithm but may reduce accuracy. Passing too large a vote_cutoff results in the algorithm resetting vote_cutoff to equal the number of trees.
  4. ne_iters: assigns the number of iterations of neighbor exploration to use. Defaults to zero. Neighbor exploration is a way to increase knn-graph accuracy.
  5. gtype: the type of graph to construct. Defaults to SimpleDiGraph. gtype=identity returns a sparse adjacency matrix.

Shrike.allknnMethod
allknn(shi::ShrikeIndex{T}, k::Int; vote_cutoff::Int=1, ne_iters::Int=0) where T -> approxnn_array

Returns a shi.npoints by k array of approximate nearest neighbor indexes. That is, approxnn_array[i,:] contains the indexes of the k nearest neighbors of shi.data[:, i].

Parameters

  1. ne_iters: assigns the number of iterations of neighbor exploration to use. Neighbor exploration is an inexpensive way to increase accuracy.
  2. vote_cutoff: signifies how many "votes" a point needs in order to be included in a linear search. Increasing vote_cutoff speeds up the algorithm but may reduce accuracy. Passing too large a vote_cutoff results in the algorithm resetting vote_cutoff to equal the number of trees.
