LocalModel fits locally weighted linear regression models (logistic regression for classification) to explain single predictions of a prediction model.

Details

A weighted glm is fitted with the machine learning model prediction as target. Data points are weighted by their proximity to the instance to be explained, using the gower proximity measure. L1-regularization is used to make the results sparse.

The resulting model can be seen as a surrogate for the machine learning model, which is only valid for that one point. Categorical features are binarized, depending on the category of the instance to be explained: 1 if the category is the same, 0 otherwise.

Please note that scaling continuous features in the machine learning method might be advisable when using LIME as an interpretation technique. LIME uses a distance measure to compute proximity weights for the weighted glm. Hence, the original scale of the features may influence the distance measure and therewith LIME results.

To learn more about local models, read the Interpretable Machine Learning book: https://christophm.github.io/interpretable-ml-book/lime.html

The approach is similar to LIME, but has the following differences:

  • Distance measure: Uses as default the gower proximity (= 1 - gower distance) instead of a kernel based on the Euclidean distance. Has the advantage to have a meaningful neighborhood and no kernel width to tune. When the distance is not "gower", then the stats::dist() function with the chosen method will be used, and turned into a similarity measure: \(sqrt(exp(-(distance^2) / (kernel.width^2)))\).

  • Sampling: Uses the original data instead of sampling from normal distributions. Has the advantage to follow the original data distribution.

  • Visualization: Plots effects instead of betas. Both are the same for binary features, but ared different for numerical features. For numerical features, plotting the betas makes no sense, because a negative beta might still increase the prediction when the feature value is also negative.

To learn more about local surrogate models, read the Interpretable Machine Learning book: https://christophm.github.io/interpretable-ml-book/lime.html

References

Ribeiro, M. T., Singh, S., & Guestrin, C. (2016). "Why Should I Trust You?": Explaining the Predictions of Any Classifier. Retrieved from http://arxiv.org/abs/1602.04938

Gower, J. C. (1971), "A general coefficient of similarity and some of its properties". Biometrics, 27, 623--637.

See also

plot.LocalModel and predict.LocalModel

Shapley can also be used to explain single predictions

The package lime with the original implementation

Super class

iml::InterpretationMethod -> LocalModel

Public fields

x.interest

data.frame
Single row with the instance to be explained.

k

numeric(1)
The number of features as set by the user.

model

glmnet
The fitted local model.

best.fit.index

numeric(1)
The index of the best glmnet fit.

Methods

Inherited methods


Method new()

Create a Local Model object.

Usage

LocalModel$new(
  predictor,
  x.interest,
  dist.fun = "gower",
  gower.power = 1,
  kernel.width = NULL,
  k = 3
)

Arguments

predictor

Predictor
The object (created with Predictor$new()) holding the machine learning model and the data.

x.interest

data.frame
Single row with the instance to be explained.

dist.fun

character(1))
The name of the distance function for computing proximities (weights in the linear model). Defaults to "gower". Otherwise will be forwarded to stats::dist.

gower.power

(numeric(1))
The calculated gower proximity will be raised to the power of this value. Can be used to specify the size of the neighborhood for the LocalModel (similar to kernel.width for the euclidean distance).

kernel.width

(numeric(1))
The width of the kernel for the proximity computation. Only used if dist.fun is not "gower".

k

numeric(1)
The number of features.

Returns

data.frame
Results with the feature names (feature) and contributions to the prediction.


Method predict()

Method to predict new data with the local model See also predict.LocalModel.

Usage

LocalModel$predict(newdata = NULL, ...)

Arguments

newdata

data.frame
Data to predict on.

...

Not used


Method explain()

Set a new data point to explain.

Usage

LocalModel$explain(x.interest)

Arguments

x.interest

data.frame
Single row with the instance to be explained.


Method clone()

The objects of this class are cloneable with this method.

Usage

LocalModel$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

library("randomForest")
# First we fit a machine learning model on the Boston housing data
data("Boston", package = "MASS")
X <- Boston[-which(names(Boston) == "medv")]
rf <- randomForest(medv ~ ., data = Boston, ntree = 50)
mod <- Predictor$new(rf, data = X)

# Explain the first instance of the dataset with the LocalModel method:
x.interest <- X[1, ]
lemon <- LocalModel$new(mod, x.interest = x.interest, k = 2)
#> Loading required package: glmnet
#> Loading required package: Matrix
#> Loaded glmnet 4.1-4
#> Loading required package: gower
lemon
#> Interpretation method:  LocalModel 
#> 
#> 
#> Analysed predictor: 
#> Prediction task: unknown 
#> 
#> 
#> Analysed data:
#> Sampling from data.frame with 506 rows and 13 columns.
#> 
#> 
#> Head of results:
#>             beta x.recoded    effect x.original feature feature.value
#> rm     2.9966652     6.575 19.703074      6.575      rm      rm=6.575
#> lstat -0.3041663     4.980 -1.514748       4.98   lstat    lstat=4.98

# Look at the results in a table
lemon$results
#>             beta x.recoded    effect x.original feature feature.value
#> rm     2.9966652     6.575 19.703074      6.575      rm      rm=6.575
#> lstat -0.3041663     4.980 -1.514748       4.98   lstat    lstat=4.98
# Or as a plot
plot(lemon)


# Reuse the object with a new instance to explain
lemon$x.interest
#>      crim zn indus chas   nox    rm  age  dis rad tax ptratio black lstat
#> 1 0.00632 18  2.31    0 0.538 6.575 65.2 4.09   1 296    15.3 396.9  4.98
lemon$explain(X[2, ])
lemon$x.interest
#>      crim zn indus chas   nox    rm  age    dis rad tax ptratio black lstat
#> 2 0.02731  0  7.07    0 0.469 6.421 78.9 4.9671   2 242    17.8 396.9  9.14
plot(lemon)


# LocalModel also works with multiclass classification
rf <- randomForest(Species ~ ., data = iris, ntree = 50)
X <- iris[-which(names(iris) == "Species")]
mod <- Predictor$new(rf, data = X, type = "prob", class = "setosa")

# Then we explain the first instance of the dataset with the LocalModel method:
lemon <- LocalModel$new(mod, x.interest = X[1, ], k = 2)
lemon$results
#>                    beta x.recoded     effect x.original      feature
#> Sepal.Width   0.1783374       3.5  0.6241810        3.5  Sepal.Width
#> Petal.Length -0.2159960       1.4 -0.3023944        1.4 Petal.Length
#>                 feature.value
#> Sepal.Width   Sepal.Width=3.5
#> Petal.Length Petal.Length=1.4
plot(lemon)