avatar

ResNet + Flask

用 ResNet 模型来训练 MNIST 手写数字数据集有些大材小用，不过作为学习练习也说得过去。
近来学习深度学习时，每次都是训练完数据、loss 降低、accuracy 上升之后便保存模型，之后就没有了后续。
最近找到一个项目，将深度学习模型部署到 Flask 上，大大增强了趣味性与交互性。

原项目地址：ybsdegit/Keras_flask_mnist：基于 Keras + Flask 的 MNIST 手写数字集识别系统（https://github.com/ybsdegit/Keras_flask_mnist）
界面展示

下面是训练过程（日志节选自第 2–8 轮），可以看到在测试集上的正确率可达 98% 以上（最高约 98.9%），效果甚佳。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
Epoch 2/10
469/469 [==============================] - 200s 426ms/step - loss: 0.0267 - accuracy: 0.9916 - val_loss: 0.0494 - val_accuracy: 0.9863
Epoch 3/10
469/469 [==============================] - 204s 434ms/step - loss: 0.0175 - accuracy: 0.9947 - val_loss: 0.0567 - val_accuracy: 0.9823
Epoch 4/10
469/469 [==============================] - 204s 435ms/step - loss: 0.0121 - accuracy: 0.9960 - val_loss: 0.0994 - val_accuracy: 0.9721
Epoch 5/10
469/469 [==============================] - 204s 434ms/step - loss: 0.0108 - accuracy: 0.9965 - val_loss: 0.0414 - val_accuracy: 0.9894
Epoch 6/10
469/469 [==============================] - 204s 435ms/step - loss: 0.0105 - accuracy: 0.9967 - val_loss: 0.0907 - val_accuracy: 0.9758
Epoch 7/10
469/469 [==============================] - 204s 435ms/step - loss: 0.0071 - accuracy: 0.9977 - val_loss: 0.0489 - val_accuracy: 0.9857
Epoch 8/10
469/469 [==============================] - 204s 435ms/step - loss: 0.0099 - accuracy: 0.9967 - val_loss: 0.0446 - val_accuracy: 0.9888

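resnet.py（注意：train.py 中通过 `from RESNET import resnet18` 导入该文件，在区分大小写的文件系统上文件名必须与之一致，即保存为 RESNET.py，或将导入语句改为 `from resnet import resnet18`）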

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import tensorflow as tf


class BasicBlock(tf.keras.layers.Layer):
    """Residual block: two 3x3 convolutions with batch norm plus a shortcut.

    The residual branch is conv3x3(strides) -> BN -> ReLU -> conv3x3 -> BN.
    The shortcut is the identity when the branch keeps the input shape. It is a
    1x1 convolution with the same strides when the branch changes either the
    spatial size or the number of channels, so both can be added element-wise.

    Args:
        filter_num: Number of output channels of the block.
        strides: Stride of the first 3x3 convolution, and of the shortcut
            projection when one is needed.
    """

    def __init__(self, filter_num, strides=1):
        super(BasicBlock, self).__init__()
        self.filter_num = filter_num
        self.strides = strides

        self.conv1 = tf.keras.layers.Conv2D(filters=filter_num, kernel_size=(3, 3), strides=strides, padding='same')
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.relu = tf.keras.layers.Activation('relu')

        self.conv2 = tf.keras.layers.Conv2D(filters=filter_num, kernel_size=(3, 3), strides=1, padding='same')
        self.bn2 = tf.keras.layers.BatchNormalization()

        # Chosen in build(), once the input channel count is known.
        self.downsample = None

    def build(self, input_shape):
        """Create the shortcut path for inputs of shape ``input_shape``."""
        in_channels = input_shape[-1]
        # A stride-1 block that changes the channel count also needs a
        # projection; otherwise the residual addition fails on a shape mismatch.
        needs_projection = self.strides != 1 or (
            in_channels is not None and in_channels != self.filter_num
        )
        if needs_projection:
            self.downsample = tf.keras.Sequential()
            self.downsample.add(
                tf.keras.layers.Conv2D(filters=self.filter_num, kernel_size=(1, 1), strides=self.strides)
            )
        else:
            self.downsample = lambda x: x
        super(BasicBlock, self).build(input_shape)

    def call(self, inputs, training=None):
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out, training=training)

        identity = self.downsample(inputs)

        # Residual addition followed by the block's final ReLU.
        return tf.nn.relu(out + identity)


class ResNet(tf.keras.Model):
    """ResNet-style classifier: conv stem, four residual stages, pooling, dense head.

    Args:
        layer_dims: Four integers giving the number of ``BasicBlock``s in each
            stage (e.g. ``[2, 2, 2, 2]`` as used by ``resnet18``).
        num_classes: Number of output logits.
    """

    def __init__(self, layer_dims, num_classes=10):
        super(ResNet, self).__init__()

        # Stem: 3x3 conv with Conv2D's default 'valid' padding (H and W shrink
        # by 2), BN and ReLU. Then comes a stride-1 'same' max-pool that keeps
        # the spatial size.
        self.stem = tf.keras.Sequential([
            tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1)),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding='same'),
        ])

        # Stages 2-4 start with stride 2, halving H and W (rounding up) while
        # doubling the channel count.
        self.layer1 = self.build_resblock(filter_num=64, blocks=layer_dims[0])
        self.layer2 = self.build_resblock(filter_num=128, blocks=layer_dims[1], strides=2)
        self.layer3 = self.build_resblock(filter_num=256, blocks=layer_dims[2], strides=2)
        self.layer4 = self.build_resblock(filter_num=512, blocks=layer_dims[3], strides=2)

        # With Keras' default channels-last format: [b, h, w, 512] -> [b, 512].
        self.avgpool = tf.keras.layers.GlobalAveragePooling2D()
        # No activation, so the model outputs raw logits of shape [b, num_classes].
        self.fully_con = tf.keras.layers.Dense(num_classes)

    def call(self, inputs, training=None, mask=None):
        """Return logits of shape [b, num_classes] for the image batch ``inputs``.

        ``mask`` is accepted for Keras API compatibility and is not used.
        """
        x = self.stem(inputs, training=training)

        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)

        # [b, h, w, 512] -> [b, 512]
        x = self.avgpool(x)
        # [b, 512] -> [b, num_classes]
        return self.fully_con(x)

    def build_resblock(self, filter_num, blocks, strides=1):
        """Stack ``BasicBlock``s with ``filter_num`` channels into one stage.

        Only the first block applies ``strides``; the remaining ``blocks - 1``
        use stride 1. At least one block is always created.
        """
        res_blocks = tf.keras.Sequential()
        res_blocks.add(BasicBlock(filter_num, strides))

        for _ in range(1, blocks):
            res_blocks.add(BasicBlock(filter_num, strides=1))

        return res_blocks


def resnet18(num_classes=10):
    """Build a ResNet-18-style model: four stages of two BasicBlocks each.

    Args:
        num_classes: Number of output logits (default 10, e.g. MNIST digits).
    """
    return ResNet(layer_dims=[2, 2, 2, 2], num_classes=num_classes)


def resnet34(num_classes=10):
    """Build a ResNet-34-style model with [3, 4, 6, 3] BasicBlocks per stage.

    Args:
        num_classes: Number of output logits (default 10).
    """
    return ResNet(layer_dims=[3, 4, 6, 3], num_classes=num_classes)


# model = resnet34()
# model.build(input_shape=(None, 32, 32, 3))
# model.summary()
#
# model2 = resnet18()
# model2.build(input_shape=(None, 32, 32, 3))
# model2.summary()

train.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import tensorflow as tf
from RESNET import resnet18
import numpy as np


def preprocess(x, y):
    """Map function for the tf.data pipelines.

    Converts the image to float32 and divides by 255, so 8-bit pixel values
    land in [0, 1]. Casts the label tensor to int32.
    """
    pixels = tf.cast(x, dtype=tf.float32)
    labels = tf.cast(y, dtype=tf.int32)
    return pixels / 255, labels


# Load MNIST, append a trailing channel axis to the images and one-hot encode
# the labels (10 classes). Then wrap each split in a tf.data pipeline; only the
# training split is shuffled.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train[..., np.newaxis]
x_test = x_test[..., np.newaxis]
y_train, y_test = tf.one_hot(y_train, depth=10), tf.one_hot(y_test, depth=10)

db_train = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .map(preprocess)
    .shuffle(10000)
    .batch(128)
)
db_test = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .map(preprocess)
    .batch(128)
)


# sample = next(iter(db_train))
# print(sample[0].shape, sample[1].shape)


def main():
    """Build a ResNet-18 for 28x28x1 inputs, train it for 10 epochs and save it.

    Trains on the module-level ``db_train`` pipeline. The ``db_test`` split is
    used as validation data and evaluated after every epoch.
    """
    model = resnet18()
    model.build(input_shape=(None, 28, 28, 1))
    model.summary()

    # Pass `learning_rate` explicitly. Depending on the Keras version, the
    # legacy `lr` alias is either ignored with a warning (training then silently
    # uses the default rate instead of 1e-4) or rejected outright.
    model.compile(optimizer=tf.optimizers.Adam(learning_rate=1e-4),
                  # The model emits raw logits and the labels are one-hot vectors.
                  loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])

    model.fit(db_train, epochs=10, validation_data=db_test, validation_freq=1)

    # NOTE: app.py loads the model from './model/my_model', so this output must
    # be moved under ./model/ before serving.
    # NOTE(review): Keras 3 requires a `.keras` extension for model.save; confirm
    # the TF/Keras version in use.
    model.save("my_model")


# Start training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()

app.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import re
import base64
import numpy as np
import tensorflow as tf
from flask import Flask, render_template, request
from tensorflow.keras.preprocessing.image import img_to_array, load_img

app = Flask(__name__)

# Saved model directory to serve. train.py saves it as "my_model", so it is
# expected to have been moved under ./model/ (TODO confirm deployment layout).
model_file = './model/my_model'

# Loaded once at import time and reused by every request handler. A
# module-scope `global` declaration is unnecessary: this assignment already
# creates a module global.
model = tf.keras.models.load_model(model_file)

@app.route('/')
def index():
    """Render the ``index.html`` template for the site root."""
    page = "index.html"
    return render_template(page)

@app.route('/predict/', methods=['GET', 'POST'])
def predict():
    """Classify the handwritten digit sent in the request body.

    The raw body is expected to contain ``base64,<payload>`` (an image data
    URL). ``Parse_Image`` writes the decoded image to ``output.png``. That file
    is reloaded as a 28x28 grayscale array, scaled by 1/255 and fed to the model.

    Returns:
        The predicted class index as a string, or an error message with HTTP
        status 400 when the body carries no base64 payload.
    """
    data = request.get_data()
    # Parse_Image cannot handle a body without a base64 payload (e.g. a plain
    # GET request). Reject it here instead of failing with a 500 error.
    if b'base64,' not in data:
        return 'Request body must contain a base64-encoded image data URL.', 400

    # NOTE(review): every request writes and re-reads the same 'output.png', so
    # concurrent requests may read each other's image. Consider a per-request
    # temporary file or in-memory decoding.
    Parse_Image(data)
    img = img_to_array(load_img('output.png', target_size=(28, 28), color_mode="grayscale")) / 255.
    img = np.expand_dims(img, axis=0)  # add a leading batch axis of size 1

    # model.predict_classes() cannot be used with the ResNet model, so take the
    # argmax of the predicted scores instead.
    scores = model.predict(img)
    return str(np.argmax(scores))


def Parse_Image(imgData):
    """Decode the base64 payload of an image data URL and save it to ./output.png.

    Args:
        imgData: Raw request body bytes containing ``base64,<payload>``,
            e.g. ``b"data:image/png;base64,iVBOR..."``.

    Raises:
        ValueError: If no ``base64,`` payload is present, or the payload is not
            valid base64 (``binascii.Error`` is a ``ValueError`` subclass).
    """
    # DOTALL lets the payload span line breaks. Line-wrapped base64 would
    # otherwise be truncated at the first newline.
    match = re.search(b'base64,(.*)', imgData, re.DOTALL)
    if match is None:
        raise ValueError("request data does not contain a base64 image payload")

    # decodebytes skips characters outside the base64 alphabet, such as
    # newlines. Decode before opening the file, so bad input leaves any
    # existing output.png untouched.
    image_bytes = base64.decodebytes(match.group(1))
    with open('./output.png', 'wb') as output:
        output.write(image_bytes)


# When run directly, serve the app with Flask's built-in server, listening only
# on the loopback interface (127.0.0.1), port 3335.
if __name__ == '__main__':
    app.run(host="127.0.0.1", port=3335)

文章作者: gh
文章链接: https://ghclub.top/posts/51850/
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 GHBlog
打赏
  • 微信
    微信
  • 支付寶
    支付寶

评论