# -*- coding: utf-8 -*-
# file: resnet.py
# time: 14:43 29/01/2025
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# Homepage: https://yangheng95.github.io
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2025. All Rights Reserved.
# Adapted from: https://github.com/terry-r123/RNABenchmark/blob/main/downstream/structure/resnet.py
"""
ResNet implementation for genomic sequence analysis.
This module provides a ResNet architecture adapted for processing genomic sequences
and their structural representations. It includes basic blocks, bottleneck blocks,
and a complete ResNet implementation optimized for genomic data.
"""
from torch import Tensor
import torch.nn as nn
from typing import Type, Callable, Union, List, Optional
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """
    Build a bias-free 3x3 convolution whose padding tracks its dilation.

    Args:
        in_planes (int): Number of input channels
        out_planes (int): Number of output channels
        stride (int): Stride for the convolution (default: 1)
        groups (int): Number of groups for grouped convolution (default: 1)
        dilation (int): Dilation factor for the convolution (default: 1)

    Returns:
        nn.Conv2d: 3x3 convolution layer (no bias term)
    """
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        # A 3x3 kernel dilated by d spans 2*d + 1 pixels, so padding by d keeps
        # height/width unchanged at stride 1.
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)
def conv1x1(in_planes, out_planes, stride=1):
    """
    Build a bias-free pointwise (1x1) convolution.

    Args:
        in_planes (int): Number of input channels
        out_planes (int): Number of output channels
        stride (int): Stride for the convolution (default: 1)

    Returns:
        nn.Conv2d: 1x1 convolution layer (no bias term, no padding)
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer
def conv5x5(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """
    5x5 convolution with size-preserving padding.

    The padding is ``2 * dilation`` so that, at stride 1, the output keeps the
    input's height and width for any dilation (the same contract as ``conv3x3``,
    which pads by ``dilation``). For the default ``dilation=1`` the padding is 2.

    Args:
        in_planes (int): Number of input channels
        out_planes (int): Number of output channels
        stride (int): Stride for the convolution (default: 1)
        groups (int): Number of groups for grouped convolution (default: 1)
        dilation (int): Dilation factor for the convolution (default: 1)

    Returns:
        nn.Conv2d: 5x5 convolution layer (no bias term)
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=5,
        stride=stride,
        # A 5x5 kernel dilated by d spans 4*d + 1 pixels; padding each side by
        # 2*d cancels that growth so spatial size is preserved at stride 1.
        padding=2 * dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
