
LEARNING RATE SCHEDULER #238

Closed
glenn-jocher opened this issue Apr 24, 2019 · 56 comments

@glenn-jocher
Member

glenn-jocher commented Apr 24, 2019

The original darknet learning rate (LR) scheduler parameters are set in a model's *.cfg file:

  • learning_rate: initial LR
  • burn_in: number of batches to ramp LR from 0 to learning_rate in epoch 0
  • max_batches: the number of batches to train the model to
  • policy: type of LR scheduler
  • steps: batch numbers at which LR is reduced
  • scales: LR multiple applied at steps (gamma in PyTorch)

[Screenshot: example darknet *.cfg LR scheduler settings]

In this repo LR scheduling is set in train.py. We set the initial and final LRs as hyperparameters hyp['lr0'] and hyp['lrf'], where the final LR = lr0 * (10 ** lrf). For example, if the initial LR is 0.001 and the final LR is 100 times smaller (a 1e-2 multiple), then hyp['lr0']=0.001 and hyp['lrf']=-2. The plot below shows two of the available PyTorch LR schedulers, with the MultiStepLR scheduler following the original darknet implementation (at batch_size=64 on COCO). To learn more please visit:
https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

[Plot: PyTorch LR schedules vs epoch (LambdaLR inverse-exp and MultiStepLR)]

The LR hyperparameters are tunable, along with all the rest of the model hyperparameters, in train.py:

yolov3/train.py

Lines 13 to 25 in 1771ffb

# Hyperparameters
# 0.861 0.956 0.936 0.897 1.51 10.39 0.1367 0.01057 0.01181 0.8409 0.1287 0.001028 -3.441 0.9127 0.0004841
hyp = {'k': 10.39,  # loss multiple
       'xy': 0.1367,  # xy loss fraction
       'wh': 0.01057,  # wh loss fraction
       'cls': 0.01181,  # cls loss fraction
       'conf': 0.8409,  # conf loss fraction
       'iou_t': 0.1287,  # iou target-anchor training threshold
       'lr0': 0.001028,  # initial learning rate
       'lrf': -3.441,  # final learning rate = lr0 * (10 ** lrf)
       'momentum': 0.9127,  # SGD momentum
       'weight_decay': 0.0004841,  # optimizer weight decay
       }

Actual LR scheduling is set further down in train.py and has been tuned for COCO training. You may want to set your own scheduler according to your specific custom dataset and training requirements, and adjust its hyperparameters accordingly.

yolov3/train.py

Lines 102 to 109 in bd2378f

# Scheduler https://github.com/ultralytics/yolov3/issues/238
# lf = lambda x: 1 - x / epochs # linear ramp to zero
# lf = lambda x: 10 ** (hyp['lrf'] * x / epochs) # exp ramp
lf = lambda x: 1 - 10 ** (hyp['lrf'] * (1 - x / epochs)) # inverse exp ramp
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf, last_epoch=start_epoch - 1)
# scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[218, 245], gamma=0.1, last_epoch=start_epoch-1)
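
For reference, a LambdaLR scheduler like this is then typically stepped once per epoch in the training loop. A minimal sketch, not the exact repo loop:

# Minimal sketch: step the LambdaLR scheduler once per epoch (illustration only)
for epoch in range(start_epoch, epochs):
    for i, batch in enumerate(dataloader):
        ...  # forward pass, loss, loss.backward(), optimizer.step(), optimizer.zero_grad()
    scheduler.step()  # update the LR after every epoch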

glenn-jocher added the question and tutorial labels Apr 24, 2019
glenn-jocher self-assigned this Apr 29, 2019
@announce1

Hello glenn-jocher,
Sorry to interrupt you, but I am confused about the following learning rate decay function:
lf = lambda x: 1 - 10 ** (hyp['lrf'] * (1 - x / epochs))
How did you arrive at this function?
Thank you for your reply.

@glenn-jocher
Member Author

@announce1 this equation corresponds to the orange and green curves above. It's an inverse exponential that decays the LR to zero by the final epoch. This type of curve is tunable with a single hyperparameter, hyp['lrf']. We selected it for this simple tunability, its continuous nature, and its good performance in comparisons against other functions such as linear, exponential, and step schedules.
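
For reference, a minimal sketch that evaluates this multiplier over training (assuming the default 273 epochs and lrf=-3.441; values are illustrative only):

epochs, lrf = 273, -3.441
lf = lambda x: 1 - 10 ** (lrf * (1 - x / epochs))  # inverse exp ramp
for e in (0, 68, 136, 204, 272):
    print(e, round(lf(e), 4))  # multiplier decays from ~1.0 toward 0 at the final epoch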

@announce1

Hi glenn-jocher,
Thank you for your reply.
I understand your reasoning for tuning this function and these hyperparameters. Thank you very much.
What still confuses me is where this function came from.
Was it from another paper or another model's loss calculation?
Thanks a lot.

@glenn-jocher
Member Author

@announce1 the function is a simple exponential; it's very common in statistics.

@abhinav3

abhinav3 commented Jun 6, 2019

Hi glenn-jocher,
I'm trying to train it on my custom dataset containing roughly 160K images and around 50 classes with batch size = 48.
I was using the default hyperparameter values as mentioned in the train.py file, which are:
hyp = {'xy': 0.2,  # xy loss gain
       'wh': 0.1,  # wh loss gain
       'cls': 0.04,  # cls loss gain
       'conf': 4.5,  # conf loss gain
       'iou_t': 0.5,  # iou target-anchor training threshold
       'lr0': 0.001,  # initial learning rate
       'lrf': -4.,  # final learning rate = lr0 * (10 ** lrf)
       'momentum': 0.90,  # SGD momentum
       'weight_decay': 0.0005}

During the first epoch itself, I got the 'WARNING: nan loss detected, ending training' message.

I've trained yolov2 on a similar dataset with learning_rate = 0.00001 and it converged quite well without using any momentum etc.

So, if I use 'lr0': 0.00001 (initial learning rate) with your code, what other hyperparameter values do you suggest, e.g. for momentum, lrf and weight_decay?

Thanks

@glenn-jocher
Member Author

@abhinav3 your wh loss is likely diverging. See #307
You should probably lower your hyp['wh'] gain, or increase your burn-in period.

@glenn-jocher
Member Author

LR burn-in during the first 1000 batches:

yolov3/train.py

Lines 221 to 226 in 1a9aa30

# SGD burn-in
if epoch == 0 and i <= n_burnin:
    lr = hyp['lr0'] * (i / n_burnin) ** 4
    for x in optimizer.param_groups:
        x['lr'] = lr

[Plot: LR burn-in ramp over the first 1000 batches]

@H-YunHui

@glenn-jocher
I have a question about the epochs: when I set different total epochs, the LR during the earlier epochs is the same, so why are the APs so different for the same LR (I have only one class)?

@glenn-jocher
Member Author

glenn-jocher commented Aug 29, 2019

@www12345678 Hello, thank you for your interest in our work! Please note that most technical problems are due to:

  • Your changes to the default repository. If your issue is not reproducible in a fresh git clone version of this repository we cannot debug it. Before going further, run this code and ensure your issue persists:
sudo rm -rf yolov3  # remove existing repo
git clone https://github.com/ultralytics/yolov3 && cd yolov3  # git clone latest
python3 detect.py  # verify detection
python3 train.py  # verify training (a few batches only)
# CODE TO REPRODUCE YOUR ISSUE HERE
  • Your custom data. If your issue is not reproducible with COCO data we cannot debug it. Visit our Custom Training Tutorial for exact details on how to format your custom data. Examine train_batch0.jpg and test_batch0.jpg for a sanity check of training and testing data.
  • Your environment. If your issue is not reproducible in a GCP Quickstart Guide VM we cannot debug it. Ensure you meet the requirements specified in the README: Unix, MacOS, or Windows with Python >= 3.7, PyTorch >= 1.1, etc. You can also use our Google Colab Notebook to test your code in a working environment.

If none of these apply to you, we suggest you close this issue and raise a new one using the Bug Report template, providing screenshots and minimum viable code to reproduce your issue. Thank you!

@mozpp

mozpp commented Nov 21, 2019

yolov3/train.py

Lines 254 to 263 in 74b5750

# Hyperparameter burn-in
# n_burn = nb - 1  # min(nb // 5 + 1, 1000)  # number of burn-in batches
# if ni <= n_burn:
#     for m in model.named_modules():
#         if m[0].endswith('BatchNorm2d'):
#             m[1].momentum = 1 - i / n_burn * 0.99  # BatchNorm2d momentum falls from 1 - 0.01
#     g = (i / n_burn) ** 4  # gain rises from 0 - 1
#     for x in optimizer.param_groups:
#         x['lr'] = hyp['lr0'] * g
#         x['weight_decay'] = hyp['weight_decay'] * g

Why "burn in" is disabled now?

@glenn-jocher
Member Author

@mozpp burn-in is no longer needed. GIoU stabilizes the unbounded wh loss.

@mozpp

mozpp commented Nov 21, 2019

@mozpp burn-in is no longer needed. GIoU stabilizes the unbounded wh loss.

Could you explain how "prebias" works? I think it is similar to "burn-in".

@glenn-jocher
Member Author

@mozpp no, prebias attempts to aggressively optimize the neuron biases on the Conv2d() layers preceding each YOLO layer. There are only 765 of these in yolov3-spp; the rest of the network is frozen and unaffected. It has no relation to burn-in, which is a reduced LR over the initial batches.

@yujianll

yujianll commented Dec 5, 2019

@glenn-jocher Sorry, I'm still confused about what prebias does.
I checked the full model_info under prebias mode, and it seems all parameters are set to requires_grad = True. Does it copy the weights from the pre-trained model and fine-tune all parameters?

@glenn-jocher
Member Author

@yujianll see #460

@yujianll

yujianll commented Dec 5, 2019

@glenn-jocher Thanks!
Do you have any suggestions for avoiding GPU memory issues?
I'm using the yolov3-spp model with 13 classes, batch size 32, prebias, and no transfer learning. I encountered CUDA out of memory during the training step.

It seems simply reducing the batch size fixes it.

@glenn-jocher
Member Author

@yujianll reduce --batch-size

@developer0hye
Contributor

@yujianll
If you use group normalization instead of batch normalization, you can train the model with a small batch size.

@yujianll

@developer0hye Thanks! Is there an argument that I can set to use group normalization?

@developer0hye
Contributor

developer0hye commented Dec 10, 2019

@yujianll
There is no argument for it... That's a shame... We need to implement it! But, it is easy to implement.
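
A rough sketch of that swap, purely as an illustration (convert_bn_to_gn and num_groups=32 are assumptions, not repo code; num_groups must divide the channel count):

import torch.nn as nn

def convert_bn_to_gn(module, num_groups=32):
    # Recursively replace each BatchNorm2d with a GroupNorm over the same channel count
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            g = num_groups if child.num_features % num_groups == 0 else 1  # fall back to 1 group
            setattr(module, name, nn.GroupNorm(g, child.num_features))
        else:
            convert_bn_to_gn(child, num_groups)
    return module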

@glenn-jocher
Do you have the plan for using group normalization??

@FranciscoReveriano
Contributor

This would be a great addition!

@developer0hye
Contributor

@FranciscoReveriano
Yeah, I think so!

@glenn-jocher
Member Author

@developer0hye @FranciscoReveriano my understanding of groupnorm is that yes it can approach the results of batchnorm with smaller batch sizes, but not exceed them. I don't have any plans for it at the moment, but if you implement and compare results on coco_64img.data it might be a worthwhile PR.

@glenn-jocher
Member Author

@FranciscoReveriano yes, the new cosine scheduler does seem to be superior. It helps in the middle of training, and it also helps the final mAP a little. I'm going to try it on a full training run on full COCO this week, and if it helps there too I'll make it the default scheduler. :)

@glenn-jocher
Member Author

Cosine LR scheduler is the new default. See #238 (comment)

yolov3/train.py

Lines 144 to 145 in 84371f6

lf = lambda x: 0.5 * (1 + math.cos(x * math.pi / epochs)) # cosine https://arxiv.org/pdf/1812.01187.pdf
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
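
For reference, with eta_min=0 this LambdaLR form should match PyTorch's built-in cosine scheduler. A one-line sketch, assuming T_max=epochs:

scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0)  # equivalent built-in form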

@FranciscoReveriano
Contributor

@glenn-jocher I saw that you also made a declining Pre-bias

@glenn-jocher
Member Author

@FranciscoReveriano yes, I updated it to vary smoothly from the initial prebias conditions (lr and momentum) at epoch 0, to the normal conditions over 3 epochs. I think this might help a bit.

@FranciscoReveriano
Contributor

@glenn-jocher I have seen that when I run a prebias at a higher magnitude than the initial one, it helps with the gradients.

@developer0hye
Contributor

developer0hye commented Mar 6, 2020

@glenn-jocher
Good work!

I have a question: PyTorch already has its own cosine scheduler, so why re-implement it?

Learning Rate warmup with cosine scheduler

This is my own implementation of learning rate warmup with the official PyTorch cosine LR scheduler!

Have you ever tried applying a learning rate warmup for stable training in the early epochs?
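
A rough sketch of one way a warmup could be prepended to the cosine schedule via LambdaLR (illustration only; n_warmup is an assumed value, not something from this repo):

import math
from torch.optim import lr_scheduler

n_warmup = 3  # warmup epochs (assumed)
def lf(x):
    if x < n_warmup:
        return (x + 1) / n_warmup  # linear warmup toward the base LR
    return 0.5 * (1 + math.cos((x - n_warmup) * math.pi / (epochs - n_warmup)))  # cosine decay

scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)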

@glenn-jocher
Copy link
Member Author

glenn-jocher commented Mar 6, 2020

@developer0hye yes, this is true; perhaps we should use a warmup period, or 'burn-in' as some people call it.

I'm a bit confused about how to handle the initial iterations because right now we have a prebias period where we actually use much higher LRs for the model biases (weight LRs stay the same) for the first few epochs, which is somewhat opposite of a warmup:

yolov3/train.py

Lines 216 to 229 in 65eeb1b

# Prebias
if prebias:
    ne = max(round(30 / nb), 3)  # number of prebias epochs
    ps = np.interp(epoch, [0, ne], [0.1, hyp['lr0'] * 2]), \
         np.interp(epoch, [0, ne], [0.9, hyp['momentum']])  # prebias settings (lr=0.1, momentum=0.9)
    if epoch == ne:
        print_model_biases(model)
        prebias = False

    # Bias optimizer settings
    optimizer.param_groups[2]['lr'] = ps[0]
    if optimizer.param_groups[2].get('momentum') is not None:  # for SGD but not Adam
        optimizer.param_groups[2]['momentum'] = ps[1]

The main problem is that, particularly for the yolo output biases, the obj and classification biases should be extremely negative to reflect a very low chance of being predicted. For example, for classification on COCO, the mean output bias should be math.log(1 / (80 - 1)) ≈ -4.4, to reflect a 1/80 probability, but this assumes a well-balanced dataset. Class imbalances can skew this significantly, e.g. in COCO most of the objects are people I think, so the bias outputs are severely skewed, causing serious instabilities in early iterations, especially with momentum optimizers.

yolov3/models.py

Lines 88 to 92 in 65eeb1b

# Initialize preceding Conv2d() bias (https://arxiv.org/pdf/1708.02002.pdf section 3.3)
try:
    p = math.log(1 / (modules.nc - 0.99))  # class probability -> sigmoid(p) = 1/nc
    if arc == 'default' or arc == 'Fdefault':  # default
        b = [-4.5, p]  # obj, cls
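
A quick sanity check of this initialization (a sketch, assuming nc=80 as on COCO):

import math

nc = 80  # COCO classes
p = math.log(1 / (nc - 0.99))  # ≈ -4.37
print(1 / (1 + math.exp(-p)))  # sigmoid(p) ≈ 1/nc ≈ 0.0125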

@FranciscoReveriano
Contributor

I am finding that holding a stable, larger learning rate for a longer period avoids NaN losses later down the road.

@developer0hye
Contributor

developer0hye commented Mar 7, 2020

@glenn-jocher
have you tried the hard negative mining method used in SSD to address the objectness imbalance problem?

I think using hard negative mining together with focal loss could be the solution.

@FranciscoReveriano
Contributor

@developer0hye Do you have a paper? Or article discussing it?

@glenn-jocher
Member Author

@developer0hye I've seen focal loss accelerate training and produce higher mAPs sooner, but I found that it also accelerates _over_training and results in a lower final mAP. I could never get it to help on COCO.

I think towards the end of training the main challenge is suppressing overtraining (keeping the validation losses from increasing). If we could manage that, then longer training would result in better mAPs.

@developer0hye
Contributor

developer0hye commented Mar 7, 2020

@FranciscoReveriano
There are some explanations of the hard negative mining method in the official SSD paper!
ssd paper

And... the method I mentioned (hard negative mining with focal loss) is just my opinion.
Recently, I implemented my own yolov3-tiny and tried to train the model from scratch on the VOC2007 dataset, but I failed to train it. I am trying to find the reason for this.

@glenn-jocher
Thank you for your insight!
In my experience, a burn-in learning rate warmup over 5 epochs (updated per iteration) worked well.

@FranciscoReveriano
Contributor

@developer0hye Thanks. I am going to look more into it. I tried focal loss on my dataset but didn't get very good results. Maybe this will help with focal loss.

@developer0hye
Contributor

developer0hye commented Mar 7, 2020

@FranciscoReveriano

Thanks. This repository showed focal loss improving the performance of YOLOv3.
But... when I applied focal loss to my dataset with my own yolov3, it harmed the performance.

@glenn-jocher
Member Author

@developer0hye label smoothing seems like an easy update, and it seems to help in the repo you pointed to. For the positive labels, I'm assigning the GIoU value to them rather than 1.0, as I thought this would help sort the boxes more by IOU for NMS rather than by confidence alone.

At the lower end though, I assign all negative examples a target of 0.0. We could modify this to 0.1 as in the label smoothing example to see the effect.

@glenn-jocher
Member Author

@developer0hye wait, I got myself confused. Can we apply label smoothing to nn.BCEWithLogitsLoss() as well as nn.CrossEntropyLoss()? In this repo we now only use nn.BCEWithLogitsLoss() for both obj and cls.

@developer0hye
Contributor

@glenn-jocher yeah, I think so. We can apply label smoothing to nn.BCEWithLogitsLoss().

@glenn-jocher
Member Author

Just a note, per WongKinYiu/CrossStagePartialNetworks#6 (comment) label smoothing should only be applied to class loss, not obj loss.

From https://arxiv.org/pdf/1902.04103.pdf we have this, but they neglect to state the value of epsilon unfortunately.
[Screenshot: label smoothing excerpt from https://arxiv.org/pdf/1902.04103.pdf]

@glenn-jocher
Member Author

I found in the TensorFlow code two different implementations of label smoothing:

  1. Categorical Cross Entropy:
    y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes)
  2. Binary Cross Entropy
    y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing

where I assume label_smoothing=0.1 is a typical smoothing value
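
A minimal sketch of the binary (BCE) variant applied to targets, assuming label_smoothing=0.1 (illustration only, not repo code):

import torch
import torch.nn as nn

eps = 0.1  # assumed typical label_smoothing value
bce = nn.BCEWithLogitsLoss()

def smooth_bce_targets(y_true, eps=eps):
    return y_true * (1.0 - eps) + 0.5 * eps  # positives -> 0.95, negatives -> 0.05

logits = torch.randn(8, 80)  # example cls outputs
targets = torch.zeros(8, 80)
targets[range(8), torch.randint(0, 80, (8,))] = 1.0  # one positive class per row
loss = bce(logits, smooth_bce_targets(targets))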

@developer0hye
Contributor

@glenn-jocher
I think so. The label_smoothing value is set to 0.1 in many repositories that use the method.

@github-actions

github-actions bot commented Aug 1, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@glenn-jocher
Member Author

@developer0hye yes, a label smoothing value of 0.1 seems to be commonly used across various repositories that apply the label smoothing method. This can be a good starting point for experimentation, and further adjustments can be made based on specific dataset characteristics and performance evaluation.
