Source code for super_gradients.common.environment.env_helpers

import argparse
import importlib
import os
import socket
import sys
from functools import wraps
from typing import Any

from omegaconf import OmegaConf

from super_gradients.common.environment import environment_config


class TerminalColours:
    """
    ANSI escape sequences (SGR codes) used to colour or style text printed to a terminal.

    Prefix a string with one of these constants and terminate it with ``ENDC`` to reset the style.

    Usage: https://stackoverflow.com/questions/287871/how-to-print-colored-text-in-python?page=1&tab=votes#tab-top
    """

    # Foreground colours (SGR 90-97 range, i.e. the "bright" variants).
    HEADER = "\033[95m"  # bright magenta
    OKBLUE = "\033[94m"  # bright blue
    OKCYAN = "\033[96m"  # bright cyan
    OKGREEN = "\033[92m"  # bright green
    WARNING = "\033[93m"  # bright yellow
    FAIL = "\033[91m"  # bright red
    # SGR 0: reset every colour/attribute back to the terminal default.
    ENDC = "\033[0m"
    # Text attributes.
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"


class ColouredTextFormatter:
    """Helper for printing text surrounded by terminal colour escape sequences."""

    @staticmethod
    def print_coloured_text(text: str, colour: str):
        """
        Print ``text`` preceded by ``colour`` and followed by the reset sequence.

        :param text:    The text to print.
        :param colour:  Escape sequence to emit before the text (e.g. a ``TerminalColours`` constant).
        :return:        The result of ``print`` (i.e. None).
        """
        # The trailing ENDC resets the style so that later output is not affected.
        pieces = (colour, text, TerminalColours.ENDC)
        return print("".join(pieces))


def get_cls(cls_path):
    """
    A resolver for Hydra/OmegaConf that returns a class itself instead of an instance of it.

    usage:
    class_of_optimizer: ${class:torch.optim.Adam}

    :param cls_path: Dotted path whose last component is the attribute name and whose
                     leading components form the module to import.
    :return:         The attribute named by the last component, looked up on the imported module.
    """
    # Everything before the last dot is the module path; the remainder is the attribute name.
    module_path, _, attr_name = cls_path.rpartition(".")
    # Importing guarantees the module is registered in sys.modules before the lookup below.
    importlib.import_module(module_path)
    return getattr(sys.modules[module_path], attr_name)


def get_environ_as_type(environment_variable_name: str, default=None, cast_to_type: type = str) -> object:
    """
    Tries to get an environment variable and cast it into a requested type.

    :param environment_variable_name: Name of the environment variable to read.
    :param default:                   Value used when the variable is not set. It is cast like a real value.
    :param cast_to_type:              Callable used to convert the raw value (defaults to ``str``).
    :return: The value converted by ``cast_to_type``, or None when the variable is not set and no
             default was given.
    :raises ValueError: If the value could not be cast into type 'cast_to_type'. The original
                        exception is attached as ``__cause__``.
    """
    value = os.environ.get(environment_variable_name, default)
    if value is None:
        return None

    try:
        return cast_to_type(value)
    except Exception as e:
        # Chain the underlying error instead of printing it, so callers see the real cause in the traceback.
        raise ValueError(
            f"Failed to cast environment variable {environment_variable_name} to type {cast_to_type}: the value {value} is not a valid {cast_to_type}"
        ) from e


def hydra_output_dir_resolver(ckpt_root_dir, experiment_name):
    """
    Build the output directory path of an experiment.

    :param ckpt_root_dir:   Root directory of the checkpoints. When None, ``environment_config.PKG_CHECKPOINTS_DIR``
                            is used instead.
    :param experiment_name: Name of the experiment, appended to the root directory.
    :return:                The root directory and the experiment name joined with ``os.path.sep``.
    """
    # The package-level default is only looked up when no explicit root was given.
    root_dir = environment_config.PKG_CHECKPOINTS_DIR if ckpt_root_dir is None else ckpt_root_dir
    return root_dir + os.path.sep + experiment_name


def init_trainer():
    """
    Initialize the super_gradients environment.

    This function should be the first thing to be called by any code running super_gradients.
    It resolves conflicts between the different tools, packages and environments used and
    prepares the super_gradients environment.

    The initialization runs only once: ``environment_config.INIT_TRAINER`` is set at the end of
    the first call and any later call returns immediately.
    """
    if environment_config.INIT_TRAINER:
        return

    register_hydra_resolvers()

    # Pop local_rank from the command line if it was specified there.
    # NOTE(review): presumably the leftover flag would break hydra's argument parsing - confirm.
    args_local_rank = pop_arg("local_rank", default_value=-1)

    # Set local_rank with priority order (env variable > args.local_rank > args.default_value)
    environment_config.DDP_LOCAL_RANK = int(os.getenv("LOCAL_RANK", default=args_local_rank))
    environment_config.INIT_TRAINER = True
