vit_explainability

LinearModelPytorchWrapper(backbone, linear_classifier, example_input, device)

Bases: Module

Pytorch wrapper for scikit-learn linear models.

Parameters:

  • backbone (Module) –

    Backbone model used to extract features

  • linear_classifier (LinearClassifierMixin) –

    Scikit-learn linear classifier model

  • example_input (Tensor) –

    Example input used to infer the backbone's output shape

  • device (device) –

    The device on which to run the backbone and classifier

Source code in quadra/utils/vit_explainability.py (lines 307-327)
def __init__(
    self,
    backbone: torch.nn.Module,
    linear_classifier: LinearClassifierMixin,
    example_input: torch.Tensor,
    device: torch.device,
):
    super().__init__()
    self.device = device
    self.backbone = backbone.to(device)
    if not isinstance(linear_classifier, LinearClassifierMixin):
        raise TypeError("Classifier is not of type LinearClassifierMixin.")
    self.num_classes = len(linear_classifier.classes_)
    self.linear_classifier = linear_classifier
    with torch.no_grad():
        output = self.backbone(example_input.to(device))
        num_filters = output.shape[-1]

    self.classifier = torch.nn.Linear(num_filters, self.num_classes).to(device)
    self.classifier.weight.data = torch.from_numpy(linear_classifier.coef_).float()
    self.classifier.bias.data = torch.from_numpy(linear_classifier.intercept_).float()
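
A minimal usage sketch: the toy backbone, random data and LogisticRegression below are purely illustrative stand-ins for a real ViT feature extractor and a fitted scikit-learn linear classifier; only LinearModelPytorchWrapper comes from this module.

import numpy as np
import torch
from sklearn.linear_model import LogisticRegression

from quadra.utils.vit_explainability import LinearModelPytorchWrapper

# Toy backbone standing in for a ViT feature extractor: it flattens the image
# and projects it to a 16-dimensional embedding.
backbone = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 16))

# Fit a scikit-learn linear classifier on (random) backbone features.
with torch.no_grad():
    features = backbone(torch.randn(30, 3, 32, 32)).numpy()
labels = np.array([0, 1, 2] * 10)
clf = LogisticRegression(max_iter=1000).fit(features, labels)

wrapped = LinearModelPytorchWrapper(
    backbone=backbone,
    linear_classifier=clf,
    example_input=torch.randn(1, 3, 32, 32),  # used only to infer the feature size
    device=torch.device("cpu"),
)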

VitAttentionGradRollout(model, attention_layer_names=None, discard_ratio=0.9, classifier=None, example_input=None)

Attention gradient rollout class. The constructor registers hooks on the specified layers of the model. Only four layers are hooked by default, given the high load on the GPU; the best maps are obtained by hooking all blocks.

Parameters:

  • model (Module) –

    Pytorch model

  • attention_layer_names (Optional[List[str]], default: None ) –

    On which layers to register the hooks

  • discard_ratio (float, default: 0.9 ) –

    Percentage of elements to discard

  • classifier (Optional[LinearClassifierMixin], default: None ) –

    Scikit-learn classifier. Leave it as None if the model already has a classifier on top.

  • example_input (Optional[Tensor], default: None ) –

    Example input passed to LinearModelPytorchWrapper; required when classifier is not None

Source code in quadra/utils/vit_explainability.py (lines 193-238)
def __init__(  # pylint: disable=W0102
    self,
    model: torch.nn.Module,
    attention_layer_names: Optional[List[str]] = None,
    discard_ratio: float = 0.9,
    classifier: Optional[LinearClassifierMixin] = None,
    example_input: Optional[torch.Tensor] = None,
):
    if attention_layer_names is None:
        attention_layer_names = [
            "blocks.6.attn.attn_drop",
            "blocks.7.attn.attn_drop",
            "blocks.10.attn.attn_drop",
            "blocks.11.attn.attn_drop",
        ]

    if classifier is not None:
        if example_input is None:
            raise ValueError(
                "Must provide an input example to LinearModelPytorchWrapper when classifier is not None"
            )
        self.model = LinearModelPytorchWrapper(
            backbone=model,
            linear_classifier=classifier,
            example_input=example_input,
            device=next(model.parameters()).device,
        )
    else:
        self.model = model  # type: ignore[assignment]

    self.discard_ratio = discard_ratio
    self.f_hook_handles: List[torch.utils.hooks.RemovableHandle] = []
    self.b_hook_handles: List[torch.utils.hooks.RemovableHandle] = []
    for name, module in self.model.named_modules():
        for layer_name in attention_layer_names:
            if layer_name in name:
                self.f_hook_handles.append(module.register_forward_hook(self.get_attention))
                self.b_hook_handles.append(module.register_backward_hook(self.get_attention_gradient))
    self.attentions: List[torch.Tensor] = []
    self.attention_gradients: List[torch.Tensor] = []
    # Activate gradients
    blocks_list = [x.split("blocks")[1].split(".attn")[0] for x in attention_layer_names]
    for name, module in model.named_modules():
        for p in module.parameters():
            if "blocks" in name and any(x in name for x in blocks_list):
                p.requires_grad = True
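
A hedged construction sketch: it assumes a timm ViT backbone whose attention-dropout modules are named blocks.N.attn.attn_drop, matching the defaults above; timm and the model name are assumptions, only VitAttentionGradRollout comes from this module.

import timm
import torch

from quadra.utils.vit_explainability import VitAttentionGradRollout

# Assumption: a timm ViT whose attention dropout modules are named
# "blocks.N.attn.attn_drop". Recent timm versions may use a fused attention
# path that bypasses attn_drop, in which case these hooks never fire.
vit = timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=10)

grad_rollout_cam = VitAttentionGradRollout(
    model=vit,
    # Hooking every block usually gives the best maps, at a higher GPU cost.
    attention_layer_names=[f"blocks.{i}.attn.attn_drop" for i in range(12)],
    discard_ratio=0.9,
)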

__call__(input_tensor, targets_list)

Called when the class instance is used as a function.

Parameters:

  • input_tensor (Tensor) –

    Model's input tensor

  • targets_list (List[int]) –

    List of target classes, one per batch element. If None, the argmax of the model output is used

Returns:

  • out ( ndarray ) –

    Batch of output masks

Source code in quadra/utils/vit_explainability.py (lines 264-294)
def __call__(self, input_tensor: torch.Tensor, targets_list: List[int]) -> np.ndarray:
    """Called when the class instance is used as a function.

    Args:
        input_tensor: Model's input tensor
        targets_list: List of targets. If None, argmax is used

    Returns:
        out: Batch of output masks
    """
    self.attentions.clear()
    self.attention_gradients.clear()

    self.model.zero_grad(set_to_none=True)
    self.model.to(input_tensor.device)
    output = self.model(input_tensor).cpu()

    class_mask = torch.zeros(output.shape)
    if targets_list is None:
        targets_list = output.argmax(dim=1)
    class_mask[torch.arange(output.shape[0]), targets_list] = 1
    loss = (output * class_mask).sum()
    loss.backward()
    out = grad_rollout(
        self.attentions,
        self.attention_gradients,
        self.discard_ratio,
        aspect_ratio=(input_tensor.shape[-1] / input_tensor.shape[-2]),
    )

    return out
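
Continuing the construction sketch above, a hedged example of the call itself; upscaling the returned patch-level masks to the input resolution is an assumption about how they are typically consumed, not something this method does.

import torch
import torch.nn.functional as F

images = torch.randn(2, 3, 224, 224)
masks = grad_rollout_cam(images, targets_list=[3, 7])  # numpy array, one mask per image

# Upscale the patch-level masks to the input resolution for visualisation.
masks = torch.from_numpy(masks).unsqueeze(1)  # (batch, 1, h, w)
masks = F.interpolate(masks, size=images.shape[-2:], mode="bilinear", align_corners=False)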

get_attention(module, inpt, out)

Hook to return attention.

Parameters:

  • module (Module) –

    Torch module

  • inpt (Tensor) –

    Input tensor

  • out (Tensor) –

    Output tensor, in this case the attention

Source code in quadra/utils/vit_explainability.py (lines 240-250)
def get_attention(
    self, module: torch.nn.Module, inpt: torch.Tensor, out: torch.Tensor  # pylint: disable=W0613
) -> None:
    """Hook to return attention.

    Args:
        module: Torch module
        inpt: Input tensor
        out: Output tensor, in this case the attention
    """
    self.attentions.append(out.detach().clone().cpu())

get_attention_gradient(module, grad_input, grad_output)

Hook to return attention gradients.

Parameters:

  • module (Module) –

    Torch module

  • grad_input (Tensor) –

    Gradients with respect to the module input, in this case the attention gradients that are stored

  • grad_output (Tensor) –

    Gradients with respect to the module output

Source code in quadra/utils/vit_explainability.py (lines 252-262)
def get_attention_gradient(
    self, module: torch.nn.Module, grad_input: torch.Tensor, grad_output: torch.Tensor  # pylint: disable=W0613
) -> None:
    """Hook to return attention.

    Args:
        module: Torch module
        grad_input: Gradients' input tensor
        grad_output: Gradients' output tensor, in this case the attention
    """
    self.attention_gradients.append(grad_input[0].detach().clone().cpu())

VitAttentionRollout(model, attention_layer_names=None, head_fusion='mean', discard_ratio=0.9)

Attention rollout class. The constructor registers hooks on the specified layers of the model. Only four layers are hooked by default, given the high load on the GPU; the best attention maps are obtained by hooking all blocks.

Parameters:

  • model (Module) –

    Model

  • attention_layer_names (Optional[List[str]], default: None ) –

    On which layers to register the hook

  • head_fusion (str, default: 'mean' ) –

    Strategy of fusion for attention heads

  • discard_ratio (float, default: 0.9 ) –

    Percentage of elements to discard

Source code in quadra/utils/vit_explainability.py (lines 75-97)
def __init__(
    self,
    model: torch.nn.Module,
    attention_layer_names: Optional[List[str]] = None,
    head_fusion: str = "mean",
    discard_ratio: float = 0.9,
):
    if attention_layer_names is None:
        attention_layer_names = [
            "blocks.6.attn.attn_drop",
            "blocks.7.attn.attn_drop",
            "blocks.10.attn.attn_drop",
            "blocks.11.attn.attn_drop",
        ]
    self.model = model
    self.head_fusion = head_fusion
    self.discard_ratio = discard_ratio
    self.f_hook_handles: List[torch.utils.hooks.RemovableHandle] = []
    for name, module in self.model.named_modules():
        for layer_name in attention_layer_names:
            if layer_name in name:
                self.f_hook_handles.append(module.register_forward_hook(self.get_attention))
    self.attentions: List[torch.Tensor] = []

__call__(input_tensor)

Called when the class instance is used as a function.

Parameters:

  • input_tensor (Tensor) –

    Input tensor

Returns:

  • out ( ndarray ) –

    Batch of output masks

Source code in quadra/utils/vit_explainability.py (lines 111-130)
def __call__(self, input_tensor: torch.Tensor) -> np.ndarray:
    """Called when the class instance is used as a function.

    Args:
        input_tensor: Input tensor

    Returns:
        out: Batch of output masks
    """
    self.attentions.clear()
    with torch.no_grad():
        self.model(input_tensor)
    out = rollout(
        self.attentions,
        self.discard_ratio,
        self.head_fusion,
        aspect_ratio=(input_tensor.shape[-1] / input_tensor.shape[-2]),
    )

    return out
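
An end-to-end sketch for the gradient-free rollout, under the same timm ViT assumption as above; only VitAttentionRollout itself comes from this module.

import timm
import torch

from quadra.utils.vit_explainability import VitAttentionRollout

vit = timm.create_model("vit_base_patch16_224", pretrained=False).eval()

rollout_cam = VitAttentionRollout(
    model=vit,
    attention_layer_names=[f"blocks.{i}.attn.attn_drop" for i in range(12)],
    head_fusion="max",  # "mean", "max" or "min"
    discard_ratio=0.9,
)

# With a 224x224 input and 16x16 patches the masks come out as a (2, 14, 14) array.
masks = rollout_cam(torch.randn(2, 3, 224, 224))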

get_attention(module, inpt, out)

Hook to return attention.

Parameters:

  • module (Module) –

    Torch module

  • inpt (Tensor) –

    Input tensor

  • out (Tensor) –

    Output tensor, in this case the attention

Source code in quadra/utils/vit_explainability.py (lines 99-109)
def get_attention(
    self, module: torch.nn.Module, inpt: torch.Tensor, out: torch.Tensor  # pylint: disable=W0613
) -> None:
    """Hook to return attention.

    Args:
        module: Torch module
        inpt: Input tensor
        out: Output tensor, in this case the attention
    """
    self.attentions.append(out.detach().clone().cpu())

grad_rollout(attentions, gradients, discard_ratio=0.9, aspect_ratio=1.0)

Apply gradient rollout on Attention matrices.

Parameters:

  • attentions (List[Tensor]) –

    Attention matrices

  • gradients (List[Tensor]) –

    Target class gradient matrices

  • discard_ratio (float, default: 0.9 ) –

    Percentage of elements to discard

  • aspect_ratio (float, default: 1.0 ) –

    Model inputs' width divided by height

Returns:

  • mask ( ndarray ) –

    Output mask, still needs a resize

Source code in quadra/utils/vit_explainability.py (lines 136-179)
def grad_rollout(
    attentions: List[torch.Tensor], gradients: List[torch.Tensor], discard_ratio: float = 0.9, aspect_ratio: float = 1.0
) -> np.ndarray:
    """Apply gradient rollout on Attention matrices.

    Args:
        attentions: Attention matrices
        gradients: Target class gradient matrices
        discard_ratio: Percentage of elements to discard
        aspect_ratio: Model inputs' width divided by height

    Returns:
        mask: Output mask, still needs a resize
    """
    result = torch.eye(attentions[0].size(-1))
    with torch.no_grad():
        for attention, grad in zip(attentions, gradients):
            weights = grad
            attention_heads_fused = torch.mean((attention * weights), dim=1)
            attention_heads_fused[attention_heads_fused < 0] = 0
            # Drop the lowest attentions, but
            # don't drop the class token
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            flat.scatter_(-1, indices, 0)
            I = torch.eye(attention_heads_fused.size(-1))
            a = (attention_heads_fused + 1.0 * I) / 2
            a = a / a.sum(dim=-1).unsqueeze(1)
            result = torch.matmul(a, result)
    # Look at the total attention between the class token,
    # and the image patches
    mask = result[:, 0, 1:]
    batch_size = mask.size(0)
    # TODO: Non squared input-size handling can be improved. Not easy though
    height = math.floor((mask.size(-1) / aspect_ratio) ** 0.5)
    total_size = mask.size(-1)
    width = math.floor(total_size / height)
    if mask.size(-1) > (height * width):
        to_remove = mask.size(-1) - (height * width)
        mask = mask[:, :-to_remove]
    mask = mask.reshape(batch_size, height, width).numpy()
    mask = mask / mask.max(axis=(1, 2), keepdims=True)

    return mask
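
A quick shape check with random tensors standing in for real attentions and gradients (batch of 2, 3 heads, a class token plus a 14x14 patch grid); the tensors are purely illustrative.

import torch

from quadra.utils.vit_explainability import grad_rollout

num_tokens = 1 + 14 * 14  # class token + patch tokens
attentions = [torch.rand(2, 3, num_tokens, num_tokens) for _ in range(4)]
gradients = [torch.rand(2, 3, num_tokens, num_tokens) for _ in range(4)]

mask = grad_rollout(attentions, gradients, discard_ratio=0.9, aspect_ratio=1.0)
print(mask.shape)  # (2, 14, 14), each mask normalised to a maximum of 1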

rollout(attentions, discard_ratio=0.9, head_fusion='mean', aspect_ratio=1.0)

Apply rollout on Attention matrices.

Parameters:

  • attentions (List[Tensor]) –

    List of Attention matrices coming from different blocks

  • discard_ratio (float, default: 0.9 ) –

    Percentage of elements to discard

  • head_fusion (str, default: 'mean' ) –

    Strategy of fusion of attention heads

  • aspect_ratio (float, default: 1.0 ) –

    Model inputs' width divided by height

Returns:

  • mask ( ndarray ) –

    Output mask, still needs a resize

Source code in quadra/utils/vit_explainability.py (lines 14-61)
def rollout(
    attentions: List[torch.Tensor], discard_ratio: float = 0.9, head_fusion: str = "mean", aspect_ratio: float = 1.0
) -> np.ndarray:
    """Apply rollout on Attention matrices.

    Args:
        attentions: List of Attention matrices coming from different blocks
        discard_ratio: Percentage of elements to discard
        head_fusion: Strategy of fusion of attention heads
        aspect_ratio: Model inputs' width divided by height

    Returns:
        mask: Output mask, still needs a resize
    """
    result = torch.eye(attentions[0].size(-1))
    with torch.no_grad():
        for attention in attentions:
            if head_fusion == "mean":
                attention_heads_fused = attention.mean(dim=1)
            elif head_fusion == "max":
                attention_heads_fused = attention.max(dim=1)[0]
            elif head_fusion == "min":
                attention_heads_fused = attention.min(dim=1)[0]
            else:
                raise ValueError("Attention head fusion type Not supported")
            # Drop the lowest attentions, but
            # don't drop the class token
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            flat.scatter_(-1, indices, 0)
            I = torch.eye(attention_heads_fused.size(-1))
            a = (attention_heads_fused + 1.0 * I) / 2
            a = a / a.sum(dim=-1).unsqueeze(1)
            result = torch.matmul(a, result)
    # Look at the total attention between the class token and the image patches
    mask = result[:, 0, 1:]
    batch_size = mask.size(0)
    # TODO: Non squared input-size handling can be improved. Not easy though
    height = math.floor((mask.size(-1) / aspect_ratio) ** 0.5)
    total_size = mask.size(-1)
    width = math.floor(total_size / height)
    if mask.size(-1) > (height * width):
        to_remove = mask.size(-1) - (height * width)
        mask = mask[:, :-to_remove]
    mask = mask.reshape(batch_size, height, width).numpy()
    mask = mask / mask.max(axis=(1, 2), keepdims=True)

    return mask
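
The same kind of shape check for the gradient-free variant, this time with a non-square token grid: a hypothetical 224x448 input with 16x16 patches gives a 14x28 grid and an aspect_ratio of 2.0.

import torch

from quadra.utils.vit_explainability import rollout

num_tokens = 1 + 14 * 28  # class token + patch tokens of a 14x28 grid
attentions = [torch.rand(2, 3, num_tokens, num_tokens) for _ in range(4)]

mask = rollout(attentions, discard_ratio=0.9, head_fusion="mean", aspect_ratio=2.0)
print(mask.shape)  # (2, 14, 28)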