# https://github.com/mlcommons/training/blob/e3769c8dcf88cd21e1001dd2f894b40a1513ec5d/image_classification/tensorflow2/lars_optimizer.py
# changes: don't call lr_t if it's not a schedule

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layer-wise Adaptive Rate Scaling optimizer for large-batch training."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
# from tf2_common.training import optimizer_v2modified
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import training_ops
from tensorflow.python.ops import state_ops


# class LARSOptimizer(optimizer_v2modified.OptimizerV2Modified):
class LARSOptimizer(optimizer_v2.OptimizerV2):
|  |   """Layer-wise Adaptive Rate Scaling for large batch training.
 | ||
|  | 
 | ||
|  |   Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
 | ||
|  |   I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)
 | ||
|  | 
 | ||
|  |   Implements the LARS learning rate scheme presented in the paper above. This
 | ||
|  |   optimizer is useful when scaling the batch size to up to 32K without
 | ||
|  |   significant performance degradation. It is recommended to use the optimizer
 | ||
|  |   in conjunction with:
 | ||
|  |       - Gradual learning rate warm-up
 | ||
|  |       - Linear learning rate scaling
 | ||
|  |       - Poly rule learning rate decay
 | ||
|  | 
 | ||
|  |   Note, LARS scaling is currently only enabled for dense tensors. Sparse tensors
 | ||
|  |   use the default momentum optimizer.
 | ||
|  |   """
 | ||
|  | 
 | ||
|  |   def __init__(
      self,
      learning_rate,
      momentum=0.9,
      weight_decay=0.0001,
      # The LARS coefficient is a hyperparameter
      eeta=0.001,
      epsilon=0.0,
      name="LARSOptimizer",
      # Enable skipping variables from LARS scaling.
      # TODO(sameerkm): Enable a direct mechanism to pass a
      # subset of variables to the optimizer.
      skip_list=None,
      use_nesterov=False,
      **kwargs):
    """Construct a new LARS Optimizer.

    Args:
      learning_rate: A `Tensor`, floating point value, or a schedule that is a
        `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
        that takes no arguments and returns the actual value to use. The
        learning rate.
      momentum: A floating point value. Momentum hyperparameter.
      weight_decay: A floating point value. Weight decay hyperparameter.
      eeta: LARS coefficient as used in the paper. Default set to the LARS
        coefficient from the paper. (eeta / weight_decay) determines the
        highest scaling factor in LARS.
      epsilon: Optional epsilon parameter to be set in models that have very
        small gradients. Default set to 0.0.
      name: Optional name prefix for variables and ops created by
        LARSOptimizer.
      skip_list: List of strings to enable skipping variables from LARS
        scaling. If any of the strings in skip_list is a substring of
        var.name, variable 'var' is skipped from LARS scaling. For a typical
        classification model with batch normalization, the skip_list is
        ['batch_normalization', 'bias'].
      use_nesterov: When set to True, Nesterov momentum will be enabled.
      **kwargs: Keyword arguments.

    Raises:
      ValueError: If a hyperparameter is set to a nonsensical value.
    """
    if momentum < 0.0:
      raise ValueError("momentum should be non-negative: %s" % momentum)
    if weight_decay < 0.0:
      raise ValueError(
          "weight_decay should be non-negative: %s" % weight_decay)
    super(LARSOptimizer, self).__init__(name=name, **kwargs)

    self._set_hyper("learning_rate", learning_rate)

    # When directly using class members, instead of
    # _set_hyper and _get_hyper (such as learning_rate above),
    # the values are fixed after __init__() and are not
    # updated during the training process.
    # This provides better performance but less flexibility.
    self.momentum = momentum
    self.weight_decay = weight_decay
    self.eeta = eeta
    self.epsilon = epsilon or backend_config.epsilon()
    self._skip_list = skip_list
    self.use_nesterov = use_nesterov

  def _prepare_local(self, var_device, var_dtype, apply_state):
    lr_t = self._get_hyper("learning_rate", var_dtype)
    local_step = math_ops.cast(self.iterations, var_dtype)
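    # Per the "changes" note at the top of this file: only evaluate lr_t when
    # it is a schedule/callable; a plain float or tensor learning rate is used
    # as-is.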
    if callable(lr_t):
      lr_t = math_ops.cast(lr_t(local_step), var_dtype)
    learning_rate_t = array_ops.identity(lr_t)

    apply_state[(var_device, var_dtype)].update(
        dict(
            learning_rate=learning_rate_t,
            ))

  def _create_slots(self, var_list):
    for v in var_list:
      self.add_slot(v, "momentum")

  def compute_lr(self, grad, var, coefficients):
    scaled_lr = coefficients["learning_rate"]
    if self._skip_list is None or not any(v in var.name
                                          for v in self._skip_list):
      w_norm = linalg_ops.norm(var, ord=2)
      g_norm = linalg_ops.norm(grad, ord=2)
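      # Layer-wise trust ratio from the LARS paper:
      #   trust_ratio = eeta * ||w|| / (||g|| + weight_decay * ||w|| + epsilon)
      # with a fallback to 1.0 when either norm is zero, so such variables are
      # updated with the plain learning rate.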
      trust_ratio = array_ops.where(
          math_ops.greater(w_norm, 0),
          array_ops.where(
              math_ops.greater(g_norm, 0),
              (self.eeta * w_norm /
               (g_norm + self.weight_decay * w_norm + self.epsilon)), 1.0), 1.0)

      scaled_lr = coefficients["learning_rate"] * trust_ratio
      # Add the weight regularization gradient
      grad = grad + self.weight_decay * var
    return scaled_lr, grad

  def _apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    scaled_lr, grad = self.compute_lr(grad, var, coefficients)
    mom = self.get_slot(var, "momentum")
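    # The LARS scaling is already folded into the gradient, so apply_momentum
    # is called with a learning rate of 1.0 and the pre-scaled gradient.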
    return training_ops.apply_momentum(
        var,
        mom,
        math_ops.cast(1.0, var.dtype.base_dtype),
        grad * scaled_lr,
        self.momentum,
        use_locking=False,
        use_nesterov=self.use_nesterov)

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    scaled_lr, grad = self.compute_lr(grad, var, coefficients)
    mom = self.get_slot(var, "momentum")
    # Use ApplyKerasMomentum instead of ApplyMomentum
    # training_ops.resource_apply_keras_momentum(
    #     var.handle,
    #     mom.handle,
    #     scaled_lr,
    #     grad,
    #     coefficients["momentum"],
    #     use_locking=False,
    #     use_nesterov=self.use_nesterov)

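    # Hand-rolled equivalent of the Keras-style momentum update above: the
    # accumulator stores the (negative) update,
    #   mom_t = momentum * mom - scaled_lr * grad,
    # and the variable moves by mom_t, with an extra look-ahead step when
    # Nesterov momentum is enabled.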
    mom_t = mom * self.momentum - grad * scaled_lr
    mom_t = state_ops.assign(mom, mom_t, use_locking=False)
    if self.use_nesterov:
      var_t = var + mom_t * self.momentum - grad * scaled_lr
    else:
      var_t = var + mom_t
    return state_ops.assign(var, var_t, use_locking=False).op

  # Fallback to momentum optimizer for sparse tensors
  def _apply_sparse(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    mom = self.get_slot(var, "momentum")
    return training_ops.sparse_apply_momentum(
        var,
        mom,
        coefficients["learning_rate"],
        grad.values,
        grad.indices,
        self.momentum,
        use_locking=False,
        use_nesterov=self.use_nesterov)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    mom = self.get_slot(var, "momentum")
    return training_ops.resource_sparse_apply_keras_momentum(
        var.handle,
        mom.handle,
        coefficients["learning_rate"],
        grad,
        indices,
        self.momentum,
        use_locking=False,
        use_nesterov=self.use_nesterov)

  def get_config(self):
    config = super(LARSOptimizer, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "momentum": self.momentum,
        "weight_decay": self.weight_decay,
        "eeta": self.eeta,
        "epsilon": self.epsilon,
        "use_nesterov": self.use_nesterov,
    })
    return config
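

# A minimal, hedged usage sketch (not part of the original MLCommons file). It
# assumes a TF2 environment in which the private `tensorflow.python` modules
# imported above are still available; the hyperparameters and the toy loss are
# illustrative only. It follows the docstring's recommendation of a poly-rule
# learning rate decay; gradual warm-up would normally be layered on top.
if __name__ == "__main__":
  schedule = tf.keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=0.1,
      decay_steps=1000,
      end_learning_rate=0.0,
      power=2.0)
  optimizer = LARSOptimizer(
      learning_rate=schedule,
      momentum=0.9,
      weight_decay=0.0001,
      # Skip LARS scaling for batch-norm and bias variables, per the docstring.
      skip_list=["batch_normalization", "bias"])

  var = tf.Variable([[1.0, 2.0], [3.0, 4.0]])

  @tf.function
  def train_step():
    # Exercises _prepare_local, compute_lr and _resource_apply_dense.
    with tf.GradientTape() as tape:
      loss = tf.reduce_sum(tf.square(var))
    grads = tape.gradient(loss, [var])
    optimizer.apply_gradients(zip(grads, [var]))
    return loss

  print("loss:", train_step().numpy())
  print("updated var:", var.numpy())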