import torch
from torch import nn
from torch.quantization import quantize_dynamic
class DemoModel(nn.Module):
    """Tiny Conv2d -> ReLU -> Linear network used as a quantization demo.

    The submodules are exposed as ``conv``, ``relu`` and ``fc``; those
    attribute names are also the keys used in the module's state dict.
    """

    def __init__(self) -> None:
        super().__init__()
        # Single-channel 1x1 convolution.
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1)
        self.relu = nn.ReLU()
        # nn.Linear acts on the last dimension, so the input's trailing
        # dimension must have size 2.
        self.fc = nn.Linear(2, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``conv``, then ``relu``, then ``fc`` to *x*.

        Args:
            x: Input tensor with one channel and a last dimension of size 2,
                e.g. shape ``(N, 1, H, 2)``.

        Returns:
            The output of the final linear layer.
        """
        x = self.conv(x)
        x = self.relu(x)
        x = self.fc(x)
        return x
# Build the float32 reference model, then derive a dynamically quantized
# version of it with int8 weights.
model_fp32 = DemoModel()
model_int8 = quantize_dynamic(
    model=model_fp32,
    qconfig_spec={nn.Linear},  # quantize only the Linear layers
    dtype=torch.qint8,
)