# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
# Changes Made from original:
# import paths
# ******************************************************************************
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: skip-file
"""Generally useful utility functions."""
from __future__ import print_function
import codecs
import collections
import json
import math
import os
import sys
import time
from distutils import version
import tensorflow as tf
[docs]def check_tensorflow_version():
min_tf_version = "1.4.0-dev20171024"
if (version.LooseVersion(tf.__version__) < version.LooseVersion(min_tf_version)):
raise EnvironmentError("Tensorflow version must >= %s" % min_tf_version)
[docs]def safe_exp(value):
"""Exponentiation with catching of overflow error."""
try:
ans = math.exp(value)
except OverflowError:
ans = float("inf")
return ans
[docs]def print_time(s, start_time):
"""Take a start time, print elapsed duration, and return a new time."""
print("%s, time %ds, %s." % (s, (time.time() - start_time), time.ctime()))
sys.stdout.flush()
return time.time()
[docs]def print_out(s, f=None, new_line=True):
"""Similar to print but with support to flush and output to a file."""
if isinstance(s, bytes):
s = s.decode("utf-8")
if f:
f.write(s.encode("utf-8"))
if new_line:
f.write(b"\n")
# stdout
out_s = s.encode("utf-8")
if not isinstance(out_s, str):
out_s = out_s.decode("utf-8")
print(out_s, end="", file=sys.stdout)
if new_line:
sys.stdout.write("\n")
sys.stdout.flush()
[docs]def print_hparams(hparams, skip_patterns=None, header=None):
"""Print hparams, can skip keys based on pattern."""
if header:
print_out("%s" % header)
values = hparams.values()
for key in sorted(values.keys()):
if not skip_patterns or all(
[skip_pattern not in key for skip_pattern in skip_patterns]):
print_out(" %s=%s" % (key, str(values[key])))
[docs]def load_hparams(model_dir):
"""Load hparams from an existing model directory."""
hparams_file = os.path.join(model_dir, "hparams")
if tf.gfile.Exists(hparams_file):
print_out("# Loading hparams from %s" % hparams_file)
with codecs.getreader("utf-8")(tf.gfile.GFile(hparams_file, "rb")) as f:
try:
hparams_values = json.load(f)
hparams = tf.contrib.training.HParams(**hparams_values)
except ValueError:
print_out(" can't load hparams file")
return None
return hparams
else:
return None
[docs]def maybe_parse_standard_hparams(hparams, hparams_path):
"""Override hparams values with existing standard hparams config."""
if hparams_path and tf.gfile.Exists(hparams_path):
print_out("# Loading standard hparams from %s" % hparams_path)
with codecs.getreader("utf-8")(tf.gfile.GFile(hparams_path, "rb")) as f:
hparams.parse_json(f.read())
return hparams
[docs]def save_hparams(out_dir, hparams):
"""Save hparams."""
hparams_file = os.path.join(out_dir, "hparams")
print_out(" saving hparams to %s" % hparams_file)
with codecs.getwriter("utf-8")(tf.gfile.GFile(hparams_file, "wb")) as f:
f.write(hparams.to_json(indent=4, sort_keys=True))
[docs]def debug_tensor(s, msg=None, summarize=10):
"""Print the shape and value of a tensor at test time. Return a new tensor."""
if not msg:
msg = s.name
return tf.Print(s, [tf.shape(s), s], msg + " ", summarize=summarize)
[docs]def add_summary(summary_writer, global_step, tag, value):
"""Add a new summary to the current summary_writer.
Useful to log things that are not part of the training graph, e.g., tag=BLEU.
"""
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
summary_writer.add_summary(summary, global_step)
[docs]def get_config_proto(log_device_placement=False, allow_soft_placement=True,
num_intra_threads=0, num_inter_threads=0):
# GPU options:
# https://www.tensorflow.org/versions/r0.10/how_tos/using_gpu/index.html
config_proto = tf.ConfigProto(
log_device_placement=log_device_placement,
allow_soft_placement=allow_soft_placement)
config_proto.gpu_options.allow_growth = True
# CPU threads options
if num_intra_threads:
config_proto.intra_op_parallelism_threads = num_intra_threads
if num_inter_threads:
config_proto.inter_op_parallelism_threads = num_inter_threads
return config_proto
[docs]def format_text(words):
"""Convert a sequence words into sentence."""
if (not hasattr(words, "__len__") # for numpy array
and not isinstance(words, collections.Iterable)):
words = [words]
return b" ".join(words)
[docs]def format_bpe_text(symbols, delimiter=b"@@"):
"""Convert a sequence of bpe words into sentence."""
words = []
word = b""
if isinstance(symbols, str):
symbols = symbols.encode()
delimiter_len = len(delimiter)
for symbol in symbols:
if len(symbol) >= delimiter_len and symbol[-delimiter_len:] == delimiter:
word += symbol[:-delimiter_len]
else: # end of a word
word += symbol
words.append(word)
word = b""
return b" ".join(words)
[docs]def format_spm_text(symbols):
"""Decode a text in SPM (https://github.com/google/sentencepiece) format."""
return u"".join(format_text(symbols).decode("utf-8").split()).replace(
u"\u2581", u" ").strip().encode("utf-8")