import os
import tempfile
import pkg_resources
import torch
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common import explicit_params_validation, ADNNModelRepositoryDataInterfaces
from super_gradients.training.pretrained_models import MODEL_URLS
from super_gradients.common.environment import environment_config
try:
    from torch.hub import download_url_to_file, load_state_dict_from_url
except (ModuleNotFoundError, ImportError, NameError):
    from torch.hub import _download_url_to_file as download_url_to_file
    from torch.utils.model_zoo import load_url as load_state_dict_from_url
logger = get_logger(__name__)
def get_checkpoints_dir_path(experiment_name: str, ckpt_root_dir: str = None):
"""Creating the checkpoint directory of a given experiment.
:param experiment_name: Name of the experiment.
:param ckpt_root_dir: Local root directory path where all experiment logging directories will
reside. When none is give, it is assumed that pkg_resources.resource_filename('checkpoints', "")
exists and will be used.
:return: checkpoints_dir_path
"""
if ckpt_root_dir:
return os.path.join(ckpt_root_dir, experiment_name)
elif os.path.exists(environment_config.PKG_CHECKPOINTS_DIR):
return os.path.join(environment_config.PKG_CHECKPOINTS_DIR, experiment_name)
else:
raise ValueError("Illegal checkpoints directory: pass ckpt_root_dir that exists, or add 'checkpoints' to resources.")
def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt_name: str, external_checkpoint_path: str):
"""
    Gets the local path to the checkpoint file, which will be:
        - By default: YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name/ckpt_name.
        - If the checkpoint file is remotely located:
            when overwrite_local_checkpoint=True it will be saved under a temporary path, which will be returned;
            otherwise it will be downloaded to YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name and overwrite
            YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name/ckpt_name if such a file exists.
        - external_checkpoint_path, when external_checkpoint_path is not None.

    @param source_ckpt_folder_name: The folder where the checkpoint is saved. When set to None, experiment_name is used.
    @param experiment_name: experiment name attribute in the trainer
    @param ckpt_name: checkpoint filename
    @param external_checkpoint_path: full path to the checkpoint file (possibly located outside the super_gradients/checkpoints directory)
    @return: local path to the checkpoint file
"""
if external_checkpoint_path:
return external_checkpoint_path
else:
checkpoints_folder_name = source_ckpt_folder_name or experiment_name
checkpoints_dir_path = get_checkpoints_dir_path(checkpoints_folder_name)
return os.path.join(checkpoints_dir_path, ckpt_name)
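
# Resolution-order sketch for get_ckpt_local_path (all paths below are hypothetical):
#
#     get_ckpt_local_path(None, "cifar_resnet", "ckpt_best.pth", "/tmp/my_ckpt.pth")
#     # -> "/tmp/my_ckpt.pth" (an external path always wins)
#     get_ckpt_local_path(None, "cifar_resnet", "ckpt_best.pth", None)
#     # -> "<checkpoints dir>/cifar_resnet/ckpt_best.pth"
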
def adaptive_load_state_dict(net: torch.nn.Module, state_dict: dict, strict: str):
"""
    Adaptively loads state_dict into net, by first adapting the state_dict keys to net's layer names if needed.
    @param net: (nn.Module) network to load the state_dict into
    @param state_dict: (dict) checkpoint state_dict
    @param strict: (str) key-matching strictness; "no_key_matching" falls back to positional matching when strict loading fails
    @return: None
"""
try:
net.load_state_dict(state_dict["net"] if "net" in state_dict.keys() else state_dict, strict=strict)
except (RuntimeError, ValueError, KeyError) as ex:
if strict == "no_key_matching":
adapted_state_dict = adapt_state_dict_to_fit_model_layer_names(net.state_dict(), state_dict)
net.load_state_dict(adapted_state_dict["net"], strict=True)
else:
raise_informative_runtime_error(net.state_dict(), state_dict, ex)
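
# A hedged example of the fallback: with strict="no_key_matching", a checkpoint whose keys
# carry an extra "module." prefix (e.g. saved from an nn.DataParallel wrapper) can still be
# loaded, since keys are re-matched by position rather than by name. The model and path are
# placeholders; assumes torchvision is installed:
#
#     net = torchvision.models.resnet18()
#     ckpt = torch.load("resnet18_dataparallel.pth", map_location="cpu")
#     adaptive_load_state_dict(net, ckpt, strict="no_key_matching")
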
@explicit_params_validation(validation_type="None")
def copy_ckpt_to_local_folder(
local_ckpt_destination_dir: str,
ckpt_filename: str,
remote_ckpt_source_dir: str = None,
path_src: str = "local",
overwrite_local_ckpt: bool = False,
load_weights_only: bool = False,
):
"""
    Copy the checkpoint from any supported source to a local destination path
    :param local_ckpt_destination_dir: destination where the checkpoint will be saved to
    :param ckpt_filename: ckpt_best.pth or ckpt_latest.pth
    :param remote_ckpt_source_dir: Name of the source checkpoint to be loaded (S3 model / full URL)
    :param path_src: "s3" / "url"
    :param overwrite_local_ckpt: determines if the checkpoint will be saved in the destination dir or in a temp folder
    :param load_weights_only: when False, the remote log files are copied alongside the checkpoint (S3 source only)
    :return: Path to the local checkpoint file
"""
ckpt_file_full_local_path = None
# IF NOT DEFINED - IT IS SET TO THE TARGET's FOLDER NAME
remote_ckpt_source_dir = local_ckpt_destination_dir if remote_ckpt_source_dir is None else remote_ckpt_source_dir
if not overwrite_local_ckpt:
# CREATE A TEMP FOLDER TO SAVE THE CHECKPOINT TO
download_ckpt_destination_dir = tempfile.gettempdir()
        logger.warning(
            "PLEASE NOTICE - YOU ARE IMPORTING A REMOTE CHECKPOINT WITH overwrite_local_checkpoint = False "
            "-> IT WILL BE REDIRECTED TO A TEMP FOLDER AND DELETED ON MACHINE RESTART"
        )
else:
# SAVE THE CHECKPOINT TO MODEL's FOLDER
download_ckpt_destination_dir = pkg_resources.resource_filename("checkpoints", local_ckpt_destination_dir)
if path_src.startswith("s3"):
model_checkpoints_data_interface = ADNNModelRepositoryDataInterfaces(data_connection_location=path_src)
# DOWNLOAD THE FILE FROM S3 TO THE DESTINATION FOLDER
ckpt_file_full_local_path = model_checkpoints_data_interface.load_remote_checkpoints_file(
ckpt_source_remote_dir=remote_ckpt_source_dir,
ckpt_destination_local_dir=download_ckpt_destination_dir,
ckpt_file_name=ckpt_filename,
overwrite_local_checkpoints_file=overwrite_local_ckpt,
)
if not load_weights_only:
# COPY LOG FILES FROM THE REMOTE DIRECTORY TO THE LOCAL ONE ONLY IF LOADING THE CURRENT MODELs CKPT
model_checkpoints_data_interface.load_all_remote_log_files(
model_name=remote_ckpt_source_dir, model_checkpoint_local_dir=download_ckpt_destination_dir
)
if path_src == "url":
        ckpt_file_full_local_path = os.path.join(download_ckpt_destination_dir, ckpt_filename)
# DOWNLOAD THE FILE FROM URL TO THE DESTINATION FOLDER
download_url_to_file(remote_ckpt_source_dir, ckpt_file_full_local_path, progress=True)
return ckpt_file_full_local_path
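
# Usage sketch for the "url" source (the URL and directory names are placeholders):
#
#     local_path = copy_ckpt_to_local_folder(
#         local_ckpt_destination_dir="my_experiment",
#         ckpt_filename="ckpt_best.pth",
#         remote_ckpt_source_dir="https://example.com/my_experiment/ckpt_best.pth",
#         path_src="url",
#         overwrite_local_ckpt=False,  # the file lands in the temp dir and its path is returned
#     )
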
def read_ckpt_state_dict(ckpt_path: str, device="cpu"):
    """Read a checkpoint state_dict from a local path, mapping storages to the requested device ("cpu" or "cuda")."""
    if not os.path.exists(ckpt_path):
        raise ValueError(f"Incorrect checkpoint path: {ckpt_path}")
if device == "cuda":
state_dict = torch.load(ckpt_path)
else:
state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
return state_dict
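
# Typical use: load to CPU first and inspect the top-level keys before deciding what to
# restore (the path is a placeholder; the key set varies per checkpoint):
#
#     ckpt = read_ckpt_state_dict("/data/checkpoints/my_experiment/ckpt_best.pth")
#     print(ckpt.keys())  # e.g. dict_keys(['net', 'ema_net', 'optimizer', 'epoch'])
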
def adapt_state_dict_to_fit_model_layer_names(model_state_dict: dict, source_ckpt: dict, exclude: list = None, solver: callable = None):
"""
    Given a model state dict and a source checkpoint, tries to rename the checkpoint keys to fit the model's layer names
    so that the weights can be properly loaded into the model. Raises ValueError on a layer-shape mismatch.
    :param model_state_dict: the model state_dict
    :param source_ckpt: checkpoint dict
    :param exclude: optional list of layer-name substrings to exclude from matching
    :param solver: callable with signature (ckpt_key, ckpt_val, model_key, model_val)
        that returns a desired weight for ckpt_val.
    :return: renamed checkpoint dict, wrapped as {"net": ...}
"""
if "net" in source_ckpt.keys():
source_ckpt = source_ckpt["net"]
    exclude = exclude or []
    model_state_dict_excluded = {k: v for k, v in model_state_dict.items() if not any(x in k for x in exclude)}
new_ckpt_dict = {}
for (ckpt_key, ckpt_val), (model_key, model_val) in zip(source_ckpt.items(), model_state_dict_excluded.items()):
if solver is not None:
ckpt_val = solver(ckpt_key, ckpt_val, model_key, model_val)
if ckpt_val.shape != model_val.shape:
raise ValueError(f"ckpt layer {ckpt_key} with shape {ckpt_val.shape} does not match {model_key}" f" with shape {model_val.shape} in the model")
new_ckpt_dict[model_key] = ckpt_val
return {"net": new_ckpt_dict}
def load_checkpoint_to_model(
ckpt_local_path: str, load_backbone: bool, net: torch.nn.Module, strict: str, load_weights_only: bool, load_ema_as_net: bool = False
):
"""
    Loads the state dict in ckpt_local_path into net and returns the checkpoint's state dict.

    @param load_ema_as_net: Will load the EMA weights inside the checkpoint file into the network when set
    @param ckpt_local_path: local path to the checkpoint file
    @param load_backbone: whether to load the checkpoint into net.backbone only
    @param net: network to load the checkpoint into
    @param strict: key-matching strictness, passed through to adaptive_load_state_dict
    @param load_weights_only: when True, all entries other than "net" are stripped from the returned checkpoint
    @return: the checkpoint's state dict
"""
if ckpt_local_path is None or not os.path.exists(ckpt_local_path):
error_msg = "Error - loading Model Checkpoint: Path {} does not exist".format(ckpt_local_path)
raise RuntimeError(error_msg)
if load_backbone and not hasattr(net, "backbone"):
raise ValueError("No backbone attribute in net - Can't load backbone weights")
# LOAD THE LOCAL CHECKPOINT PATH INTO A state_dict OBJECT
checkpoint = read_ckpt_state_dict(ckpt_path=ckpt_local_path)
    if load_ema_as_net:
        if "ema_net" not in checkpoint.keys():
            raise ValueError("Can't load EMA network - no EMA network stored in checkpoint file")
        checkpoint["net"] = checkpoint["ema_net"]
# LOAD THE CHECKPOINTS WEIGHTS TO THE MODEL
if load_backbone:
adaptive_load_state_dict(net.backbone, checkpoint, strict)
else:
adaptive_load_state_dict(net, checkpoint, strict)
message_suffix = " checkpoint." if not load_ema_as_net else " EMA checkpoint."
message_model = "model" if not load_backbone else "model's backbone"
logger.info("Successfully loaded " + message_model + " weights from " + ckpt_local_path + message_suffix)
if load_weights_only or load_backbone:
# DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS
        for key in list(checkpoint.keys()):
            if key != "net":
                checkpoint.pop(key)
return checkpoint
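
# End-to-end sketch (the path and model are placeholders; strict takes the same values as
# adaptive_load_state_dict):
#
#     ckpt = load_checkpoint_to_model(
#         ckpt_local_path="/data/checkpoints/my_experiment/ckpt_best.pth",
#         load_backbone=False,
#         net=my_model,
#         strict="no_key_matching",
#         load_weights_only=True,  # the returned dict is stripped down to its "net" entry
#     )
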
class MissingPretrainedWeightsException(Exception):
"""Exception raised by unsupported pretrianed model.
Attributes:
message -- explanation of the error
"""
def __init__(self, desc):
        self.message = "Missing pretrained weights: " + desc
super().__init__(self.message)
def _yolox_ckpt_solver(ckpt_key, ckpt_val, model_key, model_val):
"""
    Helper method for reshaping an old pretrained checkpoint's Focus-layer weights into 6x6 conv weights.
"""
if (
ckpt_val.shape != model_val.shape
and ckpt_key == "module._backbone._modules_list.0.conv.conv.weight"
and model_key == "_backbone._modules_list.0.conv.weight"
):
model_val.data[:, :, ::2, ::2] = ckpt_val.data[:, :3]
model_val.data[:, :, 1::2, ::2] = ckpt_val.data[:, 3:6]
model_val.data[:, :, ::2, 1::2] = ckpt_val.data[:, 6:9]
model_val.data[:, :, 1::2, 1::2] = ckpt_val.data[:, 9:12]
replacement = model_val
else:
replacement = ckpt_val
return replacement
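
# Shape sketch for the solver above (sizes are illustrative; the exact channel counts depend
# on the YoloX variant): the old Focus stem weight has shape (C_out, 12, 3, 3), since Focus
# stacks four pixel-shifted RGB slices along the channel axis, while the replacement stem
# conv has shape (C_out, 3, 6, 6). Each 3-channel slice of the old kernel is scattered onto
# one of the four even/odd spatial offsets of the 6x6 kernel:
#
#     old = torch.randn(32, 12, 3, 3)      # hypothetical Focus weight
#     new = torch.zeros(32, 3, 6, 6)
#     new[:, :, ::2, ::2] = old[:, 0:3]    # same assignment pattern as the solver
#     new[:, :, 1::2, ::2] = old[:, 3:6]
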
def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):
"""
    Loads pretrained weights from the MODEL_URLS dictionary into model
    @param architecture: name of the model's architecture
    @param model: model to load the pretrained weights into
    @param pretrained_weights: name of the pretrained weights (e.g. imagenet)
    @return: None
"""
model_url_key = architecture + "_" + str(pretrained_weights)
if model_url_key not in MODEL_URLS.keys():
raise MissingPretrainedWeightsException(model_url_key)
url = MODEL_URLS[model_url_key]
unique_filename = url.split("https://deci-pretrained-models.s3.amazonaws.com/")[1].replace("/", "_").replace(" ", "_")
map_location = torch.device("cpu")
pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
_load_weights(architecture, model, pretrained_state_dict)
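
# Usage sketch (the architecture/weights pair must exist as a key in MODEL_URLS; the names
# below are illustrative):
#
#     model = ...  # instantiate the architecture first
#     load_pretrained_weights(model, architecture="yolox_s", pretrained_weights="coco")
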
def _load_weights(architecture, model, pretrained_state_dict):
if "ema_net" in pretrained_state_dict.keys():
pretrained_state_dict["net"] = pretrained_state_dict["ema_net"]
solver = _yolox_ckpt_solver if "yolox" in architecture else None
adapted_pretrained_state_dict = adapt_state_dict_to_fit_model_layer_names(
model_state_dict=model.state_dict(), source_ckpt=pretrained_state_dict, solver=solver
)
model.load_state_dict(adapted_pretrained_state_dict["net"], strict=False)
def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):
"""
    Loads pretrained weights from a local checkpoint file into model
    @param architecture: name of the model's architecture
    @param model: model to load the pretrained weights into
    @param pretrained_weights: path to the pretrained weights file
    @return: None
"""
map_location = torch.device("cpu")
pretrained_state_dict = torch.load(pretrained_weights, map_location=map_location)
_load_weights(architecture, model, pretrained_state_dict)