fix some bugs for latest TE (#160)
**Description**
Fix some bugs for the latest TE and add a unit test for them.
1. TE only allocates the fp8 weight for the first micro batch. MS-AMP allocates a zero-size tensor for the fp8 weight, because `tex.fp8_cast_transpose_fused` will allocate the memory for it. However, the latest TE introduces a `Float8Tensor` data structure that uses `_data` to store the underlying fp8 tensor, so the shape comparison in `set_fp8_weights` must use the shape of `_data`; otherwise TE allocates a zero-size tensor for every non-first micro batch (see the sketch after this list).
2. It seems that with the latest TE, Megatron-LM can't converge (tested with GPT-345M). The newest TE version that does converge is v1.1, so pin TE back to v1.1.
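For illustration, here is a minimal sketch of the corrected comparison. The helper name is hypothetical; `weight{i}_fp8` and `_data` come from the diff below, and the fallback for older TE (a plain tensor with no `_data`) is an assumption:

```python
def fp8_weight_is_reusable(module, attr_name, shape):
    """Hypothetical helper mirroring the fixed check in `set_fp8_weights`.

    The latest TE wraps the fp8 weight in a `Float8Tensor` whose raw storage
    lives in `_data`; comparing the storage shape (rather than the wrapper's)
    tells us whether the weight cached on the first micro batch can be reused.
    """
    if not hasattr(module, attr_name):
        return False
    weight = getattr(module, attr_name)
    data = getattr(weight, '_data', weight)  # older TE returns a plain tensor
    return data.shape == shape
```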
tocean committed Feb 22, 2024
1 parent 0a28e0f commit 9ac98df
Showing 5 changed files with 8 additions and 5 deletions.
2 changes: 1 addition & 1 deletion dockerfile/torch1.14-cuda11.8.dockerfile
```diff
@@ -50,7 +50,7 @@ RUN cd third_party/msccl && \
     make install
 # cache TE build to save time in CI
 RUN python3 -m pip install --upgrade pip && \
-    python3 -m pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
+    python3 -m pip install git+https://github.com/NVIDIA/TransformerEngine.git@v1.1
 
 ADD . .
 RUN python3 -m pip install . && \
```
2 changes: 1 addition & 1 deletion dockerfile/torch2.1-cuda12.2.dockerfile
```diff
@@ -50,7 +50,7 @@ RUN cd third_party/msccl && \
     make install
 # cache TE build to save time in CI
 RUN python3 -m pip install --upgrade pip && \
-    python3 -m pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
+    python3 -m pip install git+https://github.com/NVIDIA/TransformerEngine.git@v1.1
 
 ADD . .
 RUN python3 -m pip install . && \
```
2 changes: 1 addition & 1 deletion msamp/te/modules.py
```diff
@@ -62,7 +62,7 @@ def set_fp8_weights(self):
             weight_cast_attr = f'weight{i}_fp8'
             weight_transpose_attr = f'weight{i}_t_fp8'
 
-            if (hasattr(self, weight_cast_attr) and getattr(self, weight_cast_attr).shape == shape):
+            if (hasattr(self, weight_cast_attr) and getattr(self, weight_cast_attr)._data.shape == shape):
                 return
 
             setattr(
```
2 changes: 1 addition & 1 deletion pyproject.toml
```diff
@@ -38,7 +38,7 @@ classifiers=[
 ]
 dependencies = [
     "torch",
-    "transformer-engine@git+https://github.com/NVIDIA/TransformerEngine.git@stable",
+    "transformer-engine@git+https://github.com/NVIDIA/TransformerEngine.git@v1.1",
     "colorlog>=6.7.0",
     "deepspeed==0.13.1",
     "mpi4py",
```
5 changes: 4 additions & 1 deletion tests/te/test_te_replacer.py
```diff
@@ -164,5 +164,8 @@ def test_fp8_ddp_with_te(self):
         fp8_format = Format.HYBRID
         fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max')
         with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
-            output = model(x, attention_mask=None)
+            output = model(x, attention_mask=None, is_first_microbatch=True)
         output.sum().backward()
+        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+            output = model(x, attention_mask=None, is_first_microbatch=False)
+        output.sum().backward()
```
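For context, here is a hedged sketch of the micro-batch pattern the updated test exercises (the loop and `microbatches` are illustrative, not part of the commit): TE casts and caches the fp8 weight only when `is_first_microbatch=True`, and later micro batches reuse the cached cast, which is the path the fixed shape check in `set_fp8_weights` keeps alive.

```python
# Illustrative micro-batch loop in the style of the test above; `model`,
# `fp8_recipe`, and `microbatches` are assumed to be defined already.
import transformer_engine.pytorch as te

for step, x in enumerate(microbatches):
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        # The fp8 weight is (re)cast only on the first micro batch.
        output = model(x, attention_mask=None, is_first_microbatch=(step == 0))
    output.sum().backward()
```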
