# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module with useful wrappers used for the development of InferPy.
"""
from functools import wraps
import inferpy as inf
import numpy as np
from six import iteritems
import pandas as pd
def tf_run_wrapper(f):
    """Decorator that evaluates the output of ``f`` in the default TensorFlow session.

    When applied to a function ``f``, the wrapper replaces the tensor(s) returned
    by ``f`` with their evaluation in ``inf.util.Runtime.tf_sess``. In doing so,
    the API user only works with standard Python types.

    The wrapped function accepts an extra keyword argument ``tf_run``. It is
    consumed by the wrapper and never forwarded to ``f``. If it is omitted,
    ``inf.util.Runtime.tf_run_default`` decides whether to evaluate. When
    evaluation is disabled (falsy), the raw output of ``f`` is returned.

    How the output of ``f`` is fetched:
      * objects whose type lives in the ``numpy`` module, and plain lists, are
        passed to ``tf_sess.run`` directly and its result is returned;
      * a plain ``dict`` has its values fetched and is rebuilt with the same keys;
      * anything else is fetched as a single element and its value is returned.
    """
    @wraps(f)
    def wrapper(*args, **kwargs):
        # Pop ``tf_run`` so it is never passed through to ``f``.
        if "tf_run" in kwargs:
            tf_run = kwargs.pop("tf_run")
        else:
            tf_run = inf.util.Runtime.tf_run_default

        output_tf = f(*args, **kwargs)
        if not tf_run:
            return output_tf

        output_type = type(output_tf)
        sess = inf.util.Runtime.tf_sess
        # Exact type checks (not isinstance) preserve the original dispatch:
        # e.g. an OrderedDict is fetched as one element, not key by key.
        if output_type.__module__ == np.__name__ or output_type is list:
            return sess.run(output_tf)
        if output_type is dict:
            keys = list(output_tf.keys())
            values = sess.run([output_tf[k] for k in keys])
            return dict(zip(keys, values))
        return sess.run([output_tf])[0]
    return wrapper
def multishape(f):
    """Decorator that maps a method over a 1-D vector given as its first argument.

    The wrapped callable is expected to be a method. ``args[0]`` is ``self``, and
    ``args[1]`` is a value that may be a single element or a 1-D vector:

    * If ``np.ndim(args[1]) == 0``, ``f`` is called once with the original
      arguments.
    * If ``np.ndim(args[1]) == 1``, ``f`` is called once per item and the results
      are returned as a list. ``self`` and every other positional and keyword
      argument are forwarded unchanged to each call.

    Raises:
        ValueError: if ``args[1]`` has more than one dimension.
    """
    first_arg = 1  # index of the possibly-vectorised argument (args[0] is self)

    @wraps(f)
    def wrapper(*args, **kwargs):
        value = args[first_arg]
        ndim = np.ndim(value)
        if ndim == 0:  # single element
            return f(*args, **kwargs)
        if ndim == 1:  # unidimensional vector
            # Keep every other positional argument; only the vector slot changes.
            head, tail = args[:first_arg], args[first_arg + 1:]
            return [f(*(head + (item,) + tail), **kwargs) for item in value]
        raise ValueError("@multishape wrapper can only be applied to single elements or to 1-dimension vectors")
    return wrapper
def static_multishape(f):
    """Decorator that maps a plain function over a 1-D vector given as its first argument.

    ``args[0]`` may be a single element or a 1-D vector:

    * If ``np.ndim(args[0]) == 0``, ``f`` is called once with the original
      arguments.
    * If ``np.ndim(args[0]) == 1``, ``f`` is called once per item and the results
      are returned as a list. Every other positional and keyword argument is
      forwarded unchanged to each call.

    Raises:
        ValueError: if ``args[0]`` has more than one dimension.
    """
    first_arg = 0  # index of the possibly-vectorised argument

    @wraps(f)
    def wrapper(*args, **kwargs):
        value = args[first_arg]
        ndim = np.ndim(value)
        if ndim == 0:  # single element
            return f(*args, **kwargs)
        if ndim == 1:  # unidimensional vector
            # Keep every other positional argument; only the vector slot changes.
            head, tail = args[:first_arg], args[first_arg + 1:]
            return [f(*(head + (item,) + tail), **kwargs) for item in value]
        raise ValueError("@multishape wrapper can only be applied to single elements or to 1-dimension vectors")
    return wrapper
def singleton(class_):
    """Class decorator that turns ``class_`` into a singleton.

    The returned class subclasses ``class_``. The first instantiation creates
    the instance and runs ``class_.__init__`` with the given arguments. Every
    later instantiation returns that same object without running ``__init__``
    again, so arguments passed on later calls are ignored.

    The returned class reports the ``__name__``, ``__module__``, ``__doc__``
    (and, where available, ``__qualname__``) of ``class_``.
    """
    # object.__new__ / object.__init__ raise TypeError when given extra
    # arguments and the subclass overrides them (as class_w does). So constructor
    # arguments are only forwarded when class_ provides its own implementation.
    forward_new_args = class_.__new__ is not object.__new__
    forward_init_args = class_.__init__ is not object.__init__

    class class_w(class_):
        __doc__ = class_.__doc__
        __module__ = class_.__module__
        _instance = None

        def __new__(cls, *args, **kwargs):
            if class_w._instance is None:
                parent_new = super(class_w, cls).__new__
                if forward_new_args:
                    instance = parent_new(cls, *args, **kwargs)
                else:
                    instance = parent_new(cls)
                # Not sealed yet: the first __init__ call must still run.
                instance._sealed = False
                class_w._instance = instance
            return class_w._instance

        def __init__(self, *args, **kwargs):
            if self._sealed:
                return
            if forward_init_args:
                super(class_w, self).__init__(*args, **kwargs)
            else:
                super(class_w, self).__init__()
            self._sealed = True

    class_w.__name__ = class_.__name__
    if hasattr(class_, "__qualname__"):
        class_w.__qualname__ = class_.__qualname__
    return class_w