Module keras.utils.control_flow_util
Utility functions for control flow.
This file is copied from tensorflow/python/ops/control_flow_util.py.
Expand source code
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for control flow.
This file is copied from tensorflow/python/ops/control_flow_util.py.
"""
import tensorflow.compat.v2 as tf
def InXlaContext(graph):
    """Returns whether `graph`'s control flow context is inside an XLAContext.

    Args:
      graph: Object exposing `_get_control_flow_context()` (presumably a
        `tf.Graph` -- confirm against callers).

    Returns:
      True if the graph's current control flow context, or one of its outer
      contexts, is an XLAContext; False otherwise.
    """
    return GetContainingXLAContext(
        graph._get_control_flow_context()  # pylint: disable=protected-access
    ) is not None
def GraphOrParentsInXlaContext(graph):
    """Returns whether `graph` or any of its outer graphs is in an XLA context.

    Follows the `outer_graph` links starting at `graph`. The walk stops with
    True as soon as one graph reports an XLA context, and with False once a
    graph without an `outer_graph` attribute is reached.

    Args:
      graph: The innermost graph to inspect.

    Returns:
      True if `graph` or a graph reachable through `outer_graph` is in an XLA
      context, False otherwise.
    """
    current_graph = graph
    while not InXlaContext(current_graph):
        try:
            current_graph = current_graph.outer_graph
        except AttributeError:
            # No further enclosing graph to inspect.
            return False
    return True
def IsInWhileLoop(op):
    """Returns whether `op`'s control flow context is inside a WhileContext.

    Args:
      op: Object exposing `_get_control_flow_context()` (presumably a
        `tf.Operation` -- confirm against callers).

    Returns:
      True if the op's control flow context, or one of its outer contexts, is
      a WhileContext; False otherwise.
    """
    return GetContainingWhileContext(
        op._get_control_flow_context()  # pylint: disable=protected-access
    ) is not None
def GetContainingWhileContext(ctxt, stop_ctxt=None):
    """Finds the innermost WhileContext enclosing `ctxt` (or `ctxt` itself).

    The search starts at `ctxt` and follows `outer_context` links outwards.

    Args:
      ctxt: ControlFlowContext to start from; may be None.
      stop_ctxt: ControlFlowContext, optional. If the traversal reaches a
        context equal to `stop_ctxt`, the search ends and that context is
        returned.

    Returns:
      The first context in the chain that is a WhileContext or equals
      `stop_ctxt`, or None if the chain is exhausted without finding one.
    """
    current = ctxt
    while current:
        if current.IsWhileContext():
            return current
        if current == stop_ctxt:
            return current
        current = current.outer_context
    return None
def GetContainingXLAContext(ctxt):
    """Finds the innermost XLAContext enclosing `ctxt` (or `ctxt` itself).

    The search starts at `ctxt` and follows `outer_context` links outwards.

    Args:
      ctxt: ControlFlowContext to start from; may be None.

    Returns:
      `ctxt` if `ctxt` is an XLAContext, otherwise the most nested XLAContext
      containing `ctxt`, or None if `ctxt` is not in an XLA context.
    """
    candidate = ctxt
    while candidate:
        if candidate.IsXLAContext():
            return candidate
        candidate = candidate.outer_context
    return None
def smart_cond(pred, true_fn=None, false_fn=None, name=None):  # pylint: disable=invalid-name
    """Returns `true_fn()` if the predicate `pred` is true, else `false_fn()`.

    A `tf.Variable` predicate is always dispatched to `tf.cond`. Every other
    predicate is handed to `tf.__internal__.smart_cond.smart_cond`, which
    calls the selected branch directly when `pred` is a bool or has a
    constant value and otherwise falls back to `tf.cond`.

    Args:
      pred: A scalar determining whether to return the result of `true_fn` or
        `false_fn`.
      true_fn: The callable to be performed if pred is true.
      false_fn: The callable to be performed if pred is false.
      name: Optional name prefix when using `tf.cond`.

    Returns:
      Tensors returned by the call to either `true_fn` or `false_fn`.

    Raises:
      TypeError: If `true_fn` or `false_fn` is not callable (raised by the
        underlying TensorFlow call).
    """
    if isinstance(pred, tf.Variable):
        cond_impl = tf.cond
    else:
        cond_impl = tf.__internal__.smart_cond.smart_cond
    return cond_impl(pred, true_fn=true_fn, false_fn=false_fn, name=name)
def constant_value(pred):  # pylint: disable=invalid-name
    """Return the bool value for `pred`, or None if `pred` had a dynamic value.

    Args:
      pred: A scalar, either a Python bool or a TensorFlow boolean variable
        or tensor, or the Python integer 1 or 0.

    Returns:
      For a Tensor, whatever `tf.get_static_value` reports (None when the
      value is not statically known). For a Variable, always None. For a
      Python bool or 1/0, the corresponding bool.

    Raises:
      TypeError: If `pred` is not a Variable, Tensor or bool, or Python
        integer 1 or 0.
    """
    if isinstance(pred, tf.Tensor):
        return tf.get_static_value(pred)
    # Variables must be recognised before the `in {0, 1}` test below. That
    # test hashes `pred`, and `tf.Variable.__hash__` raises TypeError when
    # tensor equality is enabled in eager mode, so checking membership first
    # would make this branch unreachable there.
    if isinstance(pred, tf.Variable):
        return None
    if isinstance(pred, bool):
        return pred
    if pred in {0, 1}:  # Accept 1/0 as valid boolean values
        return bool(pred)
    raise TypeError("`pred` must be a Tensor, or a Python bool, or 1 or 0. "
                    "Found instead: %s" % type(pred))
Functions
def GetContainingWhileContext(ctxt, stop_ctxt=None)
-
Returns the first ancestor WhileContext of `ctxt`.
Returns `ctxt` if `ctxt` is a WhileContext, or None if `ctxt` is not in a while loop.
Args
ctxt
- ControlFlowContext
stop_ctxt
- ControlFlowContext, optional. If provided, the search will end if it sees stop_ctxt.
Returns
`ctxt` if `ctxt` is a WhileContext, the most nested WhileContext containing `ctxt`, or None if `ctxt` is not in a while loop. If `stop_ctxt` is not `None`, this returns `ctxt` if it matches `stop_ctxt` in its traversal.
Expand source code
def GetContainingWhileContext(ctxt, stop_ctxt=None):
    """Finds the innermost WhileContext enclosing `ctxt` (or `ctxt` itself).

    The search starts at `ctxt` and follows `outer_context` links outwards.

    Args:
      ctxt: ControlFlowContext to start from; may be None.
      stop_ctxt: ControlFlowContext, optional. If the traversal reaches a
        context equal to `stop_ctxt`, the search ends and that context is
        returned.

    Returns:
      The first context in the chain that is a WhileContext or equals
      `stop_ctxt`, or None if the chain is exhausted without finding one.
    """
    current = ctxt
    while current:
        if current.IsWhileContext():
            return current
        if current == stop_ctxt:
            return current
        current = current.outer_context
    return None
def GetContainingXLAContext(ctxt)
-
Returns the first ancestor XLAContext of `ctxt`.
Returns `ctxt` if `ctxt` is an XLAContext, or None if `ctxt` is not in an XLA context.
Args
ctxt
- ControlFlowContext
Returns
`ctxt` if `ctxt` is an XLAContext, the most nested XLAContext containing `ctxt`, or None if `ctxt` is not in an XLA context.
Expand source code
def GetContainingXLAContext(ctxt):
    """Finds the innermost XLAContext enclosing `ctxt` (or `ctxt` itself).

    The search starts at `ctxt` and follows `outer_context` links outwards.

    Args:
      ctxt: ControlFlowContext to start from; may be None.

    Returns:
      `ctxt` if `ctxt` is an XLAContext, otherwise the most nested XLAContext
      containing `ctxt`, or None if `ctxt` is not in an XLA context.
    """
    candidate = ctxt
    while candidate:
        if candidate.IsXLAContext():
            return candidate
        candidate = candidate.outer_context
    return None
def GraphOrParentsInXlaContext(graph)
-
Expand source code
def GraphOrParentsInXlaContext(graph):
    """Returns whether `graph` or any of its outer graphs is in an XLA context.

    Follows the `outer_graph` links starting at `graph`. The walk stops with
    True as soon as one graph reports an XLA context, and with False once a
    graph without an `outer_graph` attribute is reached.

    Args:
      graph: The innermost graph to inspect.

    Returns:
      True if `graph` or a graph reachable through `outer_graph` is in an XLA
      context, False otherwise.
    """
    current_graph = graph
    while not InXlaContext(current_graph):
        try:
            current_graph = current_graph.outer_graph
        except AttributeError:
            # No further enclosing graph to inspect.
            return False
    return True
def InXlaContext(graph)
-
Expand source code
def InXlaContext(graph):
    """Returns whether `graph`'s control flow context is inside an XLAContext.

    Args:
      graph: Object exposing `_get_control_flow_context()` (presumably a
        `tf.Graph` -- confirm against callers).

    Returns:
      True if the graph's current control flow context, or one of its outer
      contexts, is an XLAContext; False otherwise.
    """
    return GetContainingXLAContext(
        graph._get_control_flow_context()  # pylint: disable=protected-access
    ) is not None
def IsInWhileLoop(op)
-
Expand source code
def IsInWhileLoop(op):
    """Returns whether `op`'s control flow context is inside a WhileContext.

    Args:
      op: Object exposing `_get_control_flow_context()` (presumably a
        `tf.Operation` -- confirm against callers).

    Returns:
      True if the op's control flow context, or one of its outer contexts, is
      a WhileContext; False otherwise.
    """
    return GetContainingWhileContext(
        op._get_control_flow_context()  # pylint: disable=protected-access
    ) is not None
def constant_value(pred)
-
Return the bool value for `pred`, or None if `pred` had a dynamic value.
Args
pred
- A scalar, either a Python bool or a TensorFlow boolean variable or tensor, or the Python integer 1 or 0.
Returns
True or False if `pred` has a constant boolean value, None otherwise.
Raises
TypeError
- If `pred` is not a Variable, Tensor or bool, or Python integer 1 or 0.
Expand source code
def constant_value(pred):  # pylint: disable=invalid-name
    """Return the bool value for `pred`, or None if `pred` had a dynamic value.

    Args:
      pred: A scalar, either a Python bool or a TensorFlow boolean variable
        or tensor, or the Python integer 1 or 0.

    Returns:
      For a Tensor, whatever `tf.get_static_value` reports (None when the
      value is not statically known). For a Variable, always None. For a
      Python bool or 1/0, the corresponding bool.

    Raises:
      TypeError: If `pred` is not a Variable, Tensor or bool, or Python
        integer 1 or 0.
    """
    if isinstance(pred, tf.Tensor):
        return tf.get_static_value(pred)
    # Variables must be recognised before the `in {0, 1}` test below. That
    # test hashes `pred`, and `tf.Variable.__hash__` raises TypeError when
    # tensor equality is enabled in eager mode, so checking membership first
    # would make this branch unreachable there.
    if isinstance(pred, tf.Variable):
        return None
    if isinstance(pred, bool):
        return pred
    if pred in {0, 1}:  # Accept 1/0 as valid boolean values
        return bool(pred)
    raise TypeError("`pred` must be a Tensor, or a Python bool, or 1 or 0. "
                    "Found instead: %s" % type(pred))
def smart_cond(pred, true_fn=None, false_fn=None, name=None)
-
Return either `true_fn()` if predicate `pred` is true else `false_fn()`.
If `pred` is a bool or has a constant value, we return either `true_fn()` or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.
Args
pred
- A scalar determining whether to return the result of `true_fn` or `false_fn`.
true_fn
- The callable to be performed if pred is true.
false_fn
- The callable to be performed if pred is false.
name
- Optional name prefix when using `tf.cond`.
Returns
Tensors returned by the call to either `true_fn` or `false_fn`.
Raises
TypeError
- If `true_fn` or `false_fn` is not callable.
Expand source code
def smart_cond(pred, true_fn=None, false_fn=None, name=None):  # pylint: disable=invalid-name
    """Returns `true_fn()` if the predicate `pred` is true, else `false_fn()`.

    A `tf.Variable` predicate is always dispatched to `tf.cond`. Every other
    predicate is handed to `tf.__internal__.smart_cond.smart_cond`, which
    calls the selected branch directly when `pred` is a bool or has a
    constant value and otherwise falls back to `tf.cond`.

    Args:
      pred: A scalar determining whether to return the result of `true_fn` or
        `false_fn`.
      true_fn: The callable to be performed if pred is true.
      false_fn: The callable to be performed if pred is false.
      name: Optional name prefix when using `tf.cond`.

    Returns:
      Tensors returned by the call to either `true_fn` or `false_fn`.

    Raises:
      TypeError: If `true_fn` or `false_fn` is not callable (raised by the
        underlying TensorFlow call).
    """
    if isinstance(pred, tf.Variable):
        cond_impl = tf.cond
    else:
        cond_impl = tf.__internal__.smart_cond.smart_cond
    return cond_impl(pred, true_fn=true_fn, false_fn=false_fn, name=name)