from collections import abc

import numpy as np
import torch
import torch.nn as nn


def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively convert Tensors in inputs from src_type to dst_type.

    Args:
        inputs: Inputs to be cast.
        src_type (torch.dtype): Source type.
        dst_type (torch.dtype): Destination type.

    Returns:
        The same type as inputs, but with all contained Tensors cast to
            dst_type.
    """
    if isinstance(inputs, nn.Module):
        return inputs
    elif isinstance(inputs, torch.Tensor):
        return inputs.to(dst_type)
    elif isinstance(inputs, str):
        return inputs
    elif isinstance(inputs, np.ndarray):
        return inputs
    elif isinstance(inputs, abc.Mapping):
        # Rebuild mappings (e.g. dict) of the same type, casting each value.
        return type(inputs)({
            k: cast_tensor_type(v, src_type, dst_type)
            for k, v in inputs.items()
        })
    elif isinstance(inputs, abc.Iterable):
        # Rebuild other iterables (e.g. list, tuple), casting each item.
        return type(inputs)(
            cast_tensor_type(item, src_type, dst_type) for item in inputs)
    else:
        return inputs


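# The helper below is an illustrative sketch, not part of the original file:
# it shows cast_tensor_type converting nested Tensors while leaving
# non-Tensor entries untouched.
def _cast_tensor_type_example():
    """Hypothetical usage sketch for ``cast_tensor_type``."""
    data = {'feat': torch.zeros(2, 3, dtype=torch.half), 'meta': 'sample'}
    out = cast_tensor_type(data, torch.half, torch.float)
    assert out['feat'].dtype == torch.float32  # Tensor values are converted
    assert out['meta'] == 'sample'  # non-Tensor values pass through unchanged

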
def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Patch the forward method of a module.

    Args:
        func (callable): The original forward method.
        src_type (torch.dtype): Type of input arguments to be converted from.
        dst_type (torch.dtype): Type of input arguments to be converted to.
        convert_output (bool): Whether to convert the output back to src_type.

    Returns:
        callable: The patched forward method.
    """

    def new_forward(*args, **kwargs):
        # Cast positional and keyword arguments to dst_type before calling
        # the original forward, then optionally cast the output back.
        output = func(*cast_tensor_type(args, src_type, dst_type),
                      **cast_tensor_type(kwargs, src_type, dst_type))
        if convert_output:
            output = cast_tensor_type(output, dst_type, src_type)
        return output

    return new_forward


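# The helper below is an illustrative sketch, not part of the original file:
# it wraps a float32 module so that it can be fed FP16 inputs and still
# return FP16 outputs.
def _patch_forward_method_example():
    """Hypothetical usage sketch for ``patch_forward_method``."""
    norm = nn.LayerNorm(4)  # parameters remain in float32
    norm.forward = patch_forward_method(norm.forward, torch.half, torch.float)
    x = torch.ones(2, 4, dtype=torch.half)
    # Inputs are cast to float32 internally; the output is cast back to half.
    assert norm(x).dtype == torch.half

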
def patch_norm_fp32(module):
    """Recursively convert normalization layers from FP16 to FP32.

    Args:
        module (nn.Module): The FP16 module whose normalization layers are to
            be converted.

    Returns:
        nn.Module: The converted module, with its normalization layers cast
            to FP32.
    """
    if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
        # Normalization statistics are numerically unstable in FP16, so the
        # parameters and buffers of norm layers are kept in FP32.
        module.float()
        if isinstance(module, nn.GroupNorm) or torch.__version__ < '1.3':
            module.forward = patch_forward_method(module.forward, torch.half,
                                                  torch.float)
    for child in module.children():
        patch_norm_fp32(child)
    return module
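

if __name__ == '__main__':
    # Minimal smoke test / usage sketch (assumed workflow, not part of the
    # original file): convert a small model to FP16, then restore its
    # normalization layers to FP32 with patch_norm_fp32.
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    model.half()
    patch_norm_fp32(model)
    print(model[0].weight.dtype)  # torch.float16: conv weights stay in FP16
    print(model[1].weight.dtype)  # torch.float32: norm weights restored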