def register_hydra_resolvers():
    """Register all the hydra resolvers required for the super-gradients recipes."""
    OmegaConf.register_new_resolver("hydra_output_dir", hydra_output_dir_resolver, replace=True)
    OmegaConf.register_new_resolver("class", lambda *args: get_cls(*args), replace=True)
    OmegaConf.register_new_resolver("add", lambda *args: sum(args), replace=True)
    OmegaConf.register_new_resolver("cond", lambda boolean, x, y: x if boolean else y, replace=True)
    OmegaConf.register_new_resolver("getitem", lambda container, key: container[key], replace=True)  # get item from a container (list, dict...)
    OmegaConf.register_new_resolver("first", lambda lst: lst[0], replace=True)  # get the first item from a list
    OmegaConf.register_new_resolver("last", lambda lst: lst[-1], replace=True)  # get the last item from a list


def pop_arg(arg_name: str, default_value: Any = None) -> Any:
    """
    Get the specified arg and remove it from argv.

    :param arg_name:        Name of the command line argument, without the leading ``--``.
    :param default_value:   Value returned when the argument is not present on the command line.
    :return:                The value parsed for ``--<arg_name>``, or ``default_value``.

    Every element of ``sys.argv`` starting with ``--<arg_name>`` is removed from ``sys.argv`` (in place)
    and appended to ``environment_config.EXTRA_ARGS``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(f"--{arg_name}", default=default_value)
    args, _ = parser.parse_known_args()

    # Remove the ddp args to not have a conflict with the use of hydra.
    # The matches are collected into a list first: removing from sys.argv while iterating over it
    # would skip the element that follows each removed one.
    flag = f"--{arg_name}"
    matching_args = [arg for arg in sys.argv if arg.startswith(flag)]
    for val in matching_args:
        environment_config.EXTRA_ARGS.append(val)
        sys.argv.remove(val)
    return vars(args)[arg_name]
def is_distributed() -> bool:
    """
    Check whether a DDP local rank was set for the current process.

    :return: True when ``environment_config.DDP_LOCAL_RANK`` is 0 or greater, False when it is negative.
    """
    return environment_config.DDP_LOCAL_RANK >= 0
def is_rank_0() -> bool:
    """
    Check if the node was launched with torch.distributed.launch and if the node is of rank 0.

    :return: True only when the ``LOCAL_RANK`` environment variable is set to "0".
    """
    return os.getenv("LOCAL_RANK") == "0"


def is_launched_using_sg():
    """
    Check if the current process is a subprocess launched using SG restart_script_with_ddp.

    :return: True only when the ``TORCHELASTIC_RUN_ID`` environment variable equals "sg_initiated".
    """
    return os.environ.get("TORCHELASTIC_RUN_ID") == "sg_initiated"


def is_main_process():
    """
    Check if current process is considered as the main process (i.e. is responsible for sanity check, atexit upload, ...).

    The definition ensures that 1 and only 1 process follows this condition, regardless of how the run was started.

    The rule is as follow:
        - If not DDP: main process is current process
        - If DDP launched using SuperGradients: main process is the launching process (rank=-1)
        - If DDP launched with torch: main process is rank 0
    """
    # If no DDP, or DDP launching process
    if not is_distributed():
        return True

    # If DDP launched using torch.distributed.launch or torchrun, we need to run the check on rank 0
    if is_rank_0() and not is_launched_using_sg():
        return True

    return False


def multi_process_safe(func):
    """
    A decorator for making sure a function runs only in main process.

    If not in DDP mode (local_rank = -1), the function will run.
    If in DDP mode, the function will run only in the main process (local_rank = 0).
    This works only for functions with no return value: on any other rank the call is skipped and None is returned.

    :param func: The function to guard.
    :return:     A wrapper around ``func`` carrying its metadata (``functools.wraps``).
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # The rank is read at call time, not at decoration time.
        if environment_config.DDP_LOCAL_RANK <= 0:
            return func(*args, **kwargs)
        return None

    return wrapper


def find_free_port() -> int:
    """
    Find an available port of current machine/node.

    Note: there is still a chance the port could be taken by other processes,
    because the probing socket is closed before the port number is returned.

    :return: A port number that was free at the time of the call.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 will cause the OS to find an available port for us
        sock.bind(("", 0))
        _ip, port = sock.getsockname()
    return port