Module keras.distribute.distributed_training_utils
Utilities related to distributed training.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities related to distributed training."""
import tensorflow.compat.v2 as tf
from keras import backend
# TODO(b/118776054): Currently we support global batch size for TPUStrategy and
# core MirroredStrategy only. Remove this check when contrib MirroredStrategy is
# no longer needed.
def global_batch_size_supported(distribution_strategy):
  return distribution_strategy.extended._global_batch_size  # pylint: disable=protected-access


def call_replica_local_fn(fn, *args, **kwargs):
  """Call a function that uses replica-local variables.

  This function correctly handles calling `fn` in a cross-replica
  context.

  Args:
    fn: The function to call.
    *args: Positional arguments to the `fn`.
    **kwargs: Keyword argument to `fn`.

  Returns:
    The result of calling `fn`.
  """
  # TODO(b/132666209): Remove this function when we support assign_*
  # for replica-local variables.
  strategy = None
  if 'strategy' in kwargs:
    strategy = kwargs.pop('strategy')
  else:
    if tf.distribute.has_strategy():
      strategy = tf.distribute.get_strategy()

  # TODO(b/120571621): TPUStrategy does not implement replica-local variables.
  is_tpu = backend.is_tpu_strategy(strategy)
  if ((not is_tpu) and strategy and tf.distribute.in_cross_replica_context()):
    with strategy.scope():
      return strategy.extended.call_for_each_replica(fn, args, kwargs)
  return fn(*args, **kwargs)


def is_distributed_variable(v):
  """Returns whether `v` is a distributed variable."""
  return (isinstance(v, tf.distribute.DistributedValues) and
          isinstance(v, tf.Variable))
Functions
def call_replica_local_fn(fn, *args, **kwargs)
Call a function that uses replica-local variables.

This function correctly handles calling `fn` in a cross-replica context.

Args
  fn: The function to call.
  *args: Positional arguments to the `fn`.
  **kwargs: Keyword argument to `fn`.

Returns
  The result of calling `fn`.
def call_replica_local_fn(fn, *args, **kwargs):
  """Call a function that uses replica-local variables.

  This function correctly handles calling `fn` in a cross-replica
  context.

  Args:
    fn: The function to call.
    *args: Positional arguments to the `fn`.
    **kwargs: Keyword argument to `fn`.

  Returns:
    The result of calling `fn`.
  """
  # TODO(b/132666209): Remove this function when we support assign_*
  # for replica-local variables.
  strategy = None
  if 'strategy' in kwargs:
    strategy = kwargs.pop('strategy')
  else:
    if tf.distribute.has_strategy():
      strategy = tf.distribute.get_strategy()

  # TODO(b/120571621): TPUStrategy does not implement replica-local variables.
  is_tpu = backend.is_tpu_strategy(strategy)
  if ((not is_tpu) and strategy and tf.distribute.in_cross_replica_context()):
    with strategy.scope():
      return strategy.extended.call_for_each_replica(fn, args, kwargs)
  return fn(*args, **kwargs)
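A rough usage sketch (not taken from the module itself): the sync-on-read `counter` variable and the `_reset` helper are hypothetical, and a `MirroredStrategy` is assumed to be available. Inside the strategy's scope the call runs in a cross-replica context, so the helper dispatches `fn` to every replica.

import tensorflow.compat.v2 as tf
from keras.distribute import distributed_training_utils as dist_utils

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # A replica-local (sync-on-read) variable; purely illustrative.
  counter = tf.Variable(
      0.0,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.SUM)

def _reset(var):
  var.assign(0.0)

# Inside the strategy's scope we are in a cross-replica context, so the
# reset is dispatched to every replica via `call_for_each_replica`.
with strategy.scope():
  dist_utils.call_replica_local_fn(_reset, counter, strategy=strategy)

# Outside any strategy (or if no strategy is found), `fn` is simply
# called directly with the given arguments.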
def global_batch_size_supported(distribution_strategy)
def global_batch_size_supported(distribution_strategy):
  return distribution_strategy.extended._global_batch_size  # pylint: disable=protected-access
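A sketch of how this check might be used when splitting a batch across replicas; the variable names are illustrative. Per the TODO note in the module source, the private flag it reads is set for TPUStrategy and core MirroredStrategy, which interpret batch sizes as global.

import tensorflow.compat.v2 as tf
from keras.distribute import distributed_training_utils as dist_utils

strategy = tf.distribute.MirroredStrategy()
batch_size = 64

if dist_utils.global_batch_size_supported(strategy):
  # TPUStrategy and core MirroredStrategy treat `batch_size` as the
  # global batch size across all replicas.
  per_replica_batch_size = batch_size // strategy.num_replicas_in_sync
else:
  # Other strategies treat the batch size as per-replica.
  per_replica_batch_size = batch_size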
def is_distributed_variable(v)
Returns whether `v` is a distributed variable.
def is_distributed_variable(v):
  """Returns whether `v` is a distributed variable."""
  return (isinstance(v, tf.distribute.DistributedValues) and
          isinstance(v, tf.Variable))
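A quick illustrative check, assuming a `MirroredStrategy`: a variable created under the strategy's scope is mirrored across replicas and satisfies both `isinstance` checks, while an ordinary variable does not.

import tensorflow.compat.v2 as tf
from keras.distribute import distributed_training_utils as dist_utils

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  mirrored = tf.Variable(1.0)  # a MirroredVariable (DistributedValues + Variable)
plain = tf.Variable(1.0)       # an ordinary variable created outside any strategy

print(dist_utils.is_distributed_variable(mirrored))  # True
print(dist_utils.is_distributed_variable(plain))     # False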