# https://github.com/mlcommons/training/blob/e3769c8dcf88cd21e1001dd2f894b40a1513ec5d/image_classification/tensorflow2/lars_optimizer.py
# Changes from upstream: don't call lr_t if it is not a schedule.

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layer-wise Adaptive Rate Scaling optimizer for large-batch training."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
# from tf2_common.training import optimizer_v2modified
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import training_ops
from tensorflow.python.ops import state_ops


# class LARSOptimizer(optimizer_v2modified.OptimizerV2Modified):
class LARSOptimizer(optimizer_v2.OptimizerV2):
  """Layer-wise Adaptive Rate Scaling for large batch training.

  Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
  I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)

  Implements the LARS learning rate scheme presented in the paper above. This
  optimizer is useful when scaling the batch size up to 32K without
  significant performance degradation. It is recommended to use the optimizer
  in conjunction with:
      - Gradual learning rate warm-up
      - Linear learning rate scaling
      - Poly rule learning rate decay

  Note, LARS scaling is currently only enabled for dense tensors. Sparse
  tensors use the default momentum optimizer.
  """
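
  # For reference, compute_lr() below applies the layer-wise trust ratio from
  # the paper:
  #     trust_ratio = eeta * ||w|| / (||g|| + weight_decay * ||w|| + epsilon)
  # and the effective step size for a variable is learning_rate * trust_ratio.
  # A minimal usage sketch (with illustrative, assumed hyperparameters) is
  # included as a comment at the bottom of this file.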

  def __init__(
      self,
      learning_rate,
      momentum=0.9,
      weight_decay=0.0001,
      # The LARS coefficient is a hyperparameter.
      eeta=0.001,
      epsilon=0.0,
      name="LARSOptimizer",
      # Enable skipping variables from LARS scaling.
      # TODO(sameerkm): Enable a direct mechanism to pass a
      # subset of variables to the optimizer.
      skip_list=None,
      use_nesterov=False,
      **kwargs):
    """Construct a new LARS Optimizer.

    Args:
      learning_rate: A `Tensor`, floating point value, or a schedule that is a
        `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
        that takes no arguments and returns the actual value to use. The
        learning rate.
      momentum: A floating point value. Momentum hyperparameter.
      weight_decay: A floating point value. Weight decay hyperparameter.
      eeta: LARS coefficient as used in the paper. Default set to LARS
        coefficient from the paper. (eeta / weight_decay) determines the
        highest scaling factor in LARS.
      epsilon: Optional epsilon parameter to be set in models that have very
        small gradients. Default set to 0.0.
      name: Optional name prefix for variables and ops created by
        LARSOptimizer.
      skip_list: List of strings to enable skipping variables from LARS
        scaling. If any of the strings in skip_list is a subset of var.name,
        variable 'var' is skipped from LARS scaling. For a typical
        classification model with batch normalization, the skip_list is
        ['batch_normalization', 'bias'].
      use_nesterov: When set to True, Nesterov momentum will be enabled.
      **kwargs: Keyword arguments.

    Raises:
      ValueError: If a hyperparameter is set to a nonsensical value.
    """
    if momentum < 0.0:
      raise ValueError("momentum should be non-negative: %s" % momentum)
    if weight_decay < 0.0:
      raise ValueError(
          "weight_decay should be non-negative: %s" % weight_decay)
    super(LARSOptimizer, self).__init__(name=name, **kwargs)

    self._set_hyper("learning_rate", learning_rate)

    # When class members are used directly, instead of
    # _set_hyper and _get_hyper (such as learning_rate above),
    # the values are fixed after __init__() and are not
    # updated during the training process.
    # This provides better performance but less flexibility.
    self.momentum = momentum
    self.weight_decay = weight_decay
    self.eeta = eeta
    self.epsilon = epsilon or backend_config.epsilon()
    self._skip_list = skip_list
    self.use_nesterov = use_nesterov

  def _prepare_local(self, var_device, var_dtype, apply_state):
    lr_t = self._get_hyper("learning_rate", var_dtype)
    local_step = math_ops.cast(self.iterations, var_dtype)
    # Change relative to upstream: only call lr_t when it is a schedule
    # (callable); a plain float or tensor learning rate is used as-is.
    if callable(lr_t):
      lr_t = math_ops.cast(lr_t(local_step), var_dtype)
    learning_rate_t = array_ops.identity(lr_t)

    apply_state[(var_device, var_dtype)].update(
        dict(learning_rate=learning_rate_t))

  def _create_slots(self, var_list):
    for v in var_list:
      self.add_slot(v, "momentum")

  def compute_lr(self, grad, var, coefficients):
    scaled_lr = coefficients["learning_rate"]
    if self._skip_list is None or not any(v in var.name
                                          for v in self._skip_list):
      w_norm = linalg_ops.norm(var, ord=2)
      g_norm = linalg_ops.norm(grad, ord=2)
      trust_ratio = array_ops.where(
          math_ops.greater(w_norm, 0),
          array_ops.where(
              math_ops.greater(g_norm, 0),
              (self.eeta * w_norm /
               (g_norm + self.weight_decay * w_norm + self.epsilon)), 1.0),
          1.0)

      scaled_lr = coefficients["learning_rate"] * trust_ratio
      # Add the weight regularization gradient.
      grad = grad + self.weight_decay * var
    return scaled_lr, grad

  def _apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    scaled_lr, grad = self.compute_lr(grad, var, coefficients)
    mom = self.get_slot(var, "momentum")
    # The per-variable scaled learning rate is folded into the gradient, so
    # the op's learning-rate input is a constant 1.0.
    return training_ops.apply_momentum(
        var,
        mom,
        math_ops.cast(1.0, var.dtype.base_dtype),
        grad * scaled_lr,
        self.momentum,
        use_locking=False,
        use_nesterov=self.use_nesterov)

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    scaled_lr, grad = self.compute_lr(grad, var, coefficients)
    mom = self.get_slot(var, "momentum")
    # Use ApplyKerasMomentum instead of ApplyMomentum:
    # training_ops.resource_apply_keras_momentum(
    #     var.handle,
    #     mom.handle,
    #     scaled_lr,
    #     grad,
    #     coefficients["momentum"],
    #     use_locking=False,
    #     use_nesterov=self.use_nesterov)

    # Keras-style momentum update, written out with state_ops:
    #   mom_t = momentum * mom - scaled_lr * grad
    #   var_t = var + mom_t            (or the Nesterov variant below)
    mom_t = mom * self.momentum - grad * scaled_lr
    mom_t = state_ops.assign(mom, mom_t, use_locking=False)
    if self.use_nesterov:
      var_t = var + mom_t * self.momentum - grad * scaled_lr
    else:
      var_t = var + mom_t
    return state_ops.assign(var, var_t, use_locking=False).op

  # Fallback to momentum optimizer for sparse tensors
  def _apply_sparse(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    mom = self.get_slot(var, "momentum")
    return training_ops.sparse_apply_momentum(
        var,
        mom,
        coefficients["learning_rate"],
        grad.values,
        grad.indices,
        self.momentum,
        use_locking=False,
        use_nesterov=self.use_nesterov)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    mom = self.get_slot(var, "momentum")
    return training_ops.resource_sparse_apply_keras_momentum(
        var.handle,
        mom.handle,
        coefficients["learning_rate"],
        grad,
        indices,
        self.momentum,
        use_locking=False,
        use_nesterov=self.use_nesterov)

  def get_config(self):
    config = super(LARSOptimizer, self).get_config()
    # Note: skip_list is not serialized, so it has to be supplied again when
    # an optimizer is rebuilt from this config.
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "momentum": self.momentum,
        "weight_decay": self.weight_decay,
        "eeta": self.eeta,
        "epsilon": self.epsilon,
        "use_nesterov": self.use_nesterov,
    })
    return config
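

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream file). The schedule, model,
# and hyperparameter values are illustrative placeholders; it assumes a TF2
# release in which optimizers built on tensorflow.python.keras.optimizer_v2
# are accepted by the tf.keras training APIs.
#
#   schedule = tf.keras.optimizers.schedules.PolynomialDecay(
#       initial_learning_rate=10.0, decay_steps=10000,
#       end_learning_rate=0.0001, power=2.0)
#   optimizer = LARSOptimizer(
#       learning_rate=schedule,
#       momentum=0.9,
#       weight_decay=0.0001,
#       skip_list=["batch_normalization", "bias"])
#   # Either pass `optimizer` to model.compile(...) or call
#   # optimizer.apply_gradients(zip(grads, variables)) in a custom loop.
# ---------------------------------------------------------------------------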