本教程演示如何应用 Captum 库中的 TracInCP 算法来实现影响样本的可解释性。TracInCP 计算给定训练样本对给定测试样本的影响分数。通俗地说,它代表了如果从训练数据集中移除该训练样本并重新训练模型,给定测试样本的损失(loss)会增加多少。此功能可应用于以下两个用例:
TracInCP 可用于任何具有多个模型检查点(checkpoint)的已训练 Pytorch 模型。
注意: 在运行本教程之前,请执行以下操作:
目前,Captum 提供 3 种实现,它们都实现了相同的 API。更具体地说,它们定义了一个 influence 方法,可以在两种不同的模式下使用:
这三种不同的实现在以下类中定义:
TracInCP:计算影响分数时考虑所有指定层中的梯度。指定过多层会减慢所有三种模式的执行速度。TracInCPFast:在 TracIn 论文的附录 F 中,作者展示了如果计算影响分数时仅考虑最后一个全连接层的梯度,则可以使用计算技巧比天真地应用反向传播计算梯度更快地完成计算。TracInCPFast 利用该技巧计算影响分数,仅考虑最后一个全连接层。如果您想相对于 TracInCP 减少时间和内存消耗,TracInCPFast 非常有用。TracInCPFastRandProj:前两个类不适用于“交互式”使用,因为在影响分数模式或前 K 个最具影响力模式下,每次调用 influence 所需的时间与训练数据集大小成正比。另一方面,TracInCPFastRandProj 支持“交互式”使用,即针对这两种模式的 influence 调用时间为常数。代价是在 TracInCPFastRandProj.__init__ 中,需要进行预处理,将与每个训练样本相关的嵌入(embeddings)存储到最近邻数据结构中。此预处理耗费的时间和内存与训练数据集大小成正比。此外,可以应用随机投影(random projections)来减少内存使用,代价是这两种模式中使用的影响分数仅为近似正确。与 TracInCPFast 一样,该类仅考虑最后一个全连接层的梯度,如果您想相对于 TracInCP 减少时间和内存消耗,它将非常有用。%matplotlib inline
%load_ext autoreload
%autoreload 2
import datetime
import glob
import os
import pickle
import warnings
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from captum.influence import TracInCP, TracInCPFast, TracInCPFastRandProj
from sklearn.metrics import auc, roc_curve
from torch.utils.data import DataLoader, Dataset, Subset
warnings.filterwarnings("ignore")
首先,我们将演示 TracInCP 识别影响样本的能力,即在“前 K 个最具影响力”模式下使用 influence 方法。为此,我们需要 3 个组件:
net。net 的 Pytorch Dataset。为此我们将使用 correct_dataset,即原始 CIFAR-10 训练集拆分。Dataset,即 test_dataset。为此我们将使用原始 CIFAR-10 验证集拆分。这对于监控训练也很有用。correct_dataset 训练 net 得到的检查点。我们将进行训练并保存检查点。net¶我们将使用来自以下教程的一个相对简单的模型:https://pytorch.ac.cn/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py。
我们首先定义 net 的架构:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.pool1 = nn.MaxPool2d(2, 2)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
self.relu1 = nn.ReLU()
self.relu2 = nn.ReLU()
self.relu3 = nn.ReLU()
self.relu4 = nn.ReLU()
def forward(self, x):
x = self.pool1(self.relu1(self.conv1(x)))
x = self.pool2(self.relu2(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = self.relu3(self.fc1(x))
x = self.relu4(self.fc2(x))
x = self.fc3(x)
return x
在下方的单元格中,我们初始化 net。
net = Net()
由于两个都是图像数据集,我们将首先定义 normalize 和 inverse_normalize 变换,分别用于图像到输入的转换以及输入到图像的转换。
normalize = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
inverse_normalize = transforms.Compose([
transforms.Normalize(mean = [0., 0., 0.], std = [1/0.5, 1/0.5, 1/0.5]),
transforms.Normalize(mean = [-0.5, -0.5, -0.5], std = [1., 1., 1.]),
])
correct_dataset¶correct_dataset_path = "data/cifar_10"
correct_dataset = torchvision.datasets.CIFAR10(root=correct_dataset_path, train=True, download=True, transform=normalize)
Files already downloaded and verified
test_dataset¶这将与 correct_dataset 相同,共享相同的路径和变换。唯一的区别是它使用验证集拆分。
test_dataset = torchvision.datasets.CIFAR10(root=correct_dataset_path, train=False, download=True, transform=normalize)
Files already downloaded and verified
我们将通过在 correct_dataset 上训练 net 26 个 epoch 来获取检查点。通常,应至少有 5 个检查点,它们可以均匀分布在整个训练过程中,或者更好的是,选择损失下降明显的 epoch。
我们首先定义一个训练函数,它复制自上述教程。
def train(net, num_epochs, train_dataloader, test_dataloader, checkpoints_dir, save_every):
start_time = datetime.datetime.now()
if not os.path.exists(checkpoints_dir):
os.makedirs(checkpoints_dir)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for epoch in range(num_epochs): # loop over the dataset multiple times
epoch_loss = 0.0
running_loss = 0.0
for i, data in enumerate(train_dataloader):
# get the inputs
inputs, labels = data
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
if (i + 1) % 100 == 0: # print every 100 mini-batches
print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 100))
epoch_loss += running_loss
running_loss = 0.0
if epoch % save_every == 0:
checkpoint_name = "-".join(["checkpoint", str(epoch) + ".pt"])
torch.save(
{
"epoch": epoch,
"model_state_dict": net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"loss": epoch_loss,
},
os.path.join(checkpoints_dir, checkpoint_name),
)
# Calcualate validation accuracy
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
for data in test_dataloader:
images, labels = data
# calculate outputs by running images through the network
outputs = net(images)
# the class with the highest energy is what we choose as prediction
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print("Accuracy of the network on test set at epoch %d: %d %%" % (epoch, 100 * correct / total))
total_minutes = (datetime.datetime.now() - start_time).total_seconds() / 60.0
print("Finished training in %.2f minutes" % total_minutes)
然后我们定义保存检查点的文件夹。稍后运行 TracInCP 算法时将需要此文件夹。
correct_dataset_checkpoints_dir = os.path.join("checkpoints", "cifar_10_correct_dataset")
最后,我们训练模型,将 correct_dataset 和 test_dataset 转换为 DataLoader,并每隔 5 个 epoch 保存一次检查点。在本教程中,我们已将此次训练的检查点保存在 AWS S3 上,您可以直接下载这些检查点,而无需进行耗时的训练。如果您想自己进行训练,请将下一个单元格中的 do_training 标志设置为 True。
num_epochs = 26
do_training = False # change to `True` if you want to do training
if do_training:
train(net, num_epochs, DataLoader(correct_dataset, batch_size=128, shuffle=True), DataLoader(test_dataset, batch_size=128, shuffle=True), correct_dataset_checkpoints_dir, save_every=5)
elif not os.path.exists(correct_dataset_checkpoints_dir):
# this should download the zipped folder of checkpoints from the S3 bucket
# then unzip the folder to produce checkpoints in the folder `checkpoints/cifar_10_correct_dataset`
# this is done if checkpoints do not already exist in the folder
# if the below commands do not work, please manually download and unzip the folder to produce checkpoints in that folder
os.makedirs(correct_dataset_checkpoints_dir)
!wget https://pytorch.s3.amazonaws.com/models/captum/influence-tutorials/cifar_10_correct_dataset.zip -O checkpoints/cifar_10_correct_dataset.zip
!unzip -o checkpoints/cifar_10_correct_dataset.zip -d checkpoints
我们将检查点列表 correct_dataset_checkpoint_paths 定义为训练中的所有检查点。
correct_dataset_checkpoint_paths = glob.glob(os.path.join(correct_dataset_checkpoints_dir, "*.pt"))
我们还定义了一个函数,用于将给定的检查点加载到给定的模型中。这对现在以及所有 TracInCP 实现都很有用。在 TracInCP 实现中使用时,该函数应返回检查点处的学习率。但是,如果无法获得该学习率,直接像我们一样返回 1 也是安全的,因为事实证明 TracInCP 实现对该学习率并不敏感。
def checkpoints_load_func(net, path):
weights = torch.load(path)
net.load_state_dict(weights["model_state_dict"])
return 1.
我们首先为 net 加载最后一个检查点,以便我们在下一个单元格中进行的预测针对的是训练后的模型。我们将这最后一个检查点保存为 correct_dataset_final_checkpoint,因为我们稍后会重复使用它。
correct_dataset_final_checkpoint = os.path.join(correct_dataset_checkpoints_dir, "-".join(['checkpoint', str(num_epochs - 1) + '.pt']))
checkpoints_load_func(net, correct_dataset_final_checkpoint)
1.0
现在,我们定义 test_examples_features,即用于识别影响样本的一批测试样本的特征,并存储正确标签以及预测标签。
test_examples_indices = [0,1,2,3]
test_examples_features = torch.stack([test_dataset[i][0] for i in test_examples_indices])
test_examples_predicted_probs, test_examples_predicted_labels = torch.max(F.softmax(net(test_examples_features), dim=1), dim=1)
test_examples_true_labels = torch.Tensor([test_dataset[i][1] for i in test_examples_indices]).long()
回想一下,TracInCP 算法有几种实现方式。特别是,TracInCP 比 TracInCPFast 和 TracInCPFastRandProj 更耗费时间和内存。在本教程中,为了节省时间,我们将仅使用 TracInCPFast 和 TracInCPFastRandProj。
在 TracInCPFast 和 TracInCPFastRandProj 之间做出选择时,请记住 TracInCPFastRandProj 适用于“交互式”使用,即当需要多次调用“影响分数”和“前 K 个最具影响力”模式下的 influence 方法时。作为“交互式”使用能力的代价,TracInCPFastRandProj 需要初始预处理,这可能非常耗费时间和内存。另一方面,TracInCPFast 不支持“交互式”使用,但避免了初始预处理。
TracInCPFast 实例¶我们将首先演示 TracInCPFast 的使用,以避免初始预处理(因为我们只调用一次 influence 方法,不会利用“交互式”使用功能)。
为了完整定义 TracInCPFast 实现,还需要定义几个参数:
final_fc_layer:指向最后一个全连接层的引用或名称,其梯度将用于计算影响分数。这必须是最后一层。loss_fn:训练中使用的损失函数。batch_size:用于计算影响分数的训练数据的批大小。它不影响计算出的实际影响分数,但会影响计算效率。特别是,遍历训练数据所需的批次越少,所有模式下的 influence 速度就越快。这是因为 influence 会为每个批次加载一次模型检查点。因此,batch_size 应设置得较大,但不能过大(否则内存会溢出)。vectorize:是否使用加速 Jacobian 计算的实验性功能。仅在 PyTorch 版本 >1.6 中可用。我们现在准备好创建 TracInCPFast 实例:
tracin_cp_fast = TracInCPFast(
model=net,
final_fc_layer=list(net.children())[-1],
train_dataset=correct_dataset,
checkpoints=correct_dataset_checkpoint_paths,
checkpoints_load_func=checkpoints_load_func,
loss_fn=nn.CrossEntropyLoss(reduction="sum"),
batch_size=2048,
vectorize=False,
)
TracInCPFast 计算支持者/反对者¶现在,我们调用 tracin_cp_fast 的 influence 方法来计算由 test_examples_features 和 test_examples_true_labels 表示的测试样本的影响样本。我们需要通过 proponents 布尔参数指定我们需要支持者还是反对者,并通过 k 参数指定每个测试样本返回多少个影响样本。请注意,必须指定 k。否则,将运行“影响分数”模式。此调用应耗时少于 2 分钟。
请注意,我们将测试样本作为 单个 元组传递。这是因为对于所有实现,当我们向 influence 方法传递单个批次 batch 时,我们假设 batch[-1] 包含该批次的标签,而 model(*(batch[0:-1])) 生成该批次的预测,从而使 batch[0:-1] 包含该批次的特征。这一约定是在最近的 API 更改中引入的。
此调用返回一个 namedtuple,其有序元素为 (indices, influence_scores)。indices 是形状为 (test_batch_size, k) 的 2D 张量,其中 test_batch_size 是 test_examples_batch 中测试样本的数量。influence_scores 的形状相同,但按顺序存储每个测试样本的支持者/反对者的影响分数。例如,如果 proponents 为 True,则 influence_scores[i][j] 是对测试样本 i 具有第 j 个最高正影响分数的训练样本的影响分数。
k = 10
start_time = datetime.datetime.now()
proponents_indices, proponents_influence_scores = tracin_cp_fast.influence(
(test_examples_features, test_examples_true_labels), k=k, proponents=True
)
opponents_indices, opponents_influence_scores = tracin_cp_fast.influence(
(test_examples_features, test_examples_true_labels), k=k, proponents=False
)
total_minutes = (datetime.datetime.now() - start_time).total_seconds() / 60.0
print(
"Computed proponents / opponents over a dataset of %d examples in %.2f minutes"
% (len(correct_dataset), total_minutes)
)
Computed proponents / opponents over a dataset of 50000 examples in 1.11 minutes
为了显示结果,我们定义了几个辅助函数,用于显示测试样本、显示一组训练样本,以及一个将数据集中的张量转换为适用于 matplotlib imshow 函数的辅助变换,以及一个从数值标签(即 4)到类别(即“cat”)的映射。
label_to_class = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
imshow_transform = lambda tensor_in_dataset: inverse_normalize(tensor_in_dataset.squeeze()).permute(1, 2, 0)
def display_test_example(example, true_label, predicted_label, predicted_prob, label_to_class):
fig, ax = plt.subplots()
print('true_class:', label_to_class[true_label])
print('predicted_class:', label_to_class[predicted_label])
print('predicted_prob', predicted_prob)
ax.imshow(torch.clip(imshow_transform(example), 0, 1))
plt.show()
def display_training_examples(examples, true_labels, label_to_class, figsize=(10,4)):
fig = plt.figure(figsize=figsize)
num_examples = len(examples)
for i in range(num_examples):
ax = fig.add_subplot(1, num_examples, i+1)
ax.imshow(torch.clip(imshow_transform(examples[i]), 0, 1))
ax.set_title(label_to_class[true_labels[i]])
plt.show()
return fig
def display_proponents_and_opponents(test_examples_batch, proponents_indices, opponents_indices, test_examples_true_labels, test_examples_predicted_labels, test_examples_predicted_probs):
for (
test_example,
test_example_proponents,
test_example_opponents,
test_example_true_label,
test_example_predicted_label,
test_example_predicted_prob,
) in zip(
test_examples_batch,
proponents_indices,
opponents_indices,
test_examples_true_labels,
test_examples_predicted_labels,
test_examples_predicted_probs,
):
print("test example:")
display_test_example(
test_example,
test_example_true_label,
test_example_predicted_label,
test_example_predicted_prob,
label_to_class,
)
print("proponents:")
test_example_proponents_tensors, test_example_proponents_labels = zip(
*[correct_dataset[i] for i in test_example_proponents]
)
display_training_examples(
test_example_proponents_tensors, test_example_proponents_labels, label_to_class, figsize=(20, 8)
)
print("opponents:")
test_example_opponents_tensors, test_example_opponents_labels = zip(
*[correct_dataset[i] for i in test_example_opponents]
)
display_training_examples(
test_example_opponents_tensors, test_example_opponents_labels, label_to_class, figsize=(20, 8)
)
我们可以显示每个测试样本的支持者和反对者:
display_proponents_and_opponents(
test_examples_features,
proponents_indices,
opponents_indices,
test_examples_true_labels,
test_examples_predicted_labels,
test_examples_predicted_probs,
)
test example: true_class: cat predicted_class: cat predicted_prob tensor(0.4126, grad_fn=<UnbindBackward0>)
proponents:
opponents:
test example: true_class: ship predicted_class: ship predicted_prob tensor(0.5685, grad_fn=<UnbindBackward0>)
proponents:
opponents:
test example: true_class: ship predicted_class: ship predicted_prob tensor(0.3574, grad_fn=<UnbindBackward0>)
proponents:
opponents:
test example: true_class: plane predicted_class: ship predicted_prob tensor(0.6398, grad_fn=<UnbindBackward0>)
proponents:
opponents:
我们看到结果符合直觉。例如,测试样本“猫”的支持者也都是被标记为“猫”的猫。另一方面,反对者则是看起来有点像猫但被标记为其他动物(如“狗”)的动物。因此,这些反对者的存在会使测试样本的预测偏离“猫”。
TracInCPFastRandProj 实例¶我们还定义并使用 TracInCPFastRandProj 实例来展示其优缺点。请注意,由于 TracInCPFastRandProj 将训练数据集中与每个样本相关的嵌入存储在最近邻数据结构中,__init__ 有 2 个新参数:
nearest_neighbors:这是内部用于快速查找支持者/反对者的最近邻类(测试样本的支持者/反对者是那些其某种嵌入与测试样本相似/不相似的样本,更多详情请参阅 TracIn 论文)。目前仅提供一个最近邻类:AnnoyNearestNeighbors,它封装了 Annoy 库。该类有一个参数:num_trees,即要使用的树的数量。增加此数量可以更精确地计算最近邻,但创建树需要更长的设置时间以及内存。projection_dim:嵌入维度可能过高并需要过多内存。可以使用随机投影来降低这些嵌入的维度。此参数指定了该维度(对应于 TracIn 论文附录第 15 页中的变量 d)。更详细地说,嵌入是几个“检查点嵌入”的级联,每个“检查点嵌入”对应一个特定的检查点。因此,嵌入的维度实际上是 projection_dim 乘以所使用的检查点数量。注意:初始化大约需要 10 分钟,因此请随意跳过与 TracInCPFastRandProj 相关的教程部分。
from captum.influence._utils.nearest_neighbors import AnnoyNearestNeighbors
start_time = datetime.datetime.now()
tracin_cp_fast_rand_proj = TracInCPFastRandProj(
model=net,
final_fc_layer=list(net.children())[-1],
train_dataset=correct_dataset,
checkpoints=correct_dataset_checkpoint_paths,
checkpoints_load_func=checkpoints_load_func,
loss_fn=nn.CrossEntropyLoss(reduction="sum"),
batch_size=128,
nearest_neighbors=AnnoyNearestNeighbors(num_trees=100),
projection_dim=100,
)
total_minutes = (datetime.datetime.now() - start_time).total_seconds() / 60.0
print(
"Performed pre-processing of a dataset of %d examples in %.2f minutes"
% (len(correct_dataset), total_minutes)
)
Performed pre-processing of a dataset of 50000 examples in 4.98 minutes
TracInCPFastRandProj 计算支持者/反对者¶与之前一样,我们可以使用此 TracInCPFastRandProj 实例的 influence 方法来计算支持者/反对者。与 TracInCPFast 实例不同,由于在初始化期间进行了预处理,此计算应该非常快。
k = 10
start_time = datetime.datetime.now()
proponents_indices, proponents_influence_scores = tracin_cp_fast_rand_proj.influence(
(test_examples_features, test_examples_true_labels), k=k, proponents=True
)
opponents_indices, opponents_influence_scores = tracin_cp_fast_rand_proj.influence(
(test_examples_features, test_examples_true_labels), k=k, proponents=False
)
total_minutes = (datetime.datetime.now() - start_time).total_seconds() / 60.0
print(
"Computed proponents / opponents over a dataset of %d examples in %.2f minutes"
% (len(correct_dataset), total_minutes)
)
Computed proponents / opponents over a dataset of 50000 examples in 0.01 minutes
我们可以显示每个测试样本的支持者和反对者:
display_proponents_and_opponents(
test_examples_features,
proponents_indices,
opponents_indices,
test_examples_true_labels,
test_examples_predicted_labels,
test_examples_predicted_probs,
)
test example: true_class: cat predicted_class: cat predicted_prob tensor(0.4126, grad_fn=<UnbindBackward0>)
proponents:
opponents:
test example: true_class: ship predicted_class: ship predicted_prob tensor(0.5685, grad_fn=<UnbindBackward0>)
proponents:
opponents:
test example: true_class: ship predicted_class: ship predicted_prob tensor(0.3574, grad_fn=<UnbindBackward0>)
proponents:
opponents:
test example: true_class: plane predicted_class: ship predicted_prob tensor(0.6398, grad_fn=<UnbindBackward0>)
proponents:
opponents:
我们看到支持者/反对者与之前不同,但它们仍然是合理的。
现在,我们将演示 TracInCP 识别标注错误数据的能力。与之前一样,我们需要 3 个组件:
net。net 的 Pytorch Dataset,即 mislabelled_dataset。Dataset,即 test_dataset。为此我们将使用原始 CIFAR-10 验证集拆分。与使用 TracInCP 为某些测试样本识别影响样本不同,我们此处仅将其用于监控训练;我们只是在识别训练数据中的标注错误样本。另外请注意,与 mislabelled_dataset 不同,test_dataset 不会有标注错误的样本。mislabelled_dataset 训练 net 得到的检查点。net¶我们需要再次初始化 net,因为目前 net 加载的是使用 correct_dataset 训练的参数,但我们现在想使用 mislabelled_dataset 从头开始训练 net。
net = Net()
mislabelled_dataset¶我们现在通过人为地在 correct_dataset 中引入标注错误样本来定义 mislabelled_dataset。使用人工数据可以让我们知道样本是否真实标注错误的地面真值,从而进行评估。我们通过以下程序从 correct_dataset 创建 mislabelled_dataset:我们将使用 correct_dataset 训练的 Pytorch 模型初始化为 correct_dataset_net。对于 correct_dataset 中 10% 的样本,我们使用 correct_dataset_net 预测样本属于每个类别的概率。然后我们将标签更改为概率最高但 错误 的标签。
请注意,要了解 mislabelled_dataset 中哪些样本标注错误的地面真值,我们可以比较 mislabelled_dataset 和 correct_dataset 之间的标签。另请注意,由于两个数据集具有相同的特征,mislabelled_dataset 是根据 correct_dataset 定义的。
首先,我们初始化 correct_dataset_net,加载使用 correct_dataset 训练的参数(我们将使用之前的最后一个检查点 correct_dataset_final_checkpoint)。
correct_dataset_net = Net()
checkpoints_load_func(correct_dataset_net, correct_dataset_final_checkpoint)
1.0
然后,我们为 correct_dataset 中的每个样本生成错误标签并提取正确标签。我们需要正确标签,因为 incorrect_dataset 中的某些样本仍将被正确标注。这应该耗时少于 10 分钟。
start_time = datetime.datetime.now()
incorrect_labels = []
correct_labels = []
correct_dataset_dataloader = DataLoader(correct_dataset, batch_size=128, shuffle=False)
for i, (batch_features, batch_correct_labels) in enumerate(correct_dataset_dataloader):
# get predicted probabilities of each class
batch_predictions = torch.nn.functional.softmax(correct_dataset_net(batch_features), dim=1)
# set the predicted probability of the correct class to 0
batch_predictions[torch.arange(0, len(batch_predictions)), batch_correct_labels] = 0
# most probable incorrect label is the remaining class with the highest predicted probability
batch_incorrect_labels = torch.argmax(batch_predictions, dim=1)
incorrect_labels.append(batch_incorrect_labels)
correct_labels.append(batch_correct_labels)
incorrect_labels = torch.cat(incorrect_labels)
correct_labels = torch.cat(correct_labels)
total_minutes = (datetime.datetime.now() - start_time).total_seconds() / 60.0
print("Generated incorrect labels in %.2f minutes" % total_minutes)
Generated incorrect labels in 0.42 minutes
现在,我们为 mislabelled_dataset 创建标签。10% 来自 incorrect_labels,其余来自 correct_labels。
mislabelled_proportion = 0.10
use_incorrect = torch.rand(len(incorrect_labels)) < mislabelled_proportion
mislabelled_dataset_labels = (use_incorrect * incorrect_labels) + ((~use_incorrect) * correct_labels)
最后,定义 mislabelled_dataset,它将取决于 correct_dataset,因为它们共享特征。
class MislabelledDataset(Dataset):
def __init__(self, correct_dataset: Dataset, mislabelled_dataset_labels: torch.Tensor):
self.correct_dataset, self.mislabelled_dataset_labels = correct_dataset, mislabelled_dataset_labels
def __getitem__(self, i):
return self.correct_dataset[i][0], self.mislabelled_dataset_labels[i]
def __len__(self):
return len(self.correct_dataset)
mislabelled_dataset = MislabelledDataset(correct_dataset, mislabelled_dataset_labels)
我们现在可以通过比较从 mislabelled_dataset 提取的标签与从 correct_dataset 提取的标签,来获取 mislabelled_dataset 中样本是否标注错误的地面真值,并验证 mislabelled_dataset 中确实有约 10% 的样本标注错误。该地面真值保存为 is_mislabelled,稍后将用于评估。
_incorrect_dataset_labels = torch.Tensor([mislabelled_dataset[i][1] for i in range(len(mislabelled_dataset))])
_correct_dataset_labels = torch.Tensor([correct_dataset[i][1] for i in range(len(correct_dataset))])
is_mislabelled = _incorrect_dataset_labels != _correct_dataset_labels
print("%.2f percent of the labels in `incorrect_dataset` are mislabelled." % (100 * torch.mean(is_mislabelled.float())))
10.01 percent of the labels in `incorrect_dataset` are mislabelled.
为了检测 mislabelled_dataset 中的标注错误样本,我们需要首先使用 mislabelled_dataset 训练一个模型,并保存检查点。我们将在 mislabelled_dataset 上对 net 进行 101 个 epoch 的训练。
我们首先定义保存检查点的文件夹。稍后运行 TracInCP 算法时将需要此文件夹。
mislabelled_dataset_checkpoints_dir = os.path.join("checkpoints", "cifar_10_mislabelled_dataset")
最后,我们训练模型,将 mislabelled_dataset 和 test_dataset 转换为 DataLoader,并每隔 20 个 epoch 保存一次检查点。在本教程中,我们已将此次训练的检查点保存在 AWS S3 上,您可以直接下载这些检查点,而无需进行耗时的训练。如果您想自己进行训练,请将下一个单元格中的 do_training 标志设置为 True。
num_epochs = 101
do_training = False # change to `True` if you want to do training
if do_training:
train(net, num_epochs, DataLoader(mislabelled_dataset, batch_size=128, shuffle=True), DataLoader(test_dataset, batch_size=128, shuffle=True), mislabelled_dataset_checkpoints_dir, save_every=20)
elif not os.path.exists(mislabelled_dataset_checkpoints_dir):
# this should download the zipped folder of checkpoints from the S3 bucket,
# then unzip the folder to produce checkpoints in the folder `checkpoints/cifar_10_mislabelled_dataset`
# this is done if checkpoints do not already exist in the folder
# if the below commands do not work, please manually download and unzip the folder to produce checkpoints in that folder
os.makedirs(mislabelled_dataset_checkpoints_dir)
!wget https://pytorch.s3.amazonaws.com/models/captum/influence-tutorials/cifar_10_mislabelled_dataset.zip -O checkpoints/cifar_10_mislabelled_dataset.zip
!unzip -o checkpoints/cifar_10_mislabelled_dataset.zip -d checkpoints
--2022-11-15 14:58:58-- https://pytorch.s3.amazonaws.com/models/captum/influence-tutorials/cifar_10_mislabelled_dataset.zip Resolving fwdproxy (fwdproxy)... 2401:db00:12ff:ff13:face:b00c:0:1e10 Connecting to fwdproxy (fwdproxy)|2401:db00:12ff:ff13:face:b00c:0:1e10|:8080... connected. Proxy request sent, awaiting response... 200 OK Length: 2780482 (2.7M) [application/zip] Saving to: ‘checkpoints/cifar_10_mislabelled_dataset.zip’ checkpoints/cifar_1 100%[===================>] 2.65M 824KB/s in 3.3s 2022-11-15 14:59:02 (824 KB/s) - ‘checkpoints/cifar_10_mislabelled_dataset.zip’ saved [2780482/2780482] Archive: checkpoints/cifar_10_mislabelled_dataset.zip inflating: checkpoints/cifar_10_mislabelled_dataset/checkpoint-0.pt inflating: checkpoints/cifar_10_mislabelled_dataset/checkpoint-20.pt inflating: checkpoints/cifar_10_mislabelled_dataset/checkpoint-40.pt inflating: checkpoints/cifar_10_mislabelled_dataset/checkpoint-60.pt inflating: checkpoints/cifar_10_mislabelled_dataset/checkpoint-80.pt inflating: checkpoints/cifar_10_mislabelled_dataset/checkpoint-100.pt
我们将检查点列表 mislabelled_dataset_checkpoint_paths 定义为训练中所有保存的检查点。
mislabelled_dataset_checkpoint_paths = glob.glob(os.path.join(mislabelled_dataset_checkpoints_dir, "*.pt"))
为了本教程的演示,为了节省时间/内存,我们将仅考虑最后一个全连接层中的梯度,并且不会使用 TracInCP。因为我们要计算自影响分数,我们应该使用 TracInCPFast(回想一下 TracInCPFastRandProj 不应在自影响模式下使用)。
TracInCPFast 实例¶我们现在定义 TracInCPFast 实例。初始化应该是瞬时的,因为没有进行预处理。注意我们使用了标注错误的数据集以及与之对应的训练检查点,即 mislabelled_dataset 和 mislabelled_dataset_checkpoint_paths。
tracin_cp_fast = TracInCPFast(
model=net,
final_fc_layer=list(net.children())[-1],
train_dataset=mislabelled_dataset,
checkpoints=mislabelled_dataset_checkpoint_paths,
checkpoints_load_func=checkpoints_load_func,
loss_fn=nn.CrossEntropyLoss(reduction="sum"),
batch_size=2048,
)
我们现在可以计算 incorrect_dataset 的自影响分数。注意该函数调用没有参数,因为 incorrect_dataset 在 tracin_cp_fast 初始化期间已经加载。这应该需要几分钟时间。
start_time = datetime.datetime.now()
self_influence_scores = tracin_cp_fast.self_influence()
total_minutes = (datetime.datetime.now() - start_time).total_seconds() / 60.0
print('computed self influence scores for %d examples in %.2f minutes' % (len(self_influence_scores), total_minutes))
computed self influence scores for 50000 examples in 0.50 minutes
为了通过计算自影响分数来评估 TracInCP 识别标注错误样本的能力,我们想要回答这个问题:“按自影响分数(降序)对样本进行排序,是否倾向于将标注错误的样本排在标注正确的样本之前?”我们可以通过显示使用自影响分数识别标注错误样本的 ROC 曲线来回答这个问题。这类似于我们使用预测的正概率来识别真阳性时显示 ROC 曲线来衡量性能。回想一下,我们已经计算了 is_mislabelled,即每个样本是否标注错误的地面真值。下面是 ROC 曲线:
fpr, tpr, _ = roc_curve(is_mislabelled, self_influence_scores)
fig, ax = plt.subplots()
ax.plot(fpr, tpr)
fontsize = 10
ax.set_ylabel("TPR (proportion of mislabelled examples found)", fontsize=fontsize)
ax.set_xlabel("FPR (proportion of correctly-labelled examples examined)", fontsize=fontsize)
ax.set_title("ROC curve when identifying mislabelled examples using self influence scores")
fig.show()
我们看到,通过自影响分数确定优先级倾向于将标注错误的样本排在标注正确的样本之前。请注意,如果我们使用了更好的模型(如 resnet),性能(即 ROC 曲线下面积)会好得多;本教程的目的仅是演示 TracInCP 的用法。