小白劝退预告
仅简单介绍思路,没有用作教程的打算,如果读者没有深度学习基础、计算机视觉基础、Pytorch基础、前后端基础和Flask基础 —— 会很不友好的(
项目架构 一个手写数字识别的Web应用可以分为以下几个部分:
神经网络:一个输入 (28, 28) 的矩阵后得到一个输出的神经网络,用于识别矩阵对应的数字
前端页面:为用户提供可交互的界面,与用户发生交互,并将对应的手写图片发送给后端
后端处理:接收前端提供的手写图片,并对手写图片进行矩阵化处理,经过神经网络识别后将结果发送给前端
在本项目中,主要用到以下的技术栈:
Pytorch 实现并训练的神经网络,通过 CUDA 加速计算
使用 Pytorch-MNIST 数据集
用 Flask 来构建后端,通过 OpenCV-Python 来预处理图片数据
原生前端,通过 JavaScript 的 Ajax 请求来与后端交换数据
神经网络 识别数字的过程交给神经网络来操作
训练神经网络 一个简单的卷积神经网络模型,先进行三次卷积操作,最后通过全连接到十个输出,分别对应数字 0 - 9 的的概率比重神经网络结构如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 class CNN (nn.Module): def __init__ (self ): super (CNN, self).__init__() self.c1 = nn.Conv2d(1 , 6 , 5 ) self.c2 = nn.Conv2d(6 , 16 , 5 ) self.fc1 = nn.Linear(16 * 4 * 4 , 120 ) self.fc2 = nn.Linear(120 , 84 ) self.fc3 = nn.Linear(84 , 10 ) def forward (self, X ): X = F.max_pool2d(F.relu(self.c1(X)), 2 ) X = F.max_pool2d(F.relu(self.c2(X)), 2 ) X = X.view(-1 , 16 * 4 * 4 ) X = F.relu(self.fc1(X)) X = F.relu(self.fc2(X)) X = self.fc3(X) return X
采用 Pytorch 自带的 MNIST 数据集,里面包含了 60000 张手写数字 采用 Pytorch 自带的正则化和张量化的函数,能一键将对应的 (28, 28) 的图片变为利于神经网络计算的输入代码如下:
1 2 3 4 5 6 transform = transforms.ToTensor() trainset = datasets.MNIST('data' , train=True , download=True , transform=transform) testset = datasets.MNIST('data' , train=False , download=True , transform=transform) trainloader = DataLoader(trainset, shuffle=True , batch_size=4 , num_workers=2 ) testloader = DataLoader(testset, shuffle=False , batch_size=4 , num_workers=2 )
接下来是定义的一些超参数,放在这里仅作参考 具体的训练和检测过程就不再说明了代码如下:
1 2 3 4 5 6 transform = transforms.ToTensor() model = CNN().cuda() learning_rate = 1e-2 epochs = 1 criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=learning_rate)
保存神经网络 下面简单演示了一下模型的保存 每一千次迭代我们便保存一次模型,最后选择正确率最高的那一个模型作为我们最终的神经网络代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 def saveModel (model, path ): torch.save(model, path) return def train (epochs = 1 ): for epoch in range (1 , epochs + 1 , 1 ): ... for step, data in enumerate (trainloader, 0 ): ... if step % 1000 == 0 : ... saveModel(model, f"model/CNN/{epochs} -{step} -{acc:.2 f} %.pth" ) print ("Finished training!" )
前端页面 前端提供了用户交互的场所,这里我们使用原生的 HTML、CSS 和 JavaScript
基本界面
使用 canvas 作为画板
使用两个 button 作为触发事件的开关
代码如下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 <!DOCTYPE html > <html > <head > <meta charset ="UTF-8" > <title > MNIST WEB BY Tokisakix</title > </head > <body > <h1 > MNIST + Flask</h1 > <div class ="centered" > <a href ="https://github.com/Tokisakix" > <img style ="position: absolute; top: 0; right: 0; border: 0;" src ="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt ="Fork me on GitHub" data-canonical-src ="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png" > </a > <canvas id ="canvas" width ="128" height ="128" > </canvas > </div > <div class ="centered" > <input type ="button" class ="myButton" value ="Predict" > <input type ="button" id ="clearButton" value ="Clear" > </div > <div class ="centered" > <h1 id ="result" > </h1 > </div > </body > </html >
前端接口 使用 JS 内置的鼠标事件操作在画板上绘画,并对应修改画板上的颜色分布代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 (function ( ) { var canvas = document .querySelector ("#canvas" ); var context = canvas.getContext ("2d" ); canvas.width = 280 ; canvas.height = 280 ; var Mouse = {x : 0 , y : 0 }; var lastMouse = {x : 0 , y : 0 }; context.fillStyle = "black" ; context.fillRect (0 , 0 , canvas.width , canvas.height ); context.color = "white" ; context.lineWidth = 10 ; context.lineJoin = context.lineCap = 'round' ; debug (); canvas.addEventListener ("mousemove" , function (e ) { lastMouse.x = Mouse .x ; lastMouse.y = Mouse .y ; Mouse .x = e.pageX - this .offsetLeft - 15 ; Mouse .y = e.pageY - this .offsetTop - 15 ; }, false ); canvas.addEventListener ("mousedown" , function (e ) { canvas.addEventListener ("mousemove" , onPaint, false ); }, false ); canvas.addEventListener ("mouseup" , function ( ) { canvas.removeEventListener ("mousemove" , onPaint, false ); }, false ); var onPaint = function ( ) { context.lineWidth = context.lineWidth ; context.lineJoin = "round" ; context.lineCap = "round" ; context.strokeStyle = context.color ; context.beginPath (); context.moveTo (lastMouse.x , lastMouse.y ); context.lineTo (Mouse .x , Mouse .y ); context.closePath (); context.stroke (); }; function debug ( ) { $("#clearButton" ).on ("click" , function ( ) { context.clearRect (0 , 0 , 280 , 280 ); context.fillStyle = "black" ; context.fillRect (0 , 0 , canvas.width , canvas.height ); }); } }());
当用户点击 submit 时,发送一个 HTTP 请求到后端,接口如下:
POST /predict data | answer
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 <script type="text/javascript" > $(".myButton" ).click (function ( ){ $('#result' ).text ('Predicting...' ); var canvasObj = document .getElementById ("canvas" ); var img = canvasObj.toDataURL ('image/png' ); $.ajax ({ type : "POST" , url : "/predict" , data : img, success : function (data ){ $('#result' ).text ('Predicted Output: ' + data); } }); }); </script>
后端处理 基本架构 即 Flask 的基本架构,代码如下:
1 2 3 4 5 6 7 8 9 10 11 from flask import Flask, render_template, requestimport base64app = Flask(__name__) @app.route("/" ) def Index (): return render_template("index.html" ) if __name__ == "__main__" : app.run("127.0.0.1" , 5500 )
后端接口 transfer 函数可以将对应路径的图片转换成能通过神经网络识别的张量数据predict 函数将图片数据输入到神经网络中,并返回对应的预测结果Predict 函数对接收到的 HTTP 请求做解析,将 POST 的图片数据保存到本地,并通过 transfer 和 predict 函数来计算得到预测结果,最终返回给前端
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 import cv2 as cvdef transfer (path ): img = cv.imread(path, cv.IMREAD_GRAYSCALE) img = cv.resize(img, (28 , 28 )) data = torch.tensor(img, dtype=torch.float ) return data def predict (path ): image = transfer(path) inputs = torch.reshape(image, [1 , 28 , 28 ]).cuda() outputs = model(inputs).cpu() predicts = torch.max (outputs, 1 ).indices.item() return predicts @app.route("/predict" , methods=["POST" ] ) def Predict (): dataurl = str (request.get_data()) dataurl = dataurl[24 :len (dataurl) - 1 ] data = base64.b64decode(dataurl) with open ("temp.jpg" , "+wb" ) as img: img.write(data) result = predict("temp.jpg" ) return str (result)
本地测试 此项目所有的代码可在此链接下载