TensorFlow 2.0: saving a network model as YAML


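import tensorflow as tf

# Load MNIST and scale pixel values to [0, 1] (the data is not needed for the YAML export itself)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Build the model
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Serialize the network architecture (not the trained weights) to a YAML string.
# Note: newer releases (TF 2.6+, plus the 2.3.4, 2.4.3 and 2.5.1 patch releases) make
# to_yaml() and model_from_yaml() raise RuntimeError for security reasons;
# use to_json() / model_from_json() there instead.
model_yaml = model.to_yaml()
print(model_yaml)

# Rebuild a model from the YAML string (its weights are freshly initialized)
model = tf.keras.models.model_from_yaml(model_yaml)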
model.summary()  # prints the table itself; wrapping it in print() would also print "None"


Network structure printed by model.summary():

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 512)               401920    
_________________________________________________________________
dropout (Dropout)            (None, 512)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                5130      
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
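The Param # column follows directly from the layer configuration: a Dense layer holds an input_dim x units kernel plus one bias per unit, while Flatten and Dropout hold no weights at all. The plain-Python check below reproduces those numbers; the helper name dense_params is illustrative only and is not a Keras API.

```python
from math import prod


def dense_params(input_dim: int, units: int, use_bias: bool = True) -> int:
    """Weights in a Dense layer: an input_dim x units kernel plus an optional bias per unit."""
    return input_dim * units + (units if use_bias else 0)


# Flatten turns each 28x28 image into a vector of 28*28 features and has no weights.
flatten_dim = prod((28, 28))

# (layer name, parameter count) for the model above; Dropout has no weights either.
layers = [
    ("flatten", 0),
    ("dense", dense_params(flatten_dim, 512)),
    ("dropout", 0),
    ("dense_1", dense_params(512, 10)),
]

total_params = sum(count for _, count in layers)

for name, count in layers:
    print(f"{name:<10}{count:>10,}")
print(f"{'total':<10}{total_params:>10,}")
```

This gives 784 * 512 + 512 = 401,920 for dense and 512 * 10 + 10 = 5,130 for dense_1, matching the 407,050 total in the summary.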

The resulting YAML:

backend: tensorflow
class_name: Sequential
config:
  layers:
  - class_name: InputLayer
    config:
      batch_input_shape: !!python/tuple
      - null
      - 28
      - 28
      dtype: float32
      name: flatten_input
      ragged: false
      sparse: false
  - class_name: Flatten
    config:
      batch_input_shape: !!python/tuple
      - null
      - 28
      - 28
      data_format: channels_last
      dtype: float32
      name: flatten
      trainable: true
  - class_name: Dense
    config:
      activation: relu
      activity_regularizer: null
      bias_constraint: null
      bias_initializer:
        class_name: Zeros
        config: {}
      bias_regularizer: null
      dtype: float32
      kernel_constraint: null
      kernel_initializer:
        class_name: GlorotUniform
        config:
          seed: null
      kernel_regularizer: null
      name: dense
      trainable: true
      units: 512
      use_bias: true
  - class_name: Dropout
    config:
      dtype: float32
      name: dropout
      noise_shape: null
      rate: 0.2
      seed: null
      trainable: true
  - class_name: Dense
    config:
      activation: softmax
      activity_regularizer: null
      bias_constraint: null
      bias_initializer:
        class_name: Zeros
        config: {}
      bias_regularizer: null
      dtype: float32
      kernel_constraint: null
      kernel_initializer:
        class_name: GlorotUniform
        config:
          seed: null
      kernel_regularizer: null
      name: dense_1
      trainable: true
      units: 10
      use_bias: true
  name: sequential
keras_version: 2.4.0
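The `!!python/tuple` tags in this output are Python-specific, so PyYAML's `yaml.safe_load` refuses to read the file. If you only want to inspect the architecture (for example, list the layers) without TensorFlow, one option is a `SafeLoader` subclass that accepts just that one tag and turns it into a plain tuple. This is a minimal sketch assuming PyYAML is installed; the names `KerasYamlLoader` and `list_layers` are hypothetical, and `MODEL_YAML` is a trimmed excerpt of the output above, not the full file.

```python
import yaml


class KerasYamlLoader(yaml.SafeLoader):
    """Hypothetical SafeLoader subclass that also understands !!python/tuple."""


def _construct_tuple(loader, node):
    # Build an ordinary tuple from the YAML sequence tagged !!python/tuple.
    return tuple(loader.construct_sequence(node))


# Registering on the subclass leaves yaml.SafeLoader itself unchanged.
KerasYamlLoader.add_constructor("tag:yaml.org,2002:python/tuple", _construct_tuple)

# Trimmed excerpt of the exported YAML (most layer settings omitted).
MODEL_YAML = """\
backend: tensorflow
class_name: Sequential
config:
  layers:
  - class_name: InputLayer
    config:
      batch_input_shape: !!python/tuple
      - null
      - 28
      - 28
      name: flatten_input
  - class_name: Flatten
    config:
      name: flatten
  - class_name: Dense
    config:
      activation: relu
      name: dense
      units: 512
  - class_name: Dropout
    config:
      name: dropout
      rate: 0.2
  - class_name: Dense
    config:
      activation: softmax
      name: dense_1
      units: 10
  name: sequential
keras_version: 2.4.0
"""


def list_layers(text):
    """Return (class_name, layer name) pairs from a Keras model YAML string."""
    spec = yaml.load(text, Loader=KerasYamlLoader)
    return [(layer["class_name"], layer["config"]["name"]) for layer in spec["config"]["layers"]]


for class_name, name in list_layers(MODEL_YAML):
    print(class_name, name)
```

Only the tuple tag is whitelisted, so any other `!!python/...` tag in an untrusted file still raises an error instead of being turned into arbitrary Python objects.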






Author: hadoop迷
Tags: tensorflow, yaml, interview questions