
[feat] add casual vqvae ✨ #145

Merged: 1 commit merged into PKU-YuanGroup:main on Mar 18, 2024
Conversation

@qqingzheng (Contributor) commented on Mar 16, 2024:

TODO

  • Codebase
  • Train a test version
  • Do more validation

Causal VQVAE

The structural optimizations over the VQVAE in VideoGPT are as follows:

  1. Replace regular Conv with CausalConv: repeat the first frame k-1 times, where k is the temporal size of the convolutional kernel, so the first frame is processed in isolation, as a plain image (see the sketch after this list).
  2. Use interpolation followed by convolution for upsampling.
  3. Decouple temporal compression from spatial compression: spatial compression uses 2D convolution, while temporal compression uses 3D convolution (with spatial padding so the spatial size stays unchanged).
  4. Add Axial Attention along the temporal dimension with a causal mask, so that the first frame is not associated with the other frames and each frame attends only to itself and earlier frames.
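
As a concrete illustration of item 1, here is a minimal PyTorch sketch of a causal 3D convolution with replicate-first-frame padding (an editorial sketch, not the exact code from this PR; the class and argument names are assumptions):

```python
import torch
import torch.nn as nn

class CausalConv3d(nn.Module):
    """3D convolution that is causal in time: the first frame is replicated
    k_t - 1 times in front, so output frame t only sees input frames <= t
    and the first output frame is computed from the first frame alone."""
    def __init__(self, in_ch, out_ch, kernel_size=(3, 3, 3)):
        super().__init__()
        k_t, k_h, k_w = kernel_size
        self.time_pad = k_t - 1
        # Pad only spatially inside the conv; temporal padding is done by hand.
        self.conv = nn.Conv3d(in_ch, out_ch, kernel_size,
                              padding=(0, k_h // 2, k_w // 2))

    def forward(self, x):  # x: (B, C, T, H, W)
        first = x[:, :, :1].repeat(1, 1, self.time_pad, 1, 1)
        x = torch.cat([first, x], dim=2)  # repeat first frame k_t - 1 times
        return self.conv(x)
```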

Examples (4000-step checkpoint)

Image Reconstruction: [small model vs. large model comparison images]

Video Reconstruction: [small model vs. large model comparison videos]

@qqingzheng qqingzheng marked this pull request as draft March 16, 2024 18:56
@qqingzheng qqingzheng marked this pull request as ready for review March 18, 2024 01:46
@LinB203 LinB203 merged commit 9d25001 into PKU-YuanGroup:main Mar 18, 2024
@qqingzheng qqingzheng deleted the add_casual_vqvae branch March 27, 2024 12:52
@alvinliu0 commented:

Hi,

Many thanks for your great contribution! May I ask the reason for the step-1 structural optimization of the video VQVAE:

[Replace regular Conv with CausalConv: repeat the first frame k-1 times, where k is the temporal size of the convolutional kernel, so the first frame is processed in isolation, as a plain image.]

I understand that with such padding, the temporal conv kernel only sees replicas of the first frame in the first temporal convolution window, but what is the significance of doing so? For example, if the temporal kernel size is 4, the frame count is 4, and the stride is 2, the convolved sequence would look like:

replicated frame id: 1, 1, 1, 1, 2, 3, 4, pad
convolved frame id: [1, 1, 1, 1], [1, 1, 2, 3], [2, 3, 4, pad]

It is unclear to me why the causal video VAE should work like this. Besides, how do you extend the pretrained spatial VQVAE into this causal video one?

Thanks for your explanation in advance!

Best

@qqingzheng (Contributor, Author) replied:

For the first question: if the first frame is copied k-1 times and concatenated, then no matter how many convolutional layers it passes through, the first frame remains independent of the other frames and never blends with them. In this way, the first frame is treated as an independent image.

For example:

Input: (1,2,3,4) -> (1,1,1,1,2,3,4,0)

Conv1: ([1,1,1,1], [1,1,2,3], [2,3,4,0]) is denoted as (a,b,c) -> (a,a,a,a,b,c)

Conv2: ([a,a,a,a], [a,a,b,c])

Despite passing through two layers of convolution, the computation from 1 to "a" is independent of the other frames, and this holds regardless of the number of convolution layers.
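
To make this concrete, here is a quick numerical check (an editorial sketch reusing the CausalConv3d class and imports from the sketch above, not code from this PR): perturb every frame except the first and confirm the first output frame is unchanged.

```python
# Sanity check: the first output frame depends only on the first input frame.
torch.manual_seed(0)
net = nn.Sequential(CausalConv3d(3, 8), CausalConv3d(8, 3))

x = torch.randn(1, 3, 4, 16, 16)                 # (B, C, T, H, W)
x2 = x.clone()
x2[:, :, 1:] += torch.randn_like(x2[:, :, 1:])   # perturb frames 2..T only

y, y2 = net(x), net(x2)
print(torch.allclose(y[:, :, 0], y2[:, :, 0]))   # True: frame 1 is isolated
```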

@qqingzheng (Contributor, Author) replied:

For the second question, please see #168.

@alvinliu0 commented:

Many thanks for your explanation, which makes much sense to me!

May I ask a follow-up question: with this design, how do you enable joint image-video training, and what is the significance of an independent first frame?

For example, suppose you have both an image dataset and a video dataset, and you sample i_0 from the image dataset and i_1, ..., i_4 from the video dataset. Do you mean the training batch is replicated as [i_0, i_0, i_0, i_0, i_1, i_2, i_3, i_4]? I guess not, since then the convolution of [i_0, i_0, i_1, i_2] would be meaningless, as the frames come from different datasets.

If the first frame simply comes from the first frame of the video sequence, then we cannot benefit from large-scale image datasets.

Many thanks!

@qqingzheng (Contributor, Author) replied:


Currently, we have not attempted mixed image-video training in the AE. At present, we still treat the first frame of the video as an image and train on video datasets only. However, we have adopted a mixed form of images and videos when training the DiT; you can find the relevant code in opensora/train/train_t2v.py.

@ivylilili commented:

Hi, I have a question about the design of TimeUpsample2x in the CausalVAE. Since it accepts both image and video input, for an image TimeUpsample2x does not upsample the time dimension. But for video input, if the number of frames is 4 (and the temporal compression is 4), the latent has size 1 along the time axis; will it then be treated as an 'image' when going through TimeUpsample2x and still have size 1 in the time dimension on output?

@qqingzheng (Contributor, Author) replied:


Yes, this is currently a weak point of our CausalVAE. If you have better suggestions, a PR is welcome.
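
To illustrate the edge case being discussed, here is a hypothetical sketch of how a parameter-free TimeUpsample2x might branch on the temporal size (an assumed implementation for illustration, not the repo's exact code):

```python
import torch.nn as nn
import torch.nn.functional as F

class TimeUpsample2x(nn.Module):
    """Parameter-free 2x temporal upsampling via nearest interpolation."""
    def forward(self, x):  # x: (B, C, T, H, W)
        if x.shape[2] == 1:
            # Treated as an image: the time axis is left untouched.  This is
            # the edge case above: a 4-frame clip compressed 4x also arrives
            # here with T == 1 and is mistaken for an image.
            return x
        return F.interpolate(x, scale_factor=(2, 1, 1), mode="nearest")
```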

@Birdylx commented on Apr 2, 2024:

@qqingzheng Hi, may I ask a question about the temporal compression and the loss function of the latest CausalVAE?

  1. There are no parameters in TimeDownsample2x and TimeUpsample2x; is that enough to learn the temporal compression?
  2. It seems you only compute the perceptual loss and the reconstruction loss, with no GAN loss involved?

@qqingzheng (Contributor, Author) replied on Apr 2, 2024:


Thank you for your questions; they are crucial.

For the first question: we also believe it is insufficient. This was a compromise adopted to make maximal use of the pretrained SD VAE weights. In subsequent training we have tried several ways to address it, such as special initialization schemes for the 3D convolution kernels.

For the second question: a GAN can be difficult to train, and we want to iterate the model from simple to complex. We therefore temporarily removed the GAN to reduce experiment time and cost.

@Sutongtong233 commented:

> For the second question, please see #168.

The encoder and decoder in this repository differ from yours. Perhaps you could provide your implementation of the inflated 2D VAE? Alternatively, I can contribute my implementation.

@Birdylx commented on Apr 5, 2024:


@qqingzheng I added an identity init to CausalConv3d so that it acts like an identity function at initialization. BTW, I added a GAN loss to training, and it works well: with only the perceptual loss there are many grid artifacts in the reconstructed images, while the GAN loss removes these grids and produces more realistic images.
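
For illustration, here is a hedged sketch of one way such an identity initialization could look (an assumed implementation, not Birdylx's actual code): zero all temporal taps except the one for the current frame, and make that slice a Dirac delta so the inflated layer starts as a pass-through.

```python
import torch
import torch.nn as nn

def identity_init_causal_conv3d(conv: nn.Conv3d) -> None:
    """Initialize a causal Conv3d as a pass-through: only the last temporal
    tap (the current frame, given front-only temporal padding) is non-zero,
    and that slice is a Dirac delta.  Assumes in_channels == out_channels
    and odd spatial kernel sizes."""
    k_t, k_h, k_w = conv.kernel_size
    with torch.no_grad():
        conv.weight.zero_()
        if conv.bias is not None:
            conv.bias.zero_()
        for c in range(conv.out_channels):
            conv.weight[c, c, k_t - 1, k_h // 2, k_w // 2] = 1.0
```

With front-only temporal padding, the last kernel tap multiplies the current frame, so at step 0 the layer reproduces its input frame by frame and training only gradually mixes in temporal context.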

@qqingzheng (Contributor, Author) replied:


Thank you for your work! We have also found GAN loss to be important in our subsequent experiments and will sync our latest code to the repository soon.
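
As a rough illustration of the loss terms under discussion, here is a minimal sketch of VQGAN-style training losses combining reconstruction, perceptual, and hinge GAN terms (an editorial sketch; lpips_fn, the weights, and the function names are assumptions, not this repo's code):

```python
import torch
import torch.nn.functional as F

def hinge_d_loss(real_logits, fake_logits):
    # Discriminator hinge loss, as popularized by VQGAN-style training.
    return (F.relu(1.0 - real_logits).mean() +
            F.relu(1.0 + fake_logits).mean())

def generator_loss(x, x_rec, fake_logits, lpips_fn,
                   perceptual_weight=1.0, gan_weight=0.5):
    recon = F.l1_loss(x_rec, x)              # pixel reconstruction loss
    perceptual = lpips_fn(x_rec, x).mean()   # e.g. an LPIPS network
    adversarial = -fake_logits.mean()        # push reconstructions to fool D
    return recon + perceptual_weight * perceptual + gan_weight * adversarial
```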

LinB203 added a commit that referenced this pull request on Apr 6, 2024.
LinB203 added further commits that referenced this pull request on Apr 9, 2024.