tensorflow2.0 手写数字识别训练、预测

tensorflow2.0 手写数字识别训练、预测

代码如下:

import numpy as np
import tensorflow as tf

# 加载数据
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
print(x_train.shape)
print(y_train.shape)
print(x_test.shape)
print(y_test.shape)


# 训练
def train():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    model.fit(x_train, y_train, batch_size=64, epochs=10,
              validation_data=(x_test, y_test))

    evaluate_result = model.evaluate(x_test, y_test)
    print(evaluate_result)

    # 只保留网络结构
    model.save("data/model2.h5")


# 预测
def predict():
    model = tf.keras.models.load_model("data/model2.h5")

    for i in range(10):
        x = x_test[i].reshape((-1, 28, 28))
        y_prob = model.predict(x)
        y_pred = np.argmax(y_prob)
        print("actual: {},predict:{},result:{}".format(y_test[i], y_pred, y_test[i] == y_pred))


# train()
predict()




个人资料
hadoop迷
等级:6
文章:30篇
访问:2.2w
排名: 13
上一篇: tensorflow2.0 手写数字识别
下一篇:基于四大名著训练中文词向量
标签: tensorflow、面试题
隐藏