小白劝退预告

  • 仅简单介绍思路,没有用作教程的打算,如果读者没有深度学习基础、计算机视觉基础、Pytorch基础、前后端基础和Flask基础 —— 会很不友好的(

项目架构

一个手写数字识别的Web应用可以分为以下几个部分:

  1. 神经网络:一个输入 (28, 28) 的矩阵后得到一个输出的神经网络,用于识别矩阵对应的数字
  2. 前端页面:为用户提供可交互的界面,与用户发生交互,并将对应的手写图片发送给后端
  3. 后端处理:接收前端提供的手写图片,并对手写图片进行矩阵化处理,经过神经网络识别后将结果发送给前端

在本项目中,主要用到以下的技术栈:

  • Pytorch 实现并训练的神经网络,通过 CUDA 加速计算
  • 使用 Pytorch-MNIST 数据集
  • Flask 来构建后端,通过 OpenCV-Python 来预处理图片数据
  • 原生前端,通过 JavaScriptAjax 请求来与后端交换数据

神经网络

识别数字的过程交给神经网络来操作

训练神经网络

一个简单的卷积神经网络模型,先进行三次卷积操作,最后通过全连接到十个输出,分别对应数字 0 - 9 的的概率比重
神经网络结构如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.c1 = nn.Conv2d(1, 6, 5)
self.c2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 4 * 4, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, X):
X = F.max_pool2d(F.relu(self.c1(X)), 2)
X = F.max_pool2d(F.relu(self.c2(X)), 2)
X = X.view(-1, 16 * 4 * 4)
X = F.relu(self.fc1(X))
X = F.relu(self.fc2(X))
X = self.fc3(X)
return X

采用 Pytorch 自带的 MNIST 数据集,里面包含了 60000 张手写数字
采用 Pytorch 自带的正则化和张量化的函数,能一键将对应的 (28, 28) 的图片变为利于神经网络计算的输入
代码如下:

1
2
3
4
5
6
transform = transforms.ToTensor()
trainset = datasets.MNIST('data', train=True, download=True, transform=transform)
testset = datasets.MNIST('data', train=False, download=True, transform=transform)

trainloader = DataLoader(trainset, shuffle=True, batch_size=4, num_workers=2)
testloader = DataLoader(testset, shuffle=False, batch_size=4, num_workers=2)

接下来是定义的一些超参数,放在这里仅作参考
具体的训练和检测过程就不再说明了
代码如下:

1
2
3
4
5
6
transform = transforms.ToTensor()
model = CNN().cuda()
learning_rate = 1e-2
epochs = 1
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

保存神经网络

下面简单演示了一下模型的保存
每一千次迭代我们便保存一次模型,最后选择正确率最高的那一个模型作为我们最终的神经网络
代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def saveModel(model, path):
torch.save(model, path)
return

def train(epochs = 1):
for epoch in range(1, epochs + 1, 1):
...
for step, data in enumerate(trainloader, 0):
...

if step % 1000 == 0:
...
saveModel(model, f"model/CNN/{epochs}-{step}-{acc:.2f}%.pth")
print("Finished training!")

前端页面

前端提供了用户交互的场所,这里我们使用原生的 HTML、CSS 和 JavaScript

基本界面

  1. 使用 canvas 作为画板
  2. 使用两个 button 作为触发事件的开关

代码如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>MNIST WEB BY Tokisakix</title>
</head>

<body>
<h1>MNIST + Flask</h1>

<div class="centered">
<a href="https://github.com/Tokisakix">
<img style="position: absolute; top: 0; right: 0; border: 0;" src="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png">
</a>
<canvas id="canvas" width="128" height="128"></canvas>
</div>

<div class="centered">
<input type="button" class="myButton" value="Predict">
<input type="button" id="clearButton" value="Clear">
</div>

<div class="centered">
<h1 id="result"></h1>
</div>
</body>
</html>

前端接口

使用 JS 内置的鼠标事件操作在画板上绘画,并对应修改画板上的颜色分布
代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
(function () {
var canvas = document.querySelector("#canvas");
var context = canvas.getContext("2d");
canvas.width = 280;
canvas.height = 280;

var Mouse = {x: 0, y: 0};
var lastMouse = {x: 0, y: 0};
context.fillStyle = "black";
context.fillRect(0, 0, canvas.width, canvas.height);
context.color = "white";
context.lineWidth = 10;
context.lineJoin = context.lineCap = 'round';

debug();

canvas.addEventListener("mousemove", function (e) {
lastMouse.x = Mouse.x;
lastMouse.y = Mouse.y;

Mouse.x = e.pageX - this.offsetLeft - 15;
Mouse.y = e.pageY - this.offsetTop - 15;
}, false);

canvas.addEventListener("mousedown", function (e) {
canvas.addEventListener("mousemove", onPaint, false);
}, false);

canvas.addEventListener("mouseup", function () {
canvas.removeEventListener("mousemove", onPaint, false);
}, false);

var onPaint = function () {
context.lineWidth = context.lineWidth;
context.lineJoin = "round";
context.lineCap = "round";
context.strokeStyle = context.color;

context.beginPath();
context.moveTo(lastMouse.x, lastMouse.y);
context.lineTo(Mouse.x, Mouse.y);
context.closePath();
context.stroke();
};

function debug() {
$("#clearButton").on("click", function () {
context.clearRect(0, 0, 280, 280);
context.fillStyle = "black";
context.fillRect(0, 0, canvas.width, canvas.height);
});
}
}());

当用户点击 submit 时,发送一个 HTTP 请求到后端,接口如下:

POST /predict data | answer

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
   <script type="text/javascript">
$(".myButton").click(function(){
$('#result').text('Predicting...');
var canvasObj = document.getElementById("canvas");
var img = canvasObj.toDataURL('image/png');
$.ajax({
type: "POST",
url: "/predict",
data: img,
success: function(data){
$('#result').text('Predicted Output: ' + data);
}
});
});
</script>

后端处理

基本架构

Flask 的基本架构,代码如下:

1
2
3
4
5
6
7
8
9
10
11
from flask import Flask, render_template, request
import base64

app = Flask(__name__)

@app.route("/")
def Index():
return render_template("index.html")

if __name__ == "__main__":
app.run("127.0.0.1", 5500)

后端接口

transfer 函数可以将对应路径的图片转换成能通过神经网络识别的张量数据
predict 函数将图片数据输入到神经网络中,并返回对应的预测结果
Predict 函数对接收到的 HTTP 请求做解析,将 POST 的图片数据保存到本地,并通过 transferpredict 函数来计算得到预测结果,最终返回给前端

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import cv2 as cv

def transfer(path):
img = cv.imread(path, cv.IMREAD_GRAYSCALE)
img = cv.resize(img, (28, 28))
data = torch.tensor(img, dtype=torch.float)
return data

def predict(path):
image = transfer(path)
inputs = torch.reshape(image, [1, 28, 28]).cuda()
outputs = model(inputs).cpu()
predicts = torch.max(outputs, 1).indices.item()
return predicts

@app.route("/predict", methods=["POST"])
def Predict():
dataurl = str(request.get_data())
dataurl = dataurl[24:len(dataurl) - 1]
data = base64.b64decode(dataurl)
with open("temp.jpg", "+wb") as img:
img.write(data)
result = predict("temp.jpg")
return str(result)

本地测试

此项目所有的代码可在此链接下载

  • 于是它把 8 识别成了 2(