手写数字识别算法的研究意义(104.人工智能手写汉字数字识别)
在前一章,对手写汉字数字的数据进行了预处理,本文准备使用ResNet18卷积神经网络来实现手写汉字数字的识别。关于ResNet网络模型和代码可以参看:101.人工智能——构建残差网络ResNet18网络模型。
自定数据集
#自定义数据集
import paddle
import paddle.vision.transforms as T
class MyDataset(paddle.io.Dataset):
def __init__(self,data,mode="train"):
self.data=data
self.mode=mode
def transform(self,mode):
if mode=="train":
return T.Compose([
T.ToTensor(),
T.Normalize()
])
else:
return T.Compose([
T.ToTensor(),
T.Normalize()
])
def __getitem__(self,idx):
img=mpimg.imread(os.path.join(datadir,self.data[idx][0]))
img=self.transform(self.mode)(img)
label=self.data[idx][1]
label=np.array(label).astype("int64")
label=np.reshape(label,(1))
return img,label
def __len__(self):
return len(self.data)
#返回数据集
train_dataset=MyDataset(traindata,"train")
val_dataset=MyDataset(valdata,"val")
test_dataset=MyDataset(testdata,"test")
train_loader=paddle.io.DataLoader(train_dataset,batch_size=16,shuffle=True)
val_loader=paddle.io.DataLoader(val_dataset,batch_size=16,shuffle=False)
test_loader=paddle.io.DataLoader(test_dataset,batch_size=16,shuffle=False)
#查看数据集形状
print(train_dataset.data.shape,val_dataset.data.shape,test_dataset.data.shape)
#查看批次数据
for i,data in enumerate(train_loader()):
img,label=data
print(img.shape,label.shape)
break
#运行结果
(10500, 2) (3000, 2) (1500, 2)
[16, 3, 64, 64] [16, 1]
详细代码本文这里省略。可以参看:101.人工智能——构建残差网络ResNet18网络模型。
模型训练
#查看模型结构
model = Model_ResNet18(in_channels=3, num_classes=15, use_residual=True)
params_info = paddle.summary(model, (1, 3, 64, 64))
print(params_info)
#开始训练
def train(model,train_loader,epochs=10):
#定义优化器
opt=paddle.optimizer.Adam(learning_rate=0.001,parameters=model.parameters())
bestacc=0
############################
# # 读取参数文件(恢复训练,1:加载最后的训练模型文件和优化参数文件)
# params_dict = paddle.load("models/finally.pdparams")
# opt_dict = paddle.load("models/finally.pdopt")
# # 加载参数到模型
# model.set_state_dict(params_dict)
# # 加载参数到优化器
# opt.set_state_dict(opt_dict)
############################
print('start training ... ')
model.train()
for epoch in range(epochs):
for i,data in enumerate(train_loader()):
img,label=data
#转换数据类型
img=paddle.to_tensor(img)
label=paddle.to_tensor(label)
#前向计算,获取损失函数值
out=model(img)
loss=F.cross_entropy(out,label)
avg_loss=paddle.mean(loss)
#反向传播,更新参数,清空梯度
avg_loss.backward()
opt.step()
opt.clear_gradients()
print(f"epoch:{epoch} loss:{avg_loss.numpy()}")
#验证模型
model.eval()
accs=[]
losses=[]
for i,data in enumerate(val_loader()):
img,label=data
img=paddle.to_tensor(img)
label=paddle.to_tensor(label)
out=model(img)
loss=F.cross_entropy(out,label)
avg_loss=paddle.mean(loss)
acc=paddle.metric.accuracy(out,label)
accs.append(acc.numpy())
losses.append(avg_loss.numpy())
print(f"epoch:{epoch},loss:{np.mean(losses)},acc:{np.mean(accs)}")
#保存最佳模型
if np.mean(accs)>bestacc:
bestacc=np.mean(accs)
print(f"save best model,acc:{bestacc},epoch:{epoch}")
paddle.save(model.state_dict(),"models/rawnum_resnet18_best.pdparams")
paddle.save(opt.state_dict(), 'models/rawnum_resnet18_best.pdopt')
model.train() #恢复训练模式
train(model,train_loader,epochs=10)
#训练过程,在CPU环境下训练时间很长……但从训练结果来看,准确率还在达到90%以上。
epoch:0 loss:[0.10153195]
epoch:0,loss:0.23171958327293396,acc:0.9301861524581909
save best model,acc:0.9301861524581909,epoch:0
epoch:1 loss:[0.02970559]
epoch:1,loss:0.28862079977989197,acc:0.9135638475418091
epoch:2 loss:[0.8410681]
epoch:2,loss:0.3803146481513977,acc:0.904920220375061
………………
#加载模型、预测模型
model_dict=paddle.load("models/rawnum_resnet18_best.pdparams")
model.load_dict(model_dict)
model.eval()
#随机取一条测试数据
idx=np.random.randint(len(test_dataset))
img,label=test_dataset[idx]
img=np.reshape(img,(1,3,64,64))
label=np.reshape(label,(1,1))
#print(img.shape,label.shape,label.item())
results=model(paddle.to_tensor(img))
predictlabel=np.argmax(results.numpy()) #最大值的索引,用argmax
print(f"predict:{predictlabel},label:{label.item()}")
#预测结果:随机运行5次
predict:5,label:5
predict:10,label:10
predict:6,label:5
predict:1,label:1
predict:5,label:5
从预测结果来看,和准确率90%以上是相符的,没有达到100%的识别效果。
,
免责声明:本文仅代表文章作者的个人观点,与本站无关。其原创性、真实性以及文中陈述文字和内容未经本站证实,对本文以及其中全部或者部分内容文字的真实性、完整性和原创性本站不作任何保证或承诺,请读者仅作参考,并自行核实相关内容。文章投诉邮箱:anhduc.ph@yahoo.com