CounterfactualExplanations.Models.Model (Method)
Model(model, type::AbstractModelType; likelihood::Symbol=:classification_binary)

Outer constructor for Model where the atomic model is defined and assumed to be pre-trained.

CounterfactualExplanations.Models.Model (Method)
(M::Model)(data::CounterfactualData, type::Linear; kwargs...)

Constructs a model with one linear layer for the given data. If the output is binary, this corresponds to logistic regression, since model outputs are passed through the sigmoid function. If the output is multi-class, this corresponds to multinomial logistic regression, since model outputs are passed through the softmax function.
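
To illustrate the statement above, the following sketch mimics a single linear layer in plain Julia and passes its outputs through a sigmoid (one output) or a softmax (several outputs). The names `sigmoid`, `softmax` and `linear_probs` are hypothetical and the weights are made up; this is not the package's implementation, which builds a Flux layer.

```julia
# Logistic (sigmoid) function.
sigmoid(z) = 1 / (1 + exp(-z))

# Softmax with the usual max-shift for numerical stability.
function softmax(z::AbstractVector)
    e = exp.(z .- maximum(z))
    return e ./ sum(e)
end

# One linear layer, logits = W * x .+ b, followed by the output link.
function linear_probs(W::AbstractMatrix, b::AbstractVector, x::AbstractVector)
    logits = W * x .+ b
    return length(logits) == 1 ? sigmoid.(logits) : softmax(logits)
end
```

With a one-row weight matrix this is logistic regression; with K rows it is multinomial logistic regression.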

CounterfactualExplanations.Models.fit_model (Function)
fit_model(
    counterfactual_data::CounterfactualData, model::Symbol=:MLP;
    kwrgs...
)

Fits one of the available default models to the counterfactual_data. The model argument can be used to specify the desired model. The available values correspond to the keys of the all_models_catalogue dictionary.

CounterfactualExplanations.Models.fit_model (Method)
fit_model(
    counterfactual_data::CounterfactualData, type::AbstractModelType; kwrgs...
)

A wrapper function to fit a model to the counterfactual_data for a given type of model.

Arguments

  • counterfactual_data::CounterfactualData: The data to be used for training the model.
  • type::AbstractModelType: The type of model to be trained, e.g., MLP, DecisionTreeModel, etc.

Examples

julia> using CounterfactualExplanations

julia> using CounterfactualExplanations.Models

julia> using TaijaData

julia> data = CounterfactualData(load_linearly_separable()...);

julia> M = fit_model(data, Linear())
CounterfactualExplanations.Models.Model(Chain(Dense(2 => 2)), :classification_multi, Chain(Dense(2 => 2)), Linear())

CounterfactualExplanations.Models.model_evaluation (Method)
model_evaluation(M::AbstractModel, test_data::CounterfactualData)

Helper function to evaluate an AbstractModel on a (test) data set. By default, it computes the accuracy. Any other measure, e.g. from the StatisticalMeasures package, can be passed as an argument. Currently, only measures applicable to classification tasks are supported.
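
The default accuracy measure mentioned above is the share of matching labels. A minimal sketch, assuming predicted and true labels are already available as plain vectors (`accuracy_sketch` is a hypothetical name, not part of the package):

```julia
# Share of observations whose predicted label equals the true label.
function accuracy_sketch(yhat::AbstractVector, y::AbstractVector)
    length(yhat) == length(y) || throw(DimensionMismatch("label vectors differ in length"))
    return count(yhat .== y) / length(y)
end
```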

CounterfactualExplanations.Models.predict_label (Method)
predict_label(M::AbstractModel, counterfactual_data::CounterfactualData, X::AbstractArray)

Returns the predicted output label for a given model M, data set counterfactual_data and input data X.

CounterfactualExplanations.Models.predict_label (Method)
predict_label(M::AbstractModel, counterfactual_data::CounterfactualData)

Returns the predicted output labels for all data points of data set counterfactual_data for a given model M.

CounterfactualExplanations.Models.predict_proba (Method)
predict_proba(M::AbstractModel, counterfactual_data::CounterfactualData, X::Union{Nothing,AbstractArray})

Returns the predicted output probabilities for a given model M, data set counterfactual_data and input data X.

CounterfactualExplanations.Models.probs (Method)
probs(
    M::Model,
    type::MLJModelType,
    X::AbstractArray,
)

Overloads the probs method for MLJ models.

Note for developers

Note that the underlying MLJ methods (reformat, predict) are currently incompatible with Zygote's autodiff. For differentiable MLJ models, the probs and logits methods need to be overloaded.

Flux.Losses.logitbinarycrossentropy (Method)
Flux.Losses.logitbinarycrossentropy(ce::AbstractCounterfactualExplanation)

Simply extends the logitbinarycrossentropy method to work with objects of type AbstractCounterfactualExplanation.

Flux.Losses.logitcrossentropy (Method)
Flux.Losses.logitcrossentropy(ce::AbstractCounterfactualExplanation)

Simply extends the logitcrossentropy method to work with objects of type AbstractCounterfactualExplanation.

Flux.Losses.mse (Method)
Flux.Losses.mse(ce::AbstractCounterfactualExplanation)

Simply extends the mse method to work with objects of type AbstractCounterfactualExplanation.

CounterfactualExplanations.DataPreprocessing.InputTransformer (Type)
InputTransformer

Abstract type for data transformers. This can be any of the following:

  • StatsBase.AbstractDataTransform: A data transformation object from the StatsBase package.
  • MultivariateStats.AbstractDimensionalityReduction: A dimensionality reduction object from the MultivariateStats package.
  • GenerativeModels.AbstractGenerativeModel: A generative model object from the GenerativeModels module.

CounterfactualExplanations.DataPreprocessing.CounterfactualData (Method)
CounterfactualData(
    X::AbstractMatrix,
    y::RawOutputArrayType;
    mutability::Union{Vector{Symbol},Nothing}=nothing,
    domain::Union{Any,Nothing}=nothing,
    features_categorical::Union{Vector{Vector{Int}},Nothing}=nothing,
    features_continuous::Union{Vector{Int},Nothing}=nothing,
    input_encoder::Union{Nothing,InputTransformer,TypedInputTransformer}=nothing,
)

This outer constructor method prepares features X and labels y to be used with the package. Mutability and domain constraints can be added for the features. The function also accepts arguments that specify which features are categorical and which are continuous. These arguments are currently not used.

Examples

using CounterfactualExplanations
using TaijaData
X, y = load_linearly_separable()
counterfactual_data = CounterfactualData(X, y)
CounterfactualExplanations.DataPreprocessing.CounterfactualData (Method)
function CounterfactualData(
    X::Tables.MatrixTable,
    y::RawOutputArrayType;
    kwrgs...
)

Outer constructor method that accepts a Tables.MatrixTable. By default, the indices of categorical and continuous features are automatically inferred from the features' scitype.

CounterfactualExplanations.DataPreprocessing.convert_to_1d (Method)
convert_to_1d(y::Matrix, y_levels::AbstractArray)

Helper function to convert a one-hot encoded matrix to a vector of labels. This is necessary because MLJ models require the labels to be represented as a vector, but the synthetic datasets in this package hold the labels in one-hot encoded form.

Arguments

  • y::Matrix: The one-hot encoded matrix.
  • y_levels::AbstractArray: The levels of the categorical variable.

Returns

  • labels: A vector of labels.
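
For illustration, the conversion can be sketched in plain Julia, assuming (as elsewhere in this reference) that observations are stored in columns, so each column of y is one one-hot vector. `onehot_to_labels` is a hypothetical name, not the package's function.

```julia
# Map each one-hot column to the level at the position of its largest entry.
onehot_to_labels(y::AbstractMatrix, y_levels::AbstractVector) =
    [y_levels[argmax(col)] for col in eachcol(y)]
```
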

CounterfactualExplanations.DataPreprocessing.preprocess_data_for_mlj (Method)
preprocess_data_for_mlj(data::CounterfactualData)

Helper function to preprocess data::CounterfactualData for MLJ models.

Arguments

  • data::CounterfactualData: The data to be preprocessed.

Returns

  • (df_x, y): A tuple containing the preprocessed data, with df_x being a DataFrame object and y being a categorical vector.

Example

X, y = preprocess_data_for_mlj(data)

CounterfactualExplanations.DataPreprocessing.train_test_split (Method)
train_test_split(data::CounterfactualData; test_size=0.2, keep_class_ratio=false)

Splits data into a training set and a test set.

Arguments

  • data::CounterfactualData: The data to be split.
  • test_size=0.2: Proportion of the data to be used for testing.
  • keep_class_ratio=false: If true, test samples are drawn from each class in proportion to its relative size; if false, an equal number of test samples is drawn from each class.

Returns

  • (train_data::CounterfactualData, test_data::CounterfactualData): A tuple containing the train and test splits.

Example

train, test = train_test_split(data; test_size=0.1, keep_class_ratio=true)
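
The difference between the two sampling modes of keep_class_ratio can be sketched on a plain label vector. The helper below is hypothetical, not the package's implementation: it only returns index sets, and it assumes that "equal" sampling draws test_size × N / K test points from each of the K classes.

```julia
using Random

# Hypothetical helper: returns (train_idx, test_idx) for a label vector y.
function split_indices_sketch(y::AbstractVector, test_size::Real;
                              keep_class_ratio::Bool=false,
                              rng::AbstractRNG=Random.default_rng())
    N = length(y)
    classes = unique(y)
    test_idx = Int[]
    for c in classes
        idx = shuffle(rng, findall(==(c), y))
        n_test = if keep_class_ratio
            # proportional to the class's relative size
            floor(Int, test_size * length(idx))
        else
            # the same number of test points from every class, capped by class size
            min(length(idx), floor(Int, test_size * N / length(classes)))
        end
        append!(test_idx, idx[1:n_test])
    end
    sort!(test_idx)
    return setdiff(1:N, test_idx), test_idx
end
```

With an 80/20 class split and test_size=0.25, the proportional mode keeps the 4:1 ratio in the test set, while the equal mode draws the same number of points from both classes.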

CounterfactualExplanations.CounterfactualExplanation (Method)
function CounterfactualExplanation(;
    x::AbstractArray,
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractModel,
    generator::Generators.AbstractGenerator,
    num_counterfactuals::Int=1,
    initialization::Symbol=:add_perturbation,
    convergence::Union{AbstractConvergence,Symbol}=:decision_threshold,
)

Outer method to construct a CounterfactualExplanation structure.

CounterfactualExplanations.LaplaceReduxModel (Type)
LaplaceReduxModel

Concrete type for neural networks with Laplace Approximation from the LaplaceRedux package. Currently subtyping the AbstractFluxNN model type, although this may be changed to MLJ in the future.

CounterfactualExplanations.NeuroTreeModel (Type)
NeuroTreeModel

Concrete type for differentiable tree-based models from NeuroTreeModels. Since NeuroTreeModels has an MLJ interface, we subtype the MLJModelType model type.

CounterfactualExplanations.RandomForestModel (Type)
RandomForestModel

Concrete type for random forest models from DecisionTree.jl. Since the DecisionTree package has an MLJ interface, we subtype the MLJModelType model type.

CounterfactualExplanations.apply_mutability (Method)
apply_mutability(
    ce::CounterfactualExplanation,
    Δs′::AbstractArray,
)

A subroutine that applies mutability constraints to the proposed vector of feature perturbations.
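
A minimal sketch of what applying mutability constraints can look like, assuming one constraint symbol per feature taken from :both, :increase, :decrease and :none (an assumption of this sketch; see the mutability argument of CounterfactualData for the accepted values). `apply_mutability_sketch` is hypothetical and works on a single perturbation vector.

```julia
# Hypothetical sketch: clip each proposed feature change according to its constraint.
function apply_mutability_sketch(Δ::AbstractVector, mutability::AbstractVector{Symbol})
    Δc = copy(Δ)
    for (i, m) in enumerate(mutability)
        if m == :none
            Δc[i] = 0                   # immutable feature: no change allowed
        elseif m == :increase
            Δc[i] = max(Δc[i], 0)       # only upward changes allowed
        elseif m == :decrease
            Δc[i] = min(Δc[i], 0)       # only downward changes allowed
        end                             # :both leaves the change untouched
    end
    return Δc
end
```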

CounterfactualExplanations.decode_array (Method)
decode_array(dt::GenerativeModels.AbstractGenerativeModel, x::AbstractArray)

Helper function to decode an array x using a data transform dt::GenerativeModels.AbstractGenerativeModel.

CounterfactualExplanations.decode_array (Method)
decode_array(dt::MultivariateStats.AbstractDimensionalityReduction, x::AbstractArray)

Helper function to decode an array x using a data transform dt::MultivariateStats.AbstractDimensionalityReduction.

CounterfactualExplanations.decode_array (Method)
decode_array(dt::StatsBase.AbstractDataTransform, x::AbstractArray)

Helper function to decode an array x using a data transform dt::StatsBase.AbstractDataTransform.

CounterfactualExplanations.decode_state (Function)
function decode_state(
    ce::CounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)

Applies all the applicable decoding functions:

  1. If applicable, map the state variable back from the latent space to the feature space.
  2. If and where applicable, inverse-transform features.
  3. Reconstruct all categorical encodings.

Finally, the decoded counterfactual is returned.

CounterfactualExplanations.encode_array (Method)
encode_array(dt::GenerativeModels.AbstractGenerativeModel, x::AbstractArray)

Helper function to encode an array x using a data transform dt::GenerativeModels.AbstractGenerativeModel.

CounterfactualExplanations.encode_array (Method)
encode_array(dt::MultivariateStats.AbstractDimensionalityReduction, x::AbstractArray)

Helper function to encode an array x using a data transform dt::MultivariateStats.AbstractDimensionalityReduction.

CounterfactualExplanations.encode_array (Method)
encode_array(dt::StatsBase.AbstractDataTransform, x::AbstractArray)

Helper function to encode an array x using a data transform dt::StatsBase.AbstractDataTransform.
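
A standardizing (z-score) transform is a typical StatsBase data transform. The types and functions below are a hypothetical, dependency-free stand-in (they do not call StatsBase) that shows the property the encode/decode helpers rely on: decoding an encoded array recovers the original. Features are assumed to sit in rows and observations in columns.

```julia
using Statistics

# Hypothetical stand-in for a fitted z-score transform (features in rows).
struct ZScoreSketch
    μ::Vector{Float64}
    σ::Vector{Float64}
end

fit_zscore(X::AbstractMatrix) = ZScoreSketch(vec(mean(X; dims=2)), vec(std(X; dims=2)))

# Encoding: shift and scale every feature to zero mean and unit standard deviation.
encode_sketch(dt::ZScoreSketch, x::AbstractVecOrMat) = (x .- dt.μ) ./ dt.σ

# Decoding: invert the encoding.
decode_sketch(dt::ZScoreSketch, z::AbstractVecOrMat) = z .* dt.σ .+ dt.μ
```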

CounterfactualExplanations.encode_state (Function)
function encode_state(
    ce::CounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)

Applies all required encodings to x:

  1. If applicable, it maps x to the latent space learned by the generative model.
  2. If and where applicable, it rescales features.

Finally, it returns the encoded state variable.

CounterfactualExplanations.generate_counterfactual (Method)
generate_counterfactual(
    x::Base.Iterators.Zip,
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractModel,
    generator::AbstractGenerator;
    kwargs...,
)

Overloads the generate_counterfactual method to accept a zip of factuals x and return a vector of counterfactuals.

CounterfactualExplanations.generate_counterfactual (Method)
generate_counterfactual(
    x::Matrix,
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractModel,
    generator::AbstractGenerator;
    num_counterfactuals::Int=1,
    initialization::Symbol=:add_perturbation,
    convergence::Union{AbstractConvergence,Symbol}=:decision_threshold,
    timeout::Union{Nothing,Real}=nothing,
)

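The core function that is used to run counterfactual search for a given factual x, target, counterfactual data, model and generator. Keyword arguments can be used to specify the number of counterfactuals, the initialization method, the convergence criterion (which determines, for example, the decision threshold for the predicted target class probability and the maximum number of iterations) and a timeout.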

Arguments

  • x::Matrix: Factual data point.
  • target::RawTargetType: Target class.
  • data::CounterfactualData: Counterfactual data.
  • M::Models.AbstractModel: Fitted model.
  • generator::AbstractGenerator: Generator.
  • num_counterfactuals::Int=1: Number of counterfactuals to generate for factual.
  • initialization::Symbol=:add_perturbation: Initialization method. By default, the initialization is done by adding a small random perturbation to the factual to achieve more robustness.
  • convergence::Union{AbstractConvergence,Symbol}=:decision_threshold: Convergence criterion. By default, convergence is based on the decision threshold. Possible values are :decision_threshold, :max_iter, :generator_conditions or a concrete convergence object (e.g. DecisionThresholdConvergence).
  • timeout::Union{Nothing,Real}=nothing: Timeout in seconds.

Examples

Generic generator

julia> using CounterfactualExplanations

julia> using TaijaData

# Counterfactual data and model:

julia> counterfactual_data = CounterfactualData(load_linearly_separable()...);

julia> M = fit_model(counterfactual_data, :Linear);

julia> target = 2;

julia> factual = 1;

julia> chosen = rand(findall(predict_label(M, counterfactual_data) .== factual));

julia> x = select_factual(counterfactual_data, chosen);

# Search:

julia> generator = Generators.GenericGenerator();

julia> ce = generate_counterfactual(x, target, counterfactual_data, M, generator);

julia> converged(ce.convergence, ce)
true

Broadcasting

The generate_counterfactual method can also be broadcast over a tuple containing an array. This allows for generating counterfactuals for multiple factuals in a single call.

julia> chosen = rand(findall(predict_label(M, counterfactual_data) .== factual), 5);

julia> xs = select_factual(counterfactual_data, chosen);

julia> ces = generate_counterfactual.(xs, target, counterfactual_data, M, generator);

julia> all(converged(ce.convergence, ce) for ce in ces)
true

CounterfactualExplanations.generate_counterfactual (Method)
generate_counterfactual(
    x::Matrix,
    target::RawTargetType,
    data::DataPreprocessing.CounterfactualData,
    M::Models.AbstractModel,
    generator::Generators.GrowingSpheresGenerator;
    num_counterfactuals::Int=1,
    convergence::Union{AbstractConvergence,Symbol}=Convergence.DecisionThresholdConvergence(;
        decision_threshold=(1 / length(data.y_levels)), max_iter=1000
    ),
    kwrgs...,
)

Overloads the generate_counterfactual method for the GrowingSpheresGenerator.

CounterfactualExplanations.generate_counterfactual (Method)
generate_counterfactual(x::Tuple{<:AbstractArray}, args...; kwargs...)

Overloads the generate_counterfactual method to accept a tuple containing an array. This allows for broadcasting over Zip iterators.

CounterfactualExplanations.generate_counterfactual (Method)
generate_counterfactual(
    x::Vector{<:Matrix},
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractModel,
    generator::AbstractGenerator;
    kwargs...,
)

Overloads the generate_counterfactual method to accept a vector of factuals x and return a vector of counterfactuals.

CounterfactualExplanations.guess_likelihood (Method)
guess_likelihood(y::RawOutputArrayType)

Guess the likelihood based on the scientific type of the output array. Returns a symbol indicating the guessed likelihood and the scientific type of the output array.

CounterfactualExplanations.initialize! (Method)
initialize!(ce::CounterfactualExplanation)

Initializes the counterfactual explanation. This method is called by the constructor. It does the following:

  1. Creates a dictionary to store information about the search.
  2. Initializes the counterfactual state.
  3. Initializes the search path.
  4. Initializes the loss.

CounterfactualExplanations.initialize_state (Method)
initialize_state(ce::CounterfactualExplanation)

Initializes the starting point for the factual(s):

  1. If ce.initialization is set to :identity or counterfactuals are searched in a latent space, then nothing is done.
  2. If ce.initialization is set to :add_perturbation, then a random perturbation is added to the factual following Slack et al. (2021): https://arxiv.org/abs/2106.02666. The authors show that this improves adversarial robustness.

CounterfactualExplanations.target_probs (Function)
target_probs(
    ce::CounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)

Returns the predicted probability of the target class for x. If x is nothing, the predicted probability corresponding to the counterfactual value is returned.

CounterfactualExplanations.update! (Method)
update!(ce::CounterfactualExplanation)

An important subroutine that updates the counterfactual explanation. It takes a snapshot of the current counterfactual search state and passes it to the generator. Based on the current state the generator generates perturbations. Various constraints are then applied to the proposed vector of feature perturbations. Finally, the counterfactual search state is updated.
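
To make the update cycle concrete, the sketch below runs a bare-bones gradient-based search against a fixed logistic regression: at each step it checks a decision-threshold stopping rule, computes the gradient of a cross-entropy loss towards target class 1 plus an L2 distance penalty, and takes a descent step. Everything here (`search_sketch`, the step size η, the penalty weight λ) is a hypothetical stand-in in the spirit of a generic gradient-based generator; it does not use the package's types, and the constraint step of update! is omitted.

```julia
using LinearAlgebra

logistic(z) = 1 / (1 + exp(-z))

# Hypothetical sketch: push x towards class 1 of p(y = 1 | x) = logistic(w ⋅ x + b).
function search_sketch(x::Vector{Float64}, w::Vector{Float64}, b::Float64;
                       λ=0.1, η=0.5, threshold=0.5, max_iter=1_000)
    x′ = copy(x)
    for iter in 0:max_iter-1
        p = logistic(dot(w, x′) + b)
        # Stopping rule: the target class probability has reached the threshold.
        p >= threshold && return x′, iter
        # Gradient of -log(p) w.r.t. x′ is (p - 1) * w; the penalty λ‖x′ - x‖² adds 2λ(x′ - x).
        g = (p - 1) .* w .+ 2λ .* (x′ .- x)
        x′ .-= η .* g
    end
    return x′, max_iter
end
```

The penalty term keeps the counterfactual close to the factual, so λ trades off validity against proximity.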

CounterfactualExplanations.GenerativeModels.VAE (Type)
VAE <: AbstractGenerativeModel

Constructs the Variational Autoencoder. The VAE is a subtype of AbstractGenerativeModel. Any (sub-)type of AbstractGenerativeModel is accepted by latent space generators.

Base.rand (Function)

Random.rand(encoder::Encoder, x, device=cpu)

Draws random samples from the latent distribution.

CounterfactualExplanations.Convergence.DecisionThresholdConvergence (Type)
DecisionThresholdConvergence

Convergence criterion based on the target class probability threshold. The search stops when the target class probability exceeds the predefined threshold.

Fields

  • decision_threshold::AbstractFloat: The predefined threshold for the target class probability.
  • max_iter::Int: The maximum number of iterations.
  • min_success_rate::AbstractFloat: The minimum success rate, i.e., the minimum share of counterfactuals whose target class probability must exceed the threshold.
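
The stopping rule described by these fields can be sketched as follows. The function names and default values are hypothetical; the sketch only assumes that the target class probability of each counterfactual is available as a vector.

```julia
# Hypothetical sketch: share of counterfactuals above the threshold vs. the required rate.
function threshold_converged(target_probs::AbstractVector{<:Real};
                             decision_threshold::Real=0.5, min_success_rate::Real=1.0)
    success_rate = count(p -> p > decision_threshold, target_probs) / length(target_probs)
    return success_rate >= min_success_rate
end

# The search also stops once the iteration budget is used up.
should_stop(target_probs, iter; max_iter=100, kwargs...) =
    iter >= max_iter || threshold_converged(target_probs; kwargs...)
```
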

CounterfactualExplanations.Convergence.GeneratorConditionsConvergence (Type)
GeneratorConditionsConvergence

Convergence criterion for counterfactual explanations based on the generator conditions. The search stops when the gradients of the search objective are below a certain threshold and the generator conditions are satisfied.

Fields

  • decision_threshold::AbstractFloat: The threshold for the decision probability.
  • gradient_tol::AbstractFloat: The tolerance for the gradients of the search objective.
  • max_iter::Int: The maximum number of iterations.
  • min_success_rate::AbstractFloat: The minimum success rate for the generator conditions (across counterfactuals).

CounterfactualExplanations.Convergence.converged (Function)
converged(
    convergence::InvalidationRateConvergence,
    ce::AbstractCounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)

Checks if the counterfactual search has converged when the convergence criterion is the invalidation rate.

CounterfactualExplanations.Convergence.converged (Function)
converged(
    convergence::DecisionThresholdConvergence,
    ce::AbstractCounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)

Checks if the counterfactual search has converged when the convergence criterion is the decision threshold.

CounterfactualExplanations.Convergence.converged (Function)
converged(
    convergence::MaxIterConvergence,
    ce::AbstractCounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)

Checks if the counterfactual search has converged when the convergence criterion is the maximum number of iterations. This means the counterfactual search will not terminate until the maximum number of iterations has been reached, independently of the other convergence criteria.

CounterfactualExplanations.Convergence.converged (Function)
converged(
    convergence::GeneratorConditionsConvergence,
    ce::AbstractCounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)

Checks if the counterfactual search has converged when the convergence criterion is generator_conditions.

CounterfactualExplanations.Convergence.hinge_loss (Method)
hinge_loss(convergence::InvalidationRateConvergence, ce::AbstractCounterfactualExplanation)

Calculates the hinge loss of a counterfactual explanation.

Arguments

  • convergence::InvalidationRateConvergence: The convergence criterion to use.
  • ce::AbstractCounterfactualExplanation: The counterfactual explanation to calculate the hinge loss for.

Returns

The hinge loss of the counterfactual explanation.

CounterfactualExplanations.Convergence.invalidation_rate (Method)
invalidation_rate(ce::AbstractCounterfactualExplanation)

Calculates the invalidation rate of a counterfactual explanation.

Arguments

  • ce::AbstractCounterfactualExplanation: The counterfactual explanation to calculate the invalidation rate for.
  • kwargs: Additional keyword arguments to pass to the function.

Returns

The invalidation rate of the counterfactual explanation.

CounterfactualExplanations.Generators.FeatureTweakGenerator (Method)
FeatureTweakGenerator(; penalty::Union{Nothing,Function,Vector{Function}}=Objectives.distance_l2, ϵ::AbstractFloat=0.1)

Constructs a new Feature Tweak Generator object.

By default, uses the L2-norm as the penalty to measure the distance between the counterfactual and the factual. According to the paper by Tolomei et al., another recommended choice for the penalty is the L0-norm, which simply minimizes the number of features that are changed through the tweak.

Arguments

  • penalty::Union{Nothing,Function,Vector{Function}}: The penalty function to use for the generator. Defaults to distance_l2.
  • ϵ::AbstractFloat: The tolerance value for the feature tweaks. Described at length in Tolomei et al. (https://arxiv.org/pdf/1706.06691.pdf). Defaults to 0.1.

Returns

  • generator::FeatureTweakGenerator: A non-gradient-based generator that can be used to generate counterfactuals using the feature tweak method.
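
The role of ϵ can be illustrated with the ε-satisfactory rule from Tolomei et al.: every feature tested along the tree path that leads to the desired class is moved just past the split threshold θ, to θ - ϵ for an x[i] ≤ θ condition and to θ + ϵ for an x[i] > θ condition. The sketch applies that rule to a hand-written path; `epsilon_satisfactory` and the (feature, θ, go_left) encoding of the path are hypothetical, and the search over trees and paths that the generator performs is not shown.

```julia
# Hypothetical sketch: each path condition is (feature index, threshold θ, go_left),
# where go_left = true means the desired path requires x[i] ≤ θ, else x[i] > θ.
function epsilon_satisfactory(x::AbstractVector{<:Real}, path; ϵ::Real=0.1)
    x′ = float.(x)          # broadcasting allocates a copy, so the factual is untouched
    for (i, θ, go_left) in path
        x′[i] = go_left ? θ - ϵ : θ + ϵ
    end
    return x′
end
```
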

CounterfactualExplanations.Generators.GradientBasedGenerator (Method)
GradientBasedGenerator(;
    loss::Union{Nothing,Function}=nothing,
    penalty::Penalty=nothing,
    λ::Union{Nothing,AbstractFloat,Vector{AbstractFloat}}=nothing,
    latent_space::Bool=false,
    opt::Flux.Optimise.AbstractOptimiser=Flux.Descent(),
    generative_model_params::NamedTuple=(;),
)

Default outer constructor for GradientBasedGenerator.

Arguments

  • loss::Union{Nothing,Function}=nothing: The loss function used by the model.
  • penalty::Penalty=nothing: A penalty function for the generator to penalize counterfactuals too far from the original point.
  • λ::Union{Nothing,AbstractFloat,Vector{AbstractFloat}}=nothing: The weight of the penalty function.
  • latent_space::Bool=false: Whether to use the latent space of a generative model to generate counterfactuals.
  • opt::Flux.Optimise.AbstractOptimiser=Flux.Descent(): The optimizer to use for the generator.
  • generative_model_params::NamedTuple: The parameters of the generative model associated with the generator.

Returns

  • generator::GradientBasedGenerator: A gradient-based counterfactual generator.

CounterfactualExplanations.Generators.conditions_satisfied (Method)
conditions_satisfied(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)

The default method to check if all conditions for convergence of the counterfactual search have been satisfied for gradient-based generators. By default, gradient-based search is considered to have converged as soon as the proposed changes for all features are smaller than one percent of the respective feature's standard deviation.
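
The one-percent rule above can be sketched as follows, assuming Δ is the most recent proposed perturbation of a single counterfactual and X holds the training features with features in rows. `perturbation_settled` and the τ keyword are hypothetical names.

```julia
using Statistics

function perturbation_settled(Δ::AbstractVector, X::AbstractMatrix; τ::Real=0.01)
    σ = vec(std(X; dims=2))        # per-feature standard deviation
    return all(abs.(Δ) .< τ .* σ)  # every change below τ times that feature's std
end
```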

CounterfactualExplanations.Generators.feature_selection! (Method)
feature_selection!(ce::AbstractCounterfactualExplanation)

Perform feature selection to find the dimension with the closest (but not equal) values between the ce.x (factual) and ce.s′ (counterfactual) arrays.

Arguments

  • ce::AbstractCounterfactualExplanation: An instance of the AbstractCounterfactualExplanation type representing the counterfactual explanation.

Returns

  • nothing

The function iteratively modifies the ce.s′ counterfactual array by updating its elements to match the corresponding elements in the ce.x factual array, one dimension at a time, until the predicted label of the modified ce.s′ matches the predicted label of the ce.x array.

CounterfactualExplanations.Generators.find_closest_dimension (Method)
find_closest_dimension(factual, counterfactual)

Find the dimension with the closest (but not equal) values between the factual and counterfactual arrays.

Arguments

  • factual: The factual array.
  • counterfactual: The counterfactual array.

Returns

  • closest_dimension: The index of the dimension with the closest values.

The function iterates over the indices of the factual array and calculates the absolute difference between the corresponding elements in the factual and counterfactual arrays. It returns the index of the dimension with the smallest difference, excluding dimensions where the values in factual and counterfactual are equal.
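
A direct transcription of that description into plain Julia, offered as a sketch. The name `closest_dimension_sketch` is hypothetical, and returning 0 when the two arrays agree everywhere is an assumption of this sketch, not documented behaviour.

```julia
# Hypothetical sketch: index of the smallest non-zero |factual[i] - counterfactual[i]|,
# or 0 if the two arrays agree in every dimension.
function closest_dimension_sketch(factual::AbstractVector, counterfactual::AbstractVector)
    closest, smallest = 0, Inf
    for i in eachindex(factual, counterfactual)
        diff = abs(factual[i] - counterfactual[i])
        if diff != 0 && diff < smallest
            closest, smallest = i, diff
        end
    end
    return closest
end
```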

CounterfactualExplanations.Generators.find_counterfactual (Method)
find_counterfactual(model, target_class, counterfactual_data, counterfactual_candidates)

Finds the index of the first counterfactual among the candidates by predicting their labels.

Arguments

  • model: The fitted model used for prediction.
  • target_class: Expected target class.
  • counterfactual_data: Data required for counterfactual generation.
  • counterfactual_candidates: The array of counterfactual candidates.

Returns

  • counterfactual: The index of the first counterfactual found.
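
The lookup amounts to a `findfirst` over predicted labels. A sketch, assuming candidates are stored as columns and `predict` is any function mapping one candidate to a label (`first_counterfactual_sketch` is hypothetical, not the package's function):

```julia
# Hypothetical sketch: index of the first candidate column predicted as `target`,
# or `nothing` if there is none.
function first_counterfactual_sketch(predict, target, candidates::AbstractMatrix)
    return findfirst(c -> predict(c) == target, collect(eachcol(candidates)))
end
```
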

CounterfactualExplanations.Generators.growing_spheres_generation! (Method)
growing_spheres_generation!(ce::AbstractCounterfactualExplanation)

Generate counterfactual candidates using the growing spheres generation algorithm.

Arguments

  • ce::AbstractCounterfactualExplanation: An instance of the AbstractCounterfactualExplanation type representing the counterfactual explanation.

Returns

  • nothing

This function applies the growing spheres generation algorithm to generate counterfactual candidates. It starts by generating random points uniformly on a sphere, gradually reducing the search space until no counterfactuals are found. Then it expands the search space until at least one counterfactual is found or the maximum number of iterations is reached.

The algorithm iteratively generates counterfactual candidates and predicts their labels using the model stored in ce.M. It checks if any of the predicted labels are different from the factual class. The process of reducing the search space involves halving the search radius, while the process of expanding the search space involves increasing the search radius.
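
The shrink-then-expand logic described above can be sketched against any labelling function. The code below is a simplified, hypothetical rendition (the names `growing_spheres_sketch` and `sample_layer`, the initial radius η and the sample count are assumptions): candidates are drawn uniformly from an L2 ball or spherical layer around the factual, the radius is halved while the ball still contains counterfactuals, and a layer of width η is then pushed outwards until a counterfactual appears. Of the counterfactuals found, the one closest to the factual is returned.

```julia
using LinearAlgebra, Random

# Draw n points uniformly from the layer low ≤ ‖z - x‖₂ ≤ high around x (points in columns).
function sample_layer(x::AbstractVector{<:Real}, n::Int, low::Real, high::Real)
    d = length(x)
    Z = Matrix{Float64}(undef, d, n)
    for j in 1:n
        direction = normalize(randn(d))                    # uniform on the unit sphere
        r = (rand() * (high^d - low^d) + low^d)^(1 / d)    # uniform in volume
        Z[:, j] = x .+ r .* direction
    end
    return Z
end

function growing_spheres_sketch(predict, x::AbstractVector{<:Real};
                                η::Real=1.0, n::Int=200, max_iter::Int=100)
    y0 = predict(x)
    is_cf(z) = predict(z) != y0
    # Shrink: halve the radius while the ball around x still contains counterfactuals.
    iter = 0
    while iter < max_iter && any(is_cf, eachcol(sample_layer(x, n, 0.0, η)))
        η /= 2
        iter += 1
    end
    # Expand: move a layer of width η outwards until a counterfactual is found.
    low, high = η, 2η
    for _ in 1:max_iter
        hits = [z for z in eachcol(sample_layer(x, n, low, high)) if is_cf(z)]
        isempty(hits) || return hits[argmin([norm(z - x) for z in hits])]
        low, high = high, high + η
    end
    return nothing
end
```

Sampling uniformly in volume (rather than drawing the radius uniformly) avoids over-representing points close to the inner boundary of each layer.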

CounterfactualExplanations.Generators.h (Method)
h(generator::AbstractGenerator, penalty::Function, ce::AbstractCounterfactualExplanation)

Overloads the h function for the case where a single penalty function is provided.

CounterfactualExplanations.Generators.h (Method)
h(generator::AbstractGenerator, penalty::Nothing, ce::AbstractCounterfactualExplanation)

Overloads the h function for the case where no penalty is provided.

CounterfactualExplanations.Generators.h (Method)
h(generator::AbstractGenerator, penalty::Tuple, ce::AbstractCounterfactualExplanation)

Overloads the h function for the case where a single penalty function is provided with additional keyword arguments.

CounterfactualExplanations.Generators.hinge_loss (Method)
hinge_loss(convergence::AbstractConvergence, ce::AbstractCounterfactualExplanation)

The default hinge loss for any convergence criterion. Can be overridden inside the Convergence module as part of the definition of specific convergence criteria.

CounterfactualExplanations.Generators.hyper_sphere_coordinates (Method)
hyper_sphere_coordinates(n_search_samples::Int, instance::AbstractArray, low::AbstractFloat, high::AbstractFloat; p_norm::Int=2)

Generates candidate counterfactuals using the growing spheres method based on hyper-sphere coordinates.

The implementation follows the Random Point Picking over a sphere algorithm described in the paper "Learning Counterfactual Explanations for Tabular Data" by Pawelczyk, Broelemann & Kasneci (2020), presented at The Web Conference 2020 (WWW). It ensures that points are sampled uniformly at random using insights from: http://mathworld.wolfram.com/HyperspherePointPicking.html

The growing spheres method was originally proposed in the paper "Comparison-based Inverse Classification for Interpretability in Machine Learning" by Thibaut Laugel et al. (2018), presented at the International Conference on Information Processing and Management of Uncertainty in Knowledge-Based Systems.

Arguments

  • n_search_samples::Int: The number of search samples (int > 0).
  • instance::AbstractArray: The input point array.
  • low::AbstractFloat: The lower bound (float >= 0, l < h).
  • high::AbstractFloat: The upper bound (float >= 0, h > l).
  • p_norm::Integer: The norm parameter (int >= 1).

Returns

  • candidate_counterfactuals::Array: An array of candidate counterfactuals.

CounterfactualExplanations.Generators.incompatible (Method)
incompatible(AbstractGenerator, AbstractCounterfactualExplanation)

Checks if the generator is incompatible with any of the additional specifications for the counterfactual explanations. By default, generators are assumed to be compatible.

CounterfactualExplanations.Generators.propose_state (Method)
propose_state(
    ::Models.IsDifferentiable,
    generator::AbstractGradientBasedGenerator,
    ce::AbstractCounterfactualExplanation,
)

Proposes a new state based on backpropagation for gradient-based generators and differentiable models.

CounterfactualExplanations.Generators.ℓ (Method)
ℓ(generator::AbstractGenerator, loss::Function, ce::AbstractCounterfactualExplanation)

Overloads the ℓ function for the case where a single loss function is provided.

CounterfactualExplanations.Generators.ℓ (Method)
ℓ(generator::AbstractGenerator, loss::Nothing, ce::AbstractCounterfactualExplanation)

Overloads the ℓ function for the case where no loss function is provided.

CounterfactualExplanations.Generators.∂h (Method)
∂h(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)

The default method to compute the gradient of the complexity penalty at the current counterfactual state for gradient-based generators. It assumes that Zygote.jl has gradient access.

If the penalty is not provided, it returns 0.0. By default, Zygote never works out the gradient for constants and instead returns 'nothing', so we need to add a manual step to override this behaviour. See here: https://discourse.julialang.org/t/zygote-gradient/26715.

CounterfactualExplanations.Generators.∂ℓ (Method)
∂ℓ(
    generator::AbstractGradientBasedGenerator,
    ce::AbstractCounterfactualExplanation,
)

The default method to compute the gradient of the loss function at the current counterfactual state for gradient-based generators. It assumes that Zygote.jl has gradient access.

CounterfactualExplanations.Generators.∇ (Method)
∇(
    generator::AbstractGradientBasedGenerator,
    ce::AbstractCounterfactualExplanation,
)

The default method to compute the gradient of the counterfactual search objective for gradient-based generators. It simply computes the weighted sum over partial derivatives. It assumes that Zygote.jl has gradient access. If the counterfactual is being generated using Probe, the hinge loss is added to the gradient.
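
Written out, and assuming the search objective is the loss plus λ-weighted penalties (as the loss, penalty and λ arguments of GradientBasedGenerator suggest), the gradient described above takes roughly the following form, with s′ the counterfactual state, y* the target, ℓ the loss, h_k the penalties and λ_k their weights. The Probe hinge-loss term is left out, and the second line assumes a plain descent step with learning rate η (as with the default Flux.Descent optimiser).

```latex
\nabla(s') = \frac{\partial\, \ell\big(M(s'),\, y^*\big)}{\partial s'}
           + \sum_{k} \lambda_k \, \frac{\partial h_k(s')}{\partial s'},
\qquad
s' \leftarrow s' - \eta \, \nabla(s')
```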

Base.vcat (Method)
Base.vcat(bmk1::Benchmark, bmk2::Benchmark)

Vertically concatenates two Benchmark objects.

CounterfactualExplanations.Evaluation.benchmark (Method)
benchmark(
    data::CounterfactualData;
    models::Dict{<:Any,<:Any}=standard_models_catalogue,
    generators::Union{Nothing,Dict{<:Any,<:AbstractGenerator}}=nothing,
    measure::Union{Function,Vector{Function}}=default_measures,
    n_individuals::Int=5,
    suppress_training::Bool=false,
    factual::Union{Nothing,RawTargetType}=nothing,
    target::Union{Nothing,RawTargetType}=nothing,
    store_ce::Bool=false,
    parallelizer::Union{Nothing,AbstractParallelizer}=nothing,
    kwrgs...,
)

Runs the benchmarking exercise as follows:

  1. Randomly choose a factual and target label unless specified.
  2. If no pretrained models are provided, it is assumed that a dictionary of callable model objects is provided (by default using the standard_models_catalogue).
  3. Each of these models is then trained on the data.
  4. For each model separately, randomly choose n_individuals samples from the non-target (factual) class. For each generator, create a benchmark as in benchmark(xs::Union{AbstractArray,Base.Iterators.Zip}).
  5. Finally, concatenate the results.
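Steps 1 and 4 above can be sketched as follows. The labels vector and the helper sample_individuals are hypothetical stand-ins for the label information held in CounterfactualData, not part of the package API; only the Random standard library is used.

```julia
using Random

# Sketch of steps 1 and 4 with hypothetical inputs: `labels` holds one label per
# sample. Pick a factual and a different target label at random, then draw up to
# `n_individuals` sample indices from the non-target (factual) class.
function sample_individuals(rng::AbstractRNG, labels::AbstractVector, n_individuals::Int)
    classes = unique(labels)
    factual = rand(rng, classes)
    target = rand(rng, filter(!=(factual), classes))
    candidates = findall(==(factual), labels)
    chosen = shuffle(rng, candidates)[1:min(n_individuals, length(candidates))]
    return factual, target, chosen
end

labels = [0, 0, 1, 1, 1, 0, 1, 0]
factual, target, idx = sample_individuals(MersenneTwister(42), labels, 3)
```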

If vertical_splits is set to an integer, the computations are split vertically into vertical_splits chunks. In this case, the results are stored in a temporary directory and concatenated afterwards.
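One way to picture the vertical split is to divide the list of tasks into vertical_splits chunks of near-equal size. This is an assumed illustration of the chunking idea only; the helper split_vertically is hypothetical, and the temporary-directory storage is not shown.

```julia
# Sketch (assumed behaviour, not the package internals): split the indices of
# `n_tasks` computations into `vertical_splits` chunks of near-equal size that
# could be processed and stored one at a time.
function split_vertically(n_tasks::Int, vertical_splits::Int)
    chunk_size = max(1, cld(n_tasks, vertical_splits))   # ceiling division, at least 1
    return collect(Iterators.partition(1:n_tasks, chunk_size))
end

chunks = split_vertically(10, 3)   # three chunks holding 4, 4 and 2 indices
```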

CounterfactualExplanations.Evaluation.benchmarkMethod
benchmark(
    x::Union{AbstractArray,Base.Iterators.Zip},
    target::RawTargetType,
    data::CounterfactualData;
    models::Dict{<:Any,<:AbstractModel},
    generators::Dict{<:Any,<:AbstractGenerator},
    measure::Union{Function,Vector{Function}}=default_measures,
    xids::Union{Nothing,AbstractArray}=nothing,
    dataname::Union{Nothing,Symbol,String}=nothing,
    verbose::Bool=true,
    store_ce::Bool=false,
    parallelizer::Union{Nothing,AbstractParallelizer}=nothing,
    kwrgs...,
)

First generates counterfactual explanations for the factual x, the target and the data using each of the provided models and generators. Then generates a Benchmark for the resulting vector of counterfactual explanations as in benchmark(counterfactual_explanations::Vector{CounterfactualExplanation}).
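The "each of the provided models and generators" step amounts to iterating over every model/generator pair. The sketch below uses placeholder callables: each generator is a hypothetical function (model, x, target) -> counterfactual, which is not how generators are called in the package, and each result is paired with a meta data dictionary naming its model and generator.

```julia
# Sketch of the model × generator grid with placeholder callables. The generator
# signature (model, x, target) -> counterfactual is a hypothetical simplification.
function generate_grid(x, target, models::Dict, generators::Dict)
    results = Any[]
    meta = Dict{Symbol,Any}[]
    for (mname, model) in models, (gname, generator) in generators
        push!(results, generator(model, x, target))
        push!(meta, Dict{Symbol,Any}(:model => mname, :generator => gname))
    end
    return results, meta
end

models = Dict(:linear => x -> sum(x))
generators = Dict(:shift => (m, x, t) -> x .+ t, :scale => (m, x, t) -> x .* t)
results, meta = generate_grid([1.0, 2.0], 2.0, models, generators)
```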

CounterfactualExplanations.Evaluation.benchmarkMethod
benchmark(
    counterfactual_explanations::Vector{CounterfactualExplanation};
    meta_data::Union{Nothing,<:Vector{<:Dict}}=nothing,
    measure::Union{Function,Vector{Function}}=default_measures,
    store_ce::Bool=false,
)

Generates a Benchmark for a vector of counterfactual explanations. Optionally, meta_data describing each individual counterfactual explanation can be supplied. This should be a vector of dictionaries of the same length as the vector of counterfactuals. If no meta_data is supplied, it is inferred automatically. All measure functions are applied to each counterfactual explanation. If store_ce=true, the counterfactual explanations are stored in the benchmark.
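The rules above (every measure applied to every explanation, meta data of matching length, inference when none is given) can be sketched with stand-ins. Here each "explanation" is just a vector and each measure a plain function of it; the helper benchmark_rows and its position-only meta data inference are hypothetical simplifications, not the package's Benchmark type.

```julia
# Sketch with hypothetical stand-ins: every measure is applied to every
# explanation; optional meta data must have one entry per explanation.
function benchmark_rows(explanations::Vector, measures::Vector{<:Function};
                        meta_data::Union{Nothing,Vector{<:Dict}}=nothing)
    if meta_data === nothing
        # stand-in for automatic inference: just record the position
        meta_data = [Dict{Symbol,Any}(:sample => i) for i in eachindex(explanations)]
    end
    length(meta_data) == length(explanations) ||
        throw(ArgumentError("meta_data must have one entry per explanation"))
    rows = NamedTuple[]
    for (ce, meta) in zip(explanations, meta_data), m in measures
        push!(rows, (meta = meta, measure = nameof(m), value = m(ce)))
    end
    return rows
end

rows = benchmark_rows([[1.0, 2.0], [3.0, 4.0]], [sum, maximum])   # 4 rows
```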

CounterfactualExplanations.Evaluation.compute_measureMethod
compute_measure(ce::CounterfactualExplanation, measure::Function, agg::Function)

Computes a single measure for a counterfactual explanation. The measure is applied to the counterfactual explanation ce and aggregated using the aggregation function agg.
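The measure-then-aggregate pattern can be sketched as follows, assuming a hypothetical measure that returns one value per generated counterfactual (here the Euclidean distance to the factual). Neither distance_to_factual nor compute_measure_sketch is part of the package; mean comes from the Statistics standard library.

```julia
using Statistics

# Hypothetical measure: one Euclidean distance per generated counterfactual.
distance_to_factual(factual, counterfactuals) =
    [sqrt(sum(abs2, c .- factual)) for c in counterfactuals]

# Apply the measure, then collapse its values with the aggregation function `agg`.
function compute_measure_sketch(factual, counterfactuals, measure::Function, agg::Function)
    return agg(measure(factual, counterfactuals))
end

factual = [0.0, 0.0]
counterfactuals = [[3.0, 4.0], [1.0, 0.0]]
compute_measure_sketch(factual, counterfactuals, distance_to_factual, mean)     # 3.0
compute_measure_sketch(factual, counterfactuals, distance_to_factual, maximum)  # 5.0
```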

CounterfactualExplanations.Evaluation.evaluateFunction
evaluate(
    ce::CounterfactualExplanation;
    measure::Union{Function,Vector{Function}}=default_measures,
    agg::Function=mean,
    report_each::Bool=false,
    output_format::Symbol=:Vector,
    pivot_longer::Bool=true
)

Computes evaluation measures for the counterfactual explanation. By default, no meta data is reported. For report_meta=true, meta data is inferred automatically unless it is overwritten by meta_data. The optional meta_data argument should be a vector of dictionaries of the same length as the vector of counterfactual explanations.

CounterfactualExplanations.Evaluation.to_dataframeMethod
evaluate_dataframe(
    ce::CounterfactualExplanation,
    measure::Vector{Function},
    agg::Function,
    report_each::Bool,
    pivot_longer::Bool,
    store_ce::Bool,
)

Evaluates a counterfactual explanation and returns a dataframe of evaluation measures.
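The pivot_longer flag suggests a long layout with one row per measure. The sketch below illustrates that reshaping without DataFrames.jl, turning one wide row (a NamedTuple with a field per measure) into (variable, value) rows; the helper name and the column names are assumptions, not the package's output schema.

```julia
# Sketch of a "longer" layout (assumed, illustrated without DataFrames.jl):
# one wide row with a field per measure becomes one row per (measure, value) pair.
function pivot_longer_sketch(wide::NamedTuple)
    return [(variable = k, value = v) for (k, v) in pairs(wide)]
end

wide = (validity = 1.0, distance = 2.5)
long = pivot_longer_sketch(wide)   # two rows: (:validity, 1.0) and (:distance, 2.5)
```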

CounterfactualExplanations.Evaluation.validityMethod
validity(ce::CounterfactualExplanation; γ=0.5)

Checks whether the counterfactual search has been successful with respect to the probability threshold γ. If multiple counterfactuals were generated, the function returns the proportion of successful counterfactuals.
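The proportion of successful counterfactuals can be sketched from the predicted target-class probabilities. This assumes that success means the probability reaches γ (whether the package compares with ≥ or > is not stated here), and validity_sketch is a hypothetical name.

```julia
using Statistics

# Sketch, assuming a counterfactual is successful when its predicted probability
# for the target class is at least γ; returns the share of successful ones.
validity_sketch(target_probs::AbstractVector{<:Real}; γ=0.5) = mean(target_probs .>= γ)

validity_sketch([0.9, 0.4, 0.7, 0.55])   # 0.75
```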