BootlegCassette.jl

BootlegCassette.jl is a quick and dirty package that tries to mimic the interface of Cassette.jl using IRTools.jl under the hood. This isn't a great implementation, but provided you do not use tagging and only use @context, ovderdub, prehook, posthook and recurse, BootlegCassette.jl should work as a drop-in replacement for Cassette.jl. This is mostly only relevant for Julia v1.6+ where Cassette is broken at the time of writing.

BootlegCassette.jl is currently signigicantly slower than regular Cassette.jl and has a different mechanism for setting recursion barriers. Currently, it's set by default to not recurse into functions from the Core module and also will leave the functions isdispatchtuple, eltype, convert, getproperty, and throw alone. This can be modified, but it's modified in a different way from in standard non-bootleg Cassette.jl

Examples

using BootlegCassette: BootlegCassette, @context, prehook, overdub, posthook, recurse
const Cassette = BootlegCassette

Cassette.@context Ctx 
Cassette.prehook(::Ctx, f, args...) = println(f, args)
Cassette.overdub(Ctx(), /, 1, 2)

#+RESULTS
float(1,)
AbstractFloat(1,)
Float64(1,)
sitofp(Float64, 1)
float(2,)
AbstractFloat(2,)
Float64(2,)
sitofp(Float64, 2)
/(1.0, 2.0)
div_float(1.0, 2.0)
Cassette.prehook(::Ctx, f, args...) = nothing
Cassette.prehook(::Ctx{Val{T}}, f, arg::T, rest...) where {T} = println(f, (arg, rest...))
Cassette.overdub(Ctx(metadata=Val(Int)), /, 1, 2)

#+RESULTS
 float(1,)
 AbstractFloat(1,)
 Float64(1,) 
 float(2,)
 AbstractFloat(2,)
 Float64(2,)
 0.5
Cassette.overdub(Ctx(metadata=Val(DataType)), /, 1, 2)

#+RESULTS
 sitofp(Float64, 1)
 sitofp(Float64, 2)
 0.5
Cassette.@context TraceCtx

mutable struct Trace
    current::Vector{Any}
    stack::Vector{Any}
    Trace() = new(Any[], Any[])
end

function enter!(t::Trace, args...)
    pair = args => Any[]
    push!(t.current, pair)
    push!(t.stack, t.current)
    t.current = pair.second
    return nothing
end

function exit!(t::Trace)
    t.current = pop!(t.stack)
    return nothing
end

Cassette.prehook(ctx::TraceCtx, args...) = enter!(ctx.metadata, args...)
Cassette.posthook(ctx::TraceCtx, args...) = exit!(ctx.metadata)

trace = Trace()
x, y, z = rand(3)
f(x, y, z) = x*y + y*z
Cassette.overdub(TraceCtx(metadata = trace), () -> f(x, y, z))

trace.current == Any[
    (f,x,y,z) => Any[
        (*,x,y) => Any[(Base.mul_float,x,y)=>Any[]]
        (*,y,z) => Any[(Base.mul_float,y,z)=>Any[]]
        (+,x*y,y*z) => Any[(Base.add_float,x*y,y*z)=>Any[]]
    ]
]

#+RESULTS
true
Cassette.@context SinToCosCtx

Cassette.overdub(::SinToCosCtx, ::typeof(sin), x) = cos(x)

x = rand(10)
y = Cassette.overdub(SinToCosCtx(), sum, i -> cos(i) + sin(i), x)
y == sum(i -> 2 * cos(i), x)

#+RESULTS
true
fib(x) = x < 3 ? 1 : fib(x - 2) + fib(x - 1)
fibtest(n) = fib(2 * n) + n

@context MemoizeCtx

function Cassette.overdub(ctx::MemoizeCtx, ::typeof(fib), x)
    result = get(ctx.metadata, x, 0)
    if result === 0
        result = recurse(ctx, fib, x)
        ctx.metadata[x] = result
    end
    return result
end

ctx = MemoizeCtx(metadata=Dict{Int, Int}())

@time overdub(ctx, fibtest, 20)
@time overdub(ctx, fibtest, 20)
@time fibtest(20)

#+RESULTS
   0.188974 seconds (361.71 k allocations: 21.705 MiB, 7.02% gc time, 99.87% compilation time)
   0.000010 seconds (2 allocations: 32 bytes)
   0.318917 seconds
 102334175

The final example from https://julia.mit.edu/Cassette.jl/stable/contextualdispatch.html does not currently work.