sbt-idp/cope2n-ai-fi/modules/ocr_engine/externals/sdsvtr/sdsvtr/fp16_utils.py
2023-12-12 18:51:32 +07:00

78 lines
2.5 KiB
Python
Executable File

import torch
import torch.nn as nn
import numpy as np
from collections import abc
def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively convert Tensors in ``inputs`` from ``src_type`` to ``dst_type``.

    Only tensors whose dtype equals ``src_type`` are converted. Tensors of any
    other dtype (e.g. integer index/label tensors) are returned unchanged, so
    they are not silently corrupted by a lossy float/half conversion.
    Mappings and iterables are rebuilt with their original container type.

    Args:
        inputs: Inputs to be cast. May be a Tensor, a (nested) mapping or
            iterable containing Tensors, or any other object (returned as-is).
        src_type (torch.dtype): Source type; only tensors of this dtype are
            converted.
        dst_type (torch.dtype): Destination type.

    Returns:
        An object of the same type as ``inputs`` in which every contained
        Tensor of dtype ``src_type`` has been cast to ``dst_type``.
    """
    if isinstance(inputs, nn.Module):
        return inputs
    if isinstance(inputs, torch.Tensor):
        # Leave tensors of other dtypes (e.g. torch.long) untouched.
        return inputs.to(dst_type) if inputs.dtype == src_type else inputs
    if isinstance(inputs, (str, bytes, np.ndarray)):
        # str/bytes are iterable but must not be rebuilt element by element.
        return inputs
    if isinstance(inputs, abc.Mapping):
        return type(inputs)({
            k: cast_tensor_type(v, src_type, dst_type)
            for k, v in inputs.items()
        })
    if isinstance(inputs, abc.Iterable):
        items = (cast_tensor_type(item, src_type, dst_type) for item in inputs)
        if isinstance(inputs, tuple) and hasattr(inputs, '_fields'):
            # namedtuple constructors take fields positionally, not one iterable.
            return type(inputs)(*items)
        return type(inputs)(items)
    return inputs
def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Wrap a forward method so its tensor arguments are cast before the call.

    Args:
        func (callable): The original forward method.
        src_type (torch.dtype): Type the input arguments are converted from.
        dst_type (torch.dtype): Type the input arguments are converted to.
        convert_output (bool): If True, the value returned by ``func`` is cast
            from ``dst_type`` back to ``src_type``.

    Returns:
        callable: A wrapper that forwards arbitrary positional and keyword
        arguments to ``func`` after casting them with ``cast_tensor_type``.
    """
    def new_forward(*args, **kwargs):
        cast_args = cast_tensor_type(args, src_type, dst_type)
        cast_kwargs = cast_tensor_type(kwargs, src_type, dst_type)
        result = func(*cast_args, **cast_kwargs)
        if not convert_output:
            return result
        # Reverse direction: bring the output back to the caller's dtype.
        return cast_tensor_type(result, dst_type, src_type)

    return new_forward
def patch_norm_fp32(module):
    """Recursively keep normalization layers of an FP16 model in FP32.

    Every BatchNorm (``_BatchNorm`` subclass) and GroupNorm layer found in
    ``module`` or its descendants is converted to FP32. GroupNorm layers, and
    BatchNorm layers too when ``torch.__version__ < '1.3'``, additionally get
    their ``forward`` wrapped. The wrapper casts half-precision tensor inputs
    to float and casts the output back to half.

    Args:
        module (nn.Module): The module (converted to FP16) to process.

    Returns:
        nn.Module: ``module`` itself, with its normalization layers converted
        to FP32.
    """
    norm_types = (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)
    if isinstance(module, norm_types):
        module.float()
        # NOTE(review): presumably BatchNorm accepts half inputs with float
        # params natively on torch >= 1.3, so only GroupNorm needs the
        # explicit cast wrapper there — confirm.
        needs_cast_wrapper = (isinstance(module, nn.GroupNorm)
                              or torch.__version__ < '1.3')
        if needs_cast_wrapper:
            module.forward = patch_forward_method(
                module.forward, torch.half, torch.float)
    for child in module.children():
        patch_norm_fp32(child)
    return module