TensorFlow 2.0: saving a model in JSON format

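Saving the model's network architecture in JSON format. The JSON describes only the layer structure and each layer's configuration. It does not include the trained weights or the optimizer state.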

import tensorflow as tf

# Load MNIST and scale pixel values to [0, 1] (not needed by to_json() itself)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Build the model
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Serialize the network architecture to a JSON string
model_json = model.to_json()
print(model_json)

# Rebuild a model from the JSON string (architecture only; the weights are freshly initialized)
model = tf.keras.models.model_from_json(model_json)
model.summary()  # summary() prints the table itself and returns None, so it is not wrapped in print()


Output of model.summary():

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 512)               401920    
_________________________________________________________________
dropout (Dropout)            (None, 512)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                5130      
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
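The Param # column can be checked by hand. A Dense layer with use_bias=True stores an (inputs, units) kernel plus one bias per unit, while Flatten and Dropout have no weights. The short sketch below uses plain Python with no TensorFlow; the helper dense_params is a hypothetical name introduced only for this illustration.

```python
# Recompute the "Param #" column of model.summary() by hand.
# Dense with use_bias=True: inputs * units kernel weights + units biases.
# Flatten and Dropout contribute no parameters.

def dense_params(inputs: int, units: int, use_bias: bool = True) -> int:
    """Hypothetical helper: parameter count of a fully connected layer."""
    return inputs * units + (units if use_bias else 0)

flatten_out = 28 * 28                   # Flatten: (28, 28) -> 784 features
dense = dense_params(flatten_out, 512)  # 784 * 512 + 512
dense_1 = dense_params(512, 10)         # Dropout keeps the width at 512
total = dense + dense_1

print(flatten_out, dense, dense_1, total)
```

The printed values match the summary: 401920 for dense, 5130 for dense_1, and 407050 in total.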

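The JSON looks like this (indented here for readability: model.to_json() returns the whole string on one line, while model.to_json(indent=2) passes indent through to json.dumps and produces indented output):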

{
  "class_name": "Sequential",
  "config": {
    "name": "sequential",
    "layers": [
      {
        "class_name": "InputLayer",
        "config": {
          "batch_input_shape": [
            null,
            28,
            28
          ],
          "dtype": "float32",
          "sparse": false,
          "ragged": false,
          "name": "flatten_input"
        }
      },
      {
        "class_name": "Flatten",
        "config": {
          "name": "flatten",
          "trainable": true,
          "batch_input_shape": [
            null,
            28,
            28
          ],
          "dtype": "float32",
          "data_format": "channels_last"
        }
      },
      {
        "class_name": "Dense",
        "config": {
          "name": "dense",
          "trainable": true,
          "dtype": "float32",
          "units": 512,
          "activation": "relu",
          "use_bias": true,
          "kernel_initializer": {
            "class_name": "GlorotUniform",
            "config": {}
          },
          "bias_initializer": {
            "class_name": "Zeros",
            "config": {}
          }
        }
      },
      {
        "class_name": "Dropout",
        "config": {
          "name": "dropout",
          "trainable": true,
          "dtype": "float32",
          "rate": 0.2
        }
      },
      {
        "class_name": "Dense",
        "config": {
          "name": "dense_1",
          "trainable": true,
          "dtype": "float32",
          "units": 10,
          "activation": "softmax",
          "use_bias": true,
          "kernel_initializer": {
            "class_name": "GlorotUniform",
            "config": {}
          },
          "bias_initializer": {
            "class_name": "Zeros",
            "config": {}
          }
        }
      }
    ]
  },
  "keras_version": "2.4.0",
  "backend": "tensorflow"
}
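Because the architecture is plain JSON, it can also be inspected without TensorFlow. The minimal sketch below uses only Python's standard json module on an abridged copy of the output above; the fields were trimmed by hand for brevity, so the string is an illustration rather than what to_json() returns verbatim.

```python
import json

# Abridged, hand-trimmed copy of the JSON shown above.
model_json = """
{
  "class_name": "Sequential",
  "config": {
    "name": "sequential",
    "layers": [
      {"class_name": "InputLayer", "config": {"batch_input_shape": [null, 28, 28], "name": "flatten_input"}},
      {"class_name": "Flatten", "config": {"name": "flatten"}},
      {"class_name": "Dense", "config": {"name": "dense", "units": 512, "activation": "relu"}},
      {"class_name": "Dropout", "config": {"name": "dropout", "rate": 0.2}},
      {"class_name": "Dense", "config": {"name": "dense_1", "units": 10, "activation": "softmax"}}
    ]
  },
  "keras_version": "2.4.0",
  "backend": "tensorflow"
}
"""

spec = json.loads(model_json)  # JSON null becomes Python None


def describe(layer: dict) -> str:
    """Return a one-line summary of one entry in config["layers"]."""
    cfg = layer["config"]
    details = []
    for key in ("batch_input_shape", "units", "activation", "rate"):
        if key in cfg:
            details.append(f"{key}={cfg[key]}")
    return f"{layer['class_name']:<10} {cfg['name']:<14} {', '.join(details)}"


rows = [describe(layer) for layer in spec["config"]["layers"]]
for row in rows:
    print(row)
```

Each layer entry keeps its Keras class name under class_name and its constructor arguments under config; this pairing is what Keras reads back when model_from_json rebuilds the layers.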

Author: hadoop迷