Skip to content

Commit

Permalink
clean2
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
  • Loading branch information
akoumpa committed Sep 19, 2024
1 parent e40b0c7 commit 9226555
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions nemo/collections/llm/gpt/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,6 @@ class GPTConfig(TransformerConfig, io.IOMixin):
data_step_fn: Callable = gpt_data_step

def configure_model(self, tokenizer) -> "MCoreGPTModel":
if self.enable_cuda_graph:
assert not self.gradient_accumulation_fusion, \
"gradient_accumulation_fusion is not supported with enable_cuda_graph"

vp_size = self.virtual_pipeline_model_parallel_size
if vp_size:
p_size = self.pipeline_model_parallel_size
Expand Down Expand Up @@ -267,11 +263,10 @@ def configure_model(self) -> None:
'use_te_rng_tracker', False
), "Transformer engine's RNG tracker is required for cudagraphs, it can be "\
"enabled with use_te_rng_tracker=True'."
assert not self.config.gradient_accumulation_fusion, \
"gradient_accumulation_fusion is not supported with enable_cuda_graph"

self.module = self.config.configure_model(self.tokenizer)
if self.config.enable_cuda_graph and self.training:
assert not self.config.cpu_offloading and self.config.recompute_granularity is None, \
"Cudagraphs is not supported with cpu_offloading/recompute_granularity"
self.add_module('cudagraph_manager', CudaGraphManager())

def forward(
self,
Expand Down

0 comments on commit 9226555

Please sign in to comment.