from collections import abc

import numpy as np
import torch
import torch.nn as nn


def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively convert Tensors in ``inputs`` from src_type to dst_type.

    Only tensors whose dtype equals ``src_type`` are converted; tensors of any
    other dtype (e.g. integer indices or boolean masks) are returned
    untouched so that they are not silently turned into floating point data.

    Args:
        inputs: Inputs to be cast. May be a Tensor, a mapping, an iterable
            container, or any other object.
        src_type (torch.dtype): Source type.
        dst_type (torch.dtype): Destination type.

    Returns:
        The same type as ``inputs``, but all contained Tensors of dtype
        ``src_type`` have been cast to ``dst_type``.
    """
    # Modules, strings and numpy arrays are passed through unchanged. Strings
    # and arrays must be checked before the generic Iterable branch below,
    # otherwise they would be taken apart element by element.
    if isinstance(inputs, (nn.Module, str, np.ndarray)):
        return inputs
    if isinstance(inputs, torch.Tensor):
        # Respect ``src_type``: leave tensors of unrelated dtypes alone.
        return inputs.to(dst_type) if inputs.dtype == src_type else inputs
    if isinstance(inputs, abc.Mapping):
        return type(inputs)({
            k: cast_tensor_type(v, src_type, dst_type)
            for k, v in inputs.items()
        })
    if isinstance(inputs, abc.Iterable):
        # Rebuild the container with its own type from the converted items.
        return type(inputs)(
            cast_tensor_type(item, src_type, dst_type) for item in inputs)
    return inputs


def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Patch the forward method of a module.

    Args:
        func (callable): The original forward method.
        src_type (torch.dtype): Type of input arguments to be converted from.
        dst_type (torch.dtype): Type of input arguments to be converted to.
        convert_output (bool): Whether to convert the output back to src_type.

    Returns:
        callable: The patched forward method.
    """

    def new_forward(*args, **kwargs):
        # Cast positional and keyword arguments src_type -> dst_type.
        output = func(*cast_tensor_type(args, src_type, dst_type),
                      **cast_tensor_type(kwargs, src_type, dst_type))
        if convert_output:
            # Cast the result back dst_type -> src_type.
            output = cast_tensor_type(output, dst_type, src_type)
        return output

    return new_forward


def _torch_older_than(major, minor):
    """Return True if the running torch release is older than major.minor.

    The comparison is numeric, so ``1.10`` is correctly treated as newer than
    ``1.3`` (a plain string comparison would order them the other way round).
    """
    # Drop any local build suffix such as "+cu117" before splitting.
    release = str(torch.__version__).split('+')[0].split('.')
    try:
        current = (int(release[0]), int(release[1]))
    except (IndexError, ValueError):
        # Version string without a numeric "major.minor" prefix: do not treat
        # it as an old release.
        return False
    return current < (major, minor)


def patch_norm_fp32(module):
    """Recursively convert normalization layers from FP16 to FP32.

    BatchNorm and GroupNorm layers are converted to FP32 in place. The forward
    method of GroupNorm layers (and of every converted layer when running on
    torch < 1.3) is additionally wrapped so that half inputs are cast to float
    before the call and the output is cast back to half.

    Args:
        module (nn.Module): The modules to be converted in FP16.

    Returns:
        nn.Module: The converted module, the normalization layers have been
        converted to FP32.
    """
    if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
        module.float()
        if isinstance(module, nn.GroupNorm) or _torch_older_than(1, 3):
            module.forward = patch_forward_method(module.forward, torch.half,
                                                  torch.float)
    for child in module.children():
        patch_norm_fp32(child)
    return module