# enhance_raw.py
# transform from single frame into multi-frame enhanced single raw
from __future__ import division
import os, time, scipy.io
import tensorflow as tf
import numpy as np
import rawpy
import glob
from model_sid_latest import network_enhance_raw
import platform
import os from tensorflow.python.tools import freeze_graph os.environ["CUDA_VISIBLE_DEVICES"] = "" if platform.system() == 'Windows':
data_dir = 'D:/data/LightOnOff/'
elif platform.system() == 'Linux':
data_dir = './dataset/LightOnOff/'
else:
print('platform not supported!')
assert False checkpoint_dir = './model_light_on_off/'
result_dir = './out_light_on_off/'
log_dir = './log_light_on_off/'
learning_rate = 1e-4
save_model_every_n_epoch = 10
max_epoch = 20000
if platform.system() == 'Windows':
save_output_every_n_steps = 1
else:
save_output_every_n_steps = 100 # BBF100-2
bbf_w = 4032
bbf_h = 3024 patch_h = 512
patch_w = 512 patch_h = 800
patch_w = 1024 max_level = 1023
black_level = 64 tf.reset_default_graph() # set up dataset
train_ids = os.listdir(data_dir)
train_ids.sort() def preprocess(raw, bl, wl):
im = raw.raw_image_visible.astype(np.float32)
im = np.maximum(im - bl, 0)
return im / (wl - bl) def pack_raw_bbf(path):
raw = rawpy.imread(path)
bl = 64
wl = 1023
im = preprocess(raw, bl, wl)
im = np.expand_dims(im, axis=2)
H = im.shape[0]
W = im.shape[1]
if raw.raw_pattern[0, 0] == 0: # CFA=RGGB
out = np.concatenate((im[0:H:2, 0:W:2, :],
im[0:H:2, 1:W:2, :],
im[1:H:2, 1:W:2, :],
im[1:H:2, 0:W:2, :]), axis=2)
elif raw.raw_pattern[0,0] == 2: # BGGR
out = np.concatenate((im[1:H:2, 1:W:2, :],
im[0:H:2, 1:W:2, :],
im[0:H:2, 0:W:2, :],
im[1:H:2, 0:W:2, :]), axis=2)
elif raw.raw_pattern[0,0] == 1 and raw.raw_pattern[0,1] == 0: # GRBG
out = np.concatenate((im[0:H:2, 1:W:2, :],
im[0:H:2, 0:W:2, :],
im[1:H:2, 0:W:2, :],
im[1:H:2, 1:W:2, :]), axis=2)
elif raw.raw_pattern[0,0] == 1 and raw.raw_pattern[0,1] == 2: # GBRG
out = np.concatenate((im[1:H:2, 0:W:2, :],
im[0:H:2, 0:W:2, :],
im[0:H:2, 1:W:2, :],
im[1:H:2, 1:W:2, :]), axis=2)
else:
assert False
wb = np.array(raw.camera_whitebalance)
wb[3] = wb[1]
wb = wb / wb[1]
out = np.minimum(out * wb, 1.0) # normalize the brightness
# out = np.minimum(out * 0.2 / np.maximum(1e-6, np.mean(out[:, :, 1])), 1.0) h_, w_ = im.shape[0]//2, im.shape[1]//2
out_16bit_ = np.zeros([h_, w_, 4], dtype=np.uint16)
out_16bit_[:, :, :] = np.uint16(out[:, :, :] * (wl - bl))
del out
return out_16bit_ def raw2rgb(raw): # GRBG
assert len(raw.shape)==3
h, w = raw.shape[0]<<1, raw.shape[1]<<1
rgb = np.zeros([h, w, 3])
rgb[0:h:2, 0:w:2, 1] = raw[:, :, 1]
rgb[0:h:2, 1:w:2, 0] = raw[:, :, 0]
rgb[1:h:2, 0:w:2, 2] = raw[:, :, 2]
rgb[1:h:2, 1:w:2, 1] = raw[:, :, 3]
return rgb def max_in_all(left, left_top, top, top_right, right, right_bottom, bottom, bottom_left, center):
return np.maximum(
np.maximum(
np.maximum(
np.maximum(
np.maximum(
np.maximum(
np.maximum(
np.maximum(left, left_top),
top),
top_right),
right),
right_bottom),
bottom),
bottom_left),
center) def demosaic(rgb):
for chn_id in range(3):
left = rgb[0:-2, 1:-1, chn_id]
left_top = rgb[0:-2, 0:-2, chn_id]
top = rgb[0:-2, 1:-1, chn_id]
top_right = rgb[0:-2, 2:, chn_id]
right = rgb[1:-1, 2:, chn_id]
right_bottom = rgb[2:, 2:, chn_id]
bottom = rgb[2:, 1:-1, chn_id]
bottom_left = rgb[2:, 0:-2, chn_id]
center = rgb[1:-1, 1:-1, chn_id]
rgb[1:-1, 1:-1, chn_id] = max_in_all(left, left_top, top, top_right, right, right_bottom, bottom, bottom_left, center)
return rgb def gray_ps(rgb):
return np.power(np.power(rgb[:, :, 0], 2.2) * 0.2973 + np.power(rgb[:,:,1], 2.2) * 0.6274 + np.power(rgb[:,:,2], 2.2) * 0.0753, 1/2.2) + 1e-7 def gamma_correction(x, curve_ratio):
gray_scale = np.expand_dims(gray_ps(x), axis=-1)
gray_scale_new = np.power(gray_scale, curve_ratio)
return np.minimum(x * gray_scale_new / gray_scale, 1.0) # setting the ratio of GPU global memory usage
gpu_options = tf.GPUOptions(allow_growth=True)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
in_im = tf.placeholder(tf.float32, [1, patch_h, patch_w, 4], name='input')
gt_im = tf.placeholder(tf.float32, [1, patch_h, patch_w, 4])
out_im = network_enhance_raw(in_im, patch_h, patch_w)
norm_im = tf.minimum(tf.maximum(out_im, 0.0), 1.0) ssim_loss = 1 - tf.image.ssim_multiscale(norm_im[0], gt_im[0], 1.0)
l1_loss = tf.reduce_mean(tf.reduce_sum(tf.abs(norm_im - gt_im), axis=-1))
l2_loss = tf.reduce_mean(tf.reduce_sum(tf.square(norm_im - gt_im), axis=-1))
# G_loss = ssim_loss
G_loss = l1_loss + l2_loss tf.summary.scalar('G_loss', G_loss)
tf.summary.scalar('MS-SSIM Loss', ssim_loss)
tf.summary.scalar('L1 Loss', l1_loss)
tf.summary.scalar('L2 Loss', l2_loss) t_vars = tf.trainable_variables()
lr = tf.placeholder(tf.float32)
G_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss) saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt:
print('loaded ' + ckpt.model_checkpoint_path)
saver.restore(sess, ckpt.model_checkpoint_path) # save the images for tracking training states
if not os.path.isdir(result_dir):
os.mkdir(result_dir) g_loss = np.zeros((500, 1)) merged = tf.summary.merge_all()
writer = tf.summary.FileWriter(log_dir, sess.graph) gt_files = [None] * len(train_ids)
input_files = [None] * len(train_ids) input_images = [None] * len(train_ids)
gt_images = [None] * len(train_ids) for i in range(0, len(train_ids)):
gt_files[i] = glob.glob(os.path.join(data_dir, train_ids[i]) + '/*on*.dng')[0]
input_files[i] = glob.glob(os.path.join(data_dir, train_ids[i]) + '/*off*.dng')
input_images[i] = [None] * len(input_files[i]) steps = 0
st = time.time() for epoch in range(0, max_epoch):
for ind in np.random.permutation(len(train_ids)):
steps += 1
sid = np.random.randint(0, len(input_files[ind]))
if input_images[ind][sid] is None:
input_images[ind][sid] = np.expand_dims(pack_raw_bbf(input_files[ind][sid]), axis=0)
if gt_images[ind] is None:
gt_images[ind] = np.expand_dims(np.maximum(pack_raw_bbf(gt_files[ind]), 0), axis=0) # random cropping
xx = np.random.randint(0, bbf_w//2 - patch_w)
yy = np.random.randint(0, bbf_h//2 - patch_h)
input_patch = np.float32(input_images[ind][sid][:, yy:yy + patch_h, xx:xx + patch_w, :]) / (max_level - black_level)
gt_patch = np.float32(gt_images[ind][:, yy:yy + patch_h, xx:xx + patch_w, :]) / (max_level - black_level) # random flipping
if np.random.randint(2, size=1)[0] == 1: # random flip
input_patch = np.flip(input_patch, axis=1)
gt_patch = np.flip(gt_patch, axis=1)
if np.random.randint(2, size=1)[0] == 1:
input_patch = np.flip(input_patch, axis=0)
gt_patch = np.flip(gt_patch, axis=0)
# if np.random.randint(2, size=1)[0] == 1: # random transpose
# input_patch = np.transpose(input_patch, (0, 2, 1, 3))
# gt_patch = np.transpose(gt_patch, (0, 2, 1, 3)) # summary, _, G_current, output = sess.run(
# [merged, G_opt, G_loss, out_im],
# feed_dict={
# in_im: input_patch,
# gt_im: gt_patch,
# lr: learning_rate})
# g_loss[ind] = G_current summary, output = sess.run(
[merged, out_im],
feed_dict={
in_im: input_patch,
gt_im: gt_patch,
lr: learning_rate
}) # saver.save(sess, checkpoint_dir + '%d.ckpt' % epoch)
# print('model saved.')
# exit(0) tf.train.write_graph(sess.graph_def, 'output_model/pb_model', 'model_raw2raw.pb')
freeze_graph.freeze_graph(
'output_model/pb_model/model_raw2raw.pb',
'',
False,
'./model_light_on_off/0.ckpt',
'gen/output',
'save/restore_all',
'save/Const:0',
'output_model/pb_model/frozen_model.pb',
True,
"")
exit(0) if steps % save_output_every_n_steps == 0:
loss_ = np.mean(g_loss[np.where(g_loss)])
cost_ = (time.time() - st)/save_output_every_n_steps
st = time.time()
print("%d %d Loss=%.6f Speed=%.6f" % (epoch, steps, loss_, cost_))
writer.add_summary(summary, global_step=steps)
# save the current output image for network inspection
out_ = np.minimum(np.maximum(output, 0), 1)
in_rgb = gamma_correction(demosaic(raw2rgb(input_patch[0])), 0.35)
gt_rgb = gamma_correction(demosaic(raw2rgb(gt_patch[0])), 0.35)
out_rgb = gamma_correction(demosaic(raw2rgb(out_[0])), 0.35)
temp = np.concatenate((in_rgb, gt_rgb, out_rgb), axis=1)
scipy.misc.toimage(temp * 255, high=255, low=0, cmin=0, cmax=255)\
.save(result_dir + '/%d_%s_00.jpg' % (epoch, train_ids[ind])) # clean up the memory if necessary
if platform.system() == 'Windows':
input_images[ind][sid] = None
gt_images[ind] = None if epoch % save_model_every_n_epoch == 0:
saver.save(sess, checkpoint_dir + '%d.ckpt' % epoch)
print('model saved.')

采用Tensorflow内部函数直接对模型进行冻结的更多相关文章

  1. tensorflow加载embedding模型进行可视化

    1.功能 采用python的gensim模块训练的word2vec模型,然后采用tensorflow读取模型可视化embedding向量 ps:采用C++版本训练的w2v模型,python的gensi ...

  2. TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model

      TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model Checkmate is designed to be a simple drop-i ...

  3. tensorflow训练验证码识别模型

    tensorflow训练验证码识别模型的样本可以使用captcha生成,captcha在linux中的安装也很简单: pip install captcha 生成验证码: # -*- coding: ...

  4. 开园第一篇---有关tensorflow加载不同模型的问题

    写在前面 今天刚刚开通博客,主要想法跟之前某位博主说的一样,希望通过博客园把每天努力的点滴记录下来,也算一种坚持的动力.我是小白一枚,有啥问题欢迎各位大神指教,鞠躬~~ 换了新工作,目前手头是OCR项 ...

  5. 【6】TensorFlow光速入门-python模型转换为tfjs模型并使用

    本文地址:https://www.cnblogs.com/tujia/p/13862365.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  6. 【4】TensorFlow光速入门-保存模型及加载模型并使用

    本文地址:https://www.cnblogs.com/tujia/p/13862360.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  7. 【TensorFlow】基于ssd_mobilenet模型实现目标检测

    最近工作的项目使用了TensorFlow中的目标检测技术,通过训练自己的样本集得到模型来识别游戏中的物体,在这里总结下. 本文介绍在Windows系统下,使用TensorFlow的object det ...

  8. TensorFlow学习笔记12-word2vec模型

    为什么学习word2word2vec模型? 该模型用来学习文字的向量表示.图像和音频可以直接处理原始像素点和音频中功率谱密度的强度值, 把它们直接编码成向量数据集.但在"自然语言处理&quo ...

  9. tensorflow之逻辑回归模型实现

    前面一篇介绍了用tensorflow实现线性回归模型预测sklearn内置的波士顿房价,现在这一篇就记一下用逻辑回归分类sklearn提供的乳腺癌数据集,该数据集有569个样本,每个样本有30维,为二 ...

随机推荐

  1. 一个令人蛋疼的 Microsoft.AspNet.FriendlyUrls

    我一个项目都基本上做完了,结果部署到我服务器的时候结果一直报404 找不到 一看global.asax有个路由注册的代码 public static void RegisterRoutes(Route ...

  2. 关于MySQL性能的比较

    需求:在传递一组职位编号的时候,需要统计该职位的 当天的投递情况 和 有历史记录以来总的投递量 解决方案一: 每次都进行一次数据库查询,遍历职位id,再根据职位id去查询相应时间内的投递量 /** * ...

  3. 仿淘宝左侧菜单导航栏纯Html + css 写的

    这俩天闲来没事淘宝逛了一圈看到淘宝的左侧导航菜单做的是真心的棒啊,一时兴起,查了点资料抓了几个图片仿淘宝写了个css,时间紧写的不太好,大神勿喷,给小白做个参考 废话不多说先来个效果图 接下来直接上代 ...

  4. Flask-Moment----探索

    前言:  Flask-Moment在所有的flask扩展中算是相对简单的一个了,但是还是有很多需要理解的地方.那么今天就跟着笔者一起,来学习一下flask-moment在flask项目中的应用. 首先 ...

  5. freemarker处理map的数据(二十)

    1.简易说明 (1)map取值 (2)key取值 2.实现示例 <html> <head> <meta http-equiv="content-type&quo ...

  6. git远程删除分支后,本地git branch -a 依然能看到的解决办法

    http://blog.csdn.net/qq_16885135/article/details/52777871 使用 git branch -a 命令可以查http://blog.csdn.net ...

  7. 第七章 鼠标(CHECKER1)

    CHECKER1程序将客户区划分成25个矩形,构成一个5*5的数组.如果在其中一个矩形内单击鼠标,就用X形填充该矩形.再次单击,则X形消失. /*--------------------------- ...

  8. ngx_lua_waf

    Web应用防护系统Web Application Firewall,简称WAF.针对HTTP/HTTPS的安全策略专门为Web应用提供保护的产品. OpenResty是一个基于 Nginx 与 Lua ...

  9. iframe-metamask

    iframe--require('iframe') higher level api for creating and removing iframes in browsers 用于创建或移除浏览器中 ...

  10. php7 扩展模块添加

    php 扩展模块添加   1. 新增安装扩展模块的位置 [root@node_22 ~]# ls /usr/local/php7/lib/php/extensions/no-debug-non-zts ...