|
| 1 | +### <a name="FloatQuant"></a><a name="abs">**FloatQuant**</a> |
| 2 | + |
| 3 | +Calculates the [arbitrary-precision-float-quantized](https://arxiv.org/abs/2311.12359) values of one input data (Tensor<T>) and produces one output data (Tensor<T>). |
| 4 | +Additionally, takes five floats as input, which define the scale, exponent bitwidth, mantissa bitwidth, maximum representable value and exponent bias of the quantization, |
| 5 | +all of which may be scalars or tensors with shapes broadcastable to the shape of the input data tensor. This can be used to |
| 6 | +control the granularity of the quantization. For instance, a scalar scale operand implies per-tensor scaling, while a scale operand with |
| 7 | +the same shape as the input data implies per-element scaling. |
| 8 | + |
| 9 | +*Special (symbolic) values:* Specialized floating point datatype behaviors such as supporting infinity, NaN and subnormals are specified by the attributes of the node to inform backends, but note that they do not affect the behavior of the `FloatQuant` operator. Instead, the `max_val` input is used to account for decreased representational range due |
| 10 | +to having to represent special cases. |
| 11 | + |
| 12 | +*Why `max_val` is specified explicitly?* The maximum representable value is derived from a combination of exponent and mantissa bitwidths, but also how many encodings are reserved for |
| 13 | +special (symbolic) values. This makes it nontrivial to infer the maximum representable value. For instance, OCP E5M2 reserves three encodings for NaN, whereas E4M3 reserves only one. |
| 14 | + |
| 15 | +*Integer quantization:* This operator is not intended for integer quantization, for this purpose the `IntQuant` custom op exists. |
| 16 | + |
| 17 | +#### Version |
| 18 | + |
| 19 | +This operator is not part of the ONNX standard and is not currently versioned. |
| 20 | + |
| 21 | +#### Attributes |
| 22 | + |
| 23 | +<dl> |
| 24 | +<dt><tt>has_infinity</tt> : int (default is 0)</dt> |
| 25 | +<dd>Integer value interpreted as boolean, defines whether the representation supports infinity values. The ability to represent infinity values will decrease the representable numerical range. This attribute has no effect on the execution of this operation and is intended purely to inform backends.</dd> |
| 26 | + |
| 27 | +<dt><tt>has_nan</tt> : int (default is 0)</dt> |
| 28 | +<dd>Integer value interpreted as boolean, defines whether the representation supports not-a-number (NaN) values. The ability to represent NaN values will decrease the representable numerical range. This attribute has no effect on the execution of this operation and is intended purely to inform backends.</dd> |
| 29 | + |
| 30 | +<dt><tt>has_subnormal</tt> : int (default is 1)</dt> |
| 31 | +<dd>Integer value interpreted as boolean, defines whether the representation supports subnormal values. Subnormal values have an exponent value of 0 and are interpreted to have a leading significand digit of zero rather than one. Supporting subnormals will increase the complexity of the required arithmetic datapath. This attribute has no effect on the execution of this operation and is intended purely to inform backends.</dd> |
| 32 | + |
| 33 | +<dt><tt>saturation</tt> : int (default is 1)</dt> |
| 34 | +<dd>Integer value interpreted as boolean, defines whether the representation will saturate during arithmetic. This attribute has no effect on the execution of this operation and is intended purely to inform backends.</dd> |
| 35 | + |
| 36 | +<dt><tt>rounding_mode</tt> : string (default is "ROUND")</dt> |
| 37 | +<dd>Defines how rounding should be applied during quantization. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".</dd> |
| 38 | + |
| 39 | +</dl> |
| 40 | + |
| 41 | +#### Inputs |
| 42 | + |
| 43 | +<dl> |
| 44 | +<dt><tt>X</tt> : tensor(float32)</dt> |
| 45 | +<dd>input tensor to quantize</dd> |
| 46 | +<dt><tt>scale</tt> : tensor(float32)</dt> |
| 47 | +<dd>The scale factor, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor</dd> |
| 48 | +<dt><tt>exponent_bitwidth</tt> : tensor(float32)</dt> |
| 49 | +<dd>The number of bits for the exponent used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer.</dd> |
| 50 | +<dt><tt>mantissa_bitwidth</tt> : tensor(float32)</dt> |
| 51 | +<dd>The number of bits for the mantissa used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer.</dd> |
| 52 | +<dt><tt>exponent_bias</tt> : tensor(float32)</dt> |
| 53 | +<dd>The exponent bias used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer.</dd> |
| 54 | +<dt><tt>max_val</tt> : tensor(float32)</dt> |
| 55 | +<dd>Maximum possible representable value, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. </dd> |
| 56 | +</dl> |
| 57 | + |
| 58 | + |
| 59 | +#### Outputs |
| 60 | + |
| 61 | +<dl> |
| 62 | +<dt><tt>Y</tt> : tensor(float32)</dt> |
| 63 | +<dd>Output tensor</dd> |
| 64 | +</dl> |
| 65 | + |
| 66 | +#### Examples |
| 67 | +```python |
| 68 | +def compute_max_val(exponent_bit_width, mantissa_bit_width, exponent_bias): |
| 69 | + max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias |
| 70 | + max_mantissa = np.sum(( |
| 71 | + 2. ** np.arange( |
| 72 | + 0, |
| 73 | + -1. * mantissa_bit_width - 1., |
| 74 | + -1. |
| 75 | + ))) |
| 76 | + max_val = max_mantissa * (2 ** max_exponent) |
| 77 | + return max_val |
| 78 | + |
| 79 | +import numpy as np |
| 80 | +x = np.random.rand(100).astype(np.float32) |
| 81 | +scale = 1 |
| 82 | +exponent_bitwidth = 4 |
| 83 | +mantissa_bitwidth = 3 |
| 84 | +exponent_bias = 0 |
| 85 | +max_val = compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias) |
| 86 | +rounding_mode = 'ROUND' |
| 87 | +signed = True |
| 88 | +xq = float_quantize(x, scale, exponent_bitwidth, mantissa_bitwidth, exponent_bias, max_val, rounding_mode) |
| 89 | +``` |
| 90 | + |
| 91 | + |
| 92 | +#### Sample Implementation |
| 93 | +```python |
| 94 | +# see src/qonnx/custom_op/general/floatquant.py for up-to-date implementation |
| 95 | +def float_quant( |
| 96 | + X, |
| 97 | + scale, |
| 98 | + exponent_bitwidth, |
| 99 | + mantissa_bitwidth, |
| 100 | + exponent_bias, |
| 101 | + signed, |
| 102 | + max_val=None, |
| 103 | + has_inf=False, |
| 104 | + has_nan=False, |
| 105 | + has_subnormal=False, |
| 106 | + rounding_mode="ROUND", |
| 107 | + saturation=True |
| 108 | +): |
| 109 | + """Quantize a given floating point array to minifloat format by specifying the desired minifloat quantization""" |
| 110 | + def resolve_rounding_mode(mode_string): |
| 111 | + """Resolve the rounding mode string to the corresponding numpy functions.""" |
| 112 | + mode_string = mode_string.upper() |
| 113 | + if mode_string == "ROUND": |
| 114 | + return np.round |
| 115 | + elif mode_string == "CEIL": |
| 116 | + return np.ceil |
| 117 | + elif mode_string == "FLOOR": |
| 118 | + return np.floor |
| 119 | + else: |
| 120 | + raise ValueError(f"Could not resolve rounding mode called: {mode_string}") |
| 121 | + # the comments are left to track the correspondence with the brevitas code |
| 122 | + # np version of brevitas function |
| 123 | + def inf_nan_clamp(X, inf_mask, p_max_val_mask, n_max_val_mask): |
| 124 | + if has_inf: |
| 125 | + X[p_max_val_mask] = np.inf |
| 126 | + X[n_max_val_mask] = -np.inf |
| 127 | + elif has_nan: |
| 128 | + full_max_val_mask = np.logical_or(p_max_val_mask, n_max_val_mask) |
| 129 | + X[full_max_val_mask] = np.nan |
| 130 | + X[inf_mask] = np.nan |
| 131 | + else: |
| 132 | + raise RuntimeError( |
| 133 | + "Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified" |
| 134 | + ) |
| 135 | + return X |
| 136 | + |
| 137 | + # consistency check |
| 138 | + # if bit_width != exponent_bitwidth + mantissa_bitwidth + int(signed): |
| 139 | + # raise RuntimeError("Mismatch between total bit-width, exponent, mantissa and sign.") |
| 140 | + |
| 141 | + # x = self.input_view_impl(x) # assuming input_view_impl is Identity |
| 142 | + |
| 143 | + # the following lines (up to max_value assignment) implements the float_internal_scale function from brevitas using numpy |
| 144 | + # internal_scale = float_internal_scale( |
| 145 | + # scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps) |
| 146 | + |
| 147 | + X = X / scale |
| 148 | + |
| 149 | + eps = np.finfo(X.dtype).tiny # the datatype used here and in brevitas must be the same to have the same eps |
| 150 | + fp_internal_scale_min = 1. - exponent_bias - mantissa_bitwidth |
| 151 | + |
| 152 | + internal_scale = np.floor(np.log2(np.abs(X) + eps)) - mantissa_bitwidth |
| 153 | + internal_scale = np.maximum(internal_scale, fp_internal_scale_min) # np version of: internal_scale = torch.ok(internal_scale, fp_internal_scale_min) |
| 154 | + internal_scale = np.exp2(internal_scale) |
| 155 | + |
| 156 | + x_q = internal_scale * resolve_rounding_mode(rounding_mode)(X / internal_scale) # self.float_to_int_impl(x / internal_scale) |
| 157 | + |
| 158 | + max_value = compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias) |
| 159 | + max_value = max_value if max_val is None else np.minimum(max_value, max_val) |
| 160 | + min_value = 0. if not signed else -max_value |
| 161 | + |
| 162 | + # Compute masks |
| 163 | + inf_mask = np.isinf(x_q) |
| 164 | + p_max_val_mask = x_q > max_value |
| 165 | + n_max_val_mask = x_q < min_value |
| 166 | + |
| 167 | + # first clamp everything to [min_value,max_value], basically the saturating case |
| 168 | + x_q = np.clip(x_q, min_value, max_value) # self.saturating_clamp(x_q, max_value, min_value) |
| 169 | + |
| 170 | + if not saturation: |
| 171 | + x_q = inf_nan_clamp(x_q, inf_mask, p_max_val_mask, n_max_val_mask) |
| 172 | + |
| 173 | + return x_q * scale #, self.saturating, self.inf_values, self.nan_values |
0 commit comments