【ノイズ除去】kerasでノイズ除去を試してみた2【Win5-RB】

はじめに

こんにちは、がんがんです。
前回、ノイズ除去を目的としたDNCNNに関する備忘録を書きました。
前回の記事はこちらからどうぞ。
gangannikki.hatenadiary.jp

コードが長くなってしまうため、記事を分けて書くことにしました。
今回はWin5-RBについて書いています。

概要

  1. ノイズ除去のDLモデルを試してみた
  2. cifer10を用いてノイズ除去の実験を実施

参考記事

こちらの記事を中心に実験しました。
qiita.com

また、論文も掲載しておきます。

https://arxiv.org/pdf/1707.05414.pdf

Win5-RB

Win5-RBモデルではWide Inference Network(Win)を5層使用し、それをResnetで接続しているためにこの名前がついています。
BはBatchNormalizationのBです。

モデルはこんな形をしています。

f:id:gangannikki:20190126234329p:plain
アーキテクチャ

実行結果

実験は今回もcifer10にて行っています。

f:id:gangannikki:20190126234814p:plain
上段:入力画像 中段:ノイズ付与画像 下段:出力画像

コード

#------------------------------------------------------------
#
#    Win5-RBを作成
#      Win5-RB_cifer10_keras.py
#
#------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
from keras import layers
from keras.datasets import cifar10
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Activation
from keras.models import Model
from keras.utils import plot_model
from keras.callbacks import CSVLogger, ModelCheckpoint

import tools

"""
	モデル
"""
def DNN( x_train_noisy, x_train,
         x_test_noisy, x_test,
         x_validation_noisy, x_validation,
         epochs, batch_size ):
	"""
		モデルの構築
	"""
	#  入力層
	input_img = Input( shape=( 32, 32, 3) )	#  32×32、RGB

	#  中間層
	x = Conv2D( 64, (7, 7), padding='same' )(input_img)
	X = BatchNormalization()(x)
	x = Activation('relu')(x)

	for i in range(3):
		x = Conv2D( 64, (7, 7), padding='same' )(x)
		X = BatchNormalization()(x)
		x = Activation('relu')(x)

	#  出力層
	x = Conv2D( 3, (7, 7), padding='same' )(x)
	x = BatchNormalization()(x)
	output_img = layers.add([ x, input_img ])

	"""
		モデルのコンパイル
	"""
	model = Model( input_img, output_img )
	model.compile( optimizer='adam',
	               loss='mean_squared_error',
	               metrics=["accuracy"] )

	#  アーキテクチャの可視化
	model.summary()	#  ディスプレイ上に表示
	plot_model( model, to_file="architecture.png" )

	"""
		モデルの学習
	"""
	csv_logger = CSVLogger("./training.log")
	check_point = ModelCheckpoint( filepath="./model/model.{epoch:02d}-{val_loss:.4f}.hdf5",
				       monitor="val_loss",
				       save_best_only=True,
				       mode="auto" )

	cb = [ csv_logger, check_point ]

	#  学習							   				    
	hist = model.fit( x_train_noisy, x_train,
	                  epochs=epochs,
                          batch_size=batch_size,
                          shuffle=True,
                          validation_data=(x_validation_noisy, x_validation),
                          callbacks=cb )

	"""
		モデルの評価
	"""			 
	score = model.evaluate( x_test_noisy, x_test )
	print(score)

	#  グラフの表示
	tools.plot_history( hist, epochs )
	#  画像の表示
	c10test = model.predict( x_test_noisy )
	tools.graph( x_test, x_test_noisy, c10test )

"""
	データの成形、グラフ化
"""
if __name__ == '__main__':
	#	Cifer10のLoad, データの分割
	num_classes = 10
	(x_train, y_train), (x_test, y_test) = cifar10.load_data()	#  学習:50000, テスト:10000
	x_train = x_train.astype('float32')
	x_test = x_test.astype('float32')
	x_train /= 255
	x_test /= 255

	x_validation = x_test[:7000]	#  validation_data : ( 7000, 32, 32, 3)
	x_test = x_test[7000:]		#  test_data : ( 3000, 32, 32, 3)

	#  ノイズ付与
	noise_factor = 0.1
	x_train_noisy, x_test_noisy, x_validation_noisy = \
		tools.add_noise( noise_factor, x_train, x_test, x_validation )
	
	epochs = 100
	batch_size = 32

	#  モデル関数
	DNN( x_train_noisy, x_train,
	     x_test_noisy, x_test,
	     x_validation_noisy, x_validation,
	     epochs, batch_size )

まとめ

今回は前回の続きからノイズ除去に関しての備忘録を書きました。
見た目的にはDNCNNの方が綺麗になっている印象がありますね。

個人的には異常検知とかに興味があるので、そちらのモデルとかも作ったりしていきたいです。