
Sensitivity Effects Near Grid Boundaries (Experimental Results) ~+1 AP@[.5, .95] #3293

Closed
glenn-jocher opened this issue May 31, 2019 · 47 comments


@glenn-jocher

@AlexeyAB I ran an interesting experiment recently. Using the iDetection app, I set up an iPhone to view a street for one hour and record all the detections for later analysis. This recorded about 400,000 detections over 100,000 video frames (3600 s at an average of 20 FPS). The model used was https://pjreddie.com/media/files/yolov3-spp.weights, exported PyTorch > ONNX > CoreML at a 192 x 320 width-height shape, with 6x10, 12x20 and 24x40 grids. The results worked amazingly well, but they also uncovered an effect I'd never noticed before.

In my overlay below, you can actually see the YOLOv3 grids, because for some reason there are no detections near the grid boundaries (those histogram cells are all zero).

Equally fascinating, you can see that the middle grid is used for up-close pedestrians: I can count 12 grid cells across nearest to the camera, while farther away you can actually see the transition to the largest, 24-across grid.

My question to you: have you ever seen any issues at grid intersections before, such as reduced recall? There are a lot of effects at play here, so the cause may not be the yolov3-spp.weights themselves but perhaps the PyTorch inference model or the CoreML export. I'm fairly confident, however, that PyTorch and Darknet inference are practically identical, since they produce identical test mAPs.

[Figures: single frame; 2D histogram over 1 hour]
@AlexeyAB
Owner

@glenn-jocher Hi,

This is a very interesting note.

I have never visualized the distribution of detections this way, and I believed that detections get worse at the grid boundaries, but not by much.

But I have encountered the problem of flickering detections (the blinking issue) when running detection on video. This problem can't be completely solved even with recurrent LSTM networks (#3114 (comment)), apparently because they also use the same grid at the end.

The solution may be to use two offset grids for each scale, even if each grid has fewer cells in order to maintain the same processing speed. I will think about it.

Can you show the frequency distribution of detections as a graph? For example, plot the number of detections where X is the same as the X coordinate in the figure (the red line on the image) and Y is the number of object detections.



Also equally fascinating you can see that the middle grid is used for up-close pedestrians, because I can count 12 grids across nearest to the camera, while you can actually visualize the transition to the largest 24-across grid far away.

Yes, the [yolo] layer with fewer cells is used for larger bboxes.
The [yolo] layer with more cells is used for smaller bboxes.

@glenn-jocher
Author

glenn-jocher commented May 31, 2019

Yes, here are the 1D histograms of the medium-grid area and the small-grid area. I thought about this some more. I've also seen flickering in the iDetection app; I figure it might be related to the FP16 or FP8 quantization in the final CoreML model, but perhaps there's a grid-crossing sensitivity at play as well.

Now that I think about it, it makes sense: the network has to output fairly large numbers (negative or positive) for the sigmoid of that output to get really close to 0 or 1 (a grid boundary). So perhaps all of the people are detected confidently; it's just that their xy errors are higher near the grid boundaries (biased toward the cell center rather than the edges). Perhaps the flickering is unrelated.
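To put numbers on this point, here is a small plain-Python check (nothing from either repository is assumed): it inverts the logistic to find the raw output the network must produce for a given in-cell xy offset. The required logit is about 2.2 at an offset of 0.9 and about 6.9 at 0.999, and it grows without bound as the offset approaches a cell edge at 0 or 1.

```python
import math


def logit(p):
    """Inverse of the logistic sigmoid: the raw output needed to produce p."""
    return math.log(p / (1.0 - p))


# In-cell offsets of an object center (0 = one cell edge, 1 = the other edge).
for offset in (0.5, 0.9, 0.99, 0.999):
    print(f"offset {offset}: required logit {logit(offset):.2f}")
```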

Yes, an offset grid might solve this, or maybe not. Maybe a simpler solution would be to train on objects from -0.25 to 1.25 grid cells away (compared to 0-1 now); then each grid cell would provide redundant overlap with its neighbors, and NMS would sort out the inevitable duplicates?
[Figures: 1D histograms, small grid and medium grid]

@AlexeyAB
Owner

AlexeyAB commented May 31, 2019

@glenn-jocher

I've seen flickering in the iDetection app also, I figure it might be related to the FP16 or FP8 quantization in the final CoreML model, but perhaps there's a grid crossing sensitivity at play also.

Can you temporarily disable this quantization for testing?

Yes, the lower the precision I use (FP32 -> FP16 -> INT8 -> BIT1), the worse the blinking problem.


Now that I think about it, it makes sense, because the network has to output some pretty large numbers (negative or positive) in order for a sigmoid of that output to really get close to 0 or 1 (a grid boundary). So perhaps all of the people are detected confidently, its just that their xy errors are higher near the grid boundaries (biased to the center rather than the edges). Perhaps the flickering is unrelated.

Yes, maybe there is no flickering.
There is just a small coordinate error (shifted to zero).
We can fix it almost completely by using 2 * logistic - 0.5 (or even 1.5 * logistic - 0.25) instead of logistic during both training and detection.
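A minimal sketch of the proposed activation in plain Python, written generally as scale * logistic - (scale - 1) / 2, so that scale = 2 gives 2 * logistic - 0.5 and scale = 1.5 gives 1.5 * logistic - 0.25. The function names are hypothetical. It shows that with scale > 1 an object exactly at the cell edge (offset 1.0) needs only a modest logit (ln 3, about 1.10, for scale = 2), whereas the plain logistic would need an infinite one.

```python
import math


def scaled_logistic(x, scale):
    """scale * sigmoid(x) - (scale - 1) / 2; scale=2 gives 2 * logistic - 0.5."""
    return scale / (1.0 + math.exp(-x)) - 0.5 * (scale - 1.0)


def logit_for_offset(offset, scale):
    """Raw output x with scaled_logistic(x, scale) == offset (needs scale > 1 at the edges)."""
    p = (offset + 0.5 * (scale - 1.0)) / scale  # undo the affine rescaling
    return math.log(p / (1.0 - p))              # undo the sigmoid


for scale in (1.5, 2.0):
    low, high = -0.5 * (scale - 1.0), 1.0 + 0.5 * (scale - 1.0)
    print(f"scale {scale}: output range ({low}, {high}), "
          f"logit to reach the cell edge 1.0 = {logit_for_offset(1.0, scale):.3f}")
```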



Yes, an offset grid might solve this, or maybe not. Maybe a simpler solution would be to train on objects from -.25 to 1.25 grids away (compared to 0-1 now), then each grid would provide redundant overlap with its neighboring grids, and NMS would sort out the inevitable duplicates?

Do you mean adding something here:

darknet/src/yolo_layer.c, lines 269 to 270 in 55dcd1b:

}
}

something like this?

{   // sketch: placed inside the existing b/j/i loops; b, i, j and n come from the enclosing code
    for (t = 0; t < l.max_boxes; ++t) {
        box truth = float_to_box_stride(state.truth + t*(4 + 1) + b*l.truths, 1);
        if (truth.x < 0 || truth.y < 0 || truth.x > 1 || truth.y > 1 || truth.w < 0 || truth.h < 0) {
            printf(" Wrong label: truth.x = %f, truth.y = %f, truth.w = %f, truth.h = %f \n", truth.x, truth.y, truth.w, truth.h);
        }
        int class_id = state.truth[t*(4 + 1) + b*l.truths + 4];
        if (l.map) class_id = l.map[class_id];
        if (class_id >= l.classes) continue; // if label contains class_id more than number of classes in the cfg-file

        // object center in grid-cell units (continuous, not rounded down to a cell index)
        float obj_i = (truth.x * l.w);
        float obj_j = (truth.y * l.h);

        float best_iou = 0;
        int best_n = 0;   // index of the best-matching anchor, so an integer
        box truth_shift = truth;
        truth_shift.x = truth_shift.y = 0;
        for (n = 0; n < l.total; ++n) {
            box pred = { 0 };
            pred.w = l.biases[2 * n] / state.net.w;
            pred.h = l.biases[2 * n + 1] / state.net.h;
            float iou = box_iou(pred, truth_shift);
            if (iou > best_iou) {
                best_iou = iou;
                best_n = n;
            }
        }
        int mask_n = int_index(l.mask, best_n, l.n);

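        // train this cell (i, j) if it lies within 1.5 cells of the object center on both axes,
        // i.e. neighbouring cells too, not only the cell that contains the center
        if (mask_n >= 0 && (fabs(i - obj_i) < 1.5) && (fabs(j - obj_j) < 1.5)) {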
            int obj_index = entry_index(l, b, mask_n*l.w*l.h + j*l.w + i, 4);
            l.delta[obj_index] = l.cls_normalizer * (1 - l.output[obj_index]);

            int class_index = entry_index(l, b, mask_n*l.w*l.h + j*l.w + i, 4 + 1);
            delta_yolo_class(l.output, l.delta, class_index, class_id, l.classes, l.w*l.h, 0, l.focal_loss);

            int box_index = entry_index(l, b, mask_n*l.w*l.h + j*l.w + i, 0);
            delta_yolo_box(truth, l.output, l.biases, best_n, box_index, i, j, l.w, l.h, state.net.w, state.net.h, l.delta, (2 - truth.w*truth.h), l.w*l.h, l.iou_normalizer, l.iou_loss);
        }
    }
}
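To illustrate what the `< 1.5` condition in this sketch selects along one axis, here is a plain-Python stand-in (the name `cells_within` is hypothetical). It lists every cell index `i` with `|i - obj_i| < 1.5` and the x target each such cell would be trained on, assuming delta_yolo_box uses `tx = truth.x * l.w - i`. For an object centered at 5.2 cells, cells 4, 5 and 6 are selected with targets of about 1.2, 0.2 and -0.8, so the neighbouring cells get targets outside [0, 1] that the plain logistic cannot produce.

```python
def cells_within(obj_center, grid_size, radius=1.5):
    """Hypothetical 1-D stand-in for the fabs(i - obj_i) < 1.5 test above.

    obj_center is the object center in cell units (truth.x * l.w).
    Returns (cell index, x target) pairs; the target is assumed to be
    truth.x * l.w - i, i.e. the center's offset from that cell's left edge.
    """
    selected = []
    for i in range(grid_size):
        if abs(i - obj_center) < radius:
            selected.append((i, obj_center - i))
    return selected


# An object centered at 5.2 cells on a 13-cell grid.
for i, tx in cells_within(5.2, 13):
    print(f"cell {i}: x target {tx:+.2f}")
```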

@AlexeyAB
Owner

AlexeyAB commented Jun 1, 2019

I think we should use 2 * logistic - 0.5.

It seems that for a grid with a large number of cells, the distance between stable detections is 50%.

For a grid with a small number of cells, however, the distance is only 25%.

@AlexeyAB
Owner

AlexeyAB commented Jun 1, 2019

I got +2% mAP@0.5 on yolov3-spp-pan-xnor.cfg (without online-SVR) on several classes of the Cityscapes dataset:

So maybe it is a good solution; I will test it more.

@glenn-jocher
Author

glenn-jocher commented Jun 1, 2019

Oh wow, that's huge! That's a 10% increase! 2 * logistic - 0.5 will allow detections to span the grid space from -0.5 to 1.5. Yes, that seems perfect. It is odd, as you note, that the gap is constant in pixels rather than in grid units.

I realized I should be able to repeat this test very simply on 5k.val using yolov3-spp.weights and plot the same 2D histogram. This removes the whole chain of uncertainty created by quantizing and exporting to CoreML. I'll do that now.

@glenn-jocher
Author

OK, I have the 5k.val test results here at 416 resolution. The plot below shows the xy centers of all detections (551,000 boxes). I can confirm the same phenomenon appears there as well.

python3 test.py --save-json --img-size 416
Namespace(batch_size=16, cfg='cfg/yolov3-spp.cfg', conf_thres=0.001, data_cfg='data/coco.data', img_size=416, iou_thres=0.5, nms_thres=0.5, save_json=True, weights='weights/yolov3-spp.weights')
Using CUDA device0 _CudaDeviceProperties(name='Tesla V100-SXM2-16GB', total_memory=16130MB)

               Class    Images   Targets         P         R       mAP        F1
Computing mAP: 100%|█████████████████████████████████████████████████████████████| 313/313 [06:59<00:00,  1.14s/it]
                 all     5e+03  3.58e+04     0.104     0.747     0.552     0.178

 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.335
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.563
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.347
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.151
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.359
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.493
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.280
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.432
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.459
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.254
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.496
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.620


@glenn-jocher
Author

glenn-jocher commented Jun 1, 2019

1D histograms. There are 52 tall peaks across. The gaps seem less pronounced here than before. This is for all classes in COCO, whereas my street results are for people only. The street results were also at 192x320, with FP16.

[Figure: predictions]

5k.val labels x and y distributions for reference.
[Figure: labels]

@AlexeyAB
Copy link
Owner

AlexeyAB commented Jun 1, 2019

@glenn-jocher
Can you clarify: are all 4 charts for 5k.val?
And what is the difference between the top and bottom charts?

@glenn-jocher
Author

Ah, the top two are the predicted x and y for 5k.val during testing, which produces the 0.563 mAP. The bottom two are the ground truth x and y distributions. Ideally our top 2 charts would match the bottom 2.

@AlexeyAB
Owner

AlexeyAB commented Jun 4, 2019

@glenn-jocher

Since on the more general 5k.val task the distance between peaks is smaller than for people-only detection, and the relative distance between peaks is slightly larger for the [yolo] layers for small objects than for large ones, I used:

  • 1.05 * logistic - 0.025 - for yolo layer (large objects) scale_x_y = 1.05 in cfg-file
  • 1.1 * logistic - 0.05 - for yolo layer (medium objects) scale_x_y = 1.1 in cfg-file
  • 1.2 * logistic - 0.1 - for yolo layer (small objects) scale_x_y = 1.2 in cfg-file

using this function: eac2622#diff-180a7a56172e12a8b79e41ec95ae569dR557
scal_add_ongpu(2 * l.w*l.h, l.scale_x_y, -0.5*(l.scale_x_y - 1), l.output_gpu + index, 1);
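For readers following along in Python, here is a minimal sketch of what this line does to one anchor's x and y planes, assuming `scal_add_ongpu(N, ALPHA, BETA, X, INCX)` computes `X[i*INCX] = X[i*INCX] * ALPHA + BETA` and that the logistic has already been applied to those values. `scal_add` and `decode_xy` are hypothetical names, not darknet functions. With the three scale_x_y values above, the offset range widens only slightly, to (-0.025, 1.025), (-0.05, 1.05) and (-0.1, 1.1).

```python
import math


def scal_add(values, alpha, beta):
    """Stand-in for scal_add_ongpu, assumed to compute X[i] = X[i] * alpha + beta."""
    return [v * alpha + beta for v in values]


def decode_xy(raw_xy, scale_x_y):
    """raw_xy: raw x and y outputs of one anchor (the 2 * w * h values).

    Apply the logistic, then widen its (0, 1) range to
    (-(s - 1) / 2, 1 + (s - 1) / 2) for s = scale_x_y.
    """
    activated = [1.0 / (1.0 + math.exp(-v)) for v in raw_xy]
    return scal_add(activated, scale_x_y, -0.5 * (scale_x_y - 1.0))


for s in (1.05, 1.1, 1.2):
    shift = 0.5 * (s - 1.0)
    print(f"scale_x_y = {s}: {s} * logistic - {shift:.3f}, range ({-shift:.3f}, {1 + shift:.3f})")
```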

I got +1.9% mAP on yolo_v3_spp_pan_scale.cfg.txt compared to yolo_v3_spp_pan.cfg.txt: #3114 (comment)

  • yolo_v3_spp_pan_scale.cfg.txt - 60.4% mAP (scale_x_y = 1.05, 1.1, 1.2 for different [yolo]-layers)

  • yolo_v3_spp_pan.cfg.txt - 58.5% mAP

  • yolo_v3_spp_pan_scale.cfg.txt - ~55% mAP (scale_x_y = 1.25, 1.5, 2.0 for different [yolo]-layers, the same as scale_x_y = 2.0, 2.0, 2.0)

@glenn-jocher
Author

glenn-jocher commented Jun 5, 2019

@AlexeyAB very interesting. So scale_x_y = 2.0 actually reduced mAP in this latest example but helped on the Cityscapes dataset. Is this new yolo_v3_spp_pan_scale.cfg file usable with the master branch? I can run some tests of my own as well.

I was wondering about extending the sensitivity of a grid cell itself (not just the activation function range): for example, training on objects with xy centers in the range -0.1 to 1.1, with an activation function range extending a bit past that, say -0.2 to 1.2. This might further ease the 'handoff' between one grid cell and the next by producing redundant duplicate detections at the border, which NMS would sort out. But I'm unsure whether that's viable, i.e. whether a neighbouring cell has any observability into an object not centered in it.

@AlexeyAB
Owner

AlexeyAB commented Jun 5, 2019

@glenn-jocher

Yes, you can use yolo_v3_spp_pan_scale.cfg with the last commit from this repository.
Also, I tested scale_x_y = 1.05, 1.1, 1.2 with yolo_v3_tiny_pan_lstm.cfg.txt, and it decreased the mAP.

I haven't tested much yet, so maybe it helps only some models on some datasets, or mainly at low precision (INT8/BIT1), or maybe it's just mAP fluctuation.

I was wondering about the possibility of extending the sensitivity itself (not just the activation function range) of a grid cell, so for example training on objects with xy centers in the range of -.1 to 1.1

To actively use this logic, you must add this code: #3293 (comment)
But it will increase the competition between objects for each cell, and this is bad.

For a passive version of this logic, you can simply reduce the threshold to ~0.5:

darknet/cfg/yolov3.cfg, lines 786 to 787 in abba310:

ignore_thresh = .7
truth_thresh = 1

@AlexeyAB AlexeyAB changed the title Sensitivity Effects Near Grid Boundaries (Experimental Results) Sensitivity Effects Near Grid Boundaries (Experimental Results) ~+1 AP@[.5, .95] Jun 7, 2019
@AlexeyAB
Owner

@glenn-jocher Hi,

Did you try to train a yolov3-spp model with scal_add_ongpu(2 * l.w*l.h, l.scale_x_y, -0.5*(l.scale_x_y - 1), l.output_gpu + index, 1); (or other improvements) in your repository?
Did you get any benefits?

@glenn-jocher
Author

glenn-jocher commented Jun 19, 2019

@AlexeyAB I tried a few tricks, but I stopped because I realized I need to establish a baseline training comparison to AlexeyAB/darknet first, which I still haven't been able to do on coco2014. The ultralytics/yolov3 repo has 3 main parts:

  • detect.py validated against darknet using various models (like yolov3, spp, tiny)
  • test.py validated against pycocotools to within 1% across a wide range of settings
  • train.py not yet working as well as AlexeyAB/darknet on coco2014

So I've simply been trying to reproduce the AlexeyAB/darknet training settings on yolov3-spp first, but I am still about 10% below on both mAP@0.5 and mAP@0.5:0.95. My last training run ended with the results below, and prototyping training changes is very slow because GCP keeps killing my preemptible V100 instances after only a few hours. I think I should raise a new issue here so you can help me understand the differences, though there are a few things I simply haven't had time to implement yet, like ignore_threshold.

 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.248
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.464

BUT, I did have two ideas related to this:

  1. I read the FCOS paper, and later, when observing the grid sensitivity, it occurred to me that if we could reformulate our regression space from [x, y, w, h] to the FCOS regression space of [l, r, t, b], this should theoretically eliminate the grid sensitivity issue completely. I realize this raises its own issues around anchors, and as I understand it, one of the jumps in performance from YOLOv1 to v2 was the introduction of anchors, so it's not a cut-and-dried solution; but if YOLO could migrate in this direction somehow, I believe it would resolve the grid issue nicely.
  2. I tried GIoU loss and observed results similar to the default xywh loss, but it may still be a useful replacement because it stabilizes the box loss, which currently suffers from occasional wh loss divergence (a lot of users on my repo complain that wh diverges when training on their custom datasets). The wh divergence occurs because it is the only unbounded YOLO loss, and in my repo, without the burn-in phase, the wh loss diverges every time. Another nice feature of GIoU loss is that it is independent of the box space, i.e. you can compute your boxes in xywh, lrtb, etc. and still use GIoU loss.
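A small plain-Python sketch of the box-space point in item 2: GIoU is computed on corner coordinates, so a box parameterized as xywh or as FCOS-style lrtb (distances from a reference point to the left, right, top and bottom edges) can share the same loss, taken here as 1 - GIoU. The helper names are hypothetical and are not the bbox_iou function from ultralytics/yolov3.

```python
def xywh_to_corners(x, y, w, h):
    """Center/size box -> (x1, y1, x2, y2)."""
    return x - w / 2, y - h / 2, x + w / 2, y + h / 2


def lrtb_to_corners(px, py, left, right, top, bottom):
    """FCOS-style box: distances from a reference point (px, py) to each edge."""
    return px - left, py - top, px + right, py + bottom


def giou(a, b):
    """Generalized IoU of two (x1, y1, x2, y2) boxes, in [-1, 1]."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    # area of the smallest box enclosing both a and b
    hull = (max(a[2], b[2]) - min(a[0], b[0])) * (max(a[3], b[3]) - min(a[1], b[1]))
    return inter / union - (hull - union) / hull


target = (4.0, 4.0, 8.0, 7.0)
pred_xywh = xywh_to_corners(5.0, 5.0, 4.0, 2.0)               # (3, 4, 7, 6)
pred_lrtb = lrtb_to_corners(4.0, 5.0, 1.0, 3.0, 1.0, 1.0)     # the same box
print(pred_xywh == pred_lrtb, 1.0 - giou(pred_xywh, target))  # loss = 1 - GIoU
```

Unlike plain IoU, GIoU stays informative for non-overlapping boxes (it goes negative as they move apart), which is part of why it behaves well as a bounded box loss.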

@AlexeyAB
Owner

AlexeyAB commented Jun 19, 2019

@glenn-jocher

So with ultralytics/yolov3 you can't achieve the same mAP@0.5 / mAP@0.5:0.95 as with AlexeyAB/darknet or https://github.com/pjreddie/darknet?


Did you compare mAP@0.5 / mAP@0.5:0.95 when training ultralytics/yolov3 with and without the logistic scale scale_x_y = 1.05, 1.1, 1.2? Does it give a higher mAP@0.5 than your current maximum of 0.464?


For ignore_threshold you just shouldn't decrease objectness (T0) if IoU(detection, truth) > ignore_threshold:

darknet/src/yolo_layer.c, lines 291 to 296 in 8c80ba6:

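int obj_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4);
avg_anyobj += l.output[obj_index];
// by default, push this prediction's objectness toward 0 (no object)...
l.delta[obj_index] = l.cls_normalizer * (0 - l.output[obj_index]);
// ...unless it already overlaps some truth with IoU > ignore_thresh: then apply no objectness gradient
if (best_iou > l.ignore_thresh) {
    l.delta[obj_index] = 0;
}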
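A minimal plain-Python sketch of this rule for a single prediction that is not assigned to any truth (assigned cells get their own positive objectness delta elsewhere in the layer). The helper names are hypothetical; boxes are (x1, y1, x2, y2) corners, and `best_iou` is taken over all truths in the image as in the snippet.

```python
def iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union


def objectness_delta(obj_output, pred_box, truths, ignore_thresh=0.7, cls_normalizer=1.0):
    """Objectness gradient for an unassigned prediction: push toward 0 ("no object"),
    unless it already overlaps some truth by more than ignore_thresh."""
    best_iou = max((iou(pred_box, t) for t in truths), default=0.0)
    if best_iou > ignore_thresh:
        return 0.0
    return cls_normalizer * (0.0 - obj_output)


truths = [(0.0, 0.0, 2.0, 2.0)]
print(objectness_delta(0.8, (0.0, 0.0, 2.0, 1.2), truths))  # IoU 0.6 < 0.7: pushed toward 0
```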


You should use burn_in=1000 for the first 1000 iterations: at the beginning of training most cells don't contain objects, which leads to extremely high loss and unstable training:

if (batch_num < net.burn_in) return net.learning_rate * pow((float)batch_num / net.burn_in, net.power);
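A plain-Python sketch of the warm-up formula in this line, `learning_rate * (batch_num / burn_in) ** power`. The default `power=4` is an assumption here (believed to be darknet's fallback when the cfg does not set `power`), and the schedule that applies after burn-in is not modeled.

```python
def burn_in_lr(base_lr, batch_num, burn_in=1000, power=4.0):
    """Warm-up learning rate: ramps from 0 to base_lr over the first burn_in iterations.

    power=4.0 is assumed; after burn-in the regular (steps/policy) schedule
    would apply, which is not modeled here.
    """
    if batch_num < burn_in:
        return base_lr * (batch_num / burn_in) ** power
    return base_lr


for it in (0, 250, 500, 750, 1000):
    print(it, burn_in_lr(0.001, it))
```

With power = 4 the rate stays tiny for most of the warm-up (1/16 of base_lr at the halfway point), which is what keeps the early, mostly-empty-cell loss from blowing up.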


Do you think we should use (lrtb) instead of (xywh)? Or maybe it is better to use corners (lr)(tb) as in CornerNet? #3229

@glenn-jocher
Author

glenn-jocher commented Jun 21, 2019

@AlexeyAB I haven't tried to fully train COCO from darknet53.conv.74 on AlexeyAB/darknet, but I've trained for a couple epochs with good results. I've trained ultralytics/yolov3 a few times, but so far not reaching the same performance. I think this is simply my fault for not duplicating the AlexeyAB/darknet loss function correctly yet. I will put a comparison in a new issue to clear up the differences.

One big question I had was regarding the total training time. How many full passes through the COCO2014 training set of 117264 images are done to reach 500200 batches assuming the cfg here? At first I thought it was 273, then more recently I came to believe it was 68.25, but I'm still confused.

# Training
# batch=64
# subdivisions=16

About the Logistic-scale scale_x_y = 1.05, 1.1, 1.2: I have not tried to recreate it yet, because I think there are still other unresolved differences between my loss function implementation and AlexeyAB/darknet. Here is how my implementation works:

  1. Images and targets are augmented and multi-scaled.
  2. Every target is matched to the most similar anchor at each YOLO layer (using width-height IoU) and assigned to a grid cell. This means that multiple targets in a grid cell may match the same anchor in that grid cell. However, a target is never assigned to more than one anchor.
  3. An iou_threshold hyperparameter deletes matches below the threshold (hyp['iou_t'] = 0.29 currently). I implemented this after observing much better performance with it: mAP after epoch 0 increases from 0.06 with hyp['iou_t'] = 0.00 to 0.12 with hyp['iou_t'] = 0.29. I need to replace (or complement) this with your ignore_thresh. The purpose is to prevent, say, large objects from causing gradient updates in the smallest grid cells (i.e. to further specialize the YOLO layers by object size). It might be nicer to implement this transition in a smoother fashion rather than as a binary pass/no-pass cut.
    https://github.com/ultralytics/yolov3/blob/1a9aa30efcf7cf17b79736542a8c3f77c03d4854/utils/utils.py#L318-L335
# reject below threshold ious (OPTIONAL, increases P, lowers R)
reject = True
if reject:
    j = iou > iou_thres
    t, a, gwh = targets[j], a[j], gwh[j]   # targets, anchors, wh
  4. Surviving matches are passed to the loss function. xy losses are grouped into one MSE loss, wh losses into another MSE loss, and all object confidences are passed into one BCE loss (there is no obj/noobj separation, but a positive weighting, set at 3.53, has a similar effect). Classifications are passed to a CE loss (which performed better than BCE loss).
    https://github.com/ultralytics/yolov3/blob/1a9aa30efcf7cf17b79736542a8c3f77c03d4854/utils/utils.py#L297-L307
if giou_loss:
    pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1)  # predicted
    giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True)  # giou computation
    lxy += (k * h['giou']) * (1.0 - giou).mean()  # giou loss
else:
    lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i])  # xy loss
    lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i])  # wh yolo loss
            
lconf += (k * h['conf']) * BCE(pi0[..., 4], tconf)  # obj_conf loss
lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i])  # class_conf loss

loss = lxy + lwh + lconf + lcls
  5. All loss terms are element-averaged (i.e. summed and then divided by the number of elements, not simply summed), and have gains applied to them based on a hyperparameter search. The hyperparameters currently in place are below. k above is the number of images in the batch.
hyp = {'giou': .035,  # giou loss gain
       'xy': 0.20,  # xy loss gain
       'wh': 0.10,  # wh loss gain
       'cls': 0.035,  # cls loss gain
       'conf': 1.61,  # conf loss gain
       'conf_bpw': 3.53,  # conf BCELoss positive_weight
       'iou_t': 0.29,  # iou target-anchor training threshold
       'lr0': 0.001,  # initial learning rate
       'momentum': 0.90,  # SGD momentum
       'weight_decay': 0.0005}  # optimizer weight decay
  6. Gradients are accumulated until 64 images have been processed, then an SGD update is applied.
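A minimal plain-Python sketch of steps 2 and 3 for the anchors of a single YOLO layer (helper names are hypothetical, not ultralytics/yolov3 functions): each target keeps only its best anchor by width-height IoU, i.e. IoU computed as if both boxes shared a center, and the match is dropped when that IoU does not exceed `iou_t`. The anchors in the usage example are the three smallest default YOLOv3 anchors, in pixels; a 12x14 target matches the first one, while a 100x100 target is too large for all of them and is rejected.

```python
def wh_iou(wh1, wh2):
    """IoU of two boxes with a common center, given only (width, height)."""
    inter = min(wh1[0], wh2[0]) * min(wh1[1], wh2[1])
    return inter / (wh1[0] * wh1[1] + wh2[0] * wh2[1] - inter)


def match_targets(target_whs, anchor_whs, iou_t=0.29):
    """Return (target index, anchor index, wh-IoU) for every surviving match."""
    matches = []
    for ti, twh in enumerate(target_whs):
        ious = [wh_iou(twh, awh) for awh in anchor_whs]
        best = max(range(len(ious)), key=ious.__getitem__)  # one anchor per target
        if ious[best] > iou_t:  # reject weak matches (step 3)
            matches.append((ti, best, ious[best]))
    return matches


anchors = [(10, 13), (16, 30), (33, 23)]
print(match_targets([(12, 14), (100, 100)], anchors))
```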

I believe burn-in is correctly implemented in ultralytics/yolov3. In this plot one batch is really a minibatch of 16 images, corresponding to batch = 64, subdivisions = 4 (or batch = 64, subdivisions = 16?) in the *.cfg files. This burn-in LR applies to all model parameters during epoch 0 only.
ultralytics/yolov3#238 (comment)
[Figure: LR]

Wow, the cornernet-squeeze and cornernet-saccade results are super impressive! I was not aware of those. I'll have to read the paper.

@AlexeyAB
Owner

@glenn-jocher

Images and targets are augmented and multi-scaled.

  1. For resizing, which approach do you use (see Resizing : keeping aspect ratio, or not #232 (comment))?
  • Does it keep the aspect ratio (letter_box), as is done in the pjreddie repo,
  • or does it just resize the image to the network size, as is done in my repo?

One big question I had was regarding the total training time. How many full passes through the COCO2014 training set of 117264 images are done to reach 500200 batches assuming the cfg here? At first I thought it was 273, then more recently I came to believe it was 68.25, but I'm still confused.

  2. batch=64 is the default, which means each iteration loads 64 images. So 500200 iterations load 64*500200 = 32 012 800 images.
    If MS COCO contains 117 264 images, then it was trained for 32012800 / 117264 ≈ 273 epochs, i.e. about 273 full passes through the COCO2014 training set.
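The same arithmetic as a quick Python check. `subdivisions` does not enter it: in darknet it only splits each batch of 64 images into smaller minibatches to save memory, with one weight update per batch.

```python
batch = 64              # images per iteration ([net] batch=64)
max_batches = 500200    # iterations
num_images = 117264     # COCO2014 training images, as quoted above

images_seen = batch * max_batches   # 32 012 800
epochs = images_seen / num_images   # about 273
print(images_seen, round(epochs, 2))
```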

Every target is matched to the most similar anchor at each YOLO layer (using width-height IoU). This means that multiple targets may match the same anchor (but not vice versa).

This is correct, the same as it is done in this repo.

An iou_threshold hyperparameter deletes matches below threshold (hyp['iou_t'] = 0.29 currently). I implemented this after observing much better performance with it, i.e. mAP after epoch 0 increases from 0.06 to 0.12.

  3. Did you test it for only 1 epoch on MS COCO? Try testing it for more epochs.

classifications passed to a CE loss (which performed better than BCE loss).

  4. Do you mean that Softmax (only one class_id per anchor) is better than Logistic (multi-label classification)?

@glenn-jocher
Author

  1. For resizing do you use: #232 (comment)
  • Does it keep the aspect ratio (letter_box), as is done in the pjreddie repo,
  • or does it just resize the image to the network size, as is done in my repo?

Oh my goodness. So AlexeyAB/darknet does not letterbox images when resizing? I thought it was important to maintain the aspect ratio of the shapes. Well this is an easy change I can implement and observe the effect after epoch 0.

One big question I had was regarding the total training time. How many full passes through the COCO2014 training set of 117264 images are done to reach 500200 batches assuming the cfg here? At first I thought it was 273, then more recently I came to believe it was 68.25, but I'm still confused.

  2. batch=64 is the default, which means each iteration loads 64 images. So 500200 iterations load 64*500200 = 32 012 800 images.
    If MS COCO contains 117 264 images, then it was trained for 32012800 / 117264 ≈ 273 epochs, i.e. about 273 full passes through the COCO2014 training set.

Ah ok, got it. That's unfortunate, that means about 1 week of training time on a V100 :(

Every target is matched to the most similar anchor at each YOLO layer (using width-height IoU). This means that multiple targets may match the same anchor (but not vice versa).

This is correct, the same as it is done in this repo.

An iou_threshold hyperparameter deletes matches below threshold (hyp['iou_t'] = 0.29 currently). I implemented this after observing much better performance with it, i.e. mAP after epoch 0 increases from 0.06 to 0.12.

  3. Did you test it for only 1 epoch on MS COCO? Try testing it for more epochs.

Yes, this is the crux of the experiments I have been running. I've been doing hyperparameter tuning and other testing using epoch 0 results, under the assumption that an improvement after epoch 0 also means an improvement after epoch 273, but it's not clear to me that's always the case. What is the minimum number of COCO epochs you would run to experiment with changes?

classifications passed to a CE loss (which performed better than BCE loss).

  4. Do you mean that Softmax (only one class_id per anchor) is better than Logistic (multi-label classification)?

Yes, I've seen that (at least in PyTorch) CE outperforms BCE for single-label classification, not only in YOLOv3 but also on pure classification tasks like MNIST (I tested this myself). PyTorch's CE loss "combines log_softmax and nll_loss in a single function", whereas BCE loss is a sigmoid layer combined with binary cross entropy.
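To make the difference concrete, a plain-Python sketch of the two per-sample losses (not PyTorch's implementations, which additionally average over the batch by default): softmax cross entropy makes the classes compete for a single label, while sigmoid BCE scores every class independently, so several classes can be positive at once.

```python
import math


def softmax_ce(logits, target):
    """Softmax cross entropy for one sample: classes compete for a single label."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[target]


def sigmoid_bce(logits, targets):
    """Sigmoid BCE summed over classes: each class is an independent yes/no,
    so any number of targets may be 1 (multi-label)."""
    total = 0.0
    for x, y in zip(logits, targets):
        p = 1.0 / (1.0 + math.exp(-x))
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total


logits = [2.0, 0.5, -1.0]
print(softmax_ce(logits, 0), sigmoid_bce(logits, [1, 0, 0]))
```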

I think the different strategies for formulating the regression problem are interesting (lrtb, xywh, cornernet, centernet etc.), and obviously very important, but one of the things I saw that worked best in regression networks seems strangely absent in object detection, which is normalizing the regression network targets to zero mean and unity variance (and then applying calibration curves later during testing and detection to bring the network outputs back to the range you want them). So I think irrespective of what the regression space is (xywh, lrtb etc), we also want the regression targets to produce a statistical distribution that has zero mean and a variance as close to 1.0 as possible.

In shallow regression networks I've worked on in the past, like https://github.com/ultralytics/wave, arxiv, this simple change has had huge impacts on network performance. I think the current anchor system is partly accomplishing this, and these other methods may also deal with the issue in indirect ways, but I think there is great room for improvement left. I don't have a specific recommendation right now, but I think this is an extremely important concept to keep in mind.
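A minimal sketch of the idea in plain Python, under simplifying assumptions: a single scalar regression target, statistics fitted once on the training targets, and the "calibration" step reduced to the inverse affine map. The function names and sample values are hypothetical; this is not what either repository does.

```python
import statistics


def fit_standardizer(train_targets):
    """Mean and (population) standard deviation of the training targets."""
    return statistics.mean(train_targets), statistics.pstdev(train_targets)


def standardize(value, mean, std):
    """Map a raw target to a zero-mean, unit-variance training target."""
    return (value - mean) / std


def unstandardize(z, mean, std):
    """Map a network output back to the original target units."""
    return z * std + mean


raw = [0.8, 1.1, 1.9, 2.4, 3.0]  # e.g. some box-size regression targets
mean, std = fit_standardizer(raw)
z = [standardize(v, mean, std) for v in raw]
print(round(statistics.mean(z), 6), round(statistics.pstdev(z), 6))
```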

@glenn-jocher
Author

glenn-jocher commented Jun 21, 2019

Ok I've added a scaleFill resize option (I borrowed the term from Apple's vocabulary for resizing options https://developer.apple.com/documentation/vision/vnimagecropandscaleoption). The original letterboxing (scaleFit) is what I've been doing, similar to original darknet apparently. I will test each for one epoch.
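A sketch of the geometry behind the two options. The function names are hypothetical, splitting the padding evenly on both sides is an assumption, and no pixels are actually resampled. scaleFit uses one scale factor for both axes and pads the remainder, while scaleFill stretches each axis independently. For a 640x480 image and a 416x416 input, scaleFit gives 416x312 content with 52 px of padding above and below.

```python
def scale_fit(w, h, target_w, target_h):
    """Letterbox: keep aspect ratio, pad the rest. Returns (new_w, new_h, pad_x, pad_y)."""
    r = min(target_w / w, target_h / h)  # single scale factor for both axes
    new_w, new_h = round(w * r), round(h * r)
    return new_w, new_h, (target_w - new_w) / 2, (target_h - new_h) / 2


def scale_fill(w, h, target_w, target_h):
    """Stretch to fill: no padding, aspect ratio changes. Returns (new_w, new_h, sx, sy)."""
    return target_w, target_h, target_w / w, target_h / h


print(scale_fit(640, 480, 416, 416))
print(scale_fill(640, 480, 416, 416))
```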

[Figure: scaleFit, train_batch0]

[Figure: scaleFill, train_batch0]

@AlexeyAB
Owner

AlexeyAB commented Jun 21, 2019

@glenn-jocher

Oh my goodness. So AlexeyAB/darknet does not letterbox images when resizing? I thought it was important to maintain the aspect ratio of the shapes. Well this is an easy change I can implement and observe the effect after epoch 0.

Yes, it is important for most competitions.
But in my projects it is better to use simple resizing, because training and test images have the same size in most cases, so objects end up larger: #232 (comment)

It would be interesting to see which approach is better, scaleFit or scaleFill.


How many coco epochs would you say is the minimum to run to experiment with changes?

I think 1 epoch for preliminary tests is enough.
I'm just not sure that there will always be the same effect for 200 epochs.


Yes, I've seen that (at least in PyTorch) CE outperforms BCE for single-label classification, not only in YOLOv3 but also on pure classification tasks like MNIST (I tested this myself). PyTorch's CE loss "combines log_softmax and nll_loss in a single function", whereas BCE loss is a sigmoid layer combined with binary cross entropy.

BCE (sigmoid, for multi-label classification) was introduced mainly for the OpenImages dataset, where many objects are placed close to each other, so one anchor can detect several objects; this increases accuracy. Joseph tried to make one universal model that is optimal for most datasets, even at the expense of (decreased) accuracy on other datasets.


I think the different strategies for formulating the regression problem are interesting (lrtb, xywh, cornernet, centernet etc.), and obviously very important, but one of the things I saw that worked best in regression networks seems strangely absent in object detection, which is normalizing the regression network targets to zero mean and unity variance (and then applying calibration curves later during testing and detection to bring the network outputs back to the range you want them). So I think irrespective of what the regression space is (xywh, lrtb etc), we also want the regression targets to produce a statistical distribution that has zero mean and a variance as close to 1.0 as possible.

What is the difference between this approach and Batch-normalization?
If you can describe this with more recommendations and details in a separate issue, I will add it to the roadmap for consideration. Or try implementing and testing it in ultralytics/yolov3.

Batch normalization Before and After:
[Figure: BatchNorm]

@AlexeyAB
Owner

Oh my goodness. So AlexeyAB/darknet does not letterbox images when resizing?

I added a letter_box parameter that can be used in the [net] section of the cfg file to train while keeping the aspect ratio: c9129c2

[net]
letter_box=1
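For readers unfamiliar with letterboxing: the image is resized by a single scale factor so that it fits inside the network input, and the leftover area is padded, so the aspect ratio is preserved. Below is a minimal sketch of the geometry only. The function name and the centered-padding choice are assumptions; this is not the darknet code from c9129c2.

```python
def letterbox_geometry(img_w, img_h, net_w, net_h):
    """Return (new_w, new_h, pad_left, pad_top) for an aspect-preserving resize."""
    scale = min(net_w / img_w, net_h / img_h)   # one scale factor for both axes
    new_w = round(img_w * scale)
    new_h = round(img_h * scale)
    pad_left = (net_w - new_w) // 2             # center the resized image
    pad_top = (net_h - new_h) // 2
    return new_w, new_h, pad_left, pad_top

# a 1280x720 frame placed into a 416x416 network input
print(letterbox_geometry(1280, 720, 416, 416))
```

A plain resize would instead stretch the 16:9 frame to 1:1, changing the apparent box aspect ratios that the anchors are matched against.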

@glenn-jocher
Author

glenn-jocher commented Jun 23, 2019

@AlexeyAB I tested out a few changes, but did not observe improvements. The results are a bit hard to read because, I think, the effects of the changes are getting lost in the noise. These are all after 1 epoch using img-size 320, batch 64.

# Default training command: python3 train.py --data data/coco.data --img-size 320 --single-scale --batch-size 64 --accumulate 1 --epochs 1  # 0.449hr FP32 P100, 0.279/0.324hr V100 FP32/FP16
#   P         R       mAP        F1
0.111     0.268     0.122     0.144  # default
0.087     0.281     0.109     0.121  # default mixed precision with nvidia apex
0.131     0.261     0.119     0.157  # scaleFill 
0.110     0.285     0.129     0.140  # scale_xy 1.2 
0.104     0.276     0.123     0.141  # scale_xy 1.5
0.109     0.286     0.124     0.132  # scale_xy 2.0
0.053     0.229     0.064    0.0768  # iou threshold = 0.0
0.114      0.28     0.125     0.139  # giou ** 2

To study the regression targets, I ran 1 epoch and collected all the values that are actually passed to the MSE loss function as ground truths (including augmentation, etc.). The results do look pretty normalized. Clearly the anchors are doing a good job of centering the wh targets about zero. I'll have to think about this more, as in my example the network did not use an activation function on the output layer before computing losses.
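One reason the wh targets come out centered near zero: in YOLOv3 the wh target is the log of the ratio between the box size and the matched anchor size. A box that exactly matches its anchor therefore gives a target of 0. The small sketch below uses the default YOLOv3 anchor 30x61 and made-up box sizes for illustration.

```python
import math
import statistics

def wh_targets(box_wh, anchor_wh):
    """YOLOv3-style wh targets: log of box size relative to the matched anchor."""
    (bw, bh), (aw, ah) = box_wh, anchor_wh
    return math.log(bw / aw), math.log(bh / ah)

anchor = (30.0, 61.0)                                          # one of the default anchors
boxes = [(30.0, 61.0), (24.0, 50.0), (40.0, 70.0), (33.0, 55.0)]  # illustrative matches
tw = [wh_targets(b, anchor)[0] for b in boxes]
print([round(t, 3) for t in tw])
print(round(statistics.fmean(tw), 3))  # near zero when the anchors fit the data
```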

In my implementation YOLOv3 xy loss is computed after sigmoiding. The predictions range from 0-1 as do the targets. Is this how you do it? Also did you have any luck with giou? I like giou because it only needs one loss-balancing hyperparameter for the box, rather than the two I have now, one for xy, one for wh.
(figure: targets)

@AlexeyAB
Owner

@glenn-jocher

I tested out a few changes, but did not observe improvements.


In my implementation YOLOv3 xy loss is computed after sigmoiding. The predictions range from 0-1 as do the targets. Is this how you do it?

Yes, both the delta and the summed loss are calculated after the sigmoid (for x, y, objectness and all classes):

darknet/src/yolo_layer.c

Lines 244 to 247 in 5ec3592

activate_array(l.output + index, 2 * l.w*l.h, LOGISTIC); // x,y,
scal_add_cpu(2 * l.w*l.h, l.scale_x_y, -0.5*(l.scale_x_y - 1), l.output + index, 1); // scale x,y
index = entry_index(l, b, n*l.w*l.h, 4);
activate_array(l.output + index, (1 + l.classes)*l.w*l.h, LOGISTIC);
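The two xy lines above apply a logistic activation and then call scal_add_cpu with scale l.scale_x_y and bias -0.5*(l.scale_x_y - 1). The Python sketch below writes out that transform to show the range of the resulting xy offset. It transcribes only the arithmetic, not the darknet memory layout.

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def scaled_xy(logit, scale_x_y=1.0):
    """sigmoid, then x * scale + bias with bias = -0.5 * (scale - 1)."""
    return sigmoid(logit) * scale_x_y - 0.5 * (scale_x_y - 1.0)

for s in (1.0, 1.2, 2.0):
    lo = 0.5 * (1.0 - s)         # limit as logit -> -inf
    hi = 1.0 + 0.5 * (s - 1.0)   # limit as logit -> +inf
    print(f"scale_x_y={s}: offset range ({lo:.2f}, {hi:.2f}), "
          f"logit 0 -> {scaled_xy(0.0, s):.2f}")
```

The offset still equals 0.5 (the cell center) at logit 0. The expansion is symmetric, so with scale_x_y=1.2 the range becomes (-0.1, 1.1).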


I didn't train yolov3 on MS COCO with giou; I just checked the trained model from https://github.com/generalized-iou/g-darknet , and it gives good mAP@75 and mAP@0.5...0.95
#3249 (comment)


I like giou because it only needs one loss-balancing hyperparameter for the box, rather than the two I have now, one for xy, one for wh.

Can you clarify a little more? I don't understand: do you mean scale_xy as the hyperparameter for xy? And what hyperparameter do you use for wh?

@glenn-jocher
Author

@AlexeyAB I'm not sure that the differences between the runs are statistically reliable. Let me think. I can reproduce the same results on the same hardware and environment. All seeds are set to 0 before each training, so as long as the experiments don't change the number (or order) of the random numbers generated by CUDA, PyTorch, NumPy and Python, the results should be directly comparable. Hmm, OK, so yes, your observations are valid: scale_xy 1.2 and giou both seemed to help a little.
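The comparability argument above can be illustrated with Python's own random module, used here as a stand-in for the CUDA, PyTorch and NumPy generators, which follow the same seeded-sequence principle. With a fixed seed the sequence of draws is identical. A single extra draw inserted by an experiment, however, shifts every draw after it.

```python
import random

def draws(seed, extra_draw=False, n=5):
    """Return n random numbers after seeding; optionally consume one extra number first."""
    rng = random.Random(seed)
    if extra_draw:
        rng.random()          # e.g. an experiment that adds one augmentation call
    return [rng.random() for _ in range(n)]

baseline = draws(0)
print(baseline == draws(0))                            # same seed, same draws
print(baseline == draws(0, extra_draw=True))           # one extra draw changes the rest
print(baseline[1:] == draws(0, extra_draw=True)[:4])   # ...by shifting the sequence
```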

Note that I was lazy with scale_xy: I used 1.2 in all 3 layers.

Another surprise I found was that mixed precision training (using https://github.com/NVIDIA/apex) worsened the results significantly to 0.109 mAP. I updated my little table in my previous comment with this. All the other results use full FP32, the PyTorch default.

The loss-balancing hyperparameters I was referring to are defined here. I have to use these to balance out the contribution from each loss term, i.e. total loss = 0.2*xy + 0.1*wh + 0.035*cls + 1.61*conf with the current hyperparameters. If I use loss = 1.0*xy + 1.0*wh + 1.0*cls + 1.0*conf, mAP stays at 0.0, I think (I will test it again today to be sure). I found these values through a combination of manual tuning and a genetic search algorithm, using epoch 0 results as the basis for the tuning.

https://github.com/ultralytics/yolov3/blob/d208f006a11126986d0f6c069200f429b7260886/train.py#L14

hyp = {'giou': .035,  # giou loss gain
       'xy': 0.20,  # xy loss gain
       'wh': 0.10,  # wh loss gain
       'cls': 0.035,  # cls loss gain
       'conf': 1.61,  # conf loss gain
       'conf_bpw': 3.53,  # conf BCELoss positive_weight
       'iou_t': 0.29,  # iou target-anchor training threshold
       'lr0': 0.001,  # initial learning rate
       'momentum': 0.90,  # SGD momentum
       'weight_decay': 0.0005}  # optimizer weight decay
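The sketch below writes out the weighted sum from the previous paragraph using the four gains that appear in the formula. The per-term loss values are placeholders. The 'giou' gain, which would replace the separate xy and wh gains when a GIoU box loss is used, is left out.

```python
# the four gains used in the total-loss formula above (copied from the hyp dict)
gains = {'xy': 0.20, 'wh': 0.10, 'cls': 0.035, 'conf': 1.61}

def total_loss(losses, gains):
    """Weighted sum of loss components: sum(gain[k] * loss[k])."""
    return sum(gains[k] * losses[k] for k in gains)

# placeholder per-term losses for one batch
losses = {'xy': 1.0, 'wh': 1.0, 'cls': 1.0, 'conf': 1.0}
print(total_loss(losses, gains))                      # balanced by the tuned gains
print(total_loss(losses, {k: 1.0 for k in gains}))    # unweighted sum
```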

@glenn-jocher
Author

@AlexeyAB I was thinking about this topic today, because now I'm not really sure if the dark regions in the xy histogram are caused by simple xy prediction misalignments (which your scale_x_y parameter aims to help), or by actual failures in detection.

Failures in detection would better explain the blinking objects in the video. The 'handoff' between one grid point and another for a moving object may be failing at the boundary. What do you think?

A partial solution is of course multi-scale inference as we already do with FPN, but for small objects this will not help, as they will only pair with P3 anchors, and coincidentally small objects show the worst COCO mAP.

In this video you can see the motorcycles blink significantly, and the cars almost not at all. But I can't tell how much of this is simply due to the cars having more pixels, or whether overlapping P3, P4, P5 anchors also help them avoid blinking:
https://www.youtube.com/watch?v=iUpZagGyhlM

@AlexeyAB
Owner

AlexeyAB commented May 2, 2020

@glenn-jocher

I was thinking about this topic today, because now I'm not really sure if the dark regions in the xy histogram are caused by simple xy prediction misalignments (which your scale_x_y parameter aims to help), or by actual failures in detection.

Do you use scale_x_y= in your repo? Can you show the 2D histogram over 1 hour, as in your first message here, for models with and without scale_x_y=?


Failures in detection would much more explain blinking objects video. The 'handoff' between one grid point and another for a moving object may be failing at the boundary. What do you think?

Yes, sure. There are at least two 'handoffs':

  1. between two grid cells - we try to avoid it by using scale_x_y=
  2. between two anchors - we try to avoid it by using many anchors for 1 truth (iou_thresh=0.213)
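The second mechanism can be sketched as follows. The sketch assumes the usual YOLOv3 matching, where truth and anchor boxes are compared by width and height only (both centered at the same point). Every anchor whose IoU with the truth exceeds iou_thresh is assigned, in addition to the best one. The anchors are the default YOLOv3 COCO anchors. The function names are illustrative, and the split of anchors across output layers is ignored.

```python
ANCHORS = [(10, 13), (16, 30), (33, 23), (30, 61), (62, 45),
           (59, 119), (116, 90), (156, 198), (373, 326)]

def wh_iou(wh1, wh2):
    """IoU of two boxes that share the same center (only width/height matter)."""
    inter = min(wh1[0], wh2[0]) * min(wh1[1], wh2[1])
    union = wh1[0] * wh1[1] + wh2[0] * wh2[1] - inter
    return inter / union

def assign_anchors(truth_wh, anchors=ANCHORS, iou_thresh=0.213):
    """Indices of anchors assigned to one truth: the best anchor plus any above iou_thresh."""
    ious = [wh_iou(truth_wh, a) for a in anchors]
    best = max(range(len(anchors)), key=ious.__getitem__)
    return sorted({best} | {i for i, v in enumerate(ious) if v > iou_thresh})

print(assign_anchors((40, 40)))                   # several anchors for one 40x40 truth
print(assign_anchors((40, 40), iou_thresh=1.0))   # best anchor only
```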

A partial solution is of course multi-scale inference as we already do with FPN, but for small objects this will not help, as they will only pair with P3 anchors, and coincidentally small objects show the worst COCO mAP.

  • Why "only"? Do you mean small objects that are 4x smaller than the smallest P4 anchor? Such objects will be processed by P3 even if we use scales 0.25, 0.5, 1, 2, 4 then small objects

  • multi-scale test-time data-augmentation during inference will greatly reduce FPS

  • I think a better solution for AP/FPS is to use a higher network resolution instead of multi-scale inference, and even better to use both a higher network resolution and P6.


In this video you can see the motorcycles blink significantly, and the cars almost not at all. But I can't tell how much of this is simply due to the cars having more pixels, or whether overlapping p3,p4,p5 anchors helps them as well avoid blinking:
https://www.youtube.com/watch?v=iUpZagGyhlM

Maybe this "blinking small bike" is due to iou_thresh=0.213, which helps detection in some cases and harms it in others. More specifically, iou_thresh=0.213:

  • it (many anchors to one truth) helps transfer an object from one anchor to another anchor
  • it (many truths to one anchor) harms detection when many small objects are close to each other: e.g., if 3 cars and 1 bike are very close to each other with IoU > 0.213, they will use 1 anchor, so the averaged class probability will say car for all 4 objects (3 cars and 1 bike)

@glenn-jocher
Author

@AlexeyAB yes, you are right: multi-scale inference is only useful for paper metrics, competitions, etc. Yes, there are these two handoffs, but the anchor handoff enjoys significant overlap, whereas the grid cell handoff does not overlap in many cases.

The way I understand scale_x_y is that 1 object is still assigned to only 1 grid cell (per output layer), but the box center prediction is allowed to range from, say, -0.1 to 1.1. This alleviates the problem of trying to predict a box center very near a grid edge. But I don't think this solves the objectness handoff problem. I think what's happening is that objectness is understandably dropping at the grid cell boundaries and falling below the inference detection threshold, hence the 'blinking'. I'm planning to do a study on the video to test this hypothesis, e.g. by plotting confidence vs box center. This would be natural, as we are asking for 1.0 objectness up until the boundary, and then for 0.0 objectness right after it. In practice the predictions will move smoothly from 1.0 to 0.0 near the boundary, not suddenly.

And yes, by 'only' I mean that P4 objects may enjoy grid overlap with P3 and P5, partly eliminating the grid issue for medium objects. Tiny P3 objects, smaller than 4x the smallest P4 anchor, do not get any help from the P4 grid. In any case, even if P3 objects did match to P4 anchors, 50% of the P3 grid boundaries are also P4 grid boundaries, and the problem would remain in those cases.
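The 50% figure can be checked directly. P3 has stride 8 and P4 has stride 16 (the YOLOv3 strides, consistent with the 24x40 and 12x20 grids for a 192x320 input). Along one axis, the P3 cell boundaries therefore sit at multiples of 8 pixels and the P4 boundaries at multiples of 16. The sketch below counts the interior boundaries they share for a 320-pixel-wide input.

```python
def interior_boundaries(size, stride):
    """Pixel positions of the internal grid lines along one image axis."""
    return set(range(stride, size, stride))

size = 320
p3 = interior_boundaries(size, 8)    # 39 interior lines
p4 = interior_boundaries(size, 16)   # 19 interior lines
shared = p3 & p4
print(len(shared), len(p3), round(len(shared) / len(p3), 3))
print(p4 <= p3)                      # every P4 line is also a P3 line
```

The shared fraction is 19 of 39 here, about half. The small mismatch from exactly 50% comes from the image edges.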

I think a solution may be to allow the box regressions to vary a little past the boundaries (as scale_x_y already does), and also to allow objectness to remain high even if the object is a little past the cell boundaries. Then there would be true objectness overlap. I don't think you are doing this already, are you?
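The toy sketch below illustrates the proposal with a hypothetical `margin` parameter, expressed as a fraction of a cell. A cell becomes a positive objectness target when the object's center falls inside the cell expanded by the margin. Near a boundary, both neighboring cells are then positives instead of exactly one. This illustrates the idea only and is not an existing darknet or ultralytics option.

```python
import math

def positive_cells(center_x, margin=0.0):
    """Column indices of grid cells (unit width) whose objectness target would be 1.

    Cell i normally covers [i, i + 1); with a margin it covers
    [i - margin, i + 1 + margin), so neighboring cells overlap.
    """
    cells = set()
    for i in range(math.floor(center_x - margin) - 1, math.floor(center_x + margin) + 2):
        if i - margin <= center_x < i + 1 + margin:
            cells.add(i)
    return sorted(cells)

print(positive_cells(4.95))               # standard assignment: one cell
print(positive_cells(4.95, margin=0.1))   # near the boundary: both neighbors
print(positive_cells(4.50, margin=0.1))   # at the cell center: still one cell
```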

@glenn-jocher
Author

glenn-jocher commented May 3, 2020

@AlexeyAB one really interesting thought I had for solving the objectness handoff would be if even number output layers (P4, P6) are shifted by 0.5 grid cells in both x and y. Then output layers near each other share no grid borders (i.e. P3-P4, and P4-P5 share no more borders), and perhaps then the problem would be much reduced. Implementation would be tricky though, and depend on linear interpolation to shift the image +0.5 grid points in xy on affected P layers (i.e. P4 and P6), and then shift all P4, P6 predictions -0.5 points in xy.

@AlexeyAB
Owner

AlexeyAB commented May 3, 2020

@glenn-jocher

Yes there are these two handoffs, but the anchor handoff enjoys significant overlap, whereas the grid cell handoff does not overlap in many cases.

Yes, in the current iou_thresh>0.213 implementation only several anchors from 1 cell can be used for one truth, not several anchors from several cells.

  1. Do you want to try using several anchors from several cells for one truth?

This would be natural, as we are asking for 1.0 objectness up until the boundary, and then we are asking for 0.0 objectness right after the boundary. In practice the predictions will move smoothly from 1.0 to 0.0 near the boundary, not suddenly.

Yes.
2. Do you want to try something like this?

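// old code: x, y = raw outputs of the conv layer for x, y
// i, j = column and row of the grid cell; w, h = grid cell width and height
float x_temp = sigmoid(x)*scale_x_y - (scale_x_y - 1) / 2;
float y_temp = sigmoid(y)*scale_x_y - (scale_x_y - 1) / 2;
float x_real = (i + x_temp) * w;
float y_real = (j + y_temp) * h;

// + new code: replace the hard targets
// objectness_truth = 1;
// class_prob_truth = 1;
// with a soft target that is largest at the cell center (sigmoid = 0.5)
float x_d = sigmoid(x)*(1 - sigmoid(x))*scale_x_y;
float y_d = sigmoid(y)*(1 - sigmoid(y))*scale_x_y;
float objectness_truth = x_d * y_d; // instead of 1
float class_prob_truth = x_d * y_d; // instead of 1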

I think a solution may be to allow both the box regressions to vary past the boundaries a little (as scale_x_y already does), but also to allow objectness to remain high even if the object is a little past the cell boundaries. Then there would be true objectness overlap. I don't think you are doing this already are you?

  3. Do you want to try sharing 1 anchor between the masks of adjacent layers (P3->P4, P4->P5, P5->P6)?
    I used this in some models on small datasets and it helps; something like:
[yolo]
mask=0,1,2

[yolo]
mask=2,3,4,5

[yolo]
mask=5,6,7,8
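To get a feel for the soft targets proposed in item 2 above, here is a minimal Python transcription of the x_d * y_d expression. The function name is hypothetical. sigmoid(x)*(1 - sigmoid(x)) peaks at 0.25 when sigmoid(x) = 0.5, i.e. at the cell center. The product therefore peaks at (0.25 * scale_x_y)^2 there and falls toward 0 as the predicted offset approaches a cell edge.

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def soft_target(x, y, scale_x_y=1.0):
    """objectness/class target = x_d * y_d, with x_d = sigmoid(x)*(1 - sigmoid(x))*scale_x_y."""
    x_d = sigmoid(x) * (1.0 - sigmoid(x)) * scale_x_y
    y_d = sigmoid(y) * (1.0 - sigmoid(y)) * scale_x_y
    return x_d * y_d

print(soft_target(0.0, 0.0))         # cell center, scale_x_y = 1
print(soft_target(0.0, 0.0, 1.2))    # cell center, scale_x_y = 1.2
print(soft_target(4.0, 0.0, 1.2))    # predicted x center close to the cell edge
```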

@HamsterHuey

HamsterHuey commented Jul 23, 2020

I believe the sensitivity effect you guys are seeing here may be somewhat impacted by the anchor boxes, but also could be related to a more fundamental issue with information loss during downsampling steps in the network. You may not have come across this nice paper about antialiasing downsampling at some extra computational cost to make networks more translationally invariant:

https://richzhang.github.io/antialiased-cnns/

This seems to help somewhat, though the best solution (at the cost of halving the frame rate) is simply test-time augmentation: send a copy of the input image offset in x and y, run inference on both, and then combine the detections.
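The core antialiasing idea from that paper, blurring before subsampling, can be shown in one dimension. The minimal sketch below assumes a [1, 2, 1]/4 blur kernel and stride 2. Plain subsampling either keeps or completely loses a one-pixel feature, depending on whether it sits at an even or odd position. Blurring first keeps the same total response for both positions.

```python
def subsample(signal, stride=2):
    """Plain strided subsampling, as in a stride-2 layer with no antialiasing."""
    return signal[::stride]

def blur(signal, kernel=(0.25, 0.5, 0.25)):
    """1-D convolution with zero padding; [1, 2, 1]/4 is a simple low-pass filter."""
    n, half = len(signal), len(kernel) // 2
    out = []
    for i in range(n):
        acc = 0.0
        for k, w in enumerate(kernel):
            j = i + k - half
            if 0 <= j < n:
                acc += w * signal[j]
        out.append(acc)
    return out

def impulse(pos, n=8):
    """A single bright pixel at position pos."""
    return [1.0 if i == pos else 0.0 for i in range(n)]

for pos in (4, 5):   # the same feature, shifted by one pixel
    x = impulse(pos)
    print(pos, sum(subsample(x)), sum(subsample(blur(x))))
```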

@LukeAI

LukeAI commented Jul 23, 2020

https://richzhang.github.io/antialiased-cnns/

This seems to help some though the best solution (but halving frame-rates) is to simply use test-time augmentation with sending a copy of the input image but offset in x and y and running inference across both and then combining declarations.

This was tried and seemed to give slightly worse results:
#3672

@HamsterHuey

@LukeAI - Ah thanks. Yes, I also had mixed results implementing it on a different architecture (CenterNet) though the idea behind the paper seems sound and felt like it should have helped.

@AlexeyAB
Owner

@HamsterHuey It helps only for small datasets and without shift data augmentation.

@viplix3

viplix3 commented Jul 27, 2020

@AlexeyAB, I've been working with a custom object detector based on YOLOv3, and I am facing the same issue. The model gives good detections with very high confidence when tested on images, but I've observed a flickering effect when testing on video sequences. On further inspection, I found that my model suffers from the same sensitivity-near-grid-boundaries issue discussed in this thread.

Details regarding model training
Model has been trained using 512x512 image resolution with the letterbox technique as used in the original YOLOv3. The dataset is a custom pedestrian detection dataset, no feature-extractor pre-training is done and the model has been trained from scratch on the said pedestrian detection dataset.

I've gone through the whole thread but I can't seem to figure out how you ended up with the given scale_xy values for different branches.
Please correct me if I am wrong.
As I understand it, you've modified x_mid_offset and y_mid_offset so that their values fall in a range like (-0.1, +1.1) (the exact range you are getting might be different). This expanded range should solve the problem of sigmoid saturation near the grid boundaries.

For achieving this, you used the following modified equation:
x_mid_offset = scale_x * sigmoid(t_x) - gap_factor * (scale_x - 1)
gap_factor = no_detection_region_in_pixels / grid_width_in_pixels

Rather than the original YOLOv3 equation
x_mid_offset = sigmoid(t_x)
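For comparison, the darknet lines quoted earlier in this thread compute sigmoid(t_x) * scale_x_y - 0.5 * (scale_x_y - 1). That is the same form, but with a fixed gap_factor of 0.5 rather than a value derived from a measured pixel gap. The short sketch below computes the offset range as a function of gap_factor. It shows that the range extends equally below 0 and above 1 only when gap_factor = 0.5.

```python
def offset_range(scale_x, gap_factor):
    """Limits of scale_x * sigmoid(t_x) - gap_factor * (scale_x - 1) as t_x -> -inf, +inf."""
    shift = gap_factor * (scale_x - 1.0)
    return -shift, scale_x - shift

for g in (0.25, 0.5, 0.75):
    lo, hi = offset_range(1.2, g)
    print(f"gap_factor={g}: ({lo:.3f}, {hi:.3f}), "
          f"below 0 by {-lo:.3f}, above 1 by {hi - 1:.3f}")
```

Solving -lo = hi - 1 gives 2 * gap_factor * (scale_x - 1) = scale_x - 1, i.e. gap_factor = 0.5 for any scale_x > 1.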

I tried the above mentioned modified equation and analyzed the results post-training.
I analyzed the x_mid_offset values for each of the 3 detection branches of a model trained without the modification and of one trained with it. I calculated the gap between the grid values, but the gaps seem to have gotten worse (larger) after the modification.

Model trained using original YOLOv3 equations

  1. Big Object Branch (16, 16, 18)
**** No detection region/Grid gap summary (values in pixels) ****
10th Percentile: 3.368145751953125
25th Percentile: 4.016387939453125
50th Percentile: 4.351654052734375
Minimum Gap: 2.9743194580078125
Maximum Gap: 6.04254150390625
Mean Gap: 4.527173105875651
  2. Medium Object Branch (32, 32, 18)
**** No detection region/Grid gap summary (values in pixels) ****
10th Percentile: 1.4226531982421875
25th Percentile: 1.6012496948242188
50th Percentile: 1.7520599365234375
Minimum Gap: 0.578338623046875
Maximum Gap: 2.762908935546875
Mean Gap: 1.7731404458322833
  3. Small Object Branch (64, 64, 18)
**** No detection region/Grid gap summary (values in pixels) ****
10th Percentile: 0.722906494140625
25th Percentile: 0.8274993896484375
50th Percentile: 0.998382568359375
Minimum Gap: 0.643280029296875
Maximum Gap: 2.018594741821289
Mean Gap: 1.070833168332539

Model trained using modified equations as mentioned in this comment

  1. Big Object Branch (16, 16, 18)
**** No detection region/Grid gap summary (values in pixels) ****
10th Percentile: 6.392932128906249
25th Percentile: 6.830898284912109
50th Percentile: 7.550201416015625
Minimum Gap: 5.757049560546875
Maximum Gap: 11.434921264648438
Mean Gap: 7.691924540201823
  2. Medium Object Branch (32, 32, 18)
**** No detection region/Grid gap summary (values in pixels) ****
10th Percentile: 2.1123046875
25th Percentile: 2.1699914932250977
50th Percentile: 2.3658294677734375
Minimum Gap: 1.945770263671875
Maximum Gap: 3.4285125732421875
Mean Gap: 2.420973562425183
  3. Small Object Branch (64, 64, 18)
**** No detection region/Grid gap summary (values in pixels) ****
10th Percentile: 1.1284088134765624
25th Percentile: 1.3569717407226562
50th Percentile: 1.5756683349609375
Minimum Gap: 0.618408203125
Maximum Gap: 2.300570487976074
Mean Gap: 1.554993160187252

Can you shed more light on how you selected the scale_x_y values?
If required, I can provide other details like histograms, scatter plots, etc.

@AlexeyAB
Owner

@viplix3
Do you use this repository for training and detection? https://github.com/AlexeyAB/darkne

Attach both cfg-files.

If required, I can provide other details like histograms, scatter plots, etc.

Yes.

@viplix3

viplix3 commented Jul 27, 2020

@AlexeyAB

Do you use this repository for training and detection? https://github.com/AlexeyAB/darkne

Attach both cfg-files.

Sorry, I didn't use https://github.com/AlexeyAB/darknet, so I cannot provide any cfg files.
I used TensorFlow code for training and detection, which I wrote myself with the help of some other open-source repositories.

Is my understanding correct regarding the modified equations?

For achieving this, you used the following modified equation:
x_mid_offset = scale_x * sigmoid(t_x) - gap_factor * (scale_x - 1)
gap_factor = no_detection_region_in_pixels / grid_width_in_pixels

I'm providing the relevant figures below. All predictions are made on the same test set with a confidence threshold of 0.1 and an NMS threshold of 0.3.

Model trained with original YOLOv3 offset equation

Big Object Branch (16, 16, 18) x_mid_offset histograms at 512x512 resolution (figure: y_offset_histogram_big_object_branch)
Medium Object Branch (32, 32, 18) x_mid_offset histograms at 512x512 resolution (figure: x_offset_histogram_medium_object_branch)
Small Object Branch (64, 64, 18) x_mid_offset histograms at 512x512 resolution (figure: x_offset_histogram_small_object_branch)
Big Object Branch (16, 16, 18) x_mid_offset scatter plot at 512x512 resolution (figure: big_object_branch_xy_offset)
Medium Object Branch (32, 32, 18) x_mid_offset scatter plot at 512x512 resolution (figure: medium_object_branch_xy_offset)
Small Object Branch (64, 64, 18) x_mid_offset scatter plot at 512x512 resolution (figure: small_object_branch_xy_offset)
x_mid_offset scatter plot, normalized and combined (figure: all_object_branch_xy_offset_combined)

Model trained with modified offset equations

Big Object Branch (16, 16, 18) x_mid_offset histograms at 512x512 resolution (figure: y_offset_histogram_big_object_branch)
Medium Object Branch (32, 32, 18) x_mid_offset histograms at 512x512 resolution (figure: x_offset_histogram_medium_object_branch)
Small Object Branch (64, 64, 18) x_mid_offset histograms at 512x512 resolution (figure: x_offset_histogram_small_object_branch)
Big Object Branch (16, 16, 18) x_mid_offset scatter plot at 512x512 resolution (figure: big_object_branch_xy_offset)
Medium Object Branch (32, 32, 18) x_mid_offset scatter plot at 512x512 resolution (figure: medium_object_branch_xy_offset)
Small Object Branch (64, 64, 18) x_mid_offset scatter plot at 512x512 resolution (figure: small_object_branch_xy_offset)
x_mid_offset scatter plot, normalized and combined (figure: all_object_branch_xy_offset_combined)

Please note that the prediction scatter plot shape is like a parabola because the GT distribution is like that. It has nothing to do with incorrect model training.

It is pretty evident from the histograms and scatter plots that the model is not able to predict boxes near the grid boundaries.

@AlexeyAB
Owner

It seems something is wrong with your code.

@viplix3

viplix3 commented Jul 27, 2020

It seems something wrong with your code.

I can assure you my code is fine: the model trained with it has been tested exhaustively on over 100k frames for detection performance, and no unusually wrong detections have been observed so far.

@glenn-jocher
Author

@viplix3 the general idea is that it is impossible to generate an output of 0 or 1 from a sigmoid as the input neuron would need to be outputting -inf or inf. This is the discovery I made and the effect you are seeing in your plots.

The solution is to expand the output space past 0-1 (i.e. -0.2 to 1.2) while retaining the targets to a smaller space, allowing model outputs to more easily spread across the grid. The specifics of how you do this probably don't matter much, so you should experiment and see what works best for your experiment.
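The saturation argument can be made concrete by inverting the sigmoid: compute the logit t_x needed to place a box center at a given offset. With a plain sigmoid, the required logit diverges as the offset approaches 0 or 1. With the expanded output range used by darknet's scale_x_y (output = sigmoid * s - 0.5 * (s - 1)), the cell edges are reached at finite logits. The function name below is illustrative.

```python
import math

def required_logit(offset, scale=1.0):
    """Logit t such that sigmoid(t) * scale - 0.5 * (scale - 1) == offset."""
    p = (offset + 0.5 * (scale - 1.0)) / scale   # required sigmoid output
    if not 0.0 < p < 1.0:
        return math.copysign(math.inf, p - 0.5)  # unreachable with finite logits
    return math.log(p / (1.0 - p))

for offset in (0.5, 0.9, 0.99, 0.999, 1.0):
    print(offset, required_logit(offset), required_logit(offset, scale=1.2))
```

For example, reaching the cell edge (offset 1.0) needs an infinite logit with a plain sigmoid, but only ln(11) (about 2.4) with scale 1.2.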

If you arrive at any innovative solutions please update here!

@viplix3

viplix3 commented Jul 28, 2020

@glenn-jocher Thanks for the clarification.
I was trying to understand how @AlexeyAB did it. But as per your explanation, there are no set rules for solving this problem.
I will be training a model in which I'll use sigmoid with a gain factor of 0.8 for (x_mid, y_mid) offset predictions. This should relax the domain -> range saturation of sigmoid near the 0 and 1 values, which in turn should help in narrowing the sensitivity regions near the grid boundaries.

@glenn-jocher
Copy link
Author

@viplix3 sure, there's all sorts of creative ways you could circumvent this or mitigate it. You can expand your sigmoid scaling to stay away from the edges, or for example you could use an FCOS style box regression, which doesn't even require the model to output a centerpoint. That would completely negate the issue.
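For reference, FCOS-style regression predicts, for each feature-map location inside a box, the distances from that location to the four box sides (l, t, r, b). There is no center offset to squash into a cell. Below is a minimal sketch of the target encoding and decoding using plain pixel distances. Details such as the exp() mapping FCOS uses to keep predicted distances positive are omitted.

```python
def encode_ltrb(point, box):
    """Distances from a location (px, py) to the sides of box (x1, y1, x2, y2)."""
    px, py = point
    x1, y1, x2, y2 = box
    return px - x1, py - y1, x2 - px, y2 - py

def decode_ltrb(point, ltrb):
    """Recover the box corners from a location and its (l, t, r, b) distances."""
    px, py = point
    l, t, r, b = ltrb
    return px - l, py - t, px + r, py + b

def inside(ltrb):
    """A location is a positive sample only if all four distances are positive."""
    return min(ltrb) > 0

box = (100.0, 40.0, 180.0, 200.0)
for point in [(104.0, 44.0), (140.0, 120.0), (200.0, 120.0)]:
    ltrb = encode_ltrb(point, box)
    print(point, ltrb, inside(ltrb), decode_ltrb(point, ltrb) == box)
```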

7 participants