class BasicBlock(nn.Module):
    """
    Pre-activation residual block with a 3x3 and a 5x5 convolution.

    The residual branch is ``norm -> ReLU -> dropout -> conv3x3 -> norm -> ReLU ->
    dropout -> conv5x5``. Its result is added to the block input, which is first
    passed through ``downsample`` when one is given. Normalization layers are
    applied channels-last: the NCHW tensor is permuted so channels form the final
    axis, which is the axis ``nn.LayerNorm(C)`` normalizes.

    Attributes:
        expansion (int): Output-channel expansion factor (1)
        conv1: 3x3 convolution ``inplanes -> planes``; carries the block stride
        bn1: Normalization of the block input (``inplanes`` channels)
        conv2: 5x5 convolution ``planes -> planes``
        bn2: Normalization of the intermediate features (``planes`` channels)
        relu: ReLU activation function
        drop: Dropout layer (p=0.25)
        downsample: Optional module applied to the block input on the residual path
        stride: Stride of ``conv1``
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """
        Initialize the BasicBlock.

        The positional parameter order matches ``Bottleneck`` (inplanes, planes,
        stride, downsample, groups, base_width, dilation, norm_layer). That is the
        order in which ``ResNet._make_layer`` passes arguments when it builds blocks.

        Args:
            inplanes (int): Number of input channels
            planes (int): Number of output channels
            stride (int): Stride of the first convolution (default: 1)
            downsample: Module applied to the input on the residual path, needed
                when the stride or channel count changes (default: None)
            groups (int): Accepted for signature compatibility; unused (default: 1)
            base_width (int): Accepted for signature compatibility; unused (default: 64)
            dilation (int): Must be 1; larger values are rejected (default: 1)
            norm_layer: Factory called with a channel count; the resulting layer is
                applied channels-last (default: None, uses nn.LayerNorm)

        Raises:
            NotImplementedError: If dilation > 1 is specified
        """
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.LayerNorm
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        # bn1 normalizes the block *input*, which has inplanes channels.
        self.bn1 = norm_layer(inplanes)
        self.relu = nn.ReLU(inplace=False)
        self.drop = nn.Dropout(0.25, inplace=False)
        self.conv2 = conv5x5(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the BasicBlock.

        Args:
            x (Tensor): Input tensor [batch_size, inplanes, height, width]

        Returns:
            Tensor: ``branch(x) + shortcut(x)`` with ``planes`` channels. It has
            the same shape as ``x`` when stride is 1 and inplanes == planes.
        """
        identity = x
        # Move channels last for the norm (LayerNorm normalizes the final axis), then back.
        out = self.bn1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        out = self.relu(out)
        out = self.drop(out)
        out = self.conv1(out)
        out = self.bn2(out.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        out = self.relu(out)
        out = self.drop(out)
        out = self.conv2(out)
        if self.downsample is not None:
            # The shortcut must see the original NCHW input, not the permuted view.
            identity = self.downsample(identity)
        return out + identity
class Bottleneck(nn.Module):
    """
    Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand, plus a shortcut.

    Each convolution is followed by normalization, with ReLU after the first two.
    The sum of the branch and the (optionally downsampled) input goes through a
    final ReLU. The block emits ``planes * expansion`` channels.

    Attributes:
        expansion (int): Output-channel multiplier (4)
        conv1: 1x1 convolution ``inplanes -> width``
        bn1: Normalization after ``conv1``
        conv2: 3x3 convolution ``width -> width`` (carries stride, groups, dilation)
        bn2: Normalization after ``conv2``
        conv3: 1x1 convolution ``width -> planes * expansion``
        bn3: Normalization after ``conv3``
        relu: ReLU activation function
        downsample: Optional module applied to the input on the shortcut path
        stride: Stride of ``conv2``
    """

    # Torchvision's Bottleneck strides the 3x3 convolution (self.conv2). The
    # original "Deep residual learning for image recognition" paper
    # (https://arxiv.org/abs/1512.03385) instead strides the first 1x1 convolution
    # (self.conv1). The torchvision layout is known as ResNet V1.5 and is reported
    # to be more accurate:
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """
        Initialize the Bottleneck block.

        Args:
            inplanes (int): Number of input channels
            planes (int): Bottleneck base size; the block outputs
                ``planes * expansion`` channels
            stride (int): Stride of the 3x3 convolution (default: 1)
            downsample: Module applied to the input on the shortcut path (default: None)
            groups (int): Number of groups for the 3x3 convolution (default: 1)
            base_width (int): Scales the inner width as
                ``int(planes * base_width / 64) * groups`` (default: 64)
            dilation (int): Dilation of the 3x3 convolution (default: 1)
            norm_layer: Factory called with a channel count (default: None, uses BatchNorm2d)
        """
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        mid_channels = int(planes * (base_width / 64.0)) * groups
        out_channels = planes * self.expansion
        # NOTE(review): the norm layers below are applied directly to NCHW tensors,
        # which suits BatchNorm2d. ResNet in this module defaults norm_layer to
        # nn.LayerNorm, which would normalize the last (width) axis here instead.
        # Confirm before combining the two.
        # Creation order is significant: it fixes the RNG draws used for
        # parameter initialisation and the state_dict key order.
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, mid_channels)
        self.bn1 = norm_layer(mid_channels)
        self.conv2 = conv3x3(mid_channels, mid_channels, stride, groups, dilation)
        self.bn2 = norm_layer(mid_channels)
        self.conv3 = conv1x1(mid_channels, out_channels)
        self.bn3 = norm_layer(out_channels)
        self.relu = nn.ReLU(inplace=False)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the Bottleneck block.

        Args:
            x (Tensor): Input tensor [batch_size, inplanes, height, width]

        Returns:
            Tensor: ``relu(branch(x) + shortcut(x))`` with ``planes * expansion``
            channels; spatial size follows the stride of ``conv2``
        """
        out = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(norm(conv(out)))
        out = self.bn3(self.conv3(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(out + shortcut)
class ResNet(nn.Module):
    """
    Single-stage ResNet for genomic sequences and their structural representations.

    The model applies these steps in order:
    - a 3x3 stem convolution (``channels -> 48``);
    - a channels-last normalization and a ReLU;
    - one stage of residual blocks;
    - global average pooling over the two spatial axes;
    - a linear head that produces one value per sample.

    Normalization defaults to ``nn.LayerNorm`` rather than batch normalization.

    Attributes:
        _norm_layer: Normalization layer factory
        inplanes: Number of channels entering the next block to be built
        dilation: Current dilation factor used when building blocks
        groups: Number of groups passed to each block
        base_width: Base width passed to each block (``width_per_group``)
        conv1: Stem 3x3 convolution
        bn1: Stem normalization layer
        relu: ReLU activation function
        layer1: The single stage of residual blocks
        fc1: Linear head mapping pooled features to a single output
    """

    def __init__(
        self,
        channels,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 1,
        replace_stride_with_dilation=None,
        norm_layer=None,
    ) -> None:
        """
        Initialize the ResNet architecture.

        Args:
            channels (int): Number of input channels
            block: Type of ResNet block (BasicBlock or Bottleneck)
            layers (List[int]): Number of blocks per stage; only ``layers[0]`` is
                used because a single stage is built
            zero_init_residual (bool): Whether to zero-initialize the scale of the
                final-listed norm in each residual branch (default: False)
            groups (int): Number of groups passed to each block (default: 1)
            width_per_group (int): Base width passed to each block (default: 1)
            replace_stride_with_dilation: None or a 3-element sequence of bools.
                It is validated but not otherwise used by this single-stage model
                (default: None)
            norm_layer: Normalization layer type (default: None, uses LayerNorm)

        Raises:
            ValueError: If replace_stride_with_dilation is neither None nor a
                3-element sequence
        """
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.LayerNorm
        self._norm_layer = norm_layer
        self.inplanes = 48
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            channels, self.inplanes, kernel_size=3, stride=1, padding=1
        )
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=False)
        self.layer1 = self._make_layer(block, 48, layers[0])
        # _make_layer leaves self.inplanes at the stage's output width
        # (48 * block.expansion). Sizing the head from it, rather than a literal
        # 48, keeps the head correct for blocks whose expansion is not 1.
        self.fc1 = nn.Linear(self.inplanes, 1)
        # Zero the scale of bn3 (Bottleneck) / bn2 (BasicBlock). With the default
        # zero norm biases, each residual branch then starts at zero and each
        # block behaves like an identity. Reported to improve accuracy by
        # 0.2~0.3% in https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        """
        Create a stage of ResNet blocks and advance ``self.inplanes``.

        Args:
            block: Type of ResNet block to use
            planes (int): Base channel count for the stage (output is
                ``planes * block.expansion``)
            blocks (int): Number of blocks in the stage
            stride (int): Stride for the first block (default: 1)
            dilate (bool): Replace the stride with dilation (default: False)

        Returns:
            nn.Sequential: Sequential container of ResNet blocks
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # NOTE(review): norm_layer is applied here directly to the NCHW output
            # of conv1x1. With the default nn.LayerNorm that normalizes the last
            # (width) axis rather than channels. Confirm before relying on this
            # path, which is only taken when the stride or the channel count changes.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )
        layers = []
        # Positional order: inplanes, planes, stride, downsample, groups,
        # base_width, dilation, norm_layer.
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )
        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """
        Forward pass implementation.

        Args:
            x (Tensor): Input tensor [batch_size, channels, height, width]

        Returns:
            Tensor: Output tensor of shape [batch_size, 1]
        """
        # [bz,hd,len,len]
        x = self.conv1(x)
        # The stem norm is applied channels-last (LayerNorm normalizes the final axis).
        x = x.permute(0, 2, 3, 1)
        x = self.bn1(x)
        x = x.permute(0, 3, 1, 2)
        x = self.relu(x)
        x = self.layer1(x)
        # Global average pooling over the two spatial axes -> [batch, features].
        x = x.mean(dim=[2, 3])
        x = self.fc1(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the ResNet.

        Args:
            x (Tensor): Input tensor [batch_size, channels, height, width]

        Returns:
            Tensor: Output tensor of shape [batch_size, 1]
        """
        return self._forward_impl(x)
def resnet_b16(channels=128, bbn=16):
    """
    Build a single-stage ResNet made of ``bbn`` BasicBlocks.

    Args:
        channels (int): Number of input channels (default: 128)
        bbn (int): Number of BasicBlocks in the stage (default: 16)

    Returns:
        ResNet: Model mapping [batch, channels, height, width] inputs to
        [batch, 1] outputs
    """
    stage_depths = [bbn]
    return ResNet(channels, block=BasicBlock, layers=stage_depths)