pytorch量化中torch.quantize_per_tensor()函数参数详解

torch.quantize_per_tensor(input,scale, zero_point, dtype)实现8位量化:

摘要:对该函数各个参数的分析

量化:

计算机运算时,默认32位浮点数,若将32位浮点数,变成8位定点数,会快很多。
目前pytorch中的反向传播不支持量化,所以该量化只用于评估训练好的模型,或者将32位浮点数模型存储为8位定点数模型,读取8位定点数模型后需要转换为32位浮点数才能进行神经网络参数的训练。

量化函数原型:Q = torch.quantize_per_tensor(input,scale = 0.025 , zero_point = 0, dtype = torch.quint8)

**

  1. input为准备量化的32位浮点数,Q为量化后的8位定点数
  2. dtype为量化类型,quint8代表8位无符号数,qint8代表8位带符号数,最高位是符号位
  3. 假设量化为qint8,设量化后的数Q为0001_1101,最高位为0(符号位),所以是正数;后7位转换为10进制是29,所以Q代表的数为 :zero_point + Q * scale = 0 + 29 * 0.025 = 0.725
  4. 所以最终使用print显示Q时,显示的不是0001_1101而是0.725,但它在计算机中存储时,是0001_1101
  5. 使用dequantize()可以解除量化
  6. 量化公式为:pytorch量化中torch.quantize_per_tensor()函数参数详解

**

代码及其运行结果:

pytorch量化中torch.quantize_per_tensor()函数参数详解

总结:

以zero_point为中心,用8位数Q代表input离中心有多远,scale为距离单位
即input ≈ zero_point + Q * scale