import copy
import os
import time
from enum import Enum
import math
import numpy as np
import onnx
import onnxruntime
import torch
import signal
from typing import List
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.utils.detection_utils import DetectionVisualization, DetectionPostPredictionCallback
from super_gradients.training.utils.segmentation_utils import BinarySegmentationVisualization
import cv2
logger = get_logger(__name__)
try:
from deci_lab_client.client import DeciPlatformClient
from deci_lab_client.models import ModelBenchmarkState
from deci_lab_client.models.model_metadata import ModelMetadata
_imported_deci_lab_failure = None
except (ImportError, NameError, ModuleNotFoundError) as import_err:
    logger.warning("Failed to import deci_lab_client")
_imported_deci_lab_failure = import_err
class Phase(Enum):
PRE_TRAINING = "PRE_TRAINING"
TRAIN_BATCH_END = "TRAIN_BATCH_END"
TRAIN_BATCH_STEP = "TRAIN_BATCH_STEP"
TRAIN_EPOCH_START = "TRAIN_EPOCH_START"
TRAIN_EPOCH_END = "TRAIN_EPOCH_END"
VALIDATION_BATCH_END = "VALIDATION_BATCH_END"
VALIDATION_EPOCH_END = "VALIDATION_EPOCH_END"
VALIDATION_END_BEST_EPOCH = "VALIDATION_END_BEST_EPOCH"
TEST_BATCH_END = "TEST_BATCH_END"
TEST_END = "TEST_END"
POST_TRAINING = "POST_TRAINING"
class ContextSgMethods:
"""
    Class for delegating SgModel's methods, so that only the relevant ones (phase-wise) are accessible.
"""
def __init__(self, **methods):
for attr, attr_val in methods.items():
setattr(self, attr, attr_val)
class PhaseContext:
"""
Represents the input for phase callbacks, and is constantly updated after callback calls.
"""
def __init__(self, epoch=None, batch_idx=None, optimizer=None, metrics_dict=None, inputs=None, preds=None,
target=None, metrics_compute_fn=None, loss_avg_meter=None, loss_log_items=None, criterion=None,
device=None, experiment_name=None, ckpt_dir=None, net=None, lr_warmup_epochs=None, sg_logger=None,
train_loader=None, valid_loader=None,
training_params=None, ddp_silent_mode=None, checkpoint_params=None, architecture=None,
arch_params=None, metric_idx_in_results_tuple=None,
metric_to_watch=None, valid_metrics=None, context_methods=None):
self.epoch = epoch
self.batch_idx = batch_idx
self.optimizer = optimizer
self.inputs = inputs
self.preds = preds
self.target = target
self.metrics_dict = metrics_dict
self.metrics_compute_fn = metrics_compute_fn
self.loss_avg_meter = loss_avg_meter
self.loss_log_items = loss_log_items
self.criterion = criterion
self.device = device
self.stop_training = False
self.experiment_name = experiment_name
self.ckpt_dir = ckpt_dir
self.net = net
self.lr_warmup_epochs = lr_warmup_epochs
self.sg_logger = sg_logger
self.train_loader = train_loader
self.valid_loader = valid_loader
self.training_params = training_params
self.ddp_silent_mode = ddp_silent_mode
self.checkpoint_params = checkpoint_params
self.architecture = architecture
self.arch_params = arch_params
self.metric_idx_in_results_tuple = metric_idx_in_results_tuple
self.metric_to_watch = metric_to_watch
self.valid_metrics = valid_metrics
self.context_methods = context_methods
    def update_context(self, **kwargs):
for attr, attr_val in kwargs.items():
setattr(self, attr, attr_val)
class PhaseCallback:
def __init__(self, phase: Phase):
self.phase = phase
def __call__(self, *args, **kwargs):
raise NotImplementedError
def __repr__(self):
return self.__class__.__name__
class ModelConversionCheckCallback(PhaseCallback):
"""
Pre-training callback that verifies model conversion to onnx given specified conversion parameters.
The model is converted, then inference is applied with onnx runtime.
    Use this callback with the same arguments as DeciLabUploadCallback to prevent conversion failures at the end of training.
Attributes:
model_meta_data: (ModelMetadata) model's meta-data object.
The following parameters may be passed as kwargs in order to control the conversion to onnx:
    :param opset_version (default=10)
:param do_constant_folding (default=True)
:param dynamic_axes (default=
{'input': {0: 'batch_size'},
# Variable length axes
'output': {0: 'batch_size'}}
)
:param input_names (default=["input"])
:param output_names (default=["output"])
:param rtol (default=1e-03)
:param atol (default=1e-05)
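    Example (illustrative sketch, not from the library docs): the callback only reads name, primary_batch_size and
    input_dimensions from the metadata object, and is typically registered through the phase_callbacks list of the
    training params.
        conversion_check = ModelConversionCheckCallback(
            model_meta_data=model_meta_data,  # an existing deci_lab_client ModelMetadata instance
            opset_version=11,
            rtol=1e-03,
            atol=1e-05,
        )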
"""
def __init__(self, model_meta_data: ModelMetadata, **kwargs):
super(ModelConversionCheckCallback, self).__init__(phase=Phase.PRE_TRAINING)
self.model_meta_data = model_meta_data
self.opset_version = kwargs.get("opset_version", 10)
        self.do_constant_folding = kwargs.get("do_constant_folding", True)
self.input_names = kwargs.get("input_names") or ["input"]
self.output_names = kwargs.get("output_names") or ["output"]
self.dynamic_axes = kwargs.get("dynamic_axes") or {"input": {0: "batch_size"}, "output": {0: "batch_size"}}
self.rtol = kwargs.get("rtol", 1e-03)
self.atol = kwargs.get("atol", 1e-05)
def __call__(self, context: PhaseContext):
model = copy.deepcopy(context.net.module)
model = model.cpu()
model.eval() # Put model into eval mode
if hasattr(model, "prep_model_for_conversion"):
model.prep_model_for_conversion(input_size=self.model_meta_data.input_dimensions)
x = torch.randn(
self.model_meta_data.primary_batch_size, *self.model_meta_data.input_dimensions, requires_grad=False
)
tmp_model_path = os.path.join(context.ckpt_dir, self.model_meta_data.name + "_tmp.onnx")
with torch.no_grad():
torch_out = model(x)
torch.onnx.export(
model, # Model being run
x, # Model input (or a tuple for multiple inputs)
tmp_model_path, # Where to save the model (can be a file or file-like object)
export_params=True, # Store the trained parameter weights inside the model file
opset_version=self.opset_version,
do_constant_folding=self.do_constant_folding,
input_names=self.input_names,
output_names=self.output_names,
dynamic_axes=self.dynamic_axes,
)
onnx_model = onnx.load(tmp_model_path)
onnx.checker.check_model(onnx_model)
ort_session = onnxruntime.InferenceSession(
tmp_model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)
# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: x.cpu().numpy()}
ort_outs = ort_session.run(None, ort_inputs)
# TODO: Ideally we don't want to check this but have the certainty of just calling torch_out.cpu()
        if isinstance(torch_out, (list, tuple)):
            torch_out = torch_out[0]
# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(torch_out.cpu().numpy(), ort_outs[0], rtol=self.rtol, atol=self.atol)
os.remove(tmp_model_path)
logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!")
class DeciLabUploadCallback(PhaseCallback):
"""
Post-training callback for uploading and optimizing a model.
Attributes:
        model_meta_data: (ModelMetadata) model's meta-data object.
        optimization_request_form: (dict) optimization request form object.
        auth_token: (str) Deci platform API token used to authenticate the platform client.
        ckpt_name: (str) default="ckpt_best.pth", the filename of the checkpoint inside the checkpoint directory.
The following parameters may be passed as kwargs in order to control the conversion to onnx:
:param opset_version
:param do_constant_folding
:param dynamic_axes
:param input_names
:param output_names
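    Example (illustrative sketch; model_meta_data and optimization_request_form are assumed to be pre-built
    deci_lab_client objects):
        upload_callback = DeciLabUploadCallback(
            model_meta_data=model_meta_data,
            optimization_request_form=optimization_request_form,
            auth_token="<your Deci platform token>",
            ckpt_name="ckpt_best.pth",
        )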
"""
def __init__(self, model_meta_data, optimization_request_form, auth_token: str = None, ckpt_name="ckpt_best.pth", **kwargs):
super().__init__(phase=Phase.POST_TRAINING)
if _imported_deci_lab_failure is not None:
raise _imported_deci_lab_failure
self.model_meta_data = model_meta_data
self.optimization_request_form = optimization_request_form
self.conversion_kwargs = kwargs
self.ckpt_name = ckpt_name
self.platform_client = DeciPlatformClient("api.deci.ai", 443, https=True)
self.platform_client.login(token=auth_token)
    @staticmethod
def log_optimization_failed():
logger.info("We couldn't finish your model optimization. Visit https://console.deci.ai for details")
    def upload_model(self, model):
"""
This function will upload the trained model to the Deci Lab
Args:
model: The resulting model from the training process
"""
self.platform_client.add_model(
add_model_request=self.model_meta_data,
optimization_request=self.optimization_request_form,
local_loaded_model=model,
)
    def get_optimization_status(self, optimized_model_name: str):
"""
        Fetches the optimized version of the trained model and checks its benchmark status.
        The status is polled from the server every 30 seconds; the process either times out after 30 minutes
        or logs the successful optimization, whichever happens first.
Args:
optimized_model_name (str): Optimized model name
Returns:
bool: whether or not the optimized model has been benchmarked
"""
        def handler(_signum, _frame):
            # Raising here aborts the polling loop below when the 30-minute alarm fires.
            logger.error("Process timed out. Visit https://console.deci.ai for details")
            raise TimeoutError
        signal.signal(signal.SIGALRM, handler)
        signal.alarm(1800)
        finished = False
        try:
            while not finished:
                optimized_model = self.platform_client.get_model_by_name(name=optimized_model_name).data
                if optimized_model.benchmark_state not in [ModelBenchmarkState.IN_PROGRESS, ModelBenchmarkState.PENDING]:
                    finished = True
                else:
                    time.sleep(30)
        except TimeoutError:
            return False
        signal.alarm(0)
        return True
def __call__(self, context: PhaseContext):
"""
This function will attempt to upload the trained model and schedule an optimization for it.
Args:
context (PhaseContext): Training phase context
"""
try:
model = copy.deepcopy(context.net)
model_state_dict_path = os.path.join(context.ckpt_dir, self.ckpt_name)
model_state_dict = torch.load(model_state_dict_path)["net"]
model.load_state_dict(state_dict=model_state_dict)
model = model.module.cpu()
if hasattr(model, "prep_model_for_conversion"):
model.prep_model_for_conversion(input_size=self.model_meta_data.input_dimensions)
self.upload_model(model=model)
model_name = self.model_meta_data.name
logger.info(f"Successfully added {model_name} to the model repository")
optimized_model_name = f"{model_name}_1_1"
logger.info("We'll wait for the scheduled optimization to finish. Please don't close this window")
success = self.get_optimization_status(optimized_model_name=optimized_model_name)
if success:
logger.info("Successfully finished your model optimization. Visit https://console.deci.ai for details")
else:
DeciLabUploadCallback.log_optimization_failed()
except Exception as ex:
DeciLabUploadCallback.log_optimization_failed()
logger.error(ex)
class LRCallbackBase(PhaseCallback):
"""
Base class for hard coded learning rate scheduling regimes, implemented as callbacks.
"""
def __init__(self, phase, initial_lr, update_param_groups, train_loader_len, net, training_params, **kwargs):
super(LRCallbackBase, self).__init__(phase)
self.initial_lr = initial_lr
self.lr = initial_lr
self.update_param_groups = update_param_groups
self.train_loader_len = train_loader_len
self.net = net
self.training_params = training_params
def __call__(self, context: PhaseContext, **kwargs):
if self.is_lr_scheduling_enabled(context):
self.perform_scheduling(context)
    def is_lr_scheduling_enabled(self, context: PhaseContext):
"""
Predicate that controls whether to perform lr scheduling based on values in context.
@param context: PhaseContext: current phase's context.
@return: bool, whether to apply lr scheduling or not.
"""
raise NotImplementedError
    def update_lr(self, optimizer, epoch, batch_idx=None):
if self.update_param_groups:
param_groups = self.net.module.update_param_groups(
optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len
)
optimizer.param_groups = param_groups
else:
            # UPDATE THE OPTIMIZER'S LEARNING RATE IN EACH PARAMETER GROUP
for param_group in optimizer.param_groups:
param_group["lr"] = self.lr
class WarmupLRCallback(LRCallbackBase):
"""
LR scheduling callback for linear step warmup.
    The LR climbs from warmup_initial_lr to initial_lr in even steps. When warmup_initial_lr is None, the climb
    starts from initial_lr / (1 + warmup_epochs).
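    For example (worked through from the constructor logic below), with initial_lr=0.1, lr_warmup_epochs=4 and
    warmup_initial_lr left as None, the warmup starts at 0.1 / (4 + 1) = 0.02 and grows by
    (0.1 - 0.02) / 4 = 0.02 per epoch: 0.02, 0.04, 0.06, 0.08, 0.1.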
"""
def __init__(self, **kwargs):
super(WarmupLRCallback, self).__init__(Phase.TRAIN_EPOCH_START, **kwargs)
self.warmup_initial_lr = self.training_params.warmup_initial_lr or self.initial_lr / (
self.training_params.lr_warmup_epochs + 1
)
self.warmup_step_size = (self.initial_lr - self.warmup_initial_lr) / self.training_params.lr_warmup_epochs
    def is_lr_scheduling_enabled(self, context):
return self.training_params.lr_warmup_epochs >= context.epoch
class StepLRCallback(LRCallbackBase):
"""
    Hard coded step learning rate scheduling (i.e. decay at specific milestones).
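    For example, with lr_updates=[100, 150] and lr_decay_factor=0.1, the learning rate is decayed by a factor of 0.1
    at epoch 100 and again at epoch 150. Alternatively, passing step_lr_update_freq (with an empty lr_updates list)
    generates evenly spaced milestones between the warmup and cooldown epochs, as done in __init__ below.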
"""
def __init__(self, lr_updates, lr_decay_factor, step_lr_update_freq=None, **kwargs):
super(StepLRCallback, self).__init__(Phase.TRAIN_EPOCH_END, **kwargs)
if step_lr_update_freq and len(lr_updates):
raise ValueError(
"Only one of [lr_updates, step_lr_update_freq] should be passed to StepLRCallback constructor"
)
if step_lr_update_freq:
max_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
warmup_epochs = self.training_params.lr_warmup_epochs
lr_updates = [
int(np.ceil(step_lr_update_freq * x))
for x in range(1, max_epochs)
if warmup_epochs <= int(np.ceil(step_lr_update_freq * x)) < max_epochs
]
elif self.training_params.lr_cooldown_epochs > 0:
logger.warning(
"Specific lr_updates were passed along with cooldown_epochs > 0," " cooldown will have no effect."
)
self.lr_updates = lr_updates
self.lr_decay_factor = lr_decay_factor
    def is_lr_scheduling_enabled(self, context):
return self.training_params.lr_warmup_epochs <= context.epoch
class ExponentialLRCallback(LRCallbackBase):
"""
Exponential decay learning rate scheduling. Decays the learning rate by `lr_decay_factor` every epoch.
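    For example, with initial_lr=0.01 and lr_decay_factor=0.95, the learning rate after n post-warmup epochs is
    roughly 0.01 * 0.95 ** n (the callback runs at TRAIN_BATCH_STEP, so intermediate values within an epoch may
    also be updated).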
"""
def __init__(self, lr_decay_factor: float, **kwargs):
super().__init__(phase=Phase.TRAIN_BATCH_STEP, **kwargs)
self.lr_decay_factor = lr_decay_factor
    def is_lr_scheduling_enabled(self, context):
post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs
class PolyLRCallback(LRCallbackBase):
"""
    Hard coded polynomial decay learning rate scheduling.
"""
def __init__(self, max_epochs, **kwargs):
super(PolyLRCallback, self).__init__(Phase.TRAIN_BATCH_STEP, **kwargs)
self.max_epochs = max_epochs
    def is_lr_scheduling_enabled(self, context):
post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs
class CosineLRCallback(LRCallbackBase):
"""
    Hard coded cosine annealing learning rate scheduling.
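    The schedule follows the usual cosine annealing form (a sketch of the intended behaviour, with
    final_lr = initial_lr * cosine_final_lr_ratio and t / T the fraction of post-warmup training completed):
        lr(t) = final_lr + 0.5 * (initial_lr - final_lr) * (1 + cos(pi * t / T))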
"""
def __init__(self, max_epochs, cosine_final_lr_ratio, **kwargs):
super(CosineLRCallback, self).__init__(Phase.TRAIN_BATCH_STEP, **kwargs)
self.max_epochs = max_epochs
self.cosine_final_lr_ratio = cosine_final_lr_ratio
    def is_lr_scheduling_enabled(self, context):
post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs
class FunctionLRCallback(LRCallbackBase):
"""
    Hard coded LR scheduling using a user-defined lr scheduling function.
"""
def __init__(self, max_epochs, lr_schedule_function, **kwargs):
super(FunctionLRCallback, self).__init__(Phase.TRAIN_BATCH_STEP, **kwargs)
        self.lr_schedule_function = lr_schedule_function
        assert callable(self.lr_schedule_function), "self.lr_schedule_function must be callable"
self.max_epochs = max_epochs
    def is_lr_scheduling_enabled(self, context):
post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs
class IllegalLRSchedulerMetric(Exception):
"""Exception raised illegal combination of training parameters.
Attributes:
message -- explanation of the error
"""
def __init__(self, metric_name, metrics_dict):
self.message = (
"Illegal metric name: " + metric_name + ". Expected one of metics_dics keys: " + str(metrics_dict.keys())
)
super().__init__(self.message)
class LRSchedulerCallback(PhaseCallback):
"""
Learning rate scheduler callback.
Attributes:
        scheduler: torch.optim.lr_scheduler._LRScheduler, the learning rate scheduler whose step() will be called.
        metric_name: str, (default=None) the metric name for the ReduceLROnPlateau learning rate scheduler.
            When __call__ receives a metrics_dict containing the key self.metric_name, the value of that metric is
            monitored for ReduceLROnPlateau (i.e. step(metrics_dict[self.metric_name]) is called).
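        Example (illustrative; "Accuracy" is a placeholder metric name, and optimizer is an existing torch optimizer):
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")
            scheduler_callback = LRSchedulerCallback(scheduler=scheduler, phase=Phase.VALIDATION_EPOCH_END,
                                                     metric_name="Accuracy")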
"""
def __init__(self, scheduler, phase, metric_name=None):
super(LRSchedulerCallback, self).__init__(phase)
self.scheduler = scheduler
self.metric_name = metric_name
def __call__(self, context: PhaseContext):
if context.lr_warmup_epochs <= context.epoch:
if self.metric_name and self.metric_name in context.metrics_dict.keys():
self.scheduler.step(context.metrics_dict[self.metric_name])
elif self.metric_name is None:
self.scheduler.step()
else:
raise IllegalLRSchedulerMetric(self.metric_name, context.metrics_dict)
def __repr__(self):
return "LRSchedulerCallback: " + repr(self.scheduler)
class MetricsUpdateCallback(PhaseCallback):
def __init__(self, phase: Phase):
super(MetricsUpdateCallback, self).__init__(phase)
def __call__(self, context: PhaseContext):
context.metrics_compute_fn.update(**context.__dict__)
if context.criterion is not None:
context.loss_avg_meter.update(context.loss_log_items, len(context.inputs))
class KDModelMetricsUpdateCallback(MetricsUpdateCallback):
def __init__(self, phase: Phase):
super().__init__(phase=phase)
def __call__(self, context: PhaseContext):
metrics_compute_fn_kwargs = {k: v.student_output if k == "preds" else v for k, v in context.__dict__.items()}
context.metrics_compute_fn.update(**metrics_compute_fn_kwargs)
if context.criterion is not None:
context.loss_avg_meter.update(context.loss_log_items, len(context.inputs))
class PhaseContextTestCallback(PhaseCallback):
"""
    A callback that saves the phase context for testing.
"""
def __init__(self, phase: Phase):
super(PhaseContextTestCallback, self).__init__(phase)
self.context = None
def __call__(self, context: PhaseContext):
self.context = context
class DetectionVisualizationCallback(PhaseCallback):
"""
A callback that adds a visualization of a batch of detection predictions to context.sg_logger
Attributes:
freq: frequency (in epochs) to perform this callback.
batch_idx: batch index to perform visualization for.
classes: class list of the dataset.
last_img_idx_in_batch: Last image index to add to log. (default=-1, will take entire batch).
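    Example (illustrative; post_prediction_callback and classes are placeholders for objects from your own setup):
        visualization_callback = DetectionVisualizationCallback(
            phase=Phase.VALIDATION_BATCH_END,
            freq=1,
            post_prediction_callback=post_prediction_callback,  # a DetectionPostPredictionCallback instance
            classes=classes,  # the dataset's class-name list
            batch_idx=0,
        )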
"""
def __init__(
self,
phase: Phase,
freq: int,
post_prediction_callback: DetectionPostPredictionCallback,
classes: list,
batch_idx: int = 0,
last_img_idx_in_batch: int = -1,
):
super(DetectionVisualizationCallback, self).__init__(phase)
self.freq = freq
self.post_prediction_callback = post_prediction_callback
self.batch_idx = batch_idx
self.classes = classes
self.last_img_idx_in_batch = last_img_idx_in_batch
def __call__(self, context: PhaseContext):
if context.epoch % self.freq == 0 and context.batch_idx == self.batch_idx:
# SOME CALCULATIONS ARE IN-PLACE IN NMS, SO CLONE THE PREDICTIONS
preds = (context.preds[0].clone(), None)
preds = self.post_prediction_callback(preds)
batch_imgs = DetectionVisualization.visualize_batch(
context.inputs, preds, context.target, self.batch_idx, self.classes
)
batch_imgs = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in batch_imgs]
batch_imgs = np.stack(batch_imgs)
tag = "batch_" + str(self.batch_idx) + "_images"
context.sg_logger.add_images(
tag=tag, images=batch_imgs[: self.last_img_idx_in_batch], global_step=context.epoch, data_format="NHWC"
)
class BinarySegmentationVisualizationCallback(PhaseCallback):
"""
A callback that adds a visualization of a batch of segmentation predictions to context.sg_logger
Attributes:
freq: frequency (in epochs) to perform this callback.
batch_idx: batch index to perform visualization for.
last_img_idx_in_batch: Last image index to add to log. (default=-1, will take entire batch).
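    Example (illustrative):
        seg_visualization_callback = BinarySegmentationVisualizationCallback(
            phase=Phase.VALIDATION_BATCH_END, freq=5, batch_idx=0
        )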
"""
def __init__(self, phase: Phase, freq: int, batch_idx: int = 0, last_img_idx_in_batch: int = -1):
super(BinarySegmentationVisualizationCallback, self).__init__(phase)
self.freq = freq
self.batch_idx = batch_idx
self.last_img_idx_in_batch = last_img_idx_in_batch
def __call__(self, context: PhaseContext):
if context.epoch % self.freq == 0 and context.batch_idx == self.batch_idx:
if isinstance(context.preds, tuple):
preds = context.preds[0].clone()
else:
preds = context.preds.clone()
batch_imgs = BinarySegmentationVisualization.visualize_batch(
context.inputs, preds, context.target, self.batch_idx
)
batch_imgs = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in batch_imgs]
batch_imgs = np.stack(batch_imgs)
tag = "batch_" + str(self.batch_idx) + "_images"
            context.sg_logger.add_images(
                tag=tag, images=batch_imgs[: self.last_img_idx_in_batch], global_step=context.epoch, data_format="NHWC"
            )
class TrainingStageSwitchCallbackBase(PhaseCallback):
"""
TrainingStageSwitchCallback
A phase callback that is called at a specific epoch (epoch start) to support multi-stage training.
It does so by manipulating the objects inside the context.
Attributes:
next_stage_start_epoch: int, the epoch idx to apply the stage change.
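    A minimal subclass sketch (illustrative; the backbone attribute is hypothetical and only shows where the
    stage-change logic goes):
        class FreezeBackboneCallback(TrainingStageSwitchCallbackBase):
            def apply_stage_change(self, context: PhaseContext):
                for param in context.net.module.backbone.parameters():  # hypothetical attribute
                    param.requires_grad = False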
"""
def __init__(self, next_stage_start_epoch: int):
super(TrainingStageSwitchCallbackBase, self).__init__(phase=Phase.TRAIN_EPOCH_START)
self.next_stage_start_epoch = next_stage_start_epoch
def __call__(self, context: PhaseContext):
if context.epoch == self.next_stage_start_epoch:
self.apply_stage_change(context)
    def apply_stage_change(self, context: PhaseContext):
"""
This method is called when the callback is fired on the next_stage_start_epoch,
and holds the stage change logic that should be applied to the context's objects.
:param context: PhaseContext, context of current phase
"""
raise NotImplementedError
class YoloXTrainingStageSwitchCallback(TrainingStageSwitchCallbackBase):
"""
YoloXTrainingStageSwitchCallback
Training stage switch for YoloX training.
Disables mosaic, and manipulates YoloX loss to use L1.
"""
def __init__(self, next_stage_start_epoch: int = 285):
super(YoloXTrainingStageSwitchCallback, self).__init__(next_stage_start_epoch=next_stage_start_epoch)
    def apply_stage_change(self, context: PhaseContext):
for transform in context.train_loader.dataset.transforms:
if hasattr(transform, "close"):
transform.close()
iter(context.train_loader)
context.criterion.use_l1 = True
class CallbackHandler:
"""
    Runs all callbacks whose phase attribute equals the given phase.
Attributes:
callbacks: List[PhaseCallback]. Callbacks to be run.
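    Example (illustrative; context is an existing PhaseContext instance):
        handler = CallbackHandler(callbacks=[TestLRCallback(lr_placeholder=[])])
        handler(Phase.VALIDATION_EPOCH_END, context)  # runs every callback whose phase matches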
"""
def __init__(self, callbacks):
self.callbacks = callbacks
def __call__(self, phase: Phase, context: PhaseContext):
for callback in self.callbacks:
if callback.phase == phase:
callback(context)
# DICT FOR LEGACY LR HARD-CODED REGIMES, WILL BE DELETED IN THE FUTURE
LR_SCHEDULERS_CLS_DICT = {
"step": StepLRCallback,
"poly": PolyLRCallback,
"cosine": CosineLRCallback,
"exp": ExponentialLRCallback,
"function": FunctionLRCallback,
}
LR_WARMUP_CLS_DICT = {"linear_step": WarmupLRCallback}
class TestLRCallback(PhaseCallback):
"""
Phase callback that collects the learning rates in lr_placeholder at the end of each epoch (used for testing). In
    the case of multiple parameter groups (i.e. multiple learning rates) the learning rate is collected from the first
one. The phase is VALIDATION_EPOCH_END to ensure all lr updates have been performed before calling this callback.
"""
def __init__(self, lr_placeholder):
super(TestLRCallback, self).__init__(Phase.VALIDATION_EPOCH_END)
self.lr_placeholder = lr_placeholder
def __call__(self, context: PhaseContext):
self.lr_placeholder.append(context.optimizer.param_groups[0]["lr"])