Skip to content

Commit

Permalink
add onnx namespace for custom ops (#1254)
Browse files Browse the repository at this point in the history
  • Loading branch information
q.yao committed Aug 16, 2021
1 parent 54907a3 commit 44edcdd
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 35 deletions.
16 changes: 8 additions & 8 deletions mmcv/ops/carafe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ class CARAFENaiveFunction(Function):
@staticmethod
def symbolic(g, features, masks, kernel_size, group_size, scale_factor):
return g.op(
'MMCVCARAFENaive',
'mmcv::MMCVCARAFENaive',
features,
masks,
kernel_size=kernel_size,
group_size=group_size,
scale_factor=scale_factor)
kernel_size_i=kernel_size,
group_size_i=group_size,
scale_factor_f=scale_factor)

@staticmethod
def forward(ctx, features, masks, kernel_size, group_size, scale_factor):
Expand Down Expand Up @@ -102,12 +102,12 @@ class CARAFEFunction(Function):
@staticmethod
def symbolic(g, features, masks, kernel_size, group_size, scale_factor):
return g.op(
'MMCVCARAFE',
'mmcv::MMCVCARAFE',
features,
masks,
kernel_size=kernel_size,
group_size=group_size,
scale_factor=scale_factor)
kernel_size_i=kernel_size,
group_size_i=group_size,
scale_factor_f=scale_factor)

@staticmethod
def forward(ctx, features, masks, kernel_size, group_size, scale_factor):
Expand Down
4 changes: 2 additions & 2 deletions mmcv/ops/cc_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class CAWeightFunction(torch.autograd.Function):

@staticmethod
def symbolic(g, t, f):
return g.op('MMCVCAWeight', t, f)
return g.op('mmcv::MMCVCAWeight', t, f)

@staticmethod
def forward(ctx, t, f):
Expand All @@ -41,7 +41,7 @@ class CAMapFunction(torch.autograd.Function):

@staticmethod
def symbolic(g, weight, v):
return g.op('MMCVCAMap', weight, v)
return g.op('mmcv::MMCVCAMap', weight, v)

@staticmethod
def forward(ctx, weight, v):
Expand Down
12 changes: 6 additions & 6 deletions mmcv/ops/deform_roi_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ class DeformRoIPoolFunction(Function):
def symbolic(g, input, rois, offset, output_size, spatial_scale,
sampling_ratio, gamma):
return g.op(
'MMCVDeformRoIPool',
'mmcv::MMCVDeformRoIPool',
input,
rois,
offset,
pooled_height=output_size[0],
pooled_width=output_size[1],
spatial_scale=spatial_scale,
sampling_ratio=sampling_ratio,
gamma=gamma)
pooled_height_i=output_size[0],
pooled_width_i=output_size[1],
spatial_scale_f=spatial_scale,
sampling_ratio_f=sampling_ratio,
gamma_f=gamma)

@staticmethod
def forward(ctx,
Expand Down
20 changes: 10 additions & 10 deletions mmcv/ops/focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ class SigmoidFocalLossFunction(Function):
@staticmethod
def symbolic(g, input, target, gamma, alpha, weight, reduction):
return g.op(
'MMCVSigmoidFocalLoss',
'mmcv::MMCVSigmoidFocalLoss',
input,
target,
gamma=gamma,
alpha=alpha,
weight=weight,
reduction=reduction)
gamma_f=gamma,
alpha_f=alpha,
weight_f=weight,
reduction_s=reduction)

@staticmethod
def forward(ctx,
Expand Down Expand Up @@ -111,13 +111,13 @@ class SoftmaxFocalLossFunction(Function):
@staticmethod
def symbolic(g, input, target, gamma, alpha, weight, reduction):
return g.op(
'MMCVSoftmaxFocalLoss',
'mmcv::MMCVSoftmaxFocalLoss',
input,
target,
gamma=gamma,
alpha=alpha,
weight=weight,
reduction=reduction)
gamma_f=gamma,
alpha_f=alpha,
weight_f=weight,
reduction_s=reduction)

@staticmethod
def forward(ctx,
Expand Down
6 changes: 3 additions & 3 deletions mmcv/ops/masked_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ class MaskedConv2dFunction(Function):
@staticmethod
def symbolic(g, features, mask, weight, bias, padding, stride):
return g.op(
'MMCVMaskedConv2d',
'mmcv::MMCVMaskedConv2d',
features,
mask,
weight,
bias,
padding=padding,
stride=stride)
padding_i=padding,
stride_i=stride)

@staticmethod
def forward(ctx, features, mask, weight, bias, padding=0, stride=1):
Expand Down
5 changes: 4 additions & 1 deletion mmcv/ops/psa_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ class PSAMaskFunction(Function):
@staticmethod
def symbolic(g, input, psa_type, mask_size):
return g.op(
'MMCVPSAMask', input, psa_type=psa_type, mask_size=mask_size)
'mmcv::MMCVPSAMask',
input,
psa_type_i=psa_type,
mask_size_i=mask_size)

@staticmethod
def forward(ctx, input, psa_type, mask_size):
Expand Down
10 changes: 5 additions & 5 deletions mmcv/ops/sync_bn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ class SyncBatchNormFunction(Function):
def symbolic(g, input, running_mean, running_var, weight, bias, momentum,
eps, group, group_size):
return g.op(
'MMCVSyncBatchNorm',
'mmcv::MMCVSyncBatchNorm',
input,
running_mean,
running_var,
weight,
bias,
momentum=momentum,
eps=eps,
group=group,
group_size=group_size)
momentum_f=momentum,
eps_f=eps,
group_i=group,
group_size_i=group_size)

@staticmethod
def forward(self, input, running_mean, running_var, weight, bias, momentum,
Expand Down

0 comments on commit 44edcdd

Please sign in to comment.