# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
# Changes Made from original:
# import paths
# quantization operations
# pruning operations
# ******************************************************************************
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: skip-file
"""Utility functions for building models."""
from __future__ import print_function
import collections
import os
import time
import numpy as np
import six
import tensorflow as tf
from tensorflow.python.ops import lookup_ops
from nlp_architect.models.gnmt.utils import iterator_utils
from nlp_architect.models.gnmt.utils import misc_utils as utils
from nlp_architect.models.gnmt.utils import vocab_utils
__all__ = [
"get_initializer", "get_device_str", "create_train_model",
"create_eval_model", "create_infer_model",
"create_emb_for_encoder_and_decoder", "create_rnn_cell", "gradient_clip",
"create_or_load_model", "load_model", "avg_checkpoints",
"compute_perplexity"
]
# If a vocab size is greater than this value, put the embedding on cpu instead
VOCAB_SIZE_THRESHOLD_CPU = 50000
# Collection for all the tensors involved in the quantization process
_QUANTIZATION_COLLECTION = 'quantization'
def get_initializer(init_op, seed=None, init_weight=None):
"""Create an initializer. init_weight is only for uniform."""
if init_op == "uniform":
assert init_weight
return tf.random_uniform_initializer(
-init_weight, init_weight, seed=seed)
elif init_op == "glorot_normal":
return tf.keras.initializers.glorot_normal(
seed=seed)
elif init_op == "glorot_uniform":
return tf.keras.initializers.glorot_uniform(
seed=seed)
else:
raise ValueError("Unknown init_op %s" % init_op)
def get_device_str(device_id, num_gpus):
"""Return a device string for multi-GPU setup."""
if num_gpus == 0:
return "/cpu:0"
device_str_output = "/gpu:%d" % (device_id % num_gpus)
return device_str_output
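# Example: with num_gpus=2, layers are placed round-robin over the GPUs, e.g.
#   get_device_str(0, 2) -> "/gpu:0", get_device_str(3, 2) -> "/gpu:1",
# and get_device_str(0, 0) -> "/cpu:0" when no GPU is available.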
class ExtraArgs(collections.namedtuple(
"ExtraArgs", ("single_cell_fn", "model_device_fn",
"attention_mechanism_fn", "encoder_emb_lookup_fn"))):
pass
class TrainModel(
collections.namedtuple("TrainModel", ("graph", "model", "iterator",
"skip_count_placeholder"))):
pass
def create_train_model(
model_creator, hparams, scope=None, num_workers=1, jobid=0,
extra_args=None):
"""Create train graph, model, and iterator."""
src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
src_vocab_file = hparams.src_vocab_file
tgt_vocab_file = hparams.tgt_vocab_file
graph = tf.Graph()
with graph.as_default(), tf.container(scope or "train"):
src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
src_vocab_file, tgt_vocab_file, hparams.share_vocab)
src_dataset = tf.data.TextLineDataset(tf.gfile.Glob(src_file))
tgt_dataset = tf.data.TextLineDataset(tf.gfile.Glob(tgt_file))
skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
iterator = iterator_utils.get_iterator(
src_dataset,
tgt_dataset,
src_vocab_table,
tgt_vocab_table,
batch_size=hparams.batch_size,
sos=hparams.sos,
eos=hparams.eos,
random_seed=hparams.random_seed,
num_buckets=hparams.num_buckets,
src_max_len=hparams.src_max_len,
tgt_max_len=hparams.tgt_max_len,
skip_count=skip_count_placeholder,
num_shards=num_workers,
shard_index=jobid,
use_char_encode=hparams.use_char_encode)
# Note: One can set model_device_fn to
# `tf.train.replica_device_setter(ps_tasks)` for distributed training.
model_device_fn = None
if extra_args:
model_device_fn = extra_args.model_device_fn
with tf.device(model_device_fn):
model = model_creator(
hparams,
iterator=iterator,
mode=tf.contrib.learn.ModeKeys.TRAIN,
source_vocab_table=src_vocab_table,
target_vocab_table=tgt_vocab_table,
scope=scope,
extra_args=extra_args)
return TrainModel(
graph=graph,
model=model,
iterator=iterator,
skip_count_placeholder=skip_count_placeholder)
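# Example (hedged sketch): building and stepping a train model. `MyModel` and a
# fully populated `hparams` object are assumptions, not part of this module.
#
#   train_model = create_train_model(MyModel, hparams)
#   with tf.Session(graph=train_model.graph) as sess:
#     sess.run(tf.global_variables_initializer())
#     sess.run(tf.tables_initializer())
#     sess.run(train_model.iterator.initializer,
#              feed_dict={train_model.skip_count_placeholder: 0})
#     train_model.model.train(sess)  # one training step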
class EvalModel(
collections.namedtuple("EvalModel",
("graph", "model", "src_file_placeholder",
"tgt_file_placeholder", "iterator"))):
pass
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
  """Create eval graph, model, src/tgt file placeholders, and iterator."""
src_vocab_file = hparams.src_vocab_file
tgt_vocab_file = hparams.tgt_vocab_file
graph = tf.Graph()
with graph.as_default(), tf.container(scope or "eval"):
src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
src_vocab_file, tgt_vocab_file, hparams.share_vocab)
reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
tgt_vocab_file, default_value=vocab_utils.UNK)
src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
src_dataset = tf.data.TextLineDataset(src_file_placeholder)
tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)
iterator = iterator_utils.get_iterator(
src_dataset,
tgt_dataset,
src_vocab_table,
tgt_vocab_table,
hparams.batch_size,
sos=hparams.sos,
eos=hparams.eos,
random_seed=hparams.random_seed,
num_buckets=hparams.num_buckets,
src_max_len=hparams.src_max_len_infer,
tgt_max_len=hparams.tgt_max_len_infer,
use_char_encode=hparams.use_char_encode)
model = model_creator(
hparams,
iterator=iterator,
mode=tf.contrib.learn.ModeKeys.EVAL,
source_vocab_table=src_vocab_table,
target_vocab_table=tgt_vocab_table,
reverse_target_vocab_table=reverse_tgt_vocab_table,
scope=scope,
extra_args=extra_args)
return EvalModel(
graph=graph,
model=model,
src_file_placeholder=src_file_placeholder,
tgt_file_placeholder=tgt_file_placeholder,
iterator=iterator)
class InferModel(
collections.namedtuple("InferModel",
("graph", "model", "src_placeholder",
"batch_size_placeholder", "iterator"))):
pass
def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
"""Create inference model."""
graph = tf.Graph()
src_vocab_file = hparams.src_vocab_file
tgt_vocab_file = hparams.tgt_vocab_file
with graph.as_default(), tf.container(scope or "infer"):
src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
src_vocab_file, tgt_vocab_file, hparams.share_vocab)
reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
tgt_vocab_file, default_value=vocab_utils.UNK)
src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)
src_dataset = tf.data.Dataset.from_tensor_slices(
src_placeholder)
iterator = iterator_utils.get_infer_iterator(
src_dataset,
src_vocab_table,
batch_size=batch_size_placeholder,
eos=hparams.eos,
src_max_len=hparams.src_max_len_infer,
use_char_encode=hparams.use_char_encode)
model = model_creator(
hparams,
iterator=iterator,
mode=tf.contrib.learn.ModeKeys.INFER,
source_vocab_table=src_vocab_table,
target_vocab_table=tgt_vocab_table,
reverse_target_vocab_table=reverse_tgt_vocab_table,
scope=scope,
extra_args=extra_args)
return InferModel(
graph=graph,
model=model,
src_placeholder=src_placeholder,
batch_size_placeholder=batch_size_placeholder,
iterator=iterator)
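# Example (hedged sketch): feeding raw, pre-tokenized text through the infer
# model. `MyModel`, `hparams` and the input sentence are assumptions.
#
#   infer_model = create_infer_model(MyModel, hparams)
#   with tf.Session(graph=infer_model.graph) as sess:
#     # ... load or initialize model parameters here ...
#     sess.run(infer_model.iterator.initializer,
#              feed_dict={infer_model.src_placeholder: ["hello world ."],
#                         infer_model.batch_size_placeholder: 1})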
def _get_embed_device(vocab_size):
"""Decide on which device to place an embed matrix given its vocab size."""
if vocab_size > VOCAB_SIZE_THRESHOLD_CPU:
return "/cpu:0"
else:
return "/gpu:0"
def _create_pretrained_emb_from_txt(
vocab_file, embed_file, num_trainable_tokens=3, dtype=tf.float32,
scope=None):
"""Load pretrain embeding from embed_file, and return an embedding matrix.
Args:
embed_file: Path to a Glove formated embedding txt file.
num_trainable_tokens: Make the first n tokens in the vocab file as trainable
variables. Default is 3, which is "<unk>", "<s>" and "</s>".
"""
vocab, _ = vocab_utils.load_vocab(vocab_file)
trainable_tokens = vocab[:num_trainable_tokens]
utils.print_out("# Using pretrained embedding: %s." % embed_file)
utils.print_out(" with trainable tokens: ")
emb_dict, emb_size = vocab_utils.load_embed_txt(embed_file)
for token in trainable_tokens:
utils.print_out(" %s" % token)
if token not in emb_dict:
emb_dict[token] = [0.0] * emb_size
emb_mat = np.array(
[emb_dict[token] for token in vocab], dtype=dtype.as_numpy_dtype())
emb_mat = tf.constant(emb_mat)
emb_mat_const = tf.slice(emb_mat, [num_trainable_tokens, 0], [-1, -1])
with tf.variable_scope(scope or "pretrain_embeddings", dtype=dtype) as scope:
with tf.device(_get_embed_device(num_trainable_tokens)):
emb_mat_var = tf.get_variable(
"emb_mat_var", [num_trainable_tokens, emb_size])
return tf.concat([emb_mat_var, emb_mat_const], 0)
def _create_or_load_embed(embed_name, vocab_file, embed_file,
vocab_size, embed_size, dtype, embed_type='dense'):
"""Create a new or load an existing embedding matrix."""
if vocab_file and embed_file:
embedding = _create_pretrained_emb_from_txt(vocab_file, embed_file)
else:
with tf.device(_get_embed_device(vocab_size)):
if embed_type == 'dense':
embedding = tf.get_variable(
embed_name, [vocab_size, embed_size], dtype)
elif embed_type == 'sparse':
embedding = tf.get_variable(
embed_name, [vocab_size, embed_size], dtype)
embedding = tf.contrib.model_pruning.apply_mask(embedding, embed_name)
else:
raise ValueError("Unknown embedding type %s!" % embed_type)
return embedding
def create_emb_for_encoder_and_decoder(share_vocab,
src_vocab_size,
tgt_vocab_size,
src_embed_size,
tgt_embed_size,
embed_type='dense',
dtype=tf.float32,
num_enc_partitions=0,
num_dec_partitions=0,
src_vocab_file=None,
tgt_vocab_file=None,
src_embed_file=None,
tgt_embed_file=None,
use_char_encode=False,
scope=None):
"""Create embedding matrix for both encoder and decoder.
Args:
share_vocab: A boolean. Whether to share embedding matrix for both
encoder and decoder.
src_vocab_size: An integer. The source vocab size.
tgt_vocab_size: An integer. The target vocab size.
src_embed_size: An integer. The embedding dimension for the encoder's
embedding.
tgt_embed_size: An integer. The embedding dimension for the decoder's
embedding.
    embed_type: 'dense' for a regular embedding variable, or 'sparse' for an
      embedding variable wrapped with a model-pruning mask.
    dtype: dtype of the embedding matrix. Defaults to tf.float32.
num_enc_partitions: number of partitions used for the encoder's embedding
vars.
num_dec_partitions: number of partitions used for the decoder's embedding
vars.
scope: VariableScope for the created subgraph. Default to "embedding".
Returns:
embedding_encoder: Encoder's embedding matrix.
embedding_decoder: Decoder's embedding matrix.
Raises:
    ValueError: if share_vocab is set but the source and target vocab sizes
      differ.
"""
if num_enc_partitions <= 1:
enc_partitioner = None
else:
    # Note: num_partitions > 1 is required for distributed training because
    # embedding_lookup tries to colocate a single-partition embedding variable
    # with the lookup ops, which may cause embedding variables to be placed
    # on worker jobs.
enc_partitioner = tf.fixed_size_partitioner(num_enc_partitions)
if num_dec_partitions <= 1:
dec_partitioner = None
else:
    # Note: num_partitions > 1 is required for distributed training because
    # embedding_lookup tries to colocate a single-partition embedding variable
    # with the lookup ops, which may cause embedding variables to be placed
    # on worker jobs.
dec_partitioner = tf.fixed_size_partitioner(num_dec_partitions)
if src_embed_file and enc_partitioner:
raise ValueError(
"Can't set num_enc_partitions > 1 when using pretrained encoder "
"embedding")
if tgt_embed_file and dec_partitioner:
raise ValueError(
"Can't set num_dec_partitions > 1 when using pretrained decdoer "
"embedding")
with tf.variable_scope(scope or "embeddings", dtype=dtype, partitioner=enc_partitioner):
# Share embedding
if share_vocab:
if src_vocab_size != tgt_vocab_size:
raise ValueError("Share embedding but different src/tgt vocab sizes"
" %d vs. %d" % (src_vocab_size, tgt_vocab_size))
assert src_embed_size == tgt_embed_size
utils.print_out("# Use the same embedding for source and target")
vocab_file = src_vocab_file or tgt_vocab_file
embed_file = src_embed_file or tgt_embed_file
embedding_encoder = _create_or_load_embed(
"embedding_share", vocab_file, embed_file,
src_vocab_size, src_embed_size, dtype, embed_type=embed_type)
embedding_decoder = embedding_encoder
else:
if not use_char_encode:
with tf.variable_scope("encoder", partitioner=enc_partitioner):
embedding_encoder = _create_or_load_embed(
"embedding_encoder", src_vocab_file, src_embed_file,
src_vocab_size, src_embed_size, dtype, embed_type=embed_type)
else:
embedding_encoder = None
with tf.variable_scope("decoder", partitioner=dec_partitioner):
embedding_decoder = _create_or_load_embed(
"embedding_decoder", tgt_vocab_file, tgt_embed_file,
tgt_vocab_size, tgt_embed_size, dtype, embed_type=embed_type)
return embedding_encoder, embedding_decoder
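# Example (minimal sketch): shared source/target embeddings; the vocab sizes
# and embedding dimension below are illustrative.
#
#   emb_enc, emb_dec = create_emb_for_encoder_and_decoder(
#       share_vocab=True, src_vocab_size=32000, tgt_vocab_size=32000,
#       src_embed_size=512, tgt_embed_size=512)
#   # emb_enc is emb_dec -> True; ids are embedded via tf.nn.embedding_lookup.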
def _single_cell(unit_type, num_units, forget_bias, dropout, mode,
residual_connection=False, device_str=None, residual_fn=None):
"""Create an instance of a single RNN cell."""
# dropout (= 1 - keep_prob) is set to 0 during eval and infer
dropout = dropout if mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0
# Cell Type
if unit_type == "lstm":
utils.print_out(" LSTM, forget_bias=%g" % forget_bias, new_line=False)
single_cell = tf.contrib.rnn.BasicLSTMCell(
num_units,
forget_bias=forget_bias)
elif unit_type == "gru":
utils.print_out(" GRU", new_line=False)
single_cell = tf.contrib.rnn.GRUCell(num_units)
elif unit_type == "layer_norm_lstm":
utils.print_out(" Layer Normalized LSTM, forget_bias=%g" % forget_bias,
new_line=False)
single_cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
num_units,
forget_bias=forget_bias,
layer_norm=True)
elif unit_type == "nas":
utils.print_out(" NASCell", new_line=False)
single_cell = tf.contrib.rnn.NASCell(num_units)
elif unit_type == "mlstm":
utils.print_out(" Masked_LSTM, forget_bias=%g" % forget_bias, new_line=False)
single_cell = tf.contrib.model_pruning.MaskedBasicLSTMCell(
num_units,
forget_bias=forget_bias)
else:
raise ValueError("Unknown unit type %s!" % unit_type)
# Dropout (= 1 - keep_prob)
if dropout > 0.0:
single_cell = tf.contrib.rnn.DropoutWrapper(
cell=single_cell, input_keep_prob=(1.0 - dropout))
utils.print_out(" %s, dropout=%g " % (type(single_cell).__name__, dropout),
new_line=False)
# Residual
if residual_connection:
single_cell = tf.contrib.rnn.ResidualWrapper(
single_cell, residual_fn=residual_fn)
utils.print_out(" %s" % type(single_cell).__name__, new_line=False)
# Device Wrapper
if device_str:
single_cell = tf.contrib.rnn.DeviceWrapper(single_cell, device_str)
utils.print_out(" %s, device=%s" %
(type(single_cell).__name__, device_str), new_line=False)
return single_cell
def _cell_list(unit_type, num_units, num_layers, num_residual_layers,
forget_bias, dropout, mode, num_gpus, base_gpu=0,
single_cell_fn=None, residual_fn=None):
"""Create a list of RNN cells."""
if not single_cell_fn:
single_cell_fn = _single_cell
# Multi-GPU
cell_list = []
for i in range(num_layers):
utils.print_out(" cell %d" % i, new_line=False)
single_cell = single_cell_fn(
unit_type=unit_type,
num_units=num_units,
forget_bias=forget_bias,
dropout=dropout,
mode=mode,
residual_connection=(i >= num_layers - num_residual_layers),
device_str=get_device_str(i + base_gpu, num_gpus),
residual_fn=residual_fn
)
utils.print_out("")
cell_list.append(single_cell)
return cell_list
def create_rnn_cell(unit_type, num_units, num_layers, num_residual_layers,
forget_bias, dropout, mode, num_gpus, base_gpu=0,
single_cell_fn=None):
"""Create multi-layer RNN cell.
Args:
unit_type: string representing the unit type, i.e. "lstm".
num_units: the depth of each unit.
num_layers: number of cells.
num_residual_layers: Number of residual layers from top to bottom. For
example, if `num_layers=4` and `num_residual_layers=2`, the last 2 RNN
cells in the returned list will be wrapped with `ResidualWrapper`.
forget_bias: the initial forget bias of the RNNCell(s).
dropout: floating point value between 0.0 and 1.0:
      the probability of dropout. This is ignored if `mode != TRAIN`.
mode: either tf.contrib.learn.TRAIN/EVAL/INFER
num_gpus: The number of gpus to use when performing round-robin
placement of layers.
base_gpu: The gpu device id to use for the first RNN cell in the
returned list. The i-th RNN cell will use `(base_gpu + i) % num_gpus`
as its device id.
single_cell_fn: allow for adding customized cell.
When not specified, we default to model_helper._single_cell
Returns:
An `RNNCell` instance.
"""
cell_list = _cell_list(unit_type=unit_type,
num_units=num_units,
num_layers=num_layers,
num_residual_layers=num_residual_layers,
forget_bias=forget_bias,
dropout=dropout,
mode=mode,
num_gpus=num_gpus,
base_gpu=base_gpu,
single_cell_fn=single_cell_fn)
if len(cell_list) == 1: # Single layer.
return cell_list[0]
else: # Multi layers
return tf.contrib.rnn.MultiRNNCell(cell_list)
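# Example (minimal sketch): a 4-layer LSTM stack whose top 2 layers use
# residual connections, placed round-robin over 2 GPUs.
#
#   cell = create_rnn_cell(unit_type="lstm", num_units=512, num_layers=4,
#                          num_residual_layers=2, forget_bias=1.0, dropout=0.2,
#                          mode=tf.contrib.learn.ModeKeys.TRAIN, num_gpus=2)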
def gradient_clip(gradients, max_gradient_norm):
"""Clipping gradients of a model."""
clipped_gradients, gradient_norm = tf.clip_by_global_norm(
gradients, max_gradient_norm)
gradient_norm_summary = [tf.summary.scalar("grad_norm", gradient_norm)]
gradient_norm_summary.append(
tf.summary.scalar("clipped_gradient", tf.global_norm(clipped_gradients)))
return clipped_gradients, gradient_norm_summary, gradient_norm
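# Example (minimal sketch): clip gradients before applying them; `loss` and the
# choice of optimizer are assumptions.
#
#   params = tf.trainable_variables()
#   gradients = tf.gradients(loss, params)
#   clipped_grads, grad_norm_summary, grad_norm = gradient_clip(
#       gradients, max_gradient_norm=5.0)
#   update = tf.train.GradientDescentOptimizer(1.0).apply_gradients(
#       zip(clipped_grads, params))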
def print_variables_in_ckpt(ckpt_path):
"""Print a list of variables in a checkpoint together with their shapes."""
utils.print_out("# Variables in ckpt %s" % ckpt_path)
reader = tf.train.NewCheckpointReader(ckpt_path)
variable_map = reader.get_variable_to_shape_map()
for key in sorted(variable_map.keys()):
utils.print_out(" %s: %s" % (key, variable_map[key]))
def load_model(model, ckpt_path, session, name):
"""Load model from a checkpoint."""
start_time = time.time()
try:
model.saver.restore(session, ckpt_path)
except tf.errors.NotFoundError as e:
utils.print_out("Can't load checkpoint")
print_variables_in_ckpt(ckpt_path)
utils.print_out("%s" % str(e))
session.run(tf.tables_initializer())
utils.print_out(
" loaded %s model parameters from %s, time %.2fs" %
(name, ckpt_path, time.time() - start_time))
return model
def load_quantized_model(model, ckpt_path, session, name):
"""Loads quantized model and dequantizes variables"""
start_time = time.time()
dequant_ops = []
for tsr in tf.trainable_variables():
with tf.variable_scope(tsr.name.split(':')[0], reuse=True):
quant_tsr = tf.get_variable('quantized', dtype=tf.qint8)
min_range = tf.get_variable('min_range')
max_range = tf.get_variable('max_range')
dequant_ops.append(
tsr.assign(tf.dequantize(quant_tsr, min_range, max_range, 'SCALED')))
restore_list = [tsr for tsr in tf.global_variables() if tsr not in tf.trainable_variables()]
saver = tf.train.Saver(restore_list)
try:
saver.restore(session, ckpt_path)
except tf.errors.NotFoundError as e:
utils.print_out("Can't load checkpoint")
print_variables_in_ckpt(ckpt_path)
utils.print_out("%s" % str(e))
session.run(tf.tables_initializer())
session.run(dequant_ops)
utils.print_out(
" loaded %s model parameters from %s, time %.2fs" %
(name, ckpt_path, time.time() - start_time))
return model
def add_quatization_variables(model):
"""Add to graph quantization variables"""
with model.graph.as_default():
for tsr in tf.trainable_variables():
with tf.variable_scope(tsr.name.split(':')[0]):
output, min_range, max_range = tf.quantize(
tsr, tf.reduce_min(tsr), tf.reduce_max(tsr), tf.qint8, mode='SCALED')
tf.get_variable('quantized', initializer=output, trainable=False,
collections=[_QUANTIZATION_COLLECTION])
tf.get_variable('min_range', initializer=min_range, trainable=False,
collections=[_QUANTIZATION_COLLECTION])
tf.get_variable('max_range', initializer=max_range, trainable=False,
collections=[_QUANTIZATION_COLLECTION])
def quantize_checkpoint(session, ckpt_path):
"""Quantize current loaded model and saves checkpoint in ckpt_path"""
save_list = [tsr for tsr in tf.global_variables() if tsr not in tf.trainable_variables()]
saver = tf.train.Saver(save_list)
session.run(tf.variables_initializer(tf.get_collection(_QUANTIZATION_COLLECTION)))
saver.save(session, ckpt_path)
utils.print_out('Saved quantized checkpoint as %s' % ckpt_path)
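# Example (hedged sketch): round trip through the quantization helpers. The
# model wrapper (with a .graph attribute), the sessions and the checkpoint path
# are assumptions.
#
#   add_quatization_variables(loaded_model)           # extend graph with int8 copies
#   quantize_checkpoint(sess, "/tmp/quantized.ckpt")  # save quantized weights + ranges
#   ...
#   load_quantized_model(model, "/tmp/quantized.ckpt", new_sess, "infer")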
def avg_checkpoints(model_dir, num_last_checkpoints, global_step,
global_step_name):
"""Average the last N checkpoints in the model_dir."""
checkpoint_state = tf.train.get_checkpoint_state(model_dir)
if not checkpoint_state:
utils.print_out("# No checkpoint file found in directory: %s" % model_dir)
return None
# Checkpoints are ordered from oldest to newest.
checkpoints = (
checkpoint_state.all_model_checkpoint_paths[-num_last_checkpoints:])
if len(checkpoints) < num_last_checkpoints:
utils.print_out(
"# Skipping averaging checkpoints because not enough checkpoints is "
"avaliable."
)
return None
avg_model_dir = os.path.join(model_dir, "avg_checkpoints")
if not tf.gfile.Exists(avg_model_dir):
utils.print_out(
"# Creating new directory %s for saving averaged checkpoints." %
avg_model_dir)
tf.gfile.MakeDirs(avg_model_dir)
utils.print_out("# Reading and averaging variables in checkpoints:")
var_list = tf.contrib.framework.list_variables(checkpoints[0])
var_values, var_dtypes = {}, {}
for (name, shape) in var_list:
if name != global_step_name:
var_values[name] = np.zeros(shape)
for checkpoint in checkpoints:
utils.print_out(" %s" % checkpoint)
reader = tf.contrib.framework.load_checkpoint(checkpoint)
for name in var_values:
tensor = reader.get_tensor(name)
var_dtypes[name] = tensor.dtype
var_values[name] += tensor
for name in var_values:
var_values[name] /= len(checkpoints)
# Build a graph with same variables in the checkpoints, and save the averaged
# variables into the avg_model_dir.
with tf.Graph().as_default():
tf_vars = [
        tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v])
for v in var_values]
placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
tf.Variable(global_step, name=global_step_name, trainable=False)
    saver = tf.train.Saver(tf.global_variables())
with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
for p, assign_op, (name, value) in zip(placeholders, assign_ops,
six.iteritems(var_values)):
sess.run(assign_op, {p: value})
# Use the built saver to save the averaged checkpoint. Only keep 1
# checkpoint and the best checkpoint will be moved to avg_best_metric_dir.
saver.save(
sess,
os.path.join(avg_model_dir, "translate.ckpt"))
return avg_model_dir
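# Example (minimal sketch): average the 5 most recent checkpoints in `out_dir`
# (an assumed output directory) and point evaluation at the result.
#
#   avg_dir = avg_checkpoints(out_dir, num_last_checkpoints=5,
#                             global_step=global_step,
#                             global_step_name="global_step")
#   if avg_dir:
#     ckpt_path = tf.train.latest_checkpoint(avg_dir)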
def create_or_load_model(model, model_dir, session, name):
"""Create translation model and initialize or load parameters in session."""
latest_ckpt = tf.train.latest_checkpoint(model_dir)
if latest_ckpt:
model = load_model(model, latest_ckpt, session, name)
else:
start_time = time.time()
session.run(tf.global_variables_initializer())
session.run(tf.tables_initializer())
utils.print_out(" created %s model with fresh parameters, time %.2fs" %
(name, time.time() - start_time))
global_step = model.global_step.eval(session=session)
return model, global_step
def compute_perplexity(model, sess, name):
"""Compute perplexity of the output of the model.
Args:
    model: model for which to compute perplexity.
sess: tensorflow session to use.
name: name of the batch.
Returns:
The perplexity of the eval outputs.
"""
total_loss = 0
total_predict_count = 0
start_time = time.time()
while True:
try:
output_tuple = model.eval(sess)
total_loss += output_tuple.eval_loss * output_tuple.batch_size
total_predict_count += output_tuple.predict_count
except tf.errors.OutOfRangeError:
break
perplexity = utils.safe_exp(total_loss / total_predict_count)
utils.print_time(" eval %s: perplexity %.2f" % (name, perplexity),
start_time)
return perplexity
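# Example (hedged sketch): dev-set perplexity with an eval model; `eval_model`,
# `model_dir` and the dev file paths are assumptions.
#
#   with tf.Session(graph=eval_model.graph) as sess:
#     loaded_model, _ = create_or_load_model(eval_model.model, model_dir,
#                                            sess, "eval")
#     sess.run(eval_model.iterator.initializer,
#              feed_dict={eval_model.src_file_placeholder: dev_src_file,
#                         eval_model.tgt_file_placeholder: dev_tgt_file})
#     dev_ppl = compute_perplexity(loaded_model, sess, "dev")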