Skip to content

Commit

Permalink
fix parrots op bug
Browse files Browse the repository at this point in the history
  • Loading branch information
luopeichao committed Aug 23, 2021
1 parent ac0d839 commit 1ec3b7a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion mmcv/onnx/info.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os

import torch


def is_custom_op_loaded():
flag = False
Expand All @@ -16,4 +18,4 @@ def is_custom_op_loaded():
flag = os.path.exists(ort_lib_path)
except (ImportError, ModuleNotFoundError):
pass
return flag
return flag or torch.__version__ == 'parrots'
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/parrots/sync_bn_parrots.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void sync_bn_forward_output_cuda_parrots(CudaContext& ctx,
auto running_var = buildATensor(ctx, outs[1]);
auto norm = buildATensor(ctx, outs[2]);
auto std = buildATensor(ctx, outs[3]);
auto output = buildATensor(ctx, outs[3]);
auto output = buildATensor(ctx, outs[4]);
sync_bn_forward_output_cuda(input, mean, var, running_mean, running_var,
weight, bias, norm, std, output, eps, momentum,
group_size);
Expand Down

0 comments on commit 1ec3b7a

Please sign in to comment.