nlp_architect.nn.torch package¶
Submodules¶
nlp_architect.nn.torch.distillation module¶
class nlp_architect.nn.torch.distillation.TeacherStudentDistill(teacher_model: nlp_architect.models.TrainableModel, temperature: float = 1.0, kd_w: float = 0.5, loss_w: float = 0.5)[source]¶
Bases: object
Teacher-Student knowledge distillation helper. Use this object when training a model with knowledge distillation (KD) and a teacher model.
Parameters:
- teacher_model (TrainableModel) – teacher model
- temperature (float, optional) – KD temperature. Defaults to 1.0.
- kd_w (float, optional) – teacher (KD) loss weight. Defaults to 0.5.
- loss_w (float, optional) – student loss weight. Defaults to 0.5.
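For orientation, here is a minimal sketch of the weighted loss this kind of helper computes. The standalone distill_loss function below is illustrative only and is not the class API; the temperature-scaled KL term follows the standard Hinton-style KD recipe.

    import torch.nn.functional as F

    def distill_loss(student_logits, teacher_logits, student_loss,
                     temperature=1.0, kd_w=0.5, loss_w=0.5):
        # Soften both distributions with the temperature, then measure the
        # KL divergence of the student's predictions from the teacher's.
        t = temperature
        kd = F.kl_div(
            F.log_softmax(student_logits / t, dim=-1),
            F.softmax(teacher_logits / t, dim=-1),
            reduction="batchmean",
        ) * (t * t)  # rescale so gradient magnitudes match the hard loss
        # Combine the student's own task loss with the distillation loss.
        return loss_w * student_loss + kd_w * kd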
static add_args(parser: argparse.ArgumentParser)[source]¶
Add KD arguments to the parser.
Parameters: parser (argparse.ArgumentParser) – parser
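A typical wiring into a training script; the specific flag names that add_args registers are defined by the library and not listed here:

    import argparse
    from nlp_architect.nn.torch.distillation import TeacherStudentDistill

    parser = argparse.ArgumentParser()
    TeacherStudentDistill.add_args(parser)  # registers the KD-related flags
    args = parser.parse_args()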
nlp_architect.nn.torch.quantization module¶
Quantization ops
class nlp_architect.nn.torch.quantization.FakeLinearQuantizationWithSTE[source]¶
Bases: torch.autograd.function.Function
Simulates the error caused by quantization. Uses the Straight-Through Estimator (STE) for backprop.
static backward(ctx, grad_output)[source]¶
Calculate estimated gradients for fake quantization using the Straight-Through Estimator (STE), according to: https://openreview.net/pdf?id=B1ae1lZRb
static forward(ctx, input, scale, bits=8)[source]¶
Fake-quantize the input according to the scale and number of bits: quantize, then dequantize.
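To make the STE idea concrete, a minimal self-contained sketch (not the library's implementation): the forward pass applies a non-differentiable rounding, and the backward pass passes gradients through as if that rounding were the identity.

    import torch

    class RoundSTE(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # Non-differentiable operation in the forward pass.
            return torch.round(x)

        @staticmethod
        def backward(ctx, grad_output):
            # Straight-Through Estimator: gradient of round() taken as 1.
            return grad_output

    x = torch.randn(5, requires_grad=True)
    y = RoundSTE.apply(x).sum()
    y.backward()  # x.grad is all ones despite the rounding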
class nlp_architect.nn.torch.quantization.QuantizationConfig(activation_bits=8, weight_bits=8, mode='none', start_step=0, ema_decay=0.9999, requantize_output=True)[source]¶
Bases: object
Quantization configuration object.
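A hedged construction example following the signature above; the lowercase mode strings are an assumption inferred from the 'none' default and the QuantizationMode member names below.

    from nlp_architect.nn.torch.quantization import QuantizationConfig

    config = QuantizationConfig(
        activation_bits=8,
        weight_bits=8,
        mode="ema",            # assumed: 'none', 'dynamic' or 'ema'
        start_step=1000,       # start quantizing after this many steps
        ema_decay=0.9999,      # decay for EMA-tracked activation ranges
        requantize_output=True,
    )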
class nlp_architect.nn.torch.quantization.QuantizationMode[source]¶
Bases: enum.Enum
An enumeration of the supported quantization modes.
NONE = 1¶
DYNAMIC = 2¶
EMA = 3¶
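One plausible mapping from the string mode accepted by the layers to this enum; this is an assumption for illustration, not confirmed library code.

    from nlp_architect.nn.torch.quantization import QuantizationMode

    mode = QuantizationMode["ema".upper()]  # -> QuantizationMode.EMA
    assert mode is QuantizationMode.EMA and mode.value == 3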
class nlp_architect.nn.torch.quantization.QuantizedEmbedding(*args, weight_bits=8, start_step=0, mode='none', **kwargs)[source]¶
Bases: torch.nn.modules.sparse.Embedding
Embedding layer with quantization-aware training capability.
extra_repr()[source]¶
Set the extra representation of the module.
To print customized extra information, you should reimplement this method in your own modules. Both single-line and multi-line strings are acceptable.
fake_quantized_weight¶
forward(input)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
classmethod from_config(*args, config=None, **kwargs)[source]¶
Initialize a quantized layer from a QuantizationConfig.
weight_scale¶
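A hedged usage sketch, assuming the usual nn.Embedding positional arguments (num_embeddings, embedding_dim) pass through *args and that mode accepts the lowercase QuantizationMode member names:

    import torch
    from nlp_architect.nn.torch.quantization import QuantizedEmbedding

    # 10k-token vocabulary, 128-dim vectors, 8-bit weights from step 0.
    emb = QuantizedEmbedding(10000, 128, weight_bits=8, start_step=0, mode="ema")
    ids = torch.randint(0, 10000, (4, 16))
    vectors = emb(ids)  # forward fake-quantizes the weight while training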
class nlp_architect.nn.torch.quantization.QuantizedLinear(*args, activation_bits=8, weight_bits=8, requantize_output=True, ema_decay=0.9999, start_step=0, mode='none', **kwargs)[source]¶
Bases: torch.nn.modules.linear.Linear
Linear layer with quantization-aware training capability.
extra_repr()[source]¶
Set the extra representation of the module.
To print customized extra information, you should reimplement this method in your own modules. Both single-line and multi-line strings are acceptable.
fake_quantized_weight¶
forward(input)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
classmethod from_config(*args, config=None, **kwargs)[source]¶
Initialize a quantized layer from a QuantizationConfig.
inference_quantized_forward(input)[source]¶
Simulate quantized inference: quantize the input and perform the calculation using integer numbers only. This function should only be used during inference.
quantized_weight¶
training_quantized_forward(input)[source]¶
Fake-quantized forward pass: fake-quantizes weights and activations, and learns the quantization ranges if the quantization mode is EMA. This function should only be used during training.
weight_scale¶
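Putting the pieces together, a hedged sketch of a quantization-aware workflow built from the entries above (from_config, the training-time fake-quantized forward, and the inference-time integer simulation); argument plumbing through *args is assumed to match nn.Linear.

    import torch
    from nlp_architect.nn.torch.quantization import (
        QuantizationConfig,
        QuantizedLinear,
    )

    config = QuantizationConfig(activation_bits=8, weight_bits=8, mode="ema")
    layer = QuantizedLinear.from_config(128, 64, config=config)

    x = torch.randn(32, 128)
    layer.train()
    y = layer(x)  # training path: fake-quantized weights and activations

    layer.eval()
    with torch.no_grad():
        y_q = layer.inference_quantized_forward(x)  # integer-only simulation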
nlp_architect.nn.torch.quantization.calc_max_quant_value(bits)[source]¶
Calculate the maximum symmetric quantized value according to the number of bits.
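For symmetric signed quantization this maximum is conventionally 2**(bits - 1) - 1; a quick check under that assumption:

    from nlp_architect.nn.torch.quantization import calc_max_quant_value

    print(calc_max_quant_value(8))  # expected 127, i.e. 2**7 - 1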
nlp_architect.nn.torch.quantization.dequantize(input, scale)[source]¶
Linear dequantization of the input according to the given scale.
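A reference sketch of the round trip; the quantize_ref counterpart here is the textbook symmetric formula under the scale convention implied by get_dynamic_scale below, not necessarily the library's exact code.

    import torch

    def quantize_ref(x, scale, bits=8):
        # Textbook symmetric linear quantization: scale, round, clamp.
        max_q = 2 ** (bits - 1) - 1
        return torch.round(x * scale).clamp(-max_q - 1, max_q)

    def dequantize_ref(q, scale):
        # Linear dequantization: map integers back to real values.
        return q / scale

    x = torch.randn(10)
    scale = float(127 / x.abs().max())
    x_hat = dequantize_ref(quantize_ref(x, scale), scale)
    assert torch.allclose(x, x_hat, atol=1.0 / scale)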
nlp_architect.nn.torch.quantization.get_dynamic_scale(x, bits, with_grad=False)[source]¶
Calculate a dynamic quantization scale from the input by taking the maximum absolute value of x and the number of bits.
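A sketch of what the docstring describes, under the same symmetric convention: the scale maps the largest absolute value in x onto the largest representable integer (with_grad presumably controls whether the max operation stays on the autograd tape).

    import torch

    def dynamic_scale_ref(x, bits=8):
        # Largest representable symmetric value divided by max |x|.
        max_q = 2 ** (bits - 1) - 1
        return max_q / x.abs().max()

    x = torch.randn(100)
    scale = dynamic_scale_ref(x)
    # After scaling, the tensor just fits the signed 8-bit range.
    assert float((x * scale).abs().max()) <= 127.0 + 1e-4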