add transformer base class #892

Merged: 7 commits merged into open-mmlab:master on Mar 25, 2021

Conversation

@jshilong (Collaborator) commented Mar 17, 2021

  • This PR adds the basic operations of the vision transformer.

    1. Design a unified data flow for the transformer.

    2. Design a BaseTransformerLayer as the base class of TransformerEncoderLayer and TransformerDecoderLayer in the vision transformer. It contains several basic operations such as attention, layer norm, and FFN. It can be built from a ConfigDict and supports customization: for example, you can specify any number of FFNs or LNs, and use different kinds of attention by specifying a list of ConfigDicts named attn_cfgs. It is worth mentioning that it supports pre-norm.

    3. Design a TransformerLayerSequence as the base class of TransformerEncoder and TransformerDecoder in the vision transformer. It supports customization such as specifying different kinds of transformer_layer in the TransformerLayerSequence.

Details

1 More unified data flow

Design a unified data flow for the vision transformer and add **kwargs to adapt to other transformers (e.g., the `level_start_index` and `reference_points` in Deformable DETR). A minimal forward-call sketch follows the argument list below.

query (Tensor): Input query with the shape
    `(num_query, bs, embed_dims)`.
key (Tensor): The key tensor with shape
    `(num_key, bs, embed_dims)`.
value (Tensor): The value tensor with shape
    `(num_key, bs, embed_dims)`.
query_pos (Tensor): The positional encoding for `query`.
    Default: None.
key_pos (Tensor): The positional encoding for `key`.
    Default: None.
attn_masks (list[Tensor], optional): 2D Tensors used in the
    calculation of the corresponding attention. Default: None.
query_key_padding_mask (Tensor): ByteTensor for `query`, with
    shape `[bs, num_query]`. Only used in self-attention.
    Default: None.
key_padding_mask (Tensor): ByteTensor for `key`, with
    shape `[bs, num_key]`. Default: None.

2 BaseTransformerLayer

It uses the unified data flow and can be initialized with the basic arguments common to all transformer layers, such as

        attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict):
            Configs for self-attention or cross-attention; the order
            should be consistent with `operation_order`. If it is
            a dict, it is expanded to the number of attention
            operations in `operation_order`.
        feedforward_channels (int): The hidden dimension of FFNs.
        ffn_dropout (float): Probability of an element to be zeroed
            in an FFN. Default: 0.0.
        operation_order (tuple[str]): The execution order of operations
            in the transformer, such as ('self_attn', 'norm', 'ffn', 'norm').
            Pre-norm is supported by specifying `norm` as the first
            element. Default: None.
        act_cfg (dict): The activation config for FFNs.
        norm_cfg (dict): Config dict for the normalization layer.
        ffn_num_fcs (int): The number of fully-connected layers in FFNs.
            Default: 2.

The forward function can be used in both TransformerEncoderLayer and TransformerDecoderLayer by specifying a different operation_order. Pre-norm is enabled when you specify the first operation in operation_order as norm, as in the sketch below.
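
For example, a minimal sketch of a pre-norm layer config (the attention settings mirror the DETR-style configs in the Usages section):

    prenorm_layer = dict(
        type='BaseTransformerLayer',
        attn_cfgs=dict(
            type='MultiheadAttention',
            embed_dims=256,
            num_heads=8,
            dropout=0.1),
        feedforward_channels=2048,
        ffn_dropout=0.1,
        # 'norm' first switches the layer to pre-norm; compare the
        # post-norm order ('self_attn', 'norm', 'ffn', 'norm').
        operation_order=('norm', 'self_attn', 'norm', 'ffn'))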

3 TransformerLayerSequence

It uses the unified data flow and can be initialized with the basic arguments common to all TransformerLayerSequences, such as the following (a build sketch follows the argument list):

        transformerlayers (list[obj:`mmcv.ConfigDict`] |
            obj:`mmcv.ConfigDict`): Config of the transformer layers
            in the coder. If it is an obj:`mmcv.ConfigDict`, it is
            repeated `num_layers` times to form a
            list[`mmcv.ConfigDict`]. Default: None.
        num_layers (int): The number of `TransformerLayer`s. Default: None.
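
As a minimal sketch of the repetition semantics (this assumes the base TransformerLayerSequence is itself registered, like BaseTransformerLayer, and that it stores its layers in a ModuleList named `layers`):

    from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence

    coder = build_transformer_layer_sequence(
        dict(
            type='TransformerLayerSequence',
            num_layers=4,
            # A single ConfigDict here is repeated `num_layers` times.
            transformerlayers=dict(
                type='BaseTransformerLayer',
                attn_cfgs=dict(
                    type='MultiheadAttention',
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.1),
                feedforward_channels=2048,
                ffn_dropout=0.1,
                operation_order=('self_attn', 'norm', 'ffn', 'norm'))))
    assert len(coder.layers) == 4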

Usages

  1. TransformerLayer

You can build a TransformerEncoderLayer by giving a ConfigDict that looks like the following:

  
    codelayers = dict(
        type='BaseTransformerLayer',
        attn_cfgs=dict(
            type='MultiheadAttention',
            embed_dims=256,
            num_heads=8,
            dropout=0.1),
        feedforward_channels=2048,
        ffn_dropout=0.1,
        operation_order=('self_attn', 'norm', 'ffn', 'norm'))

    # Or, more generally, you can specify different attentions in the
    # `TransformerLayer`.
    codelayers = dict(
        type='BaseTransformerLayer',
        attn_cfgs=[
            dict(
                type='Attention_1',
                embed_dims=256,
                num_heads=8,
                dropout=0.1,
                other_args=xxx,
                ...),
            dict(
                type='Attention_2',
                args=xxx,
                ...)
        ],
        feedforward_channels=2048,
        ffn_dropout=0.1,
        operation_order=('self_attn', 'norm', 'self_attn', 'ffn', 'norm'))

    # You can build the layer by
    from mmcv.cnn.bricks.transformer import build_transformer_layer
    layer = build_transformer_layer(codelayers)
        
  2. For a subclass of TransformerLayerSequence, you can build it by
    encoder = dict(
        type='DetrTransformerEncoder',
        num_layers=6,
        transformerlayers=dict(
            type='BaseTransformerLayer',
            attn_cfgs=[
                dict(
                    type='MultiheadAttention',
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.1)
            ],
            feedforward_channels=2048,
            ffn_dropout=0.1,
            operation_order=('self_attn', 'norm', 'ffn', 'norm')))

    # Or, more generally, you can specify different layers in the coder.
    encoder = dict(
        type='DetrTransformerEncoder',
        num_layers=2,
        transformerlayers=[
            dict(
                type='BaseTransformerLayer',
                attn_cfgs=[
                    dict(
                        type='MultiheadAttention',
                        embed_dims=256,
                        num_heads=8,
                        dropout=0.1)
                ],
                feedforward_channels=2048,
                ffn_dropout=0.1,
                operation_order=('self_attn', 'norm', 'ffn', 'norm')),
            dict(
                type='BaseTransformerLayer',
                attn_cfgs=[dict(type=xxx, args=xxx)],
                ...)
        ])

    # You can build it by
    from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
    encoder = build_transformer_layer_sequence(encoder)
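
Because `attn_cfgs` is resolved through the ATTENTION registry, custom attentions can be plugged in by registering them. A minimal sketch (`MyAttention` and its internals are hypothetical; the forward signature follows the (query, key, value, identity, **kwargs) convention of mmcv's attention modules, which add the residual themselves):

    import torch.nn as nn
    from mmcv.cnn.bricks.registry import ATTENTION

    @ATTENTION.register_module()
    class MyAttention(nn.Module):  # hypothetical custom attention
        def __init__(self, embed_dims, **kwargs):
            super().__init__()
            self.embed_dims = embed_dims  # read by BaseTransformerLayer
            self.proj = nn.Linear(embed_dims, embed_dims)

        def forward(self, query, key=None, value=None, identity=None, **kwargs):
            # Stand-in for a real attention computation; mmcv attention
            # modules are expected to add the residual themselves.
            identity = query if identity is None else identity
            return identity + self.proj(query)

    # `dict(type='MyAttention', embed_dims=256)` can now be used in `attn_cfgs`.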

BC-breaking

None

codecov bot commented Mar 17, 2021

Codecov Report

Merging #892 (bc869e4) into master (73bff4e) will decrease coverage by 1.73%.
The diff coverage is 16.71%.

❗ Current head bc869e4 differs from pull request most recent head 84c42d9. Consider uploading reports for the commit 84c42d9 to get more accurate results

@@            Coverage Diff             @@
##           master     #892      +/-   ##
==========================================
- Coverage   66.58%   64.84%   -1.74%     
==========================================
  Files         145      148       +3     
  Lines        8828     9140     +312     
  Branches     1605     1644      +39     
==========================================
+ Hits         5878     5927      +49     
- Misses       2633     2897     +264     
+ Partials      317      316       -1     
Flag        Coverage Δ
unittests   64.84% <16.71%> (-1.74%) ⬇️

Flags with carried forward coverage won't be shown.

Impacted Files                     Coverage Δ
mmcv/cnn/bricks/transformer.py     0.00% <0.00%> (ø)
mmcv/ops/upfirdn2d.py              15.29% <15.29%> (ø)
mmcv/ops/fused_bias_leakyrelu.py   30.90% <30.90%> (ø)
mmcv/runner/base_module.py         79.41% <71.42%> (-6.31%) ⬇️
mmcv/cnn/bricks/conv_module.py     100.00% <100.00%> (ø)
mmcv/cnn/bricks/registry.py        100.00% <100.00%> (ø)
mmcv/cnn/utils/flops_counter.py    93.63% <100.00%> (+0.45%) ⬆️
mmcv/ops/__init__.py               100.00% <100.00%> (ø)
mmcv/runner/__init__.py            100.00% <100.00%> (ø)
... and 2 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@nbei (Contributor) reviewed: some comments (4 threads on mmcv/cnn/bricks/registry.py and mmcv/cnn/bricks/transformer.py, resolved)
@nbei (Contributor) reviewed: more comments (7 threads on mmcv/cnn/bricks/transformer.py, resolved)
@nbei (Contributor) reviewed: some comments (3 threads on mmcv/cnn/bricks/transformer.py, resolved)
@nbei (Contributor) reviewed: additional comments (10 threads on mmcv/cnn/bricks/transformer.py, resolved)
@nbei (Contributor) reviewed: another round of comments (9 threads on mmcv/cnn/bricks/transformer.py, resolved)
@jshilong jshilong requested a review from nbei March 18, 2021 09:12
@nbei (Contributor) reviewed: LGTM (2 threads on mmcv/cnn/bricks/transformer.py, resolved)
@jshilong jshilong requested review from nbei and ZwwWayne March 24, 2021 06:24
@ZwwWayne ZwwWayne merged commit a9803da into open-mmlab:master Mar 25, 2021