support LightCDNet
likyoo committed Feb 10, 2024
1 parent 6ef61e4 commit 09c03eb
Showing 11 changed files with 459 additions and 13 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -37,6 +37,7 @@ Supported change detection model:
 - [x] [Changer (TGRS'2023)](configs/changer)
 - [x] [HANet (JSTARS'2023)](configs/hanet)
 - [x] [TinyCDv2 (Under Review)](configs/tinycd_v2)
+- [x] [LightCDNet (GRSL'2023)](configs/lightcdnet)
 - [x] [BAN (arXiv'2023)](configs/ban)
 - [x] [TTP (arXiv'2023)](configs/ttp)
 - [ ] ...
54 changes: 54 additions & 0 deletions configs/_base_/models/lightcdnet.py
@@ -0,0 +1,54 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
    type='DualInputSegDataPreProcessor',
    mean=[123.675, 116.28, 103.53] * 2,
    std=[58.395, 57.12, 57.375] * 2,
    bgr_to_rgb=True,
    size_divisor=32,
    pad_val=0,
    seg_pad_val=255,
    test_cfg=dict(size_divisor=32))
model = dict(
    type='DIEncoderDecoder',
    data_preprocessor=data_preprocessor,
    pretrained=None,
    backbone=dict(
        type='LightCDNet',
        stage_repeat_num=[4, 8, 4],
        net_type="small"),
    neck=dict(
        type='TinyFPN',
        exist_early_x=True,
        early_x_for_fpn=True,
        custom_block='conv',
        in_channels=[24, 48, 96, 192],
        out_channels=48,
        num_outs=4),
    decode_head=dict(
        type='DS_FPNHead',
        in_channels=[48, 48, 48, 48],
        in_index=[0, 1, 2, 3],
        channels=48,
        dropout_ratio=0.,
        num_classes=2,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='mmseg.CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='mmseg.FCNHead',
        in_channels=24,
        in_index=0,
        channels=24,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.,
        num_classes=2,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='mmseg.CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
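Two of the preprocessor settings above are easier to read with the arithmetic spelled out. The `* 2` on `mean` and `std` is plain Python list repetition, presumably so that the statistics cover both images of the bitemporal pair, and `size_divisor=32` implies that each spatial side is padded up to the next multiple of 32. The sketch below only illustrates that arithmetic; `padded_size` is a hypothetical helper and not the `DualInputSegDataPreProcessor` implementation.

```python
import math

# Per-channel statistics from the config; `* 2` repeats the three values.
mean = [123.675, 116.28, 103.53] * 2
std = [58.395, 57.12, 57.375] * 2


def padded_size(height, width, size_divisor=32):
    """Round each side up to the next multiple of size_divisor.

    Hypothetical helper for illustration only.
    """
    return (math.ceil(height / size_divisor) * size_divisor,
            math.ceil(width / size_divisor) * size_divisor)


print(len(mean), len(std))    # 6 6
print(padded_size(256, 256))  # (256, 256)
print(padded_size(250, 300))  # (256, 320)
```

With the 256x256 crops used in this commit no padding is needed, since 256 is already a multiple of 32.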
45 changes: 45 additions & 0 deletions configs/lightcdnet/README.md
@@ -0,0 +1,45 @@
# LightCDNet

[LightCDNet: Lightweight Change Detection Network Based on VHR Images](https://ieeexplore.ieee.org/document/10214556)

## Introduction

[Official Repo](https://github.com/NightSongs/LightCDNet)

[Code Snippet](https://github.com/likyoo/open-cd/blob/main/opencd/models/backbones/lightcdnet.py)

## Abstract
Lightweight change detection models are essential for industrial applications and edge devices. Reducing the model size while maintaining high accuracy is a key challenge in developing lightweight change detection models. However, many existing methods oversimplify the model architecture, leading to a loss of information and reduced performance. Therefore, developing a lightweight model that can effectively preserve the input information is a challenging problem. To address this challenge, we propose LightCDNet, a novel lightweight change detection model that effectively preserves the input information. LightCDNet consists of an early fusion backbone network and a pyramid decoder for end-to-end change detection. The core component of LightCDNet is the Deep Supervised Fusion Module (DSFM), which guides the early fusion of primary features to improve performance. We evaluated LightCDNet on the LEVIR-CD dataset and found that it achieved comparable or better performance than state-of-the-art models while being 10–117 times smaller in size.

<!-- [IMAGE] -->

<div align=center>
<img src="https://github.com/likyoo/open-cd/assets/44317497/cec088ca-cb45-4d32-8ebb-c0fd3b8d1a4c" width="90%"/>
</div>


```bibtex
@ARTICLE{10214556,
author={Xing, Yuanjun and Jiang, Jiawei and Xiang, Jun and Yan, Enping and Song, Yabin and Mo, Dengkui},
journal={IEEE Geoscience and Remote Sensing Letters},
title={LightCDNet: Lightweight Change Detection Network Based on VHR Images},
year={2023},
volume={20},
number={},
pages={1-5},
doi={10.1109/LGRS.2023.3304309}}
```

## Results and models

### LEVIR-CD

| Method | Crop Size | Lr schd | \#Param (M) | MACs (G) | Precision | Recall | F1-Score | IoU | config |
| :--------------: | :-------: | :-----: | :---------: | :------: | :-------: | :----: | :------: | :---: | ------------------------------------------------------------ |
| LightCDNet-small | 256x256 | 40000 | 0.35 | 1.65 | 91.36 | 89.81 | 90.57 | 82.77 | [config](https://github.com/likyoo/open-cd/blob/main/configs/lightcdnet/lightcdnet_s_256x256_40k_levircd.py) |
| LightCDNet-base | 256x256 | 40000 | 1.32 | 3.22 | 92.12 | 90.43 | 91.27 | 83.94 | [config](https://github.com/likyoo/open-cd/blob/main/configs/lightcdnet/lightcdnet_b_256x256_40k_levircd.py) |
| LightCDNet-large | 256x256 | 40000 | 2.82 | 5.94 | 92.43 | 90.45 | 91.43 | 84.21 | [config](https://github.com/likyoo/open-cd/blob/main/configs/lightcdnet/lightcdnet_l_256x256_40k_levircd.py) |


- All metrics are based on the category "change".
- All scores are computed on the test set.
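
As a quick sanity check on the table, F1 can be recomputed from precision and recall, and for a single binary class IoU follows from F1 as IoU = F1 / (2 - F1). The snippet below is a sketch that uses only the numbers printed in the table; because the table rounds to two decimals, recomputed values can differ from the reported ones in the last digit. The helper names are illustrative and not part of Open-CD.

```python
def f1_from_pr(precision, recall):
    """Harmonic mean of precision and recall (both in percent)."""
    return 2 * precision * recall / (precision + recall)


def iou_from_f1(f1):
    """Binary IoU implied by F1 (in percent): IoU = F1 / (2 - F1)."""
    f = f1 / 100.0
    return 100.0 * f / (2.0 - f)


# (precision, recall, reported F1, reported IoU) copied from the table.
rows = {
    'LightCDNet-small': (91.36, 89.81, 90.57, 82.77),
    'LightCDNet-base': (92.12, 90.43, 91.27, 83.94),
    'LightCDNet-large': (92.43, 90.45, 91.43, 84.21),
}

for name, (p, r, f1_reported, iou_reported) in rows.items():
    f1 = f1_from_pr(p, r)
    iou = iou_from_f1(f1)
    print(f'{name}: F1={f1:.2f} (table {f1_reported}), '
          f'IoU={iou:.2f} (table {iou_reported})')
```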
5 changes: 5 additions & 0 deletions configs/lightcdnet/lightcdnet_b_256x256_40k_levircd.py
@@ -0,0 +1,5 @@
_base_ = ['./lightcdnet_s_256x256_40k_levircd.py']

model = dict(
    backbone=dict(net_type="base"),
    neck=dict(in_channels=[24, 116, 232, 464]))
5 changes: 5 additions & 0 deletions configs/lightcdnet/lightcdnet_l_256x256_40k_levircd.py
@@ -0,0 +1,5 @@
_base_ = ['./lightcdnet_s_256x256_40k_levircd.py']

model = dict(
    backbone=dict(net_type="large"),
    neck=dict(in_channels=[24, 176, 352, 704]))
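The base and large variants override `net_type` and the neck's `in_channels` together, because the neck has to accept exactly the channel widths that the backbone emits. The following sketch cross-checks the two tables; the backbone widths are copied from `LightCDNet.__init__` in this commit and the neck widths from the three configs. `check_neck_matches_backbone` is a hypothetical helper, not an Open-CD function.

```python
# Backbone output channels per net_type, copied from LightCDNet.__init__.
BACKBONE_OUT_CHANNELS = {
    'small': [24, 48, 96, 192],
    'base': [24, 116, 232, 464],
    'large': [24, 176, 352, 704],
}

# Neck in_channels as set by the base model config and the b/l overrides.
NECK_IN_CHANNELS = {
    'small': [24, 48, 96, 192],
    'base': [24, 116, 232, 464],
    'large': [24, 176, 352, 704],
}


def check_neck_matches_backbone(net_type):
    """Return True if the neck's in_channels equal the backbone's outputs."""
    if net_type not in BACKBONE_OUT_CHANNELS:
        raise ValueError(f'unknown net_type: {net_type!r}')
    return NECK_IN_CHANNELS[net_type] == BACKBONE_OUT_CHANNELS[net_type]


for net_type in BACKBONE_OUT_CHANNELS:
    print(net_type, check_neck_matches_backbone(net_type))
```

A check of this kind is cheap to run before training and catches the case where one of the two settings is changed without the other.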
19 changes: 19 additions & 0 deletions configs/lightcdnet/lightcdnet_s_256x256_40k_levircd.py
@@ -0,0 +1,19 @@
_base_ = [
    '../_base_/models/lightcdnet.py',
    '../common/standard_256x256_40k_levircd.py']

model = dict(
    decode_head=dict(
        sampler=dict(type='mmseg.OHEMPixelSampler', thresh=0.7, min_kept=100000)))

# optimizer
optimizer = dict(
    type='AdamW',
    lr=0.003,
    betas=(0.9, 0.999),
    weight_decay=0.05)

optim_wrapper = dict(
    _delete_=True,
    type='OptimWrapper',
    optimizer=optimizer)
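The `_delete_=True` key in `optim_wrapper` matters because inherited config dicts are normally merged key by key, so settings from the base file would otherwise survive next to the new AdamW settings. The sketch below is a simplified stand-in for that merge rule, written in plain Python; `merge_cfg` is a hypothetical function, not the MMEngine implementation, and the `base` values are invented for illustration because `standard_256x256_40k_levircd.py` is not part of this commit.

```python
import copy

DELETE_KEY = '_delete_'


def merge_cfg(base, child):
    """Recursively merge `child` into a copy of `base` (simplified sketch)."""
    merged = copy.deepcopy(base)
    for key, value in child.items():
        if isinstance(value, dict):
            value = dict(value)  # shallow copy so the caller's dict is untouched
            replace = value.pop(DELETE_KEY, False)
            if not replace and isinstance(merged.get(key), dict):
                # Default behaviour: keep inherited keys, override matching ones.
                merged[key] = merge_cfg(merged[key], value)
            else:
                # `_delete_=True` (or no inherited dict): start from scratch.
                merged[key] = merge_cfg({}, value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical base values, assumed for this example only.
base = dict(
    optim_wrapper=dict(
        type='OptimWrapper',
        optimizer=dict(type='SGD', lr=0.01, momentum=0.9),
        clip_grad=dict(max_norm=1.0)))

child = dict(
    optim_wrapper=dict(
        _delete_=True,
        type='OptimWrapper',
        optimizer=dict(type='AdamW', lr=0.003, betas=(0.9, 0.999),
                       weight_decay=0.05)))

merged = merge_cfg(base, child)
print(sorted(merged['optim_wrapper']))               # ['optimizer', 'type']
print(sorted(merged['optim_wrapper']['optimizer']))  # ['betas', 'lr', 'type', 'weight_decay']
```

Without `_delete_`, the same merge would keep the inherited `clip_grad` and `momentum` entries, and `momentum` is not an argument that AdamW accepts.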
4 changes: 3 additions & 1 deletion opencd/models/backbones/__init__.py
@@ -9,8 +9,10 @@
 from .hanet import HAN
 from .vit_tuner import VisionTransformerTurner
 from .vit_sam import ViTSAM_Custom
+from .lightcdnet import LightCDNet
 
 __all__ = ['IA_ResNetV1c', 'IA_ResNeSt', 'FC_EF', 'FC_Siam_diff',
            'FC_Siam_conc', 'SNUNet_ECAM', 'TinyCD', 'IFN',
            'TinyNet', 'IA_MixVisionTransformer', 'HAN',
-           'VisionTransformerTurner', 'ViTSAM_Custom']
+           'VisionTransformerTurner', 'ViTSAM_Custom',
+           'LightCDNet']
227 changes: 227 additions & 0 deletions opencd/models/backbones/lightcdnet.py
@@ -0,0 +1,227 @@
# Copyright (c) Open-CD. All rights reserved.
import torch
import torch.nn as nn
import numpy as np
from mmcv.ops import CrissCrossAttention

from mmseg.models.utils import LayerNorm2d
from opencd.registry import MODELS


class CCA(nn.Module):
    """Recurrent Criss-Cross Attention block.

    This module wraps ``mmcv.ops.CrissCrossAttention``, introduced in `CCNet
    <https://arxiv.org/abs/1811.11721>`_.

    Args:
        channels (int): Number of input (and output) channels.
        recurrence (int): Number of recurrence of Criss Cross Attention
            module. Default: 2.
    """

    def __init__(self, channels, recurrence=2):
        super(CCA, self).__init__()
        self.recurrence = recurrence
        self.cca = CrissCrossAttention(channels)

    def forward(self, x):
        # Apply the same attention module `recurrence` times.
        for _ in range(self.recurrence):
            x = self.cca(x)
        return x


def channel_shuffle(x, groups=2):
    """Interleave the channels of ``x`` (N, C, H, W) across ``groups``."""
    bat_size, channels, h, w = x.shape
    group_c = channels // groups
    x = x.view(bat_size, groups, group_c, h, w)
    x = torch.transpose(x, 1, 2).contiguous()
    x = x.view(bat_size, -1, h, w)
    return x
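
The view, transpose, view sequence in `channel_shuffle` amounts to a fixed permutation of the channel axis. The sketch below reproduces only that permutation on plain Python lists, so it runs without PyTorch; `channel_shuffle_order` is a hypothetical helper for illustration and is not part of this commit.

```python
def channel_shuffle_order(num_channels, groups=2):
    """Return the channel order produced by view -> transpose -> view."""
    assert num_channels % groups == 0, 'channels must divide evenly into groups'
    group_c = num_channels // groups
    # Row g holds the channels of group g, like x.view(groups, group_c).
    grouped = [list(range(g * group_c, (g + 1) * group_c))
               for g in range(groups)]
    # Transposing and flattening interleaves the groups.
    return [grouped[g][c] for c in range(group_c) for g in range(groups)]


print(channel_shuffle_order(6, groups=2))  # [0, 3, 1, 4, 2, 5]
print(channel_shuffle_order(8, groups=2))  # [0, 4, 1, 5, 2, 6, 3, 7]
```

With two groups, channels from the first and second half alternate, which is what lets information from the untouched half of a `ShuffleBlock` reach the convolutional branch of the next block.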


class ShuffleBlock(nn.Module):

    def __init__(self, in_c, out_c, downsample=False):
        super(ShuffleBlock, self).__init__()
        self.downsample = downsample
        half_c = out_c // 2
        if downsample:
            self.branch1 = nn.Sequential(
                # 3*3 dw conv, stride = 2
                nn.Conv2d(in_c, in_c, 3, 2, 1, groups=in_c, bias=False),
                nn.BatchNorm2d(in_c),
                # 1*1 pw conv
                nn.Conv2d(in_c, half_c, 1, 1, 0, bias=False),
                nn.BatchNorm2d(half_c),
                nn.ReLU(True))

            self.branch2 = nn.Sequential(
                # 1*1 pw conv
                nn.Conv2d(in_c, half_c, 1, 1, 0, bias=False),
                nn.BatchNorm2d(half_c),
                nn.ReLU(True),
                # 3*3 dw conv, stride = 2
                nn.Conv2d(half_c, half_c, 3, 2, 1, groups=half_c, bias=False),
                nn.BatchNorm2d(half_c),
                # 1*1 pw conv
                nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False),
                nn.BatchNorm2d(half_c),
                nn.ReLU(True))

        else:
            assert in_c == out_c

            self.branch2 = nn.Sequential(
                # 1*1 pw conv
                nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False),
                nn.BatchNorm2d(half_c),
                nn.ReLU(True),
                # 3*3 dw conv, stride = 1
                nn.Conv2d(half_c, half_c, 3, 1, 1, groups=half_c, bias=False),
                nn.BatchNorm2d(half_c),
                # 1*1 pw conv
                nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False),
                nn.BatchNorm2d(half_c),
                nn.ReLU(True))

    def forward(self, x):
        out = None
        if self.downsample:
            # if it is downsampling, we don't need to do channel split
            out = torch.cat((self.branch1(x), self.branch2(x)), 1)
        else:
            # channel split
            channels = x.shape[1]
            c = channels // 2
            x1 = x[:, :c, :, :]
            x2 = x[:, c:, :, :]
            out = torch.cat((x1, self.branch2(x2)), 1)

        return channel_shuffle(out, 2)


class TimeAttention(nn.Module):

    def __init__(self, channels):
        super(TimeAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        attn_channels = channels // 16
        attn_channels = max(attn_channels, 8)
        self.mlp = nn.Sequential(
            nn.Conv2d(channels * 2, attn_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(attn_channels),
            nn.ReLU(),
            nn.Conv2d(attn_channels, channels * 2, kernel_size=1, bias=False),
        )

    def forward(self, x1, x2):
        x = torch.cat((x1, x2), dim=1)
        x = self.avg_pool(x)
        y = self.mlp(x)
        B, C, H, W = y.size()
        x1_attn, x2_attn = y.reshape(B, 2, C // 2, H, W).transpose(0, 1)
        x1_attn = torch.sigmoid(x1_attn)
        x2_attn = torch.sigmoid(x2_attn)
        x1 = x1 * x1_attn + x1
        x2 = x2 * x2_attn + x2
        return x1, x2
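
The last two assignments in `TimeAttention.forward` compute `x * sigmoid(a) + x`, which is the same as scaling `x` by `1 + sigmoid(a)`. Since the sigmoid lies strictly between 0 and 1, each channel is amplified by a gain between 1 and 2 and is never suppressed. The scalar sketch below illustrates that range with the standard library only; `gated_residual` is a hypothetical name used for this example.

```python
import math


def sigmoid(value):
    return 1.0 / (1.0 + math.exp(-value))


def gated_residual(feature, logit):
    """x * sigmoid(a) + x, i.e. x scaled by a gain in the open interval (1, 2)."""
    return feature * sigmoid(logit) + feature


for logit in (-4.0, 0.0, 4.0):
    gain = gated_residual(1.0, logit)
    print(f'logit={logit:+.1f} -> gain={gain:.3f}')
```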


class shuffle_fusion(nn.Module):

    def __init__(self, channels, block_num=2):
        super().__init__()

        self.stages = []
        self.stages.append(
            nn.Sequential(
                nn.Conv2d(channels, channels * 4, kernel_size=1, bias=False),
                nn.BatchNorm2d(channels * 4), nn.ReLU()))
        for i in range(block_num):
            self.stages.append(
                ShuffleBlock(channels * 4, channels * 4, downsample=False))

        self.stages = nn.Sequential(*self.stages)

        self.single_conv = nn.Sequential(
            nn.Conv2d(channels * 4, channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(channels), nn.ReLU())

        self.time_attn = TimeAttention(channels)

        self.final_conv = nn.Sequential(
            nn.Conv2d(channels * 2, channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(channels), nn.ReLU())

    def forward_single(self, x):
        identity = x
        x = self.stages(x)
        x = self.single_conv(x)
        x = identity + x
        return x

    def forward(self, x1, x2):
        x1 = self.forward_single(x1)
        x2 = self.forward_single(x2)
        x1, x2 = self.time_attn(x1, x2)
        x = self.final_conv(channel_shuffle(torch.cat((x1, x2), dim=1)))
        return x


@MODELS.register_module()
class LightCDNet(nn.Module):

    def __init__(self, stage_repeat_num, net_type="small"):
        super(LightCDNet, self).__init__()

        # Zero-based index of the last block of every stage.
        index_list = stage_repeat_num.copy()
        index_list[0] = index_list[0] - 1
        self.index_list = list(np.cumsum(index_list))
        if net_type == "small":
            self.out_channels = [24, 48, 96, 192]
            self.block_num = 4
        elif net_type == "base":
            self.out_channels = [24, 116, 232, 464]
            self.block_num = 8
        elif net_type == "large":
            self.out_channels = [24, 176, 352, 704]
            self.block_num = 16
        else:
            # Fail early instead of printing and crashing later on a
            # missing attribute.
            raise ValueError(
                f'Unsupported net_type "{net_type}"; expected "small", '
                '"base" or "large".')

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, self.out_channels[0], 3, 2, 1, bias=False),
            LayerNorm2d(self.out_channels[0]), nn.GELU())

        self.fusion_conv = shuffle_fusion(
            self.out_channels[0], block_num=self.block_num)

        in_c = self.out_channels[0]

        self.stages = []
        for stage_idx in range(len(stage_repeat_num)):
            out_c = self.out_channels[1 + stage_idx]
            repeat_num = stage_repeat_num[stage_idx]
            for i in range(repeat_num):
                if i == 0:
                    self.stages.append(
                        ShuffleBlock(in_c, out_c, downsample=True))
                else:
                    self.stages.append(
                        ShuffleBlock(in_c, in_c, downsample=False))
                in_c = out_c
        self.stages.append(CCA(channels=out_c, recurrence=2))

        self.stages = nn.Sequential(*self.stages)

    def forward(self, x1, x2):
        x1 = self.conv1(x1)
        x2 = self.conv1(x2)
        x = self.fusion_conv(x1, x2)
        outs = [x]

        for i in range(len(self.stages)):
            x = self.stages[i](x)
            if i in self.index_list:
                outs.append(x)
        return outs
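
The `index_list` built in `LightCDNet.__init__` decides which intermediate outputs `forward` collects. Assuming the shuffle blocks of consecutive stages sit back to back in `self.stages`, subtracting one from the first repeat count and taking a cumulative sum yields the zero-based position of the last block in each stage. The sketch below redoes that computation with `itertools.accumulate` instead of NumPy; `tap_indices` is a hypothetical name used only here.

```python
from itertools import accumulate


def tap_indices(stage_repeat_num):
    """Indices into the block sequence whose outputs are collected."""
    counts = list(stage_repeat_num)
    counts[0] -= 1  # convert the first count into a zero-based index
    return list(accumulate(counts))


print(tap_indices([4, 8, 4]))  # [3, 11, 15]
```

For `stage_repeat_num=[4, 8, 4]` the blocks occupy positions 0 to 3, 4 to 11 and 12 to 15, so together with the fused early feature the backbone returns four feature maps, matching the four entries of the neck's `in_channels`.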
3 changes: 2 additions & 1 deletion opencd/models/decode_heads/__init__.py
@@ -8,7 +8,8 @@
 from .ban_head import BitemporalAdapterHead
 from .ban_utils import BAN_MLPDecoder, BAN_BITHead
 from .mlpseg_head import MLPSegHead
+from .ds_fpn_head import DS_FPNHead
 
 __all__ = ['BITHead', 'Changer', 'IdentityHead', 'DSIdentityHead', 'TinyHead',
            'STAHead', 'MultiHeadDecoder', 'GeneralSCDHead', 'BitemporalAdapterHead',
-           'BAN_MLPDecoder', 'BAN_BITHead', 'MLPSegHead']
+           'BAN_MLPDecoder', 'BAN_BITHead', 'MLPSegHead', 'DS_FPNHead']