【Pytorch】迁移学习(Transfer Learning)「建议收藏」

(120) 2024-05-25 16:01:01

1.什么是迁移学习

迁移学习是一种机器学习技术,其中一个模型已在一个任务上训练好,并且该模型的经验可以用来更快地训练另一个相似的任务。这种技术的目的是为了减少在新任务上的训练时间,因为训练模型需要大量的数据和时间。

2.常见迁移学习方式

  1. 载入权重后训练所有参数
  2. 载入权重后只训练最后几层参数
  3. 载入权重后在原网络基础上再添加一层全链接层,仅训练最后一个全链接层

3.例子

以kaggle中猫狗数据集为例,猫狗数据集
导包

import torch
import torchvision
from torch import nn, optim
from torchvision import transforms, datasets, models
from tqdm import tqdm
import sys

tqdm和sys,如果不需要进度条显示可不导入。
pytorch中迁移学习模型在torchvision.models

  1. vgg16
model = models.vgg16(pretrained=True)  # pretrained=True即为返回在 ImageNet (是数据集)上预训练的模型
for parameter in model.parameters():
    parameter.requires_grad = False	   # 冻结了所有层(参数不会更新)

此时模型已导入。可用model.buffer或直接print(model)查看模型。
【Pytorch】迁移学习(Transfer Learning)「建议收藏」 (https://mushiming.com/)  第1张
猫狗数据集为二分类,所以最后一层全连接层输出应为2,修改为:

model.classifier[6] = nn.Linear(in_features=4096, out_features=2, bias=True)

修改过后该全连接层不再被冻结,参数可被更新。
可用以下方式查看模型各层的冻结情况:

for m, n in model.named_parameters():
    print(m, n.requires_grad)

【Pytorch】迁移学习(Transfer Learning)「建议收藏」 (https://mushiming.com/)  第2张
最后就是训练,训练最后一层全连接层的参数。

optimizer = optim.Adam(model.classifier.parameters(), lr=0.0001)

save_path = './save_path/vgg16_1.pth'

epochs = 15
train_steps = len(train_dataloader)
val_length = len(val_dataset)
best_acc = 0.0

for epoch in range(epochs):
    running_loss = 0.0
    model.train()
    train_bar = tqdm(train_dataloader, file=sys.stdout)
    for step, data in enumerate(train_bar):
        images, labels = data
        optimizer.zero_grad()
        output = model(images.to(device))
        loss = loss_function(output, labels.to(device))
        loss.backward()
        optimizer.step()
        running_loss += loss
        train_bar.desc = 'epoch:{}/{} loss:{:.3f}'.format(epoch+1, epochs, loss)
        
    model.eval()
    acc = 0.0
    with torch.no_grad():
        val_bar = tqdm(val_dataloader, file=sys.stdout)
        for data in val_bar:
            val_images, val_labels = data
            val_output = model(val_images.to(device))
            predict = torch.max(val_output, 1)[1]
            acc += torch.eq(predict, val_labels.to(device)).sum().item()
    
        val_accurate = acc/val_length
        print('epoch:{}/{} train_loss:{:.3f} val_accurate:{:.3f}'.format(epoch+1, epochs, running_loss/train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(model.state_dict(), save_path)

一共训练了15个epoch,一个epoch差不多43s,训练过程:
【Pytorch】迁移学习(Transfer Learning)「建议收藏」 (https://mushiming.com/)  第3张
可以看到验证集准确率差不多95%。

还可参考这篇文章

THE END

发表回复