Reference

In this reference, you will find a detailed overview of the package API.

Reference guides are technical descriptions of the machinery and how to operate it. Reference material is information-oriented.

- Diátaxis

In other words, you come here because you want to take a very close look at the code 🧐.

Content

Exported functions

CounterfactualExplanations.CounterfactualExplanation - Method
function CounterfactualExplanation(;
    x::AbstractArray,
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractFittedModel,
    generator::Generators.AbstractGenerator,
    max_iter::Int=100,
    num_counterfactuals::Int=1,
    initialization::Symbol=:add_perturbation,
    generative_model_params::NamedTuple=(;),
    min_success_rate::AbstractFloat=0.99,
    converge_when::Symbol=:decision_threshold,
    invalidation_rate::AbstractFloat=0.5,
    learning_rate::AbstractFloat=1.0,
    variance::AbstractFloat=0.01,
)

Outer method to construct a CounterfactualExplanation structure.

CounterfactualExplanations.converged - Method
converged(ce::CounterfactualExplanation)

A convenience method to determine if the counterfactual search has converged. The search is considered to have converged only if the counterfactual is valid.

CounterfactualExplanations.generate_counterfactual - Method
generate_counterfactual(
    x::Union{AbstractArray,Int},
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractFittedModel,
    generator::AbstractGenerator;
    γ::AbstractFloat=0.75,
    max_iter=1000,
)

The core function that is used to run counterfactual search for a given factual x, target, counterfactual data, model and generator. Keywords can be used to specify the desired threshold for the predicted target class probability and the maximum number of iterations.

Examples

Generic generator

using CounterfactualExplanations

# Data:
using CounterfactualExplanations.Data
using Random
Random.seed!(1234)
xs, ys = Data.toy_data_linear()
X = hcat(xs...)
counterfactual_data = CounterfactualData(X,ys')

# Model
using CounterfactualExplanations.Models: LogisticModel, probs 
# Logit model:
w = [1.0 1.0] # true coefficients
b = 0
M = LogisticModel(w, [b])

# Randomly selected factual:
x = select_factual(counterfactual_data, rand(1:size(X, 2)))
y = round(probs(M, x)[1])  # predicted label of the factual
target = y == 0 ? 1 : 0    # the target is the opposite label

# Counterfactual search:
generator = GenericGenerator()
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)

CounterfactualExplanations.generate_counterfactual - Method
generate_counterfactual(
    x::Base.Iterators.Zip,
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractFittedModel,
    generator::AbstractGenerator;
    kwargs...,
)

Overloads the generate_counterfactual method to accept a zip of factuals x and return a vector of counterfactuals.

CounterfactualExplanations.generate_counterfactual - Method
generate_counterfactual(
    x::Vector{<:Matrix},
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractFittedModel,
    generator::AbstractGenerator;
    kwargs...,
)

Overloads the generate_counterfactual method to accept a vector of factuals x and return a vector of counterfactuals.

CounterfactualExplanations.parallelize - Method
parallelize(
    parallelizer::Nothing,
    f::Function,
    args...;
    kwargs...,
)

If no AbstractParallelizer has been supplied, the function is simply called or broadcast, without any parallelization.
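A `nothing` fallback like this is naturally expressed through Julia's multiple dispatch. The sketch below is hypothetical: `run_maybe_parallel` and `square` are illustrative names, not package functions, and reading "broadcast" as mapping the function over a vector of inputs is an assumption.

```julia
# Hypothetical sketch of a `nothing` fallback for a parallelizer argument.
# `run_maybe_parallel` is an illustrative name, not a package function.

# No parallelizer supplied: just call the function directly...
run_maybe_parallel(::Nothing, f::Function, x; kwargs...) = f(x; kwargs...)

# ...or apply it to each element when given a vector of inputs (assumed behaviour).
run_maybe_parallel(::Nothing, f::Function, xs::AbstractVector; kwargs...) =
    [f(x; kwargs...) for x in xs]

square(x; offset=0) = x^2 + offset

single = run_maybe_parallel(nothing, square, 3)                  # 9
many = run_maybe_parallel(nothing, square, [1, 2, 3]; offset=1)  # [2, 5, 10]
```

The vector method is more specific than the untyped one, so dispatch picks it automatically for vector inputs.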

CounterfactualExplanations.target_probs - Function
target_probs(
    ce::CounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)

Returns the predicted probability of the target class for x. If x is nothing, the predicted probability corresponding to the counterfactual value is returned.

CounterfactualExplanations.update! - Method
update!(ce::CounterfactualExplanation)

An important subroutine that updates the counterfactual explanation. It takes a snapshot of the current counterfactual search state and passes it to the generator, which uses this state to propose feature perturbations. Various constraints are then applied to the proposed vector of feature perturbations. Finally, the counterfactual search state is updated.

CounterfactualExplanations.Evaluation.benchmark - Method
benchmark(
    data::CounterfactualData;
    models::Dict{<:Any,<:Any}=standard_models_catalogue,
    generators::Union{Nothing,Dict{<:Any,<:AbstractGenerator}}=nothing,
    measure::Union{Function,Vector{Function}}=default_measures,
    n_individuals::Int=5,
    suppress_training::Bool=false,
    factual::Union{Nothing,RawTargetType}=nothing,
    target::Union{Nothing,RawTargetType}=nothing,
    store_ce::Bool=false,
    parallelizer::Union{Nothing,AbstractParallelizer}=nothing,
    kwrgs...,
)

Runs the benchmarking exercise as follows:

  1. Randomly choose a factual and target label unless specified.
  2. If no pretrained models are provided, it is assumed that a dictionary of callable model objects is provided (by default using the standard_models_catalogue).
  3. Each of these models is then trained on the data.
  4. For each model separately choose n_individuals randomly from the non-target (factual) class. For each generator create a benchmark as in benchmark(x::Union{AbstractArray,Base.Iterators.Zip},...).
  5. Finally, concatenate the results.

If vertical_splits is set to an integer, the computations are split vertically into vertical_splits chunks. In this case, the results are stored in a temporary directory and concatenated afterwards.

CounterfactualExplanations.Evaluation.benchmark - Method
benchmark(
    x::Union{AbstractArray,Base.Iterators.Zip},
    target::RawTargetType,
    data::CounterfactualData;
    models::Dict{<:Any,<:AbstractFittedModel},
    generators::Dict{<:Any,<:AbstractGenerator},
    measure::Union{Function,Vector{Function}}=default_measures,
    xids::Union{Nothing,AbstractArray}=nothing,
    dataname::Union{Nothing,Symbol,String}=nothing,
    verbose::Bool=true,
    store_ce::Bool=false,
    parallelizer::Union{Nothing,AbstractParallelizer}=nothing,
    kwrgs...,
)

First generates counterfactual explanations for factual x, the target and data using each of the provided models and generators. Then generates a Benchmark for the vector of counterfactual explanations as in benchmark(counterfactual_explanations::Vector{CounterfactualExplanation}).

CounterfactualExplanations.Evaluation.benchmark - Method
benchmark(
    counterfactual_explanations::Vector{CounterfactualExplanation};
    meta_data::Union{Nothing,<:Vector{<:Dict}}=nothing,
    measure::Union{Function,Vector{Function}}=default_measures,
    store_ce::Bool=false,
)

Generates a Benchmark for a vector of counterfactual explanations. Optionally meta_data describing each individual counterfactual explanation can be supplied. This should be a vector of dictionaries of the same length as the vector of counterfactuals. If no meta_data is supplied, it will be automatically inferred. All measure functions are applied to each counterfactual explanation. If store_ce=true, the counterfactual explanations are stored in the benchmark.

CounterfactualExplanations.Evaluation.evaluate - Function
evaluate(
    ce::CounterfactualExplanation;
    measure::Union{Function,Vector{Function}}=default_measures,
    agg::Function=mean,
    report_each::Bool=false,
    output_format::Symbol=:Vector,
    pivot_longer::Bool=true
)

Computes evaluation measures for the counterfactual explanation. By default, no meta data is reported. For report_meta=true, meta data is automatically inferred, unless it is overwritten by meta_data. The optional meta_data argument should be a vector of dictionaries of the same length as the vector of counterfactual explanations.

CounterfactualExplanations.Evaluation.validity - Method
validity(ce::CounterfactualExplanation; γ=0.5)

Checks if the counterfactual search has been successful with respect to the probability threshold γ. If multiple counterfactuals were generated, the function returns the proportion of successful counterfactuals.
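As a minimal sketch of the proportion described above, assume the predicted target-class probability of each generated counterfactual is already available as a vector. `validity_sketch` is a hypothetical helper, and counting a counterfactual as valid when its probability is at least γ (rather than strictly above it) is an assumption.

```julia
using Statistics: mean

# Hypothetical stand-in for `validity`: given the predicted target-class
# probability of each generated counterfactual, return the share that
# reach the threshold γ (">=" is assumed here).
validity_sketch(target_class_probs::AbstractVector{<:Real}; γ::Real=0.5) =
    mean(p -> p >= γ, target_class_probs)

probs_single = [0.73]
probs_multi = [0.91, 0.42, 0.66, 0.30]

validity_sketch(probs_single)          # 1.0
validity_sketch(probs_multi)           # 0.5
validity_sketch(probs_multi; γ=0.9)    # 0.25
```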

CounterfactualExplanations.Data.load_cifar_10 - Function
load_cifar_10(n::Union{Nothing, Int}=nothing)

Loads and preprocesses data from the CIFAR-10 dataset for use in counterfactual explanations.

Arguments

  • n::Union{Nothing, Int}=nothing: The number of samples to subsample from the dataset. If n is not specified, all samples will be used.

Returns

  • counterfactual_data::CounterfactualData: A CounterfactualData object containing the preprocessed data.

Example

data = load_cifar_10(1000) # loads and preprocesses 1000 samples from the CIFAR-10 dataset

CounterfactualExplanations.Data.load_cifar_10_test - Method
load_cifar_10_test()

Loads and preprocesses test data from the CIFAR-10 dataset for use in counterfactual explanations.

Returns

  • counterfactual_data::CounterfactualData: A CounterfactualData object containing the preprocessed test data.

Example

test_data = load_cifar_10_test() # loads and preprocesses test data from the CIFAR-10 dataset

CounterfactualExplanations.Data.load_german_credit - Function
load_german_credit(n::Union{Nothing, Int}=nothing)

Loads and pre-processes UCI German Credit data.

Arguments

  • n::Union{Nothing, Int}=nothing: The number of samples to subsample from the dataset. If n is not specified, all samples will be used. Must be <= 1000 and >= 1.

Returns

  • counterfactual_data::CounterfactualData: A CounterfactualData object containing the preprocessed data.

Example

data = load_german_credit(500) # loads and preprocesses 500 samples from the German Credit dataset

CounterfactualExplanations.Data.load_uci_adult - Function
load_uci_adult(n::Union{Nothing, Int}=1000)

Loads and preprocesses data from the UCI 'Adult' dataset.

Arguments

  • n::Union{Nothing, Int}=nothing: The number of samples to subsample from the dataset.

Returns

  • counterfactual_data::CounterfactualData: A CounterfactualData object containing the preprocessed data.

Example

data = load_uci_adult(20) # loads and preprocesses 20 samples from the Adult dataset

CounterfactualExplanations.DataPreprocessing.CounterfactualData - Method
CounterfactualData(
    X::AbstractMatrix, y::AbstractMatrix;
    mutability::Union{Vector{Symbol},Nothing}=nothing,
    domain::Union{Any,Nothing}=nothing,
    features_categorical::Union{Vector{Int},Nothing}=nothing,
    features_continuous::Union{Vector{Int},Nothing}=nothing,
    standardize::Bool=false
)

This outer constructor method prepares features X and labels y to be used with the package. Mutability and domain constraints can be added for the features. The function also accepts arguments that specify which features are categorical and which are continuous. These arguments are currently not used.

Examples

using CounterfactualExplanations.Data
x, y = toy_data_linear()
X = hcat(x...)
counterfactual_data = CounterfactualData(X,y')

CounterfactualExplanations.DataPreprocessing.CounterfactualData - Method
function CounterfactualData(
    X::Tables.MatrixTable,
    y::RawOutputArrayType;
    kwrgs...
)

Outer constructor method that accepts a Tables.MatrixTable. By default, the indices of categorical and continuous features are automatically inferred from the features' scitype.

CounterfactualExplanations.Models.EvoTreeModel - Type
EvoTreeModel <: AbstractMLJModel

Constructor for gradient-boosted decision trees from the EvoTrees.jl library.

Arguments

  • model::Any: The model selected by the user. Must be a model from the MLJ library.
  • likelihood::Symbol: The likelihood of the model. Must be one of [:classification_binary, :classification_multi].

Returns

  • EvoTreeModel: An EvoTreeClassifier from EvoTrees.jl wrapped inside the EvoTreeModel class.

CounterfactualExplanations.Models.EvoTreeModel - Method
EvoTreeModel(data::CounterfactualData; kwargs...)

Constructs a new EvoTreeModel object from the data in a CounterfactualData object. Not called by the user directly.

Arguments

  • data::CounterfactualData: The CounterfactualData object containing the data to be used for training the model.

Returns

  • model::EvoTreeModel: The EvoTree model.

CounterfactualExplanations.Models.DecisionTreeModel - Method
DecisionTreeModel(data::CounterfactualData; kwargs...)

Constructs a new TreeModel object wrapped around a decision tree from the data in a CounterfactualData object. Not called by the user directly.

Arguments

  • data::CounterfactualData: The CounterfactualData object containing the data to be used for training the model.

Returns

  • model::TreeModel: A TreeModel object.

CounterfactualExplanations.Models.Linear - Method
Linear(data::CounterfactualData; kwargs...)

Constructs a model with one linear layer. If the output is binary, this corresponds to logistic regression, since model outputs are passed through the sigmoid function. If the output is multi-class, this corresponds to multinomial logistic regression, since model outputs are passed through the softmax function.
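To make the sigmoid/softmax distinction concrete, here is a plain-Julia sketch of a single linear layer without Flux. `LinearLayer`, `logits_sketch` and `probs_sketch` are hypothetical names; the package builds this model with Flux instead. The binary weights reuse the coefficients `[1.0 1.0]` from the generate_counterfactual example.

```julia
# Plain-Julia illustration of a single linear layer followed by a link
# function. This mimics the idea behind `Linear`, not its Flux implementation.
sigmoid(z) = 1 / (1 + exp(-z))

function softmax_cols(Z::AbstractMatrix)
    # Subtract the column maximum for numerical stability.
    E = exp.(Z .- maximum(Z; dims=1))
    return E ./ sum(E; dims=1)
end

struct LinearLayer
    W::Matrix{Float64}   # n_outputs × n_features
    b::Vector{Float64}   # n_outputs
end

# Raw outputs ("logits"); X holds one sample per column.
logits_sketch(m::LinearLayer, X::AbstractMatrix) = m.W * X .+ m.b

# Binary case: one output passed through the sigmoid (logistic regression).
# Multi-class case: several outputs passed through softmax (multinomial).
function probs_sketch(m::LinearLayer, X::AbstractMatrix)
    Z = logits_sketch(m, X)
    return size(Z, 1) == 1 ? sigmoid.(Z) : softmax_cols(Z)
end

X = [1.0 -1.0; 1.0 -1.0]                      # two samples, two features
binary = LinearLayer([1.0 1.0], [0.0])
multi = LinearLayer([1.0 0.0; 0.0 1.0; -1.0 -1.0], zeros(3))

p_bin = probs_sketch(binary, X)    # 1×2: sigmoid(2), sigmoid(-2)
p_multi = probs_sketch(multi, X)   # 3×2, each column sums to 1
```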

CounterfactualExplanations.Models.RandomForestModel - Method
RandomForestModel(data::CounterfactualData; kwargs...)

Constructs a new TreeModel object wrapped around a random forest from the data in a CounterfactualData object. Not called by the user directly.

Arguments

  • data::CounterfactualData: The CounterfactualData object containing the data to be used for training the model.

Returns

  • model::TreeModel: A TreeModel object.

CounterfactualExplanations.Models.fit_model - Function
fit_model(
    counterfactual_data::CounterfactualData, model::Symbol=:MLP;
    kwrgs...
)

Fits one of the available default models to the counterfactual_data. The model argument can be used to specify the desired model. The available values correspond to the keys of the all_models_catalogue dictionary.

CounterfactualExplanations.Models.logits - Method
logits(M::AbstractFittedModel, X::AbstractArray)

Generic method that is compulsory for all models. It returns the raw model predictions. In classification this is sometimes referred to as logits: the non-normalized predictions that are fed into a link function to produce predicted probabilities. In regression (not currently implemented) raw outputs typically correspond to final outputs. In other words, there is typically no normalization involved.

CounterfactualExplanations.Models.logits - Method
logits(M::EvoTreeModel, X::AbstractArray)

Calculates the logit scores output by the model M for the input data X.

Arguments

  • M::EvoTreeModel: The model selected by the user. Must be a model from the MLJ library.
  • X::AbstractArray: The feature vector for which the logit scores are calculated.

Returns

  • logits::Matrix: A matrix of logits for each output class for each data point in X.

Example

logits = Models.logits(M, x) # calculates the logit scores for each output class for the data point x

CounterfactualExplanations.Models.logits - Method
logits(M::TreeModel, X::AbstractArray)

Calculates the logit scores output by the model M for the input data X.

Arguments

  • M::TreeModel: The model selected by the user.
  • X::AbstractArray: The feature vector for which the logit scores are calculated.

Returns

  • logits::Matrix: A matrix of logits for each output class for each data point in X.

Example

logits = Models.logits(M, x) # calculates the logit scores for each output class for the data point x

CounterfactualExplanations.Models.predict_label - Method
predict_label(M::AbstractFittedModel, counterfactual_data::CounterfactualData, X::AbstractArray)

Returns the predicted output label for a given model M, data set counterfactual_data and input data X.

CounterfactualExplanations.Models.predict_label - Method
predict_label(M::AbstractFittedModel, counterfactual_data::CounterfactualData)

Returns the predicted output labels for all data points of data set counterfactual_data for a given model M.
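A common way to go from predicted probabilities to labels is to pick the most probable class. The sketch below assumes probabilities are stored with one column per sample; `labels_from_probs` and the 0.5 cut-off for a single-row binary output are illustrative assumptions, not the package's implementation.

```julia
# Hypothetical sketch of turning predicted probabilities into labels.
# `labels_from_probs` and the 0.5 cut-off for the binary case are
# illustrative assumptions, not the package's code.
function labels_from_probs(P::AbstractMatrix, levels::AbstractVector)
    if size(P, 1) == 1
        # Binary case: the single row holds the probability of levels[2].
        return [p >= 0.5 ? levels[2] : levels[1] for p in vec(P)]
    else
        # Multi-class case: pick the most probable class in each column.
        return [levels[argmax(col)] for col in eachcol(P)]
    end
end

labels_from_probs([0.2 0.8], [0, 1])                             # [0, 1]
labels_from_probs([0.1 0.7; 0.6 0.2; 0.3 0.1], ["a", "b", "c"])  # ["b", "a"]
```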

CounterfactualExplanations.Models.predict_label - Method
predict_label(M::TreeModel, X::AbstractArray)

Returns the predicted label for X.

Arguments

  • M::TreeModel: The model selected by the user.
  • X::AbstractArray: The input array for which the label is predicted.

Returns

  • labels::AbstractArray: The predicted label for each data point in X.

Example

label = Models.predict_label(M, x) # returns the predicted label for each data point in x

CounterfactualExplanations.Models.predict_proba - Method
predict_proba(M::AbstractFittedModel, counterfactual_data::CounterfactualData, X::Union{Nothing,AbstractArray})

Returns the predicted output probabilities for a given model M, data set counterfactual_data and input data X.

CounterfactualExplanations.Models.probs - Method
probs(M::AbstractFittedModel, X::AbstractArray)

Generic method that is compulsory for all models. It returns the normalized model predictions, so the predicted probabilities in the case of classification. In regression (not currently implemented) this method is redundant.

CounterfactualExplanations.Models.probs - Method
probs(M::EvoTreeModel, X::AbstractArray{<:Number, 2})

Calculates the probability scores for each output class for the two-dimensional input data matrix X.

Arguments

  • M::EvoTreeModel: The EvoTree model.
  • X::AbstractArray: The feature vector for which the predictions are made.

Returns

  • p::Matrix: A matrix of probability scores for each output class for each data point in X.

Example

probabilities = Models.probs(M, X) # calculates the probability scores for each output class for each data point in X.

CounterfactualExplanations.Models.probs - Method
probs(M::EvoTreeModel, X::AbstractArray{<:Number, 1})

Works the same way as the probs(M::EvoTreeModel, X::AbstractArray{<:Number, 2}) method above, but handles 1-dimensional rather than 2-dimensional input data.

CounterfactualExplanations.Models.probs - Method
probs(M::TreeModel, X::AbstractArray{<:Number, 2})

Calculates the probability scores for each output class for the two-dimensional input data matrix X.

Arguments

  • M::TreeModel: The TreeModel.
  • X::AbstractArray: The feature vector for which the predictions are made.

Returns

  • p::Matrix: A matrix of probability scores for each output class for each data point in X.

Example

probabilities = Models.probs(M, X) # calculates the probability scores for each output class for each data point in X.

CounterfactualExplanations.Models.probs - Method
probs(M::TreeModel, X::AbstractArray{<:Number, 1})

Works the same way as the probs(M::TreeModel, X::AbstractArray{<:Number, 2}) method above, but handles 1-dimensional rather than 2-dimensional input data.

CounterfactualExplanations.Generators.FeatureTweakGenerator - Type
FeatureTweakGenerator(ϵ::AbstractFloat=0.1)

Constructs a new Feature Tweak Generator object.

Uses the L2-norm as the penalty to measure the distance between the counterfactual and the factual. According to the paper by Tolomei et al., an alternative choice would be the L0-norm, which simply minimizes the number of features that are changed through the tweak.

Arguments

  • ϵ::AbstractFloat: The tolerance value for the feature tweaks. Described at length in Tolomei et al. (https://arxiv.org/pdf/1706.06691.pdf).

Returns

  • generator::FeatureTweakGenerator: A non-gradient-based generator that can be used to generate counterfactuals using the feature tweak method.
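The difference between the two penalties mentioned above shows up on a toy vector. The helper names `l2_cost` and `l0_cost` are hypothetical; the "L0-norm" is computed here as a count of changed features.

```julia
using LinearAlgebra: norm

# Two ways of measuring the cost of a feature tweak: the L2-norm penalises
# the size of the change, the "L0-norm" counts how many features change.
l2_cost(x, x_cf) = norm(x_cf .- x, 2)
l0_cost(x, x_cf) = count(!iszero, x_cf .- x)

x = [1.0, 2.0, 3.0]
small_many = [1.1, 2.1, 3.1]   # three small changes
large_one = [1.0, 2.0, 3.5]    # one larger change

l2_cost(x, small_many), l2_cost(x, large_one)  # about (0.173, 0.5)
l0_cost(x, small_many), l0_cost(x, large_one)  # (3, 1)
```

Under the L2 penalty the three small changes are cheaper; under the L0 count the single larger change wins.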

CounterfactualExplanations.Generators.GradientBasedGenerator - Method
GradientBasedGenerator(;
    loss::Union{Nothing,Function}=nothing,
    penalty::Penalty=nothing,
    λ::Union{Nothing,AbstractFloat,Vector{AbstractFloat}}=nothing,
    latent_space::Bool=false,
    opt::Flux.Optimise.AbstractOptimiser=Flux.Descent(),
)

Default outer constructor for GradientBasedGenerator.

Arguments

  • loss::Union{Nothing,Function}=nothing: The loss function used by the model.
  • penalty::Penalty=nothing: A penalty function for the generator to penalize counterfactuals too far from the original point.
  • λ::Union{Nothing,AbstractFloat,Vector{AbstractFloat}}=nothing: The weight of the penalty function.
  • latent_space::Bool=false: Whether to use the latent space of a generative model to generate counterfactuals.
  • opt::Flux.Optimise.AbstractOptimiser=Flux.Descent(): The optimizer to use for the generator.

Returns

  • generator::GradientBasedGenerator: A gradient-based counterfactual generator.
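To illustrate what a gradient-based generator does with `loss`, `penalty` and `λ`, here is a self-contained sketch for a logistic model with analytic gradients. It assumes binary cross-entropy as the loss, a squared L2 distance as the penalty and plain gradient descent in place of a Flux optimiser; `search_counterfactual` and its keyword defaults are hypothetical, with `γ=0.75` chosen to mirror the generate_counterfactual default.

```julia
# Sketch of gradient-based counterfactual search, assuming a logistic model
# p(y=1|x) = σ(w'x + b), binary cross-entropy as the loss and a squared L2
# distance to the factual as the penalty. Not the package's implementation.
σ(z) = 1 / (1 + exp(-z))

function search_counterfactual(x0::Vector{Float64}, target::Int,
                               w::Vector{Float64}, b::Float64;
                               λ=0.01, η=0.5, γ=0.75, max_iter=1000)
    x = copy(x0)
    for iter in 1:max_iter
        p1 = σ(w' * x + b)                       # P(y = 1 | x)
        p_target = target == 1 ? p1 : 1 - p1
        if p_target >= γ                         # threshold reached: stop
            return x, iter - 1
        end
        # Gradient of the BCE loss w.r.t. x is (p1 - target) * w;
        # gradient of the penalty λ‖x - x0‖² is 2λ(x - x0).
        grad = (p1 - target) .* w .+ 2λ .* (x .- x0)
        x .-= η .* grad                          # plain gradient descent step
    end
    return x, max_iter
end

w, b = [1.0, 1.0], 0.0     # same coefficients as the example above
x0 = [-1.0, -1.0]          # factual, predicted class 0
x_cf, n_steps = search_counterfactual(x0, 1, w, b)
σ(w' * x_cf + b)           # at least γ = 0.75 once converged
```

With these settings the search stops after a handful of steps. A larger λ pulls the counterfactual back towards the factual more strongly and, for this model, can prevent it from ever reaching the threshold.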

CounterfactualExplanations.Generators.CLUEGenerator - Method
CLUEGenerator(
    ;
    loss::Symbol=:logitbinarycrossentropy,
    complexity::Function=norm,
    λ::AbstractFloat=0.1,
    ϵ::AbstractFloat=0.1,
    τ::AbstractFloat=1e-5
)

An outer constructor method that instantiates a CLUE generator.

Examples

generator = CLUEGenerator()

CounterfactualExplanations.Generators.ProbeGenerator - Method
ProbeGenerator(; λ::AbstractFloat=0.1, loss::Symbol=:mse, penalty=distance_l1, kwargs...)

Create a generator that generates counterfactual probes using the specified loss function and penalty function.

Arguments

  • λ::AbstractFloat: The regularization parameter for the generator.
  • loss::Symbol: The loss function to use for the generator. Defaults to :mse.
  • penalty: The penalty function to use for the generator. Defaults to distance_l1.
  • kwargs: Additional keyword arguments to pass to the Generator constructor.

Returns

A Generator object that can be used to generate counterfactual probes.

Based on https://arxiv.org/abs/2203.06768.

CounterfactualExplanations.Generators.conditions_satisfied - Method
conditions_satisfied(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)

The default method to check if all conditions for convergence of the counterfactual search have been satisfied for gradient-based generators. By default, gradient-based search is considered to have converged as soon as the proposed changes for all features are smaller than one percent of the respective feature's standard deviation.
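A sketch of that stopping rule, assuming features are stored in rows (one column per sample) and the standard deviation is the sample standard deviation over the training data. `changes_negligible` is a hypothetical name.

```julia
using Statistics: std

# Sketch of the convergence rule described above: stop once every proposed
# feature change is smaller than 1% of that feature's standard deviation.
# `changes_negligible` is a hypothetical helper, not the package function.
function changes_negligible(Δ::AbstractVector, X::AbstractMatrix; tol=0.01)
    σs = vec(std(X; dims=2))            # one std per feature (features in rows)
    return all(abs.(Δ) .< tol .* σs)
end

X = [0.0 1.0 2.0 3.0;                   # feature 1
     0.0 10.0 20.0 30.0]                # feature 2, ten times the spread
changes_negligible([0.005, 0.05], X)    # true
changes_negligible([0.05, 0.05], X)     # false: too large for feature 1
```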

CounterfactualExplanations.Generators.feature_tweaking - Method
feature_tweaking(generator::FeatureTweakGenerator, M::Models.TreeModel, x::AbstractArray, target::RawTargetType)

Returns a counterfactual instance of x based on the ensemble of classifiers provided.

Arguments

  • generator::FeatureTweakGenerator: The feature tweak generator.
  • M::Models.TreeModel: The model for which the counterfactual is generated. Must be a tree-based model.
  • x::AbstractArray: The factual instance.
  • target::RawTargetType: The target class.

Returns

  • x_out::AbstractArray: The counterfactual instance.

Example

x = feature_tweaking(generator, M, x, target) # returns a counterfactual instance of x based on the ensemble of classifiers provided

CounterfactualExplanations.Generators.hinge_loss - Method
hinge_loss(ce::AbstractCounterfactualExplanation)

Calculate the hinge loss of a counterfactual explanation.

Arguments

  • ce::AbstractCounterfactualExplanation: The counterfactual explanation to calculate the hinge loss for.

Returns

The hinge loss of the counterfactual explanation.

CounterfactualExplanations.Generators.invalidation_rate - Method
invalidation_rate(ce::AbstractCounterfactualExplanation)

Calculate the invalidation rate of a counterfactual explanation.

Arguments

  • ce::AbstractCounterfactualExplanation: The counterfactual explanation to calculate the invalidation rate for.
  • kwargs: Additional keyword arguments to pass to the function.

Returns

The invalidation rate of the counterfactual explanation.

CounterfactualExplanations.Generators.mutability_constraints - Method
mutability_constraints(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)

The default method to return mutability constraints that are dependent on the current counterfactual search state. For generic gradient-based generators, no state-dependent constraints are added.

Flux.Losses.logitbinarycrossentropy - Method
Flux.Losses.logitbinarycrossentropy(ce::AbstractCounterfactualExplanation)

Simply extends the logitbinarycrossentropy method to work with objects of type AbstractCounterfactualExplanation.

Flux.Losses.logitcrossentropy - Method
Flux.Losses.logitcrossentropy(ce::AbstractCounterfactualExplanation)

Simply extends the logitcrossentropy method to work with objects of type AbstractCounterfactualExplanation.

Flux.Losses.mse - Method
Flux.Losses.mse(ce::AbstractCounterfactualExplanation)

Simply extends the mse method to work with objects of type AbstractCounterfactualExplanation.

Internal functions

CounterfactualExplanations.apply_mutability - Method
apply_mutability(
    ce::CounterfactualExplanation,
    Δs′::AbstractArray,
)

A subroutine that applies mutability constraints to the proposed vector of feature perturbations.
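A minimal sketch of such a subroutine, assuming each feature is tagged with one of `:both`, `:increase`, `:decrease` or `:none`. The helper `apply_mutability_sketch` is hypothetical and the clipping rules are an illustration, not the package's code.

```julia
# Sketch of applying mutability constraints to a proposed perturbation Δs′.
# Treat the exact rules as an assumption rather than the package's code.
function apply_mutability_sketch(Δ::AbstractVector{<:Real}, mutability::Vector{Symbol})
    Δ_out = copy(Δ)
    for (i, m) in enumerate(mutability)
        if m == :none
            Δ_out[i] = 0                     # immutable feature: no change
        elseif m == :increase
            Δ_out[i] = max(Δ_out[i], 0)      # only upward moves allowed
        elseif m == :decrease
            Δ_out[i] = min(Δ_out[i], 0)      # only downward moves allowed
        end                                  # :both leaves Δ unchanged
    end
    return Δ_out
end

Δ = [0.3, -0.2, -0.4, 0.5]
apply_mutability_sketch(Δ, [:both, :none, :increase, :decrease])  # [0.3, 0.0, 0.0, 0.0]
```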

CounterfactualExplanations.decode_array - Method
decode_array(dt::MultivariateStats.AbstractDimensionalityReduction, x::AbstractArray)

Helper function to decode an array x using a data transform dt::MultivariateStats.AbstractDimensionalityReduction.

CounterfactualExplanations.decode_array - Method
decode_array(dt::StatsBase.AbstractDataTransform, x::AbstractArray)

Helper function to decode an array x using a data transform dt::StatsBase.AbstractDataTransform.

CounterfactualExplanations.decode_state - Function
function decode_state(
    ce::CounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)

Applies all the applicable decoding functions:

  1. If applicable, map the state variable back from the latent space to the feature space.
  2. If and where applicable, inverse-transform features.
  3. Reconstruct all categorical encodings.

Finally, the decoded counterfactual is returned.

CounterfactualExplanations.encode_array - Method
encode_array(dt::MultivariateStats.AbstractDimensionalityReduction, x::AbstractArray)

Helper function to encode an array x using a data transform dt::MultivariateStats.AbstractDimensionalityReduction.

CounterfactualExplanations.encode_array - Method
encode_array(dt::StatsBase.AbstractDataTransform, x::AbstractArray)

Helper function to encode an array x using a data transform dt::StatsBase.AbstractDataTransform.

CounterfactualExplanations.encode_state - Function
function encode_state(
    ce::CounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)

Applies all required encodings to x:

  1. If applicable, it maps x to the latent space learned by the generative model.
  2. If and where applicable, it rescales features.

Finally, it returns the encoded state variable.
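The rescaling step and its inverse can be illustrated with a hand-rolled z-score transform. The package relies on data transforms such as `StatsBase.AbstractDataTransform`; the `ZScore` type and the `encode`/`decode` functions below are hypothetical.

```julia
using Statistics: mean, std

# Sketch of rescaling features and mapping them back, using a hand-rolled
# z-score transform (an illustration, not the package's transform).
struct ZScore
    μ::Vector{Float64}
    σ::Vector{Float64}
end

# Fit one mean and standard deviation per feature (features in rows).
fit_zscore(X::AbstractMatrix) = ZScore(vec(mean(X; dims=2)), vec(std(X; dims=2)))

encode(t::ZScore, x::AbstractVecOrMat) = (x .- t.μ) ./ t.σ
decode(t::ZScore, z::AbstractVecOrMat) = z .* t.σ .+ t.μ

X = [1.0 2.0 3.0; 10.0 20.0 30.0]
t = fit_zscore(X)
x = [2.5, 15.0]
z = encode(t, x)          # the search would operate in this rescaled space
x_back = decode(t, z)     # mapped back to the original feature space
```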

CounterfactualExplanations.guess_likelihood - Method
guess_likelihood(y::RawOutputArrayType)

Guess the likelihood based on the scientific type of the output array. Returns a symbol indicating the guessed likelihood and the scientific type of the output array.
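As a rough illustration, the sketch below chooses between the two classification likelihoods used elsewhere in this reference by counting distinct labels instead of inspecting the scientific type. `guess_likelihood_sketch` is hypothetical and, unlike the package function, does not return the scientific type.

```julia
# Rough stand-in for `guess_likelihood`: instead of inspecting the scientific
# type, this hypothetical helper just counts the distinct labels.
function guess_likelihood_sketch(y::AbstractVector)
    n_levels = length(unique(y))
    n_levels == 2 && return :classification_binary
    n_levels > 2 && return :classification_multi
    throw(ArgumentError("need at least two distinct labels, got $n_levels"))
end

guess_likelihood_sketch([0, 1, 1, 0])          # :classification_binary
guess_likelihood_sketch(["a", "b", "c", "a"])  # :classification_multi
```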

CounterfactualExplanations.initialize_state - Method
initialize_state(ce::CounterfactualExplanation)

Initializes the starting point for the factual(s):

  1. If ce.initialization is set to :identity or counterfactuals are searched in a latent space, then nothing is done.
  2. If ce.initialization is set to :add_perturbation, then a random perturbation is added to the factual following Slack (2021): https://arxiv.org/abs/2106.02666. The authors show that this improves adversarial robustness.
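The two options can be sketched as follows. This assumes that `:add_perturbation` means zero-mean Gaussian noise whose variance matches the `variance` keyword of the CounterfactualExplanation constructor (default 0.01); `initialize_sketch` is a hypothetical name, not a package function.

```julia
using Random

# Sketch of the two initialisation options. The Gaussian noise model is an
# assumption made for illustration.
function initialize_sketch(x::AbstractArray, initialization::Symbol;
                           variance=0.01, rng=Random.default_rng())
    initialization == :identity && return copy(x)
    initialization == :add_perturbation &&
        return x .+ sqrt(variance) .* randn(rng, size(x))
    throw(ArgumentError("unknown initialization: $initialization"))
end

rng = MersenneTwister(42)
x = [1.0, -2.0, 0.5]
x_id = initialize_sketch(x, :identity)                       # unchanged copy
x_pert = initialize_sketch(x, :add_perturbation; rng=rng)    # small random shift
```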

CounterfactualExplanations.map_from_latent - Function
map_from_latent(
    ce::CounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)

Maps the state variable back from the latent space to the feature space.

CounterfactualExplanations.map_to_latent - Function
function map_to_latent(
    ce::CounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)

Maps x from the feature space $\mathcal{X}$ to the latent space learned by the generative model.

CounterfactualExplanations.threshold_reached - Method
threshold_reached(ce::CounterfactualExplanation, x::AbstractArray)

A convenience method that determines if the predefined threshold for the target class probability has been reached for a specific sample x.

CounterfactualExplanations.threshold_reached - Method
threshold_reached(ce::CounterfactualExplanation)

A convenience method that determines if the predefined threshold for the target class probability has been reached.

CounterfactualExplanations.wants_latent_space - Method
wants_latent_space(
    ce::CounterfactualExplanation, 
    x::Union{AbstractArray,Nothing} = nothing,
)

A convenience function that checks if latent space search is applicable.

Base.vcat - Method
Base.vcat(bmk1::Benchmark, bmk2::Benchmark)

Vertically concatenates two Benchmark objects.

CounterfactualExplanations.Evaluation.compute_measure - Method
compute_measure(ce::CounterfactualExplanation, measure::Function, agg::Function)

Computes a single measure for a counterfactual explanation. The measure is applied to the counterfactual explanation ce and aggregated using the aggregation function agg.
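The measure-then-aggregate pattern can be sketched with a distance measure that returns one value per counterfactual. `l2_distances` and `compute_measure_sketch` are hypothetical, and storing counterfactuals as matrix columns is an assumption.

```julia
using LinearAlgebra: norm
using Statistics: mean

# Sketch of "apply a measure, then aggregate": the measure returns one value
# per generated counterfactual and `agg` reduces them to a single number.
l2_distances(x, counterfactuals) = [norm(c .- x) for c in eachcol(counterfactuals)]

compute_measure_sketch(x, counterfactuals, measure::Function, agg::Function) =
    agg(measure(x, counterfactuals))

x = [0.0, 0.0]
cfs = [3.0 0.0 1.0;      # three counterfactuals stored as columns
       4.0 2.0 0.0]
compute_measure_sketch(x, cfs, l2_distances, mean)     # mean of [5.0, 2.0, 1.0]
compute_measure_sketch(x, cfs, l2_distances, maximum)  # 5.0
```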

CounterfactualExplanations.Evaluation.to_dataframe - Method
evaluate_dataframe(
    ce::CounterfactualExplanation,
    measure::Vector{Function},
    agg::Function,
    report_each::Bool,
    pivot_longer::Bool,
    store_ce::Bool,
)

Evaluates a counterfactual explanation and returns a dataframe of evaluation measures.

CounterfactualExplanations.Evaluation.validity_strict - Method
validity_strict(ce::CounterfactualExplanation)

Checks if the counterfactual search has been strictly valid in the sense that it has converged with respect to the pre-specified target probability γ.

CounterfactualExplanations.DataPreprocessing.convert_to_1d - Method
convert_to_1d(y::Matrix, y_levels::AbstractArray)

Helper function to convert a one-hot encoded matrix to a vector of labels. This is necessary because MLJ models require the labels to be represented as a vector, but the synthetic datasets in this package hold the labels in one-hot encoded form.

Arguments

  • y::Matrix: The one-hot encoded matrix.
  • y_levels::AbstractArray: The levels of the categorical variable.

Returns

  • labels: A vector of labels.
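A minimal version of this conversion, assuming one row per level and one column per sample, so that each column holds a single 1. `onehot_to_labels` is a hypothetical helper name.

```julia
# Sketch of converting one-hot encoded labels back to a label vector,
# assuming one row per level and one column per sample.
function onehot_to_labels(y::AbstractMatrix, y_levels::AbstractVector)
    @assert size(y, 1) == length(y_levels) "one row per level expected"
    return [y_levels[argmax(col)] for col in eachcol(y)]
end

y = [1 0 0 1;
     0 1 0 0;
     0 0 1 0]
onehot_to_labels(y, ["cat", "dog", "bird"])   # ["cat", "dog", "bird", "cat"]
```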

CounterfactualExplanations.DataPreprocessing.get_generative_model - Method
get_generative_model(counterfactual_data::CounterfactualData)

Returns the underlying generative model. If there is no existing model available, the default generative model (VAE) is used. Otherwise, the existing generative model is expected to have been pre-trained; if it has not, a warning is triggered.

CounterfactualExplanations.DataPreprocessing.preprocess_data_for_mlj - Method
preprocess_data_for_mlj(data::CounterfactualData)

Helper function to preprocess data::CounterfactualData for MLJ models.

Arguments

  • data::CounterfactualData: The data to be preprocessed.

Returns

  • (df_x, y): A tuple containing the preprocessed data, with df_x being a DataFrame object and y being a categorical vector.

Example

X, y = preprocess_data_for_mlj(data)

CounterfactualExplanations.Models.TreeModel - Type
TreeModel <: AbstractNonDifferentiableJuliaModel

Constructor for tree-based models from the MLJ library.

Arguments

  • model::Any: The model selected by the user. Must be a model from the MLJ library.
  • likelihood::Symbol: The likelihood of the model. Must be one of [:classification_binary, :classification_multi].

Returns

  • TreeModel: A tree-based model from the MLJ library wrapped inside the TreeModel class.

CounterfactualExplanations.Models.get_individual_classifiers - Method
get_individual_classifiers(M::TreeModel)

Returns the individual classifiers in the forest. If the input is a decision tree, the method returns the decision tree itself inside an array.

Arguments

  • M::TreeModel: The model selected by the user.

Returns

  • classifiers::AbstractArray: An array of individual classifiers in the forest.

Example

classifiers = Models.get_individual_classifiers(M) # returns the individual classifiers in the forest

CounterfactualExplanations.Models.train - Method
train(M::EvoTreeModel, data::CounterfactualData; kwargs...)

Fits the model M to the data in the CounterfactualData object. This method is not called by the user directly.

Arguments

  • M::EvoTreeModel: The wrapper for an EvoTree model.
  • data::CounterfactualData: The CounterfactualData object containing the data to be used for training the model.

Returns

  • M::EvoTreeModel: The fitted EvoTree model.
CounterfactualExplanations.Models.train β€” Method
train(M::TreeModel, data::CounterfactualData; kwargs...)

Fits the model M to the data in the CounterfactualData object. This method is not called by the user directly.

Arguments

  • M::TreeModel: The wrapper for a TreeModel.
  • data::CounterfactualData: The CounterfactualData object containing the data to be used for training the model.

Returns

  • M::TreeModel: The fitted TreeModel.
CounterfactualExplanations.Generators.Generator β€” Method
Generator(;
	loss::Union{Nothing,Function}=nothing,
	penalty::Penalty=nothing,
	Ξ»::Union{Nothing,AbstractFloat,Vector{AbstractFloat}}=nothing,
	latent_space::Bool::false,
	opt::Flux.Optimise.AbstractOptimiser=Flux.Descent(),
)

An outer constructor that allows for more convenient creation of the GradientBasedGenerator type.

CounterfactualExplanations.Generators.converged - Method
converged(ce::AbstractCounterfactualExplanation)

Determines whether the counterfactual search has converged.

Arguments

  • ce::AbstractCounterfactualExplanation: The counterfactual explanation object.

Returns

  • converged::Bool: true if the search has converged, false otherwise.

CounterfactualExplanations.Generators.esatisfactory_instance - Method
esatisfactory_instance(generator::FeatureTweakGenerator, x::AbstractArray, paths::Dict{String, Dict{String, Any}})

Returns an epsilon-satisfactory counterfactual for x based on the paths provided.

Arguments

  • generator::FeatureTweakGenerator: The feature tweak generator.
  • x::AbstractArray: The factual instance.
  • paths::Dict{String, Dict{String, Any}}: The paths to the leaves of the tree that are used to tweak the features.

Returns

  • esatisfactory::AbstractArray: The epsilon-satisfactory instance.

Example

esatisfactory = esatisfactory_instance(generator, x, paths) # returns an epsilon-satisfactory counterfactual for x based on the paths provided
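To make "epsilon-satisfactory" concrete, here is a minimal sketch that tweaks a factual so that it satisfies every split along one root-to-leaf path, placing each violating feature `epsilon` beyond the threshold on the required side. `SplitCondition`, `esatisfactory_sketch` and the default `epsilon = 0.1` are hypothetical; the sketch uses a simplified path representation rather than the `Dict{String, Dict{String, Any}}` format above and is not the package's implementation.

```julia
# Hypothetical representation of one split on a root-to-leaf path:
# the path requires x[feature] <= threshold (leq = true) or x[feature] > threshold (leq = false).
struct SplitCondition
    feature::Int
    threshold::Float64
    leq::Bool
end

# Sketch: move only the features that violate a split, placing them `epsilon`
# on the required side of the threshold; all other features are left untouched.
function esatisfactory_sketch(x::AbstractVector{<:Real}, path::Vector{SplitCondition}; epsilon::Real=0.1)
    x_new = float.(x)  # broadcasting allocates a new array, so x is not mutated
    for c in path
        if c.leq && x_new[c.feature] > c.threshold
            x_new[c.feature] = c.threshold - epsilon
        elseif !c.leq && x_new[c.feature] <= c.threshold
            x_new[c.feature] = c.threshold + epsilon
        end
    end
    return x_new
end

x_f = [0.0, 5.0, 1.0]
tweak_path = [SplitCondition(1, 2.0, false), SplitCondition(2, 3.0, true)]
esatisfactory_sketch(x_f, tweak_path)  # ≈ [2.1, 2.9, 1.0]
```

Features that do not appear on the path (here the third one) keep their factual values, which keeps the tweak as small as the path allows.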

CounterfactualExplanations.Generators.feature_selection! - Method
feature_selection!(ce::AbstractCounterfactualExplanation)

Performs feature selection by repeatedly finding the dimension with the closest (but not equal) values between the ce.x (factual) and ce.s′ (counterfactual) arrays.

Arguments

  • ce::AbstractCounterfactualExplanation: An instance of the AbstractCounterfactualExplanation type representing the counterfactual explanation.

Returns

  • nothing

The function iteratively modifies the ce.s′ counterfactual array by setting its elements to the corresponding elements of the ce.x factual array, one dimension at a time, until the predicted label of the modified ce.s′ matches the predicted label of ce.x.
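The loop described above can be sketched on plain vectors as follows. `label_sum_gt1` stands in for the model's label prediction and `feature_selection_sketch` is a hypothetical name. One assumption goes beyond the description: the sketch stops before applying the change that would restore the factual label, so that the returned point is still a valid counterfactual.

```julia
# Sketch: greedily copy factual values into the counterfactual, smallest change first,
# for as long as the counterfactual keeps a label different from the factual one.
function feature_selection_sketch(predict, x::AbstractVector, s::AbstractVector)
    s = float.(s)                              # work on a copy of the counterfactual
    y_factual = predict(x)
    while true
        diffs = abs.(x .- s)
        candidates = findall(!iszero, diffs)   # dimensions where x and s still differ
        isempty(candidates) && break
        i = candidates[argmin(diffs[candidates])]
        trial = copy(s)
        trial[i] = x[i]
        predict(trial) == y_factual && break   # assumption: keep s a valid counterfactual
        s = trial
    end
    return s
end

label_sum_gt1(z) = z[1] + z[2] > 1 ? 1 : 0
feature_selection_sketch(label_sum_gt1, [0.0, 0.0], [2.0, 0.3])  # returns [2.0, 0.0]
```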

CounterfactualExplanations.Generators.find_closest_dimension - Method
find_closest_dimension(factual, counterfactual)

Find the dimension with the closest (but not equal) values between the factual and counterfactual arrays.

Arguments

  • factual: The factual array.
  • counterfactual: The counterfactual array.

Returns

  • closest_dimension: The index of the dimension with the closest values.

The function iterates over the indices of the factual array and calculates the absolute difference between the corresponding elements in the factual and counterfactual arrays. It returns the index of the dimension with the smallest difference, excluding dimensions where the values in factual and counterfactual are equal.
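A standalone sketch of this search on plain vectors follows. The name `closest_dimension_sketch` is hypothetical (chosen so as not to suggest it is the package's code), and returning `nothing` when the arrays agree in every dimension is an assumption, since that case is not specified above.

```julia
# Sketch: index of the dimension with the smallest non-zero absolute difference.
# Returns `nothing` if the two arrays are equal in every dimension (assumed behaviour).
function closest_dimension_sketch(factual::AbstractVector, counterfactual::AbstractVector)
    closest = nothing
    smallest = Inf
    for i in eachindex(factual, counterfactual)
        d = abs(factual[i] - counterfactual[i])
        if d != 0 && d < smallest   # skip dimensions where the values are equal
            smallest = d
            closest = i
        end
    end
    return closest
end

closest_dimension_sketch([0.0, 1.0, 2.0], [0.5, 1.0, 2.1])  # returns 3
```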

CounterfactualExplanations.Generators.find_counterfactual - Method
find_counterfactual(model, factual_class, counterfactual_data, counterfactual_candidates)

Finds the index of the first counterfactual among the candidates by predicting their labels.

Arguments

  • model: The fitted model used for prediction.
  • target_class: The expected target class.
  • counterfactual_data: The data required for counterfactual generation.
  • counterfactual_candidates: The array of counterfactual candidates.

Returns

  • counterfactual: The index of the first counterfactual found.
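A minimal sketch of this lookup, assuming the candidates are stored one per column and that `label_by_sum` (a hypothetical stand-in for the fitted model together with the data) maps a single candidate to a class label. `first_counterfactual_sketch` is likewise a hypothetical name.

```julia
# Sketch: return the index of the first column whose predicted label equals
# `target_class`, or `nothing` if no candidate qualifies.
function first_counterfactual_sketch(predict_label, target_class, candidates::AbstractMatrix)
    return findfirst(j -> predict_label(view(candidates, :, j)) == target_class, axes(candidates, 2))
end

label_by_sum(z) = sum(z) > 0 ? 1 : 0
cands = [-1.0 0.5 2.0;
         -1.0 0.0 1.0]
first_counterfactual_sketch(label_by_sum, 1, cands)  # returns 2
```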
CounterfactualExplanations.Generators.growing_spheres_generation! - Method
growing_spheres_generation!(ce::AbstractCounterfactualExplanation)

Generates counterfactual candidates using the growing spheres algorithm.

Arguments

  • ce::AbstractCounterfactualExplanation: An instance of the AbstractCounterfactualExplanation type representing the counterfactual explanation.

Returns

  • nothing

This function applies the growing spheres algorithm to generate counterfactual candidates. It first samples random points uniformly on a sphere and shrinks the search space until no counterfactuals are found. It then expands the search space until at least one counterfactual is found or the maximum number of iterations is reached.

The algorithm iteratively generates counterfactual candidates and predicts their labels using the model stored in ce.M, checking whether any of the predicted labels differ from the factual class. The search space is reduced by halving the search radius and expanded by increasing the search radius.
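The two phases described above can be sketched on plain vectors as follows. Everything here is an illustrative assumption rather than the package's code: the names, the default number of samples `n`, the initial radius `eta`, the use of the final shrink-phase radius as the step size while expanding, the L2 norm, and returning the first counterfactual found instead of, for example, the closest one.

```julia
using Random

# Sample n points whose L2 distance from x lies in [low, high), with uniformly random directions.
function sample_shell(rng, x::AbstractVector, n::Int, low::Real, high::Real)
    dirs = randn(rng, length(x), n)
    dirs ./= sqrt.(sum(abs2, dirs; dims=1))       # normalise each column to unit length
    radii = low .+ (high - low) .* rand(rng, n)   # one radius per sample
    return x .+ dirs .* radii'                    # one candidate per column
end

function growing_spheres_sketch(predict_label, x::AbstractVector; n::Int=200, eta::Real=1.0,
                                max_iter::Int=100, rng=Random.default_rng())
    y_factual = predict_label(x)
    first_flip(C) = findfirst(j -> predict_label(view(C, :, j)) != y_factual, axes(C, 2))

    # Shrink phase: halve the radius while the ball around x still contains counterfactuals.
    radius = float(eta)
    for _ in 1:max_iter
        first_flip(sample_shell(rng, x, n, 0.0, radius)) === nothing && break
        radius /= 2
    end

    # Expand phase: search successive shells [low, high) until a counterfactual appears.
    low, high = radius, 2radius
    for _ in 1:max_iter
        C = sample_shell(rng, x, n, low, high)
        j = first_flip(C)
        j === nothing || return C[:, j]
        low, high = high, high + radius
    end
    return nothing   # no counterfactual found within max_iter shells
end

label_by_first(z) = z[1] > 1.0 ? 1 : 0
cf = growing_spheres_sketch(label_by_first, [0.0, 0.0]; rng=MersenneTwister(42))
```

Shrinking first makes the expansion start from a radius at which no counterfactual was found, so the first shell that contains one is close to the factual.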

CounterfactualExplanations.Generators.h - Method
h(generator::AbstractGenerator, penalty::Function, ce::AbstractCounterfactualExplanation)

Overloads the h function for the case where a single penalty function is provided.

CounterfactualExplanations.Generators.h - Method
h(generator::AbstractGenerator, penalty::Nothing, ce::AbstractCounterfactualExplanation)

Overloads the h function for the case where no penalty is provided.

CounterfactualExplanations.Generators.h - Method
h(generator::AbstractGenerator, penalty::Tuple, ce::AbstractCounterfactualExplanation)

Overloads the h function for the case where a single penalty function is provided with additional keyword arguments.
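These overloads select behaviour through Julia's multiple dispatch on the type of the penalty argument. A self-contained sketch of that pattern follows, in which `penalty_value` and `l2_penalty` are hypothetical names and the tuple is assumed to hold the penalty function followed by a NamedTuple of keyword arguments. The ℓ overloads further down dispatch on the type of the loss argument in the same way.

```julia
# One method per kind of penalty argument, selected by dispatch on its type.
penalty_value(penalty::Nothing, x) = 0.0                         # no penalty
penalty_value(penalty::Function, x) = penalty(x)                  # a single penalty function
penalty_value(penalty::Tuple, x) = penalty[1](x; penalty[2]...)   # (function, keyword arguments)

# Hypothetical penalty with a keyword argument.
l2_penalty(x; scale=1.0) = scale * sum(abs2, x)

penalty_value(nothing, [1.0, 2.0])                       # returns 0.0
penalty_value(l2_penalty, [1.0, 2.0])                    # returns 5.0
penalty_value((l2_penalty, (scale=2.0,)), [1.0, 2.0])    # returns 10.0
```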

CounterfactualExplanations.Generators.hyper_sphere_coordinates - Method
hyper_sphere_coordinates(n_search_samples::Int, instance::Vector{Float64}, low::Int, high::Int; p_norm::Int=2)

Generates candidate counterfactuals using the growing spheres method based on hyper-sphere coordinates.

The implementation follows the Random Point Picking over a sphere algorithm described in the paper "Learning Counterfactual Explanations for Tabular Data" by Pawelczyk, Broelemann & Kasneci (2020), presented at The Web Conference 2020 (WWW). It ensures that points are sampled uniformly at random, using insights from: http://mathworld.wolfram.com/HyperspherePointPicking.html

The growing spheres method was originally proposed in the paper "Comparison-based Inverse Classification for Interpretability in Machine Learning" by Thibaut Laugel et al. (2018), presented at the International Conference on Information Processing and Management of Uncertainty in Knowledge-Based Systems.

Arguments

  • n_search_samples::Int: The number of search samples (int > 0).
  • instance::AbstractArray: The input point array.
  • low::AbstractFloat: The lower bound (float >= 0, low < high).
  • high::AbstractFloat: The upper bound (float >= 0, high > low).
  • p_norm::Integer: The norm parameter (int >= 1).

Returns

  • candidate_counterfactuals::Array: An array of candidate counterfactuals.
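A sketch of the sampling step described above, using only the LinearAlgebra and Random standard libraries: directions are drawn from a standard Gaussian, each is rescaled so that its p-norm equals a length drawn uniformly from [low, high), and the result is added to `instance`. The name `hyper_sphere_sketch` and the `rng` keyword are assumptions. For `p_norm = 2` the directions are uniform on the sphere; for other norms this rescaling does not, in general, give a uniform distribution over the p-sphere.

```julia
using LinearAlgebra, Random

function hyper_sphere_sketch(n_search_samples::Int, instance::AbstractVector, low::Real, high::Real;
                             p_norm::Real=2, rng=Random.default_rng())
    d = length(instance)
    delta = randn(rng, d, n_search_samples)                      # random directions, one per column
    dist = low .+ (high - low) .* rand(rng, n_search_samples)    # target lengths in [low, high)
    norms = [norm(view(delta, :, j), p_norm) for j in 1:n_search_samples]
    delta .*= (dist ./ norms)'                                   # rescale column j to length dist[j]
    return instance .+ delta                                     # candidate counterfactuals, one per column
end

C = hyper_sphere_sketch(5, [0.0, 0.0], 0.5, 1.0)   # a 2×5 matrix of candidates
```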
CounterfactualExplanations.Generators.search_path - Function
search_path(tree::Union{Leaf, Node}, target::RawTargetType, path::AbstractArray)

Return a path index list with the inequality symbols, thresholds and feature indices.

Arguments

  • tree::Union{Leaf, Node}: The root node of a decision tree.
  • target::RawTargetType: The target class.
  • path::AbstractArray: A list containing the paths found thus far.

Returns

  • paths::AbstractArray: A list of paths to the leaves of the tree to be used for tweaking the feature.

Example

paths = search_path(tree, target) # returns a list of paths to the leaves of the tree to be used for tweaking the feature
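The recursive search can be sketched with a toy tree type. `TreeSketch`, `LeafSketch`, `NodeSketch`, `search_paths_sketch` and the convention that the left branch holds `x[feature] <= threshold` are hypothetical; the sketch does not use the Leaf and Node types from the signature above. Each returned path is a vector of `(feature, threshold, inequality)` tuples leading to a leaf that predicts the target.

```julia
# Toy decision tree: internal nodes split on x[feature] <= threshold (left) vs > threshold (right).
abstract type TreeSketch end

struct LeafSketch <: TreeSketch
    label::Int
end

struct NodeSketch <: TreeSketch
    feature::Int
    threshold::Float64
    left::TreeSketch
    right::TreeSketch
end

const PathStep = Tuple{Int,Float64,Symbol}

# A leaf contributes the current path only if it predicts the target class.
search_paths_sketch(leaf::LeafSketch, target, path::Vector{PathStep}=PathStep[]) =
    leaf.label == target ? [path] : Vector{PathStep}[]

# An internal node extends the path with the condition of each branch and recurses.
function search_paths_sketch(node::NodeSketch, target, path::Vector{PathStep}=PathStep[])
    left = search_paths_sketch(node.left, target, [path; [(node.feature, node.threshold, :leq)]])
    right = search_paths_sketch(node.right, target, [path; [(node.feature, node.threshold, :gt)]])
    return vcat(left, right)
end

tree = NodeSketch(1, 0.5, LeafSketch(0), NodeSketch(2, 1.0, LeafSketch(1), LeafSketch(0)))
search_paths_sketch(tree, 1)  # returns [[(1, 0.5, :gt), (2, 1.0, :leq)]]
```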

CounterfactualExplanations.Generators.ℓ - Method
ℓ(generator::AbstractGenerator, loss::Function, ce::AbstractCounterfactualExplanation)

Overloads the ℓ function for the case where a single loss function is provided.

CounterfactualExplanations.Generators.ℓ - Method
ℓ(generator::AbstractGenerator, loss::Nothing, ce::AbstractCounterfactualExplanation)

Overloads the ℓ function for the case where no loss function is provided.

CounterfactualExplanations.Generators.∂h - Method
∂h(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)

The default method to compute the gradient of the complexity penalty at the current counterfactual state for gradient-based generators. It assumes that the gradient can be computed with Zygote.jl.

CounterfactualExplanations.Generators.∂ℓ - Method
∂ℓ(generator::AbstractGradientBasedGenerator, M::Union{Models.LogisticModel, Models.BayesianLogisticModel}, ce::AbstractCounterfactualExplanation)

The default method to compute the gradient of the loss function at the current counterfactual state for gradient-based generators. It assumes that the gradient can be computed with Zygote.jl.

CounterfactualExplanations.Generators.∇ - Method
∇(generator::AbstractGradientBasedGenerator, M::Models.AbstractDifferentiableModel, ce::AbstractCounterfactualExplanation)

The default method to compute the gradient of the counterfactual search objective for gradient-based generators. It computes the weighted sum of the partial derivatives and assumes that the gradient can be computed with Zygote.jl. If the counterfactual is being generated using Probe, the hinge loss is added to the gradient.
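For a logistic model the weighted sum can be written out by hand. The sketch below assumes a binary cross-entropy loss ℓ for a target label y, the squared L2 distance to the factual x0 as the complexity penalty h, and a single weight λ (`lambda`). All function names are hypothetical, the gradients are derived analytically rather than with Zygote.jl, and the Probe hinge term is not included. Since ∂ℓ/∂x = (σ(w'x + b) - y)w and ∂h/∂x = 2(x - x0), the gradient of ℓ + λh is their weighted sum.

```julia
using LinearAlgebra

sigmoid(z) = 1 / (1 + exp(-z))

# Hypothetical loss ℓ: binary cross-entropy of p = sigmoid(w'x + b) against a target y in {0, 1}.
function bce_loss(x, w, b, y)
    p = sigmoid(dot(w, x) + b)
    return -(y * log(p) + (1 - y) * log(1 - p))
end

# Hypothetical complexity penalty h: squared L2 distance to the factual x0.
l2_distance(x, x0) = sum(abs2, x .- x0)

# Analytic partial derivatives with respect to x.
grad_loss(x, w, b, y) = (sigmoid(dot(w, x) + b) - y) .* w
grad_penalty(x, x0) = 2 .* (x .- x0)

# Gradient of the objective ℓ(x) + λ h(x): a weighted sum of the two partial derivatives.
grad_objective(x, x0, w, b, y, lambda) = grad_loss(x, w, b, y) .+ lambda .* grad_penalty(x, x0)

w, b = [1.5, -0.5], 0.2
x0, x_cf = [0.0, 1.0], [0.3, 0.4]
grad_objective(x_cf, x0, w, b, 1, 0.1)
```

A central finite-difference check of `bce_loss + 0.1 * l2_distance` is a quick way to confirm such hand-derived gradients before relying on them.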