Source code for super_gradients.training.utils.sg_model_utils

import os
import sys
import socket
import time
from multiprocessing import Process
from pathlib import Path
from typing import Tuple, Union

import torch
from torch.utils.tensorboard import SummaryWriter

from super_gradients.training.exceptions.dataset_exceptions import UnsupportedBatchItemsFormat

# TODO: These utils should move to sg_model package as internal (private) helper functions


def try_port(port):
    """
    try_port - Helper method for tensorboard port binding
    :param port: port number to check
    :return: True if the port is available, False otherwise
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    is_port_available = False
    try:
        sock.bind(("localhost", port))
        is_port_available = True
    except Exception as ex:
        print('Port ' + str(port) + ' is in use: ' + str(ex))
    sock.close()
    return is_port_available
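
For illustration (not part of the module), the helper can be used on its own to find a free port, which is essentially what launch_tensorboard_process does below:

    from super_gradients.training.utils.sg_model_utils import try_port

    for candidate_port in range(6006, 6016):
        if try_port(candidate_port):
            print('Port {} is free'.format(candidate_port))
            break
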
def launch_tensorboard_process(checkpoints_dir_path: str, sleep_postpone: bool = True, port: int = None) -> Tuple[Process, int]:
    """
    launch_tensorboard_process - Default behavior is to scan ports 6006-6015 and use the first free one,
                                 unless a port is defined by the user
    :param checkpoints_dir_path: directory whose parent is used as the tensorboard logdir
    :param sleep_postpone: whether to sleep for 3 seconds after launching, to let the process start
    :param port: optional explicit port; when None, ports 6006-6015 are scanned
    :return: tuple of tb process, port ((None, -1) when no port could be bound)
    """
    logdir_path = str(Path(checkpoints_dir_path).parent.absolute())
    tb_cmd = 'tensorboard --logdir=' + logdir_path + ' --bind_all'
    if port is not None:
        tb_ports = [port]
    else:
        tb_ports = range(6006, 6016)

    for tb_port in tb_ports:
        if not try_port(tb_port):
            continue
        print('Starting Tensor-Board process on port: ' + str(tb_port))
        tensor_board_process = Process(target=os.system, args=(tb_cmd + ' --port=' + str(tb_port),))
        tensor_board_process.daemon = True
        tensor_board_process.start()

        # LET THE TENSORBOARD PROCESS START
        if sleep_postpone:
            time.sleep(3)
        return tensor_board_process, tb_port

    # RETURNING IRRELEVANT VALUES
    print('Failed to initialize Tensor-Board process on ports: ' + ', '.join(map(str, tb_ports)))
    return None, -1
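
A usage sketch, with a hypothetical checkpoint path:

    from super_gradients.training.utils.sg_model_utils import launch_tensorboard_process

    tb_process, tb_port = launch_tensorboard_process('/path/to/checkpoints/my_experiment')  # hypothetical path
    if tb_process is None:
        print('No free port found; tensorboard was not started')
    else:
        print('tensorboard listening on port {}'.format(tb_port))
        # The process is a daemon, so it exits together with the training process.
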
def init_summary_writer(tb_dir, checkpoint_loaded, user_prompt=False):
    """Remove previous tensorboard files from the directory and launch a tensorboard process"""
    # If the training is from scratch, walk through the destination folder and delete existing tensorboard logs
    if not checkpoint_loaded:
        for filename in os.listdir(tb_dir):
            if 'events' in filename:
                if not user_prompt:
                    print('"{}" will not be deleted'.format(filename))
                    continue

                while True:
                    # Verify with the user before deleting old tensorboard files
                    user = input('\nOLDER TENSORBOARD FILES EXIST IN EXPERIMENT FOLDER:\n"{}"\n'
                                 'DO YOU WANT TO DELETE THEM? [y/n]'.format(filename))
                    if user == 'y':
                        os.remove('{}/{}'.format(tb_dir, filename))
                        print('DELETED: {}!'.format(filename))
                        break
                    elif user == 'n':
                        print('"{}" will not be deleted'.format(filename))
                        break
                    print('Unknown answer...')

    # Launch a tensorboard process
    return SummaryWriter(tb_dir)
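
A typical call for a fresh run, with a hypothetical directory; with user_prompt=False any stale event files are reported but kept:

    import os
    from super_gradients.training.utils.sg_model_utils import init_summary_writer

    tb_dir = '/tmp/my_experiment/tb'  # hypothetical
    os.makedirs(tb_dir, exist_ok=True)
    writer = init_summary_writer(tb_dir, checkpoint_loaded=False, user_prompt=False)
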
def init_log_file(log_dir, hpmstructs=[], special_conf={}):
    """Create a file and document the HpmStructs as plain text"""
    filename = '{}/log_{}.txt'.format(log_dir, time.strftime('%b%d_%H_%M_%S', time.localtime()))
    with open(filename, 'a' if os.path.exists(log_dir) else 'w') as log_file:
        for hpm in hpmstructs:
            for key, val in hpm.__dict__.items():
                log_file.write('{}: {}\n'.format(key, val))
        for key, val in special_conf.items():
            log_file.write('{}: {}\n'.format(key, val))
        log_file.write('\nTraining process:\n-----------------')
    return filename

def add_log_to_file(filename, results_titles_list, results_values_list, epoch, max_epochs):
    """Add a message to the log file"""
    # Note: opening and closing the file every time is inefficient; it is done for experimental purposes
    with open(filename, 'a') as f:
        f.write('\nEpoch (%d/%d) - ' % (epoch, max_epochs))
        for result_title, result_value in zip(results_titles_list, results_values_list):
            if isinstance(result_value, torch.Tensor):
                result_value = result_value.item()
            f.write(result_title + ': ' + str(result_value) + '\t')
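
A small round-trip sketch for the two logging helpers above; DummyHpm is a hypothetical stand-in for an HpmStruct (anything whose fields live in __dict__), and the directory is illustrative:

    import os
    import torch
    from super_gradients.training.utils.sg_model_utils import init_log_file, add_log_to_file

    class DummyHpm:  # hypothetical stand-in for an HpmStruct-like object
        def __init__(self):
            self.lr = 0.1
            self.batch_size = 64

    os.makedirs('/tmp/my_experiment', exist_ok=True)  # the log dir must exist
    log_path = init_log_file('/tmp/my_experiment', hpmstructs=[DummyHpm()], special_conf={'seed': 42})
    add_log_to_file(log_path, ['Loss', 'Accuracy'], [torch.tensor(0.35), 0.91], epoch=1, max_epochs=10)
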
def write_training_results(writer, results_titles_list, results_values_list, epoch):
    """Stores the training and validation loss and accuracy for the current epoch in a tensorboard file"""
    for res_key, res_val in zip(results_titles_list, results_values_list):
        # USE ONLY LOWER-CASE LETTERS AND REPLACE SPACES WITH '_' TO AVOID MANY TITLES FOR THE SAME KEY
        corrected_res_key = res_key.lower().replace(' ', '_')
        writer.add_scalar(corrected_res_key, res_val, epoch)
    writer.flush()
def write_hpms(writer, hpmstructs=[], special_conf={}):
    """Stores the training and dataset hyper params in the tensorboard file"""
    hpm_string = ""
    for hpm in hpmstructs:
        for key, val in hpm.__dict__.items():
            hpm_string += '{}: {} \n '.format(key, val)
    for key, val in special_conf.items():
        hpm_string += '{}: {} \n '.format(key, val)
    writer.add_text("Hyper_parameters", hpm_string)
    writer.flush()
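
A sketch tying the two tensorboard helpers to a plain SummaryWriter; the log directory and metric values are illustrative:

    from torch.utils.tensorboard import SummaryWriter
    from super_gradients.training.utils.sg_model_utils import write_training_results, write_hpms

    writer = SummaryWriter('/tmp/my_experiment/tb')  # hypothetical log dir
    write_hpms(writer, special_conf={'optimizer': 'SGD', 'initial_lr': 0.1})
    # 'Train Loss' becomes the scalar tag 'train_loss', 'Valid Accuracy' becomes 'valid_accuracy'
    write_training_results(writer, ['Train Loss', 'Valid Accuracy'], [0.35, 0.91], epoch=1)
    writer.close()
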
# TODO: This should probably move into datasets/datasets_utils.py?
def unpack_batch_items(batch_items: Union[tuple, torch.Tensor]):
    """
    Adds support for unpacking batch items in train/validation loop.

    @param batch_items: (Union[tuple, torch.Tensor]) returned by the data loader, which is expected to be in one of
        the following formats:
        1. torch.Tensor or tuple, s.t inputs = batch_items[0], targets = batch_items[1] and len(batch_items) = 2
        2. tuple: (inputs, targets, additional_batch_items) where inputs are fed to the network, targets are their
           corresponding labels and additional_batch_items is a dictionary
           (format {additional_batch_item_i_name: additional_batch_item_i ...}) which can be accessed through the
           phase context under the attribute additional_batch_item_i_name, using a phase callback.
    @return: inputs, target, additional_batch_items
    """
    additional_batch_items = {}
    if len(batch_items) == 2:
        inputs, target = batch_items
    elif len(batch_items) == 3:
        inputs, target, additional_batch_items = batch_items
    else:
        raise UnsupportedBatchItemsFormat()
    return inputs, target, additional_batch_items
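
Both accepted formats, spelled out with toy tensors:

    import torch
    from super_gradients.training.utils.sg_model_utils import unpack_batch_items

    # Format 1: (inputs, targets)
    inputs, targets, extras = unpack_batch_items((torch.rand(4, 3, 32, 32), torch.zeros(4)))
    assert extras == {}

    # Format 2: (inputs, targets, additional_batch_items)
    batch = (torch.rand(4, 3, 32, 32), torch.zeros(4), {'img_ids': [0, 1, 2, 3]})
    inputs, targets, extras = unpack_batch_items(batch)
    assert extras['img_ids'] == [0, 1, 2, 3]
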
def log_uncaught_exceptions(logger):
    """
    Makes logger log uncaught exceptions
    @param logger: logging.Logger
    @return: None
    """

    def handle_exception(exc_type, exc_value, exc_traceback):
        # Let KeyboardInterrupt reach the default hook so Ctrl-C still behaves normally
        if issubclass(exc_type, KeyboardInterrupt):
            sys.__excepthook__(exc_type, exc_value, exc_traceback)
            return
        logger.error("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback))

    sys.excepthook = handle_exception
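
A minimal wiring sketch: once installed, any uncaught exception in the script is routed through the logger rather than only the default traceback printer.

    import logging
    from super_gradients.training.utils.sg_model_utils import log_uncaught_exceptions

    logging.basicConfig(level=logging.ERROR)
    log_uncaught_exceptions(logging.getLogger(__name__))
    raise RuntimeError('demo')  # logged as "Uncaught exception" with the full traceback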