Implementing MNIST handwritten digit recognition with PyTorch

How to implement MNIST handwritten digit recognition with PyTorch

This started because a Geek interview question happened to ask how to build a handwritten-digit recognition project with PyTorch, so here I will write up my experience and the whole process in detail.

First, import the packages we need:

import torch
import torchvision
from torch.utils.data import DataLoader

Import the dataset we need and set up the hyperparameters.

First, define the hyperparameters and the random seed:

n_epochs = 3
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10
random_seed = 1 # for reproducibility, fix a random seed so that the generated random numbers can be reproduced
torch.manual_seed(random_seed)
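
If you also care about reproducibility on a GPU, torch.manual_seed alone is not enough; a fuller, best-effort setup might look like the sketch below (optional for this walkthrough):

torch.cuda.manual_seed_all(random_seed)    # also seed the CUDA generators (a no-op without a GPU)
torch.backends.cudnn.deterministic = True  # ask cuDNN to pick deterministic kernels
torch.backends.cudnn.benchmark = False     # disable autotuning, which is a source of nondeterminism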

Next, load the dataset we need:

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size_test, shuffle=True)
# A note on the parameters: train selects the training or test split, download controls whether the dataset is downloaded, transform specifies the transformations applied to the data, batch_size is the number of samples per batch, and shuffle controls whether the data is shuffled
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
... (the same download/extract steps then run for train-labels-idx1-ubyte.gz, t10k-images-idx3-ubyte.gz and t10k-labels-idx1-ubyte.gz)

Once the data has been loaded successfully, a folder named data is created in the current directory containing everything we need.
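
As an aside, the constants (0.1307,) and (0.3081,) passed to Normalize are the global mean and standard deviation of the MNIST training pixels. If you want to verify them yourself, a minimal sketch (loading the raw images once, without Normalize) looks like this:

raw_train = torchvision.datasets.MNIST('./data/', train=True, download=True,
                                       transform=torchvision.transforms.ToTensor())
all_images = torch.stack([img for img, _ in raw_train])   # shape [60000, 1, 28, 28]
print(all_images.mean().item(), all_images.std().item())  # roughly 0.1307 and 0.3081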

Next, let's take a quick look at the data:

examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_targets)
print(example_data.shape)
tensor([3, 9, 4, 9, 9, 0, 8, 3, 1, 2, 3, 9, 1, 3, 6, 6, 4, 4, 9, 7, 3, 7, 6, 3,
        ...
        5, 7, 2, 0, 0, 2, 3, 1, 1, 5, 8, 5, 3, 9, 7, 6])
torch.Size([1000, 1, 28, 28])

The output shows that we get a batch of 1000 single-channel 28×28 images.
Let's visualize some of them with matplotlib:

import matplotlib.pyplot as plt
fig = plt.figure()
for i in range(6):
    plt.subplot(2,3,i+1)
    plt.tight_layout()
    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
    plt.title("Ground Truth: {}".format(example_targets[i]))
    plt.xticks([])
    plt.yticks([])
plt.show()

[Figure: six sample test images with their ground-truth labels]

Next, let's build our neural network.

We will use two 2D convolutional layers followed by two fully connected (linear) layers. As the activation function we choose rectified linear units (ReLU), and as a means of regularization we use two dropout layers to help avoid overfitting.

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)   # 1 input channel -> 10 feature maps
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)  # 10 -> 20 feature maps
        self.conv2_drop = nn.Dropout2d()               # dropout over whole feature maps
        self.fc1 = nn.Linear(320, 50)                  # 320 = 20 channels * 4 * 4 after the convs/pools
        self.fc2 = nn.Linear(50, 10)                   # 10 output classes (digits 0-9)
    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)                            # flatten
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)       # only active in training mode
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)                 # dim=1 avoids the implicit-dimension deprecation warning
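
If the 320 in fc1 looks mysterious: the first 5×5 convolution (no padding) shrinks the 28×28 input to 24×24, max-pooling halves it to 12×12, the second convolution gives 8×8, and the second pool 4×4, so the flattened feature size is 20 channels × 4 × 4 = 320. A quick shape check (just a sketch) confirms this:

with torch.no_grad():
    dummy = torch.zeros(1, 1, 28, 28)                  # a fake single-image batch
    net = Net()
    feat = F.relu(F.max_pool2d(net.conv1(dummy), 2))   # -> [1, 10, 12, 12]
    feat = F.relu(F.max_pool2d(net.conv2(feat), 2))    # -> [1, 20, 4, 4]
    print(feat.view(1, -1).shape)                      # torch.Size([1, 320])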

Next, initialize the network and the optimizer:

network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum)

Training the model

First, we make sure the network is in training mode. Then, each epoch iterates once over all of the training data; loading the individual batches is handled by the DataLoader.

We need to set the gradients to zero manually with optimizer.zero_grad(), because PyTorch accumulates gradients by default. We then compute the network's output (the forward pass) and the negative log-likelihood loss between the output and the ground-truth labels. Calling loss.backward() collects a fresh set of gradients, which optimizer.step() then uses to update each of the network's parameters.
We also add some printouts to keep track of progress. To be able to plot a nice training curve later, we create two lists to record the training and test losses; on the x-axis we want to show the number of training examples the network has seen so far.

train_losses = []
train_counter = []
test_losses = []
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]
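
As an aside, the F.log_softmax at the end of the network combined with the F.nll_loss used in the training loop below is exactly the cross-entropy loss on the raw logits; a tiny sanity check (a sketch with made-up data):

logits = torch.randn(4, 10)                           # a fake batch: 4 samples, 10 classes
targets = torch.tensor([0, 3, 7, 9])
a = F.nll_loss(F.log_softmax(logits, dim=1), targets)
b = F.cross_entropy(logits, targets)
print(torch.allclose(a, b))                           # True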

Before starting the actual training we can run the test loop (defined a bit further down) once, to see how much accuracy/loss we get with nothing but randomly initialized network parameters. First, the training function:

def train(epoch):
    network.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = network(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            train_counter.append(
                (batch_idx * 64) + ((epoch - 1) * len(train_loader.dataset)))
            torch.save(network.state_dict(), './model.pth')
            torch.save(optimizer.state_dict(), './optimizer.pth')
 
train(1)

Train Epoch: 1 [0/60000 (0%)]	Loss: 2.303384
Train Epoch: 1 [640/60000 (1%)]	Loss: 2.279411
Train Epoch: 1 [1280/60000 (2%)]	Loss: 2.289252
Train Epoch: 1 [1920/60000 (3%)]	Loss: 2.262080
...
Train Epoch: 1 [58880/60000 (98%)]	Loss: 0.593189
Train Epoch: 1 [59520/60000 (99%)]	Loss: 0.496419
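
Because the loop above saves network.state_dict() and optimizer.state_dict() to ./model.pth and ./optimizer.pth at every logging step, a run can later be resumed from disk; a minimal sketch:

continued_network = Net()
continued_optimizer = optim.SGD(continued_network.parameters(),
                                lr=learning_rate, momentum=momentum)
continued_network.load_state_dict(torch.load('./model.pth'))
continued_optimizer.load_state_dict(torch.load('./optimizer.pth'))
# continued_network can now be trained or evaluated exactly like network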

Now for the test loop. Here we sum up the test loss and count the correctly classified digits to compute the network's accuracy.

def test():
    network.eval()  # evaluation mode: disables dropout
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum, not average, over the batch
            pred = output.data.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
 
test()

Test set: Avg. loss: 0.1864, Accuracy: 9444/10000 (94%)

Using the context manager with torch.no_grad(), we avoid storing the computations that produce the network output in the computation graph.
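
A tiny illustration of what that means (a throwaway sketch): tensors computed under no_grad do not track gradients, so no graph is built for them.

x = torch.ones(2, 2, requires_grad=True)
with torch.no_grad():
    y = x * 2
print(y.requires_grad)  # False: no computation graph was recorded for y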

# test()  # the initial test() call above is needed, otherwise plotting later fails with: x and y must be the same size
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test()

Train Epoch: 1 [0/60000 (0%)]	Loss: 0.435152
Train Epoch: 1 [640/60000 (1%)]	Loss: 0.811450
Train Epoch: 1 [1280/60000 (2%)]	Loss: 0.483953
...
Train Epoch: 1 [58880/60000 (98%)]	Loss: 0.335937
Train Epoch: 1 [59520/60000 (99%)]	Loss: 0.462703

Test set: Avg. loss: 0.1199, Accuracy: 9628/10000 (96%)

Train Epoch: 2 [0/60000 (0%)]	Loss: 0.258914
Train Epoch: 2 [640/60000 (1%)]	Loss: 0.394483
Train Epoch: 2 [1280/60000 (2%)]	Loss: 0.169133
...
Train Epoch: 2 [58880/60000 (98%)]	Loss: 0.312394
Train Epoch: 2 [59520/60000 (99%)]	Loss: 0.218397

Test set: Avg. loss: 0.0922, Accuracy: 9713/10000 (97%)

Train Epoch: 3 [0/60000 (0%)]	Loss: 0.288775
Train Epoch: 3 [640/60000 (1%)]	Loss: 0.524792
Train Epoch: 3 [1280/60000 (2%)]	Loss: 0.411132
...
Train Epoch: 3 [58880/60000 (98%)]	Loss: 0.113755
Train Epoch: 3 [59520/60000 (99%)]	Loss: 0.212530

Test set: Avg. loss: 0.0798, Accuracy: 9747/10000 (97%)

Let's evaluate the model we just trained.

We do this by plotting the training curve:

import matplotlib.pyplot as plt
fig = plt.figure()
plt.plot(train_counter, train_losses, color='blue')
plt.scatter(test_counter, test_losses, color='red')
plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
plt.xlabel('number of training examples seen')
plt.ylabel('negative log likelihood loss')
plt.show()

[Figure: training loss (blue curve) and test loss (red points) against the number of training examples seen]

Next, let's feed in some data and check whether the model's predictions are correct:

examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
with torch.no_grad():
    output = network(example_data)
fig = plt.figure()
for i in range(6):
    plt.subplot(2,3,i+1)
    plt.tight_layout()
    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
    plt.title("Prediction: {}".format(output.data.max(1, keepdim=True)[1][i].item()))
    plt.xticks([])
    plt.yticks([])
plt.show()

[Figure: six test images with the model's predicted labels]

The predictions above show that the model performs quite well.

Next, let's talk about how to further optimize the model.

Optimizing the model

Looking at the training curve above, the loss still oscillates quite noticeably during training, so we could:

  • Increase batch_size
  • Adjust the learning rate
    At the same time, the curve shows the loss was still trending downward, so:
  • Increase the number of epochs and keep training the model (a sketch of these tweaks follows below)
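
A sketch of what those tweaks could look like in code; the specific values (10 epochs, batch size of 128) are only illustrative and would need to be tuned:

n_epochs = 10            # the loss was still falling, so train longer
batch_size_train = 128   # a larger batch gives smoother gradient estimates
learning_rate = 0.01     # worth re-tuning once the other settings change

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size_train, shuffle=True)

network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum)
# ...then rebuild train_losses/train_counter/test_losses/test_counter and
# rerun the train(epoch)/test() loop from the sections above.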
