内容目录
- MNIST数据集简介
MNIST(Modified National Institute of Standards and Technology database)是一个计算机视觉数据集,它包含 70000 张手写数字的灰度图片(60000 张训练图片和 10000 张测试图片),其中每一张图片包含 28 × 28 个像素点,因此可以用一个 28 × 28 的数字数组来表示这张图片。
每一张图片都有对应的标签,也就是图片对应的数字(0~9)。
- 代码实现
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class Net(nn.Module):
    """Two-layer fully connected classifier: 784 -> 100 -> 10.

    The input is flattened per sample, passed through a hidden linear layer
    with ReLU, and projected to 10 raw class scores (logits). No softmax is
    applied here.
    """

    def __init__(self):
        super().__init__()
        # 784 = 28 * 28 input features (one per pixel of a 28x28 image);
        # 100 hidden units are passed on to the next layer.
        self.fc1 = nn.Linear(784, 100)
        # 100 hidden units -> 10 output scores.
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        """Return the 10 class scores for each sample in the batch ``x``.

        This method is not run at instantiation; it is invoked through
        ``nn.Module.__call__`` whenever the instance is called, e.g.
        ``net(x)``.
        """
        # Keep dim 0 (the batch) and collapse everything else into one
        # feature vector, because nn.Linear acts on the last dimension.
        # Flattening discards the 2-D spatial layout of the image.
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
# Training hyper-parameters.
max_opochs = 10  # number of passes (epochs) over the training set
batch_size = 16  # number of images per mini-batch

# data
# Convert each raw image into a Tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Training split, stored under ./data and downloaded when not present locally.
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,  # shuffle the training samples
)

# Test split, read in a fixed order.
testset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
)

# Model, placed on the selected device.
net = Net().to(device)

# Loss function and optimizer (stochastic gradient descent, lr = learning rate).
loss = nn.CrossEntropyLoss()
# Alternative kept for reference:
#   optim.SGD(net.parameters(), lr=0.0001, momentum=0.9, weight_decay=0.0005)
optimzer = optim.SGD(net.parameters(), lr=0.0001)
def train():
    """Train ``net`` for ``max_opochs`` epochs and checkpoint the best weights.

    Each epoch runs one optimisation pass over ``train_loader``, then scores
    the model on ``test_loader`` without gradient tracking. The training and
    validation accuracies are printed once per epoch, and the model's
    ``state_dict`` is written to ``./mnist.pth`` whenever the validation
    accuracy beats the best value seen so far in this run.

    Relies on the module-level ``net``, ``loss``, ``optimzer``, ``device``,
    the two data loaders and their datasets.
    """
    # Must live outside the epoch loop: resetting it every epoch would make
    # each epoch overwrite the checkpoint regardless of improvement.
    best_acc = 0

    for epoch in range(max_opochs):
        # Switch back to training mode; the validation pass below leaves the
        # model in eval mode.
        net.train()
        train_correct = 0
        for data, label in train_loader:
            data = data.to(device)
            label = label.to(device)

            optimzer.zero_grad()  # clear gradients left from the last step
            output = net(data)
            batch_loss = loss(output, label)
            batch_loss.backward()
            optimzer.step()  # apply the gradient update

            # The predicted class is the index of the largest score.
            pred_class = output.argmax(dim=1)
            train_correct += torch.eq(pred_class, label).sum().item()
        train_acc = train_correct / len(trainset)

        # Validation: no parameter updates, no gradient tracking.
        net.eval()
        val_correct = 0.0
        with torch.no_grad():
            for val_image, val_label in test_loader:
                output = net(val_image.to(device))
                predict_y = output.argmax(dim=1)
                val_correct += torch.eq(predict_y, val_label.to(device)).sum().item()
        val_acc = val_correct / len(testset)

        print(train_acc, val_acc)

        if val_acc > best_acc:
            # Store the learned parameters as a state dict.
            torch.save(net.state_dict(), './mnist.pth')
            best_acc = val_acc

    print('done')


if __name__ == '__main__':
    train()
0.20858333333333334 0.3182
0.41868333333333335 0.5099
0.5836333333333333 0.6278
0.6711166666666667 0.6881
0.7029833333333333 0.7134
0.7200833333333333 0.7268
0.7287833333333333 0.7349
0.7367666666666667 0.7434
0.7457833333333334 0.7588
0.75935 0.7756
done
Process finished with exit code 0