
Commit 19c84eb

Merge pull request #180 from ebby-s/feature/arbprec_float_dtype
Merged with float_quant, added exp bias, implemented FloatQuant.infer_node_datatype()
2 parents 52d8f98 + 1611229 commit 19c84eb

File tree

16 files changed: +1019, -293 lines


.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion

```diff
@@ -26,7 +26,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install -e .[testing,qkeras]
+          pip install -e .[testing,qkeras,brevitas]
 
       - name: Run tests
         run: |
```
Lines changed: 173 additions & 0 deletions (new file)

### <a name="FloatQuant"></a><a name="abs">**FloatQuant**</a>

Calculates the [arbitrary-precision-float-quantized](https://arxiv.org/abs/2311.12359) values of one input data (Tensor&lt;T&gt;) and produces one output data (Tensor&lt;T&gt;).
Additionally, takes five floats as input, which define the scale, exponent bitwidth, mantissa bitwidth, maximum representable value and exponent bias of the quantization,
all of which may be scalars or tensors with shapes broadcastable to the shape of the input data tensor. This can be used to
control the granularity of the quantization. For instance, a scalar scale operand implies per-tensor scaling, while a scale operand with
the same shape as the input data implies per-element scaling.
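
As a quick illustration of the broadcasting convention (the shapes below are purely illustrative, not mandated by the operator):

```python
import numpy as np

X_shape = (1, 64, 32, 32)  # e.g. an NCHW activation tensor
per_tensor_scale = np.float32(0.125)                           # scalar: one scale for the whole tensor
per_channel_scale = np.full((1, 64, 1, 1), 0.125, np.float32)  # broadcastable: one scale per channel

# both scale shapes broadcast against the shape of X
print(np.broadcast_shapes(X_shape, np.shape(per_tensor_scale)))  # (1, 64, 32, 32)
print(np.broadcast_shapes(X_shape, per_channel_scale.shape))     # (1, 64, 32, 32)
```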

*Special (symbolic) values:* Specialized floating-point datatype behaviors such as supporting infinity, NaN and subnormals are specified by the attributes of the node to inform backends, but note that they do not affect the behavior of the `FloatQuant` operator. Instead, the `max_val` input is used to account for the decreased representable range due to having to represent special cases.

*Why is `max_val` specified explicitly?* The maximum representable value is derived from a combination of exponent and mantissa bitwidths, but also from how many encodings are reserved for special (symbolic) values. This makes it nontrivial to infer the maximum representable value. For instance, OCP E5M2 reserves three encodings for NaN, whereas E4M3 reserves only one.
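
As a worked illustration (assuming the standard OCP FP8 formats, E4M3 with bias 7 and E5M2 with bias 15), the maximum implied by the bitwidths alone differs from the actual OCP maxima once reserved encodings are taken into account:

```python
# Naive maximum implied by the bitwidths alone, ignoring reserved encodings:
# largest exponent = 2**exp_bits - 1 - bias, largest mantissa = 2 - 2**(-mant_bits)
def naive_max(exp_bits, mant_bits, bias):
    return (2.0 - 2.0 ** (-mant_bits)) * 2.0 ** (2 ** exp_bits - 1 - bias)

print(naive_max(4, 3, 7))   # 480.0, but OCP E4M3 reserves S.1111.111 for NaN, so max_val is 448.0
print(naive_max(5, 2, 15))  # 114688.0, but OCP E5M2 reserves exponent 11111 for inf/NaN, so max_val is 57344.0
```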

*Integer quantization:* This operator is not intended for integer quantization; for that purpose the `IntQuant` custom op exists.

#### Version

This operator is not part of the ONNX standard and is not currently versioned.

#### Attributes

<dl>
<dt><tt>has_infinity</tt> : int (default is 0)</dt>
<dd>Integer value interpreted as boolean, defines whether the representation supports infinity values. The ability to represent infinity values will decrease the representable numerical range. This attribute has no effect on the execution of this operation and is intended purely to inform backends.</dd>

<dt><tt>has_nan</tt> : int (default is 0)</dt>
<dd>Integer value interpreted as boolean, defines whether the representation supports not-a-number (NaN) values. The ability to represent NaN values will decrease the representable numerical range. This attribute has no effect on the execution of this operation and is intended purely to inform backends.</dd>

<dt><tt>has_subnormal</tt> : int (default is 1)</dt>
<dd>Integer value interpreted as boolean, defines whether the representation supports subnormal values. Subnormal values have an exponent value of 0 and are interpreted to have a leading significand digit of zero rather than one. Supporting subnormals will increase the complexity of the required arithmetic datapath. This attribute has no effect on the execution of this operation and is intended purely to inform backends.</dd>

<dt><tt>saturation</tt> : int (default is 1)</dt>
<dd>Integer value interpreted as boolean, defines whether the representation will saturate during arithmetic. This attribute has no effect on the execution of this operation and is intended purely to inform backends.</dd>

<dt><tt>rounding_mode</tt> : string (default is "ROUND")</dt>
<dd>Defines how rounding should be applied during quantization. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants of the rounding mode string are also supported: "round", "ceil", "floor".</dd>

</dl>

#### Inputs

<dl>
<dt><tt>X</tt> : tensor(float32)</dt>
<dd>input tensor to quantize</dd>
<dt><tt>scale</tt> : tensor(float32)</dt>
<dd>The scale factor, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor</dd>
<dt><tt>exponent_bitwidth</tt> : tensor(float32)</dt>
<dd>The number of bits for the exponent used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer.</dd>
<dt><tt>mantissa_bitwidth</tt> : tensor(float32)</dt>
<dd>The number of bits for the mantissa used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer.</dd>
<dt><tt>exponent_bias</tt> : tensor(float32)</dt>
<dd>The exponent bias used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer.</dd>
<dt><tt>max_val</tt> : tensor(float32)</dt>
<dd>Maximum possible representable value, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor.</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> : tensor(float32)</dt>
<dd>Output tensor</dd>
</dl>

#### Examples
```python
import numpy as np

def compute_max_val(exponent_bit_width, mantissa_bit_width, exponent_bias):
    max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias
    max_mantissa = np.sum((
        2. ** np.arange(
            0,
            -1. * mantissa_bit_width - 1.,
            -1.
        )))
    max_val = max_mantissa * (2 ** max_exponent)
    return max_val

x = np.random.rand(100).astype(np.float32)
scale = 1
exponent_bitwidth = 4
mantissa_bitwidth = 3
exponent_bias = 0
max_val = compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias)
rounding_mode = 'ROUND'
signed = True
xq = float_quant(x, scale, exponent_bitwidth, mantissa_bitwidth, exponent_bias, signed, max_val, rounding_mode=rounding_mode)
```
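
The sketch below shows how a `FloatQuant` node could be constructed with `onnx.helper`, mirroring the `IntQuant` example in `intquant_op.md`; the domain string and tensor names are illustrative assumptions rather than part of this specification:

```python
from onnx import helper

# hypothetical node construction; domain and tensor names are assumptions for illustration
node = helper.make_node(
    'FloatQuant',
    domain='qonnx.custom_op.general',  # assumed qonnx general custom-op domain
    inputs=['x', 'scale', 'exponent_bitwidth', 'mantissa_bitwidth', 'exponent_bias', 'max_val'],
    outputs=['y'],
    has_infinity=0,
    has_nan=0,
    has_subnormal=1,
    saturation=1,
    rounding_mode='ROUND',
)
```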

#### Sample Implementation
```python
# see src/qonnx/custom_op/general/floatquant.py for up-to-date implementation
def float_quant(
    X,
    scale,
    exponent_bitwidth,
    mantissa_bitwidth,
    exponent_bias,
    signed,
    max_val=None,
    has_inf=False,
    has_nan=False,
    has_subnormal=False,
    rounding_mode="ROUND",
    saturation=True
):
    """Quantize a given floating point array to minifloat format by specifying the desired minifloat quantization"""

    def resolve_rounding_mode(mode_string):
        """Resolve the rounding mode string to the corresponding numpy functions."""
        mode_string = mode_string.upper()
        if mode_string == "ROUND":
            return np.round
        elif mode_string == "CEIL":
            return np.ceil
        elif mode_string == "FLOOR":
            return np.floor
        else:
            raise ValueError(f"Could not resolve rounding mode called: {mode_string}")

    # the comments are left to track the correspondence with the brevitas code
    # np version of brevitas function
    def inf_nan_clamp(X, inf_mask, p_max_val_mask, n_max_val_mask):
        if has_inf:
            X[p_max_val_mask] = np.inf
            X[n_max_val_mask] = -np.inf
        elif has_nan:
            full_max_val_mask = np.logical_or(p_max_val_mask, n_max_val_mask)
            X[full_max_val_mask] = np.nan
            X[inf_mask] = np.nan
        else:
            raise RuntimeError(
                "Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified"
            )
        return X

    # consistency check
    # if bit_width != exponent_bitwidth + mantissa_bitwidth + int(signed):
    #     raise RuntimeError("Mismatch between total bit-width, exponent, mantissa and sign.")

    # x = self.input_view_impl(x)  # assuming input_view_impl is Identity

    # the following lines (up to the max_value assignment) implement the float_internal_scale
    # function from brevitas using numpy:
    # internal_scale = float_internal_scale(
    #     scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps)

    X = X / scale

    eps = np.finfo(X.dtype).tiny  # the datatype used here and in brevitas must be the same to have the same eps
    fp_internal_scale_min = 1. - exponent_bias - mantissa_bitwidth

    internal_scale = np.floor(np.log2(np.abs(X) + eps)) - mantissa_bitwidth
    internal_scale = np.maximum(internal_scale, fp_internal_scale_min)  # np version of the brevitas clamp to fp_internal_scale_min
    internal_scale = np.exp2(internal_scale)

    x_q = internal_scale * resolve_rounding_mode(rounding_mode)(X / internal_scale)  # self.float_to_int_impl(x / internal_scale)

    max_value = compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias)
    max_value = max_value if max_val is None else np.minimum(max_value, max_val)
    min_value = 0. if not signed else -max_value

    # Compute masks
    inf_mask = np.isinf(x_q)
    p_max_val_mask = x_q > max_value
    n_max_val_mask = x_q < min_value

    # first clamp everything to [min_value, max_value], basically the saturating case
    x_q = np.clip(x_q, min_value, max_value)  # self.saturating_clamp(x_q, max_value, min_value)

    if not saturation:
        x_q = inf_nan_clamp(x_q, inf_mask, p_max_val_mask, n_max_val_mask)

    return x_q * scale  # , self.saturating, self.inf_values, self.nan_values
```
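
As a quick sanity check of the sketch above (assuming `float_quant` and `compute_max_val` from this page are in scope), values that are already exactly representable should pass through unchanged:

```python
import numpy as np

x = np.asarray([0.25, 1.5, 448.0], dtype=np.float32)  # exactly representable as E4M3 with bias 7
xq = float_quant(x, 1.0, 4, 3, 7, True, max_val=448.0)
assert np.array_equal(x, xq)  # quantization leaves already-representable values unchanged
```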

docs/qonnx-custom-ops/quant_op.md renamed to docs/qonnx-custom-ops/intquant_op.md

Lines changed: 10 additions & 8 deletions

````diff
@@ -1,13 +1,15 @@
-### <a name="Quant"></a><a name="abs">**Quant**</a>
+### <a name="Quant"></a><a name="abs">**IntQuant**</a>
 
-Calculates the quantized values of one input data (Tensor<T>) and produces one output data (Tensor<T>).
+Calculates the integer-quantized values of one input data (Tensor<T>) and produces one output data (Tensor<T>).
 Additionally, takes three floats as input, which define the scale, zero-point and bit-width of the quantization,
 which may be scalars or tensors with number of dimensions equal to the input data tensor, for e.g. tensor-wise
 or channel-wise quantization.
 The attributes narrow and signed define how the bits of the quantization are interpreted, while the attribute
 rounding_mode defines how quantized values are rounded.
 
-Note: This operator does not work for binary or bipolar quantization, for this purpose the simpler BipolarQuant node exists.
+Notes:
+* This operator was previously named `Quant` but is renamed to `IntQuant` to distinguish it from `FloatQuant`. For a transition period, qonnx will transparently handle `Quant` as `IntQuant` for backwards compatibility reasons, but only `IntQuant` should be used for new models.
+* This operator does not work for binary or bipolar quantization, for this purpose the simpler BipolarQuant node exists.
 
 #### Version
 
@@ -48,7 +50,7 @@ This operator is not part of the ONNX standard and is not currently versioned.
 
 #### Examples
 <details>
-<summary>Quant</summary>
+<summary>IntQuant</summary>
 
 ```python
 from onnx import helper
@@ -65,7 +67,7 @@ rounding_mode = "ROUND"
 
 # Create node
 node = helper.make_node(
-    'Quant',
+    'IntQuant',
     domain='finn.custom_op.general',
     inputs=['x', 'scale', 'zeropt', 'bitwidth'],
     outputs=['y'],
@@ -79,7 +81,7 @@ node = helper.make_node(
 output_ref = quant(x, scale, zeropt, bitwidth, signed, narrow, rounding_mode)
 
 # Execute node and compare
-expect(node, inputs=[x, scale, zeropt, bitwidth], outputs=[output_ref], name='test_quant')
+expect(node, inputs=[x, scale, zeropt, bitwidth], outputs=[output_ref], name='test_intquant')
 
 ```
 
@@ -89,7 +91,7 @@ expect(node, inputs=[x, scale, zeropt, bitwidth], outputs=[output_ref], name='te
 #### Sample Implementation
 
 <details>
-<summary>Quant</summary>
+<summary>IntQuant</summary>
 
 ```python
 # SPDX-License-Identifier: Apache-2.0
@@ -179,7 +181,7 @@ def max_int(signed: bool, narrow_range: bool, bit_width: int) -> int:
     return value
 
 def resolve_rounding_mode(mode_string):
-    """Resolve the rounding mode string of Quant and Trunc ops
+    """Resolve the rounding mode string of IntQuant and Trunc ops
     to the corresponding numpy functions."""
     if mode_string == "ROUND":
         return np.round
````

setup.cfg

Lines changed: 5 additions & 0 deletions

```diff
@@ -80,6 +80,11 @@ testing =
     pytest-xdist
     pytest-cov
     pytest-randomly
+    hypothesis
+    mock
+
+brevitas =
+    brevitas>=0.11.0
 
 notebooks =
     jupyter
```

src/qonnx/core/datatype.py

Lines changed: 38 additions & 21 deletions

```diff
@@ -146,10 +146,15 @@ def get_canonical_name(self):
 
 
 class ArbPrecFloatType(BaseDataType):
-    def __init__(self, exponent_bits, mantissa_bits):
+    def __init__(self, exponent_bits, mantissa_bits, exponent_bias=None):
         self._exponent_bits = exponent_bits
         self._mantissa_bits = mantissa_bits
 
+        if not exponent_bias:
+            # default (IEEE-style) exponent bias
+            exponent_bias = (2.0 ** (exponent_bits - 1)) - 1
+        self._exponent_bias = exponent_bias
+
     def signed(self):
         "Returns whether this DataType can represent negative numbers."
         return True
@@ -165,8 +170,7 @@ def mantissa_bits(self):
         return self._mantissa_bits
 
     def exponent_bias(self):
-        # default (IEEE-style) exponent bias
-        return (2.0 ** (self.exponent_bits() - 1)) - 1
+        return self._exponent_bias
 
     def min(self):
         return -1 * self.max()
@@ -182,27 +186,32 @@ def max(self):
         return max_val
 
     def allowed(self, value):
-        # extract fields from fp32 representation
+        # fp32 format parameters
         fp32_exponent_bias = 127
         fp32_mantissa_bitwidth = 23
+        fp32_nrm_mantissa_bitwidth = fp32_mantissa_bitwidth + 1  # width of normalized mantissa with implicit 1
+        # minifloat format parameters
+        exponent_bias = self.exponent_bias()
+        min_exponent = -exponent_bias + 1  # minimum exponent if IEEE-style denormals are supported
+        mantissa_bitwidth = self.mantissa_bits()
+        nrm_mantissa_bitwidth = mantissa_bitwidth + 1  # width of normalized mantissa with implicit 1
+        # extract fields from fp32 representation
         bin_val = np.float32(value).view(np.uint32)
         exp = (bin_val & 0b01111111100000000000000000000000) >> fp32_mantissa_bitwidth
         mant = bin_val & 0b00000000011111111111111111111111
-        exponent_bias = self.exponent_bias()
-        exponent_bitwidth = self.exponent_bits()
-        mantissa_bitwidth = self.mantissa_bits()
-        max_exponent = (2.0**exponent_bitwidth) - 1.0 - exponent_bias
-        min_exponent = -exponent_bias
+        exp_biased = exp - fp32_exponent_bias  # bias the extracted raw exponent (assume not denormal)
+        mant_normalized = mant + int((2**fp32_mantissa_bitwidth) * (exp != 0))  # append implicit 1
         # for this value to be representable as this ArbPrecFloatType:
-        # the exponent must be within the representable range
-        actual_exp = exp - fp32_exponent_bias
-        exponent_ok = (min_exponent <= actual_exp) and (actual_exp <= max_exponent)
+        # the value must be within the representable range
+        range_ok = (value <= self.max()) and (value >= self.min())
         # the mantissa must be within representable range:
-        # no set bits in the mantissa beyond the allowed number of bits
-        # (computed by a mask here)
-        mantissa_mask = "0" * mantissa_bitwidth + "1" * (fp32_mantissa_bitwidth - mantissa_bitwidth)
-        mantissa_ok = (mant & int(mantissa_mask, base=2)) == 0
-        return mantissa_ok and exponent_ok
+        # no set bits in the mantissa beyond the allowed number of bits (assume value is not denormal in fp32)
+        # compute bits of precision lost to tapered precision if denormal, clamp to: 0 <= dnm_shift <= nrm_mantissa_bitwidth
+        dnm_shift = int(min(max(0, min_exponent - exp_biased), nrm_mantissa_bitwidth))
+        available_bits = nrm_mantissa_bitwidth - dnm_shift  # number of bits of precision available
+        mantissa_mask = "0" * available_bits + "1" * (fp32_nrm_mantissa_bitwidth - available_bits)
+        mantissa_ok = (mant_normalized & int(mantissa_mask, base=2)) == 0
+        return bool(mantissa_ok and range_ok)
 
     def is_integer(self):
         return False
@@ -217,7 +226,7 @@ def to_numpy_dt(self):
         return np.float32
 
     def get_canonical_name(self):
-        return "FLOAT<%d,%d>" % (self.exponent_bits(), self.mantissa_bits())
+        return "FLOAT<%d,%d,%d>" % (self.exponent_bits(), self.mantissa_bits(), self.exponent_bias())
 
     def get_num_possible_values(self):
         # TODO: consider -0 and +0 as different values?
@@ -488,9 +497,17 @@ def resolve_datatype(name):
         name = name.replace("FLOAT<", "")
         name = name.replace(">", "")
         nums = name.split(",")
-        exp_bits = int(nums[0].strip())
-        mant_bits = int(nums[1].strip())
-        return ArbPrecFloatType(exp_bits, mant_bits)
+        if len(nums) == 2:
+            exp_bits = int(nums[0].strip())
+            mant_bits = int(nums[1].strip())
+            return ArbPrecFloatType(exp_bits, mant_bits)
+        elif len(nums) == 3:
+            exp_bits = int(nums[0].strip())
+            mant_bits = int(nums[1].strip())
+            exp_bias = int(nums[2].strip())
+            return ArbPrecFloatType(exp_bits, mant_bits, exp_bias)
+        else:
+            raise KeyError("Could not resolve DataType " + name)
     else:
         raise KeyError("Could not resolve DataType " + name)
 
```
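
To illustrate the effect of this change, a minimal sketch using qonnx's `DataType` lookup (the behavior follows from the diff above):

```python
from qonnx.core.datatype import DataType

# two-argument form keeps the IEEE-style default bias, 2**(exponent_bits - 1) - 1
dt_default = DataType["FLOAT<4,3>"]
print(dt_default.exponent_bias())       # 7.0 for 4 exponent bits

# the new three-argument form pins the exponent bias explicitly
dt_biased = DataType["FLOAT<4,3,5>"]
print(dt_biased.exponent_bias())        # 5
print(dt_biased.get_canonical_name())   # FLOAT<4,3,5>
```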

src/qonnx/custom_op/general/__init__.py

Lines changed: 5 additions & 2 deletions

```diff
@@ -28,11 +28,12 @@
 
 from qonnx.custom_op.general.bipolar_quant import BipolarQuant
 from qonnx.custom_op.general.debugmarker import DebugMarker
+from qonnx.custom_op.general.floatquant import FloatQuant
 from qonnx.custom_op.general.genericpartition import GenericPartition
 from qonnx.custom_op.general.im2col import Im2Col
+from qonnx.custom_op.general.intquant import IntQuant
 from qonnx.custom_op.general.maxpoolnhwc import MaxPoolNHWC
 from qonnx.custom_op.general.multithreshold import MultiThreshold
-from qonnx.custom_op.general.quant import Quant
 from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d
 from qonnx.custom_op.general.trunc import Trunc
 from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul
@@ -46,6 +47,8 @@
 custom_op["MultiThreshold"] = MultiThreshold
 custom_op["XnorPopcountMatMul"] = XnorPopcountMatMul
 custom_op["Im2Col"] = Im2Col
-custom_op["Quant"] = Quant
+custom_op["IntQuant"] = IntQuant
+custom_op["Quant"] = IntQuant
 custom_op["Trunc"] = Trunc
 custom_op["BipolarQuant"] = BipolarQuant
+custom_op["FloatQuant"] = FloatQuant
```
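
A minimal sketch of what this registration change means in practice (keys taken directly from the diff above):

```python
from qonnx.custom_op.general import custom_op

# the legacy "Quant" key and the new "IntQuant" key resolve to the same implementation,
# and the new "FloatQuant" op is registered alongside them
assert custom_op["Quant"] is custom_op["IntQuant"]
print(custom_op["FloatQuant"])  # <class 'qonnx.custom_op.general.floatquant.FloatQuant'>
```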

0 commit comments
