Module keras.utils.layer_utils

Utilities related to layer/model functionality.

Source code
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities related to layer/model functionality."""

import tensorflow.compat.v2 as tf

import functools
import weakref

import numpy as np
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.utils.get_source_inputs')
def get_source_inputs(tensor, layer=None, node_index=None):
  """Returns the list of input tensors necessary to compute `tensor`.

  Output will always be a list of tensors
  (potentially with 1 element).

  Args:
      tensor: The tensor to start from.
      layer: Origin layer of the tensor. Will be
          determined via tensor._keras_history if not provided.
      node_index: Origin node index of the tensor.

  Returns:
      List of input tensors.
  """
  if not hasattr(tensor, '_keras_history'):
    return tensor

  if layer is None or node_index:
    layer, node_index, _ = tensor._keras_history
  if not layer._inbound_nodes:
    return [tensor]
  else:
    node = layer._inbound_nodes[node_index]
    if node.is_input:
      # Reached an Input layer, stop recursion.
      return tf.nest.flatten(node.input_tensors)
    else:
      source_tensors = []
      for layer, node_index, _, tensor in node.iterate_inbound():
        previous_sources = get_source_inputs(tensor, layer, node_index)
        # Avoid input redundancy.
        for x in previous_sources:
          if all(x is not t for t in source_tensors):
            source_tensors.append(x)
      return source_tensors


def validate_string_arg(input_data,
                        allowable_strings,
                        layer_name,
                        arg_name,
                        allow_none=False,
                        allow_callables=False):
  """Validates the correctness of a string-based arg."""
  if allow_none and input_data is None:
    return
  elif allow_callables and callable(input_data):
    return
  elif isinstance(input_data, str) and input_data in allowable_strings:
    return
  else:
    allowed_args = '`None`, ' if allow_none else ''
    allowed_args += 'a `Callable`, ' if allow_callables else ''
    allowed_args += 'or one of the following values: %s' % (allowable_strings,)
    raise ValueError(('The %s argument of layer %s received an invalid '
                      'value %s. Allowed values are: %s.') %
                     (arg_name, layer_name, input_data, allowed_args))


def count_params(weights):
  """Count the total number of scalars composing the weights.

  Args:
      weights: An iterable containing the weights on which to compute params

  Returns:
      The total number of scalars composing the weights
  """
  unique_weights = {id(w): w for w in weights}.values()
  # Ignore TrackableWeightHandlers, which will not have a shape defined.
  unique_weights = [w for w in unique_weights if hasattr(w, 'shape')]
  weight_shapes = [w.shape.as_list() for w in unique_weights]
  standardized_weight_shapes = [
      [0 if w_i is None else w_i for w_i in w] for w in weight_shapes
  ]
  return int(sum(np.prod(p) for p in standardized_weight_shapes))


def print_summary(model, line_length=None, positions=None, print_fn=None):
  """Prints a summary of a model.

  Args:
      model: Keras model instance.
      line_length: Total length of printed lines
          (e.g. set this to adapt the display to different
          terminal window sizes).
      positions: Relative or absolute positions of log elements in each line.
          If not provided, defaults to `[.33, .55, .67, 1.]`.
      print_fn: Print function to use.
          It will be called on each line of the summary.
          You can set it to a custom function
          in order to capture the string summary.
          It defaults to `print` (prints to stdout).
  """
  if print_fn is None:
    print_fn = print

  if model.__class__.__name__ == 'Sequential':
    sequential_like = True
  elif not model._is_graph_network:
    # We treat subclassed models as a simple sequence of layers, for logging
    # purposes.
    sequential_like = True
  else:
    sequential_like = True
    nodes_by_depth = model._nodes_by_depth.values()
    nodes = []
    for v in nodes_by_depth:
      if (len(v) > 1) or (len(v) == 1 and
                          len(tf.nest.flatten(v[0].keras_inputs)) > 1):
        # if the model has multiple nodes
        # or if the nodes have multiple inbound_layers
        # the model is no longer sequential
        sequential_like = False
        break
      nodes += v
    if sequential_like:
      # search for shared layers
      for layer in model.layers:
        flag = False
        for node in layer._inbound_nodes:
          if node in nodes:
            if flag:
              sequential_like = False
              break
            else:
              flag = True
        if not sequential_like:
          break

  if sequential_like:
    line_length = line_length or 65
    positions = positions or [.45, .85, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #']
  else:
    line_length = line_length or 98
    positions = positions or [.33, .55, .67, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
    relevant_nodes = []
    for v in model._nodes_by_depth.values():
      relevant_nodes += v

  def print_row(fields, positions):
    line = ''
    for i in range(len(fields)):
      if i > 0:
        line = line[:-1] + ' '
      line += str(fields[i])
      line = line[:positions[i]]
      line += ' ' * (positions[i] - len(line))
    print_fn(line)

  print_fn('Model: "{}"'.format(model.name))
  print_fn('_' * line_length)
  print_row(to_display, positions)
  print_fn('=' * line_length)

  def print_layer_summary(layer):
    """Prints a summary for a single layer.

    Args:
        layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    except RuntimeError:  # output_shape unknown in Eager mode.
      output_shape = '?'
    name = layer.name
    cls_name = layer.__class__.__name__
    if not layer.built and not getattr(layer, '_is_graph_network', False):
      # If a subclassed model has a layer that is not called in Model.call, the
      # layer will not be built and we cannot call layer.count_params().
      params = '0 (unused)'
    else:
      params = layer.count_params()
    fields = [name + ' (' + cls_name + ')', output_shape, params]
    print_row(fields, positions)

  def print_layer_summary_with_connections(layer):
    """Prints a summary for a single layer (including topological connections).

    Args:
        layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    connections = []
    for node in layer._inbound_nodes:
      if relevant_nodes and node not in relevant_nodes:
        # node is not part of the current network
        continue

      for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
        connections.append('{}[{}][{}]'.format(inbound_layer.name, node_index,
                                               tensor_index))

    name = layer.name
    cls_name = layer.__class__.__name__
    if not connections:
      first_connection = ''
    else:
      first_connection = connections[0]
    fields = [
        name + ' (' + cls_name + ')', output_shape,
        layer.count_params(), first_connection
    ]
    print_row(fields, positions)
    if len(connections) > 1:
      for i in range(1, len(connections)):
        fields = ['', '', '', connections[i]]
        print_row(fields, positions)

  layers = model.layers
  for i in range(len(layers)):
    if sequential_like:
      print_layer_summary(layers[i])
    else:
      print_layer_summary_with_connections(layers[i])
    if i == len(layers) - 1:
      print_fn('=' * line_length)
    else:
      print_fn('_' * line_length)

  if hasattr(model, '_collected_trainable_weights'):
    trainable_count = count_params(model._collected_trainable_weights)
  else:
    trainable_count = count_params(model.trainable_weights)

  non_trainable_count = count_params(model.non_trainable_weights)

  print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count))
  print_fn('Trainable params: {:,}'.format(trainable_count))
  print_fn('Non-trainable params: {:,}'.format(non_trainable_count))
  print_fn('_' * line_length)


def convert_dense_weights_data_format(dense,
                                      previous_feature_map_shape,
                                      target_data_format='channels_first'):
  """Utility useful when changing a convnet's `data_format`.

  When porting the weights of a convnet from one data format to the other,
  if the convnet includes a `Flatten` layer
  (applied to the last convolutional feature map)
  followed by a `Dense` layer, the weights of that `Dense` layer
  should be updated to reflect the new dimension ordering.

  Args:
      dense: The target `Dense` layer.
      previous_feature_map_shape: A shape tuple of 3 integers,
          e.g. `(512, 7, 7)`. The shape of the convolutional
          feature map right before the `Flatten` layer that
          came before the target `Dense` layer.
      target_data_format: One of "channels_last", "channels_first".
          Set it "channels_last"
          if converting a "channels_first" model to "channels_last",
          or reciprocally.
  """
  assert target_data_format in {'channels_last', 'channels_first'}
  kernel, bias = dense.get_weights()
  for i in range(kernel.shape[1]):
    if target_data_format == 'channels_first':
      c, h, w = previous_feature_map_shape
      original_fm_shape = (h, w, c)
      ki = kernel[:, i].reshape(original_fm_shape)
      ki = np.transpose(ki, (2, 0, 1))  # last -> first
    else:
      h, w, c = previous_feature_map_shape
      original_fm_shape = (c, h, w)
      ki = kernel[:, i].reshape(original_fm_shape)
      ki = np.transpose(ki, (1, 2, 0))  # first -> last
    kernel[:, i] = np.reshape(ki, (np.prod(previous_feature_map_shape),))
  dense.set_weights([kernel, bias])


def is_builtin_layer(layer):
  if not getattr(layer, '_keras_api_names', None):
    return False

  # Subclasses of `Layer` that are not exported inherit the export name
  # of the base layer class.
  return (layer._keras_api_names != ('keras.layers.Layer',) and
          layer._keras_api_names_v1 != ('keras.layers.Layer',))


def cached_per_instance(f):
  """Lightweight decorator for caching lazily constructed properties.

  When to use:
  This decorator provides simple caching with minimal overhead. It is designed
  for properties which are expensive to compute and static over the life of a
  class instance, and provides no mechanism for cache invalidation. Thus it is
  best suited for lazily exposing derived properties of other static data.

  For classes with custom getattr / setattr behavior (such as trackable
  objects), storing cache results as object attributes is not performant.
  Instead, a specialized cache can significantly reduce property lookup
  overhead. (While still allowing the decorated property to be lazily computed.)
  Consider the following class:

  ```
  class MyClass(object):
    def __setattr__(self, key, value):
      # Some expensive class specific code
      # ...
      # ...

      super(MyClass, self).__setattr__(key, value)

    @property
    def thing(self):
      # `thing` is expensive to compute (and may not even be requested), so we
      # want to lazily compute it and then cache it.
      output = getattr(self, '_thing', None)
      if output is None:
        self._thing = output = compute_thing(self)
      return output
  ```

  It's also worth noting that ANY overriding of __setattr__, even something as
  simple as:
  ```
    def __setattr__(self, key, value):
      super(MyClass, self).__setattr__(key, value)
  ```

  Slows down attribute assignment by nearly 10x.

  By contrast, replacing the definition of `thing` with the following sidesteps
  the expensive __setattr__ altogether:

  ```
  @property
  @tracking.cached_per_instance
  def thing(self):
    # `thing` is expensive to compute (and may not even be requested), so we
    # want to lazily compute it and then cache it.
    return compute_thing(self)
  ```

  Performance:
  The overhead for this decorator is ~0.4 us / call. A much lower overhead
  implementation (~0.085 us / call) can be achieved by using a custom dict type:

  ```
  def dict_based_cache(f):
    class Cache(dict):
      __slots__ = ()
      def __missing__(self, key):
        self[key] = output = f(key)
        return output

    return property(Cache().__getitem__)
  ```

  However, that implementation holds class instances as keys, and as a result
  blocks garbage collection. (And modifying it to use weakrefs as keys raises
  the lookup overhead to ~0.4 us.) As a result, the WeakKeyDictionary
  implementation below turns out to be more prudent.

  Args:
    f: The function to cache.

  Returns:
    f decorated with simple caching behavior.
  """

  cache = weakref.WeakKeyDictionary()

  @functools.wraps(f)
  def wrapped(item):
    output = cache.get(item)
    if output is None:
      cache[item] = output = f(item)
    return output

  wrapped.cache = cache
  return wrapped


def filter_empty_layer_containers(layer_list):
  """Filter out empty Layer-like containers and uniquify."""
  # TODO(b/130381733): Make this an attribute in base_layer.Layer.
  existing = set()
  to_visit = layer_list[::-1]
  while to_visit:
    obj = to_visit.pop()
    if id(obj) in existing:
      continue
    existing.add(id(obj))
    if hasattr(obj, '_is_layer') and not isinstance(obj, type):
      yield obj
    else:
      sub_layers = getattr(obj, 'layers', None) or []

      # Trackable data structures will not show up in ".layers" lists, but
      # the layers they contain will.
      to_visit.extend(sub_layers[::-1])

Functions

def cached_per_instance(f)

Lightweight decorator for caching lazily constructed properties.

When to use: This decorator provides simple caching with minimal overhead. It is designed for properties which are expensive to compute and static over the life of a class instance, and provides no mechanism for cache invalidation. Thus it is best suited for lazily exposing derived properties of other static data.

For classes with custom getattr / setattr behavior (such as trackable objects), storing cache results as object attributes is not performant. Instead, a specialized cache can significantly reduce property lookup overhead. (While still allowing the decorated property to be lazily computed.) Consider the following class:

class MyClass(object):
  def __setattr__(self, key, value):
    # Some expensive class specific code
    # ...
    # ...

    super(MyClass, self).__setattr__(key, value)

  @property
  def thing(self):
    # `thing` is expensive to compute (and may not even be requested), so we
    # want to lazily compute it and then cache it.
    output = getattr(self, '_thing', None)
    if output is None:
      self._thing = output = compute_thing(self)
    return output

It's also worth noting that ANY overriding of `__setattr__`, even something as simple as:

  def __setattr__(self, key, value):
    super(MyClass, self).__setattr__(key, value)

Slows down attribute assignment by nearly 10x.

By contrast, replacing the definition of `thing` with the following sidesteps the expensive `__setattr__` altogether:

@property
@tracking.cached_per_instance
def thing(self):
  # `thing` is expensive to compute (and may not even be requested), so we
  # want to lazily compute it and then cache it.
  return compute_thing(self)

Performance: The overhead for this decorator is ~0.4 us / call. A much lower overhead implementation (~0.085 us / call) can be achieved by using a custom dict type:

def dict_based_cache(f):
  class Cache(dict):
    __slots__ = ()
    def __missing__(self, key):
      self[key] = output = f(key)
      return output

  return property(Cache().__getitem__)

However, that implementation holds class instances as keys, and as a result blocks garbage collection. (And modifying it to use weakrefs as keys raises the lookup overhead to ~0.4 us.) As a result, the WeakKeyDictionary implementation below turns out to be more prudent.

Args

f
The function to cache.

Returns

f decorated with simple caching behavior.
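A minimal usage sketch; the class and the property body are illustrative, and the import path is assumed to match this module:

from keras.utils.layer_utils import cached_per_instance  # import path assumed


class MyClass(object):
  calls = 0

  @property
  @cached_per_instance
  def thing(self):
    MyClass.calls += 1          # count how often the body actually runs
    return sum(range(10000))    # stand-in for an expensive computation


obj = MyClass()
assert obj.thing == obj.thing   # second access is served from the cache
assert MyClass.calls == 1
assert obj in MyClass.thing.fget.cache  # per-instance WeakKeyDictionary entry

Note that because the cache treats `None` as "not yet computed", a property that legitimately returns `None` is recomputed on every access.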

def convert_dense_weights_data_format(dense, previous_feature_map_shape, target_data_format='channels_first')

Utility useful when changing a convnet's data_format.

When porting the weights of a convnet from one data format to the other, if the convnet includes a Flatten layer (applied to the last convolutional feature map) followed by a Dense layer, the weights of that Dense layer should be updated to reflect the new dimension ordering.

Args

dense
The target Dense layer.
previous_feature_map_shape
A shape tuple of 3 integers, e.g. (512, 7, 7). The shape of the convolutional feature map right before the Flatten layer that came before the target Dense layer.
target_data_format
One of "channels_last", "channels_first". Set it "channels_last" if converting a "channels_first" model to "channels_last", or reciprocally.
def count_params(weights)

Count the total number of scalars composing the weights.

Args

weights
An iterable containing the weights on which to compute params

Returns

The total number of scalars composing the weights
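A quick sketch; variable shapes are illustrative and the import path is assumed to match this module:

import tensorflow as tf
from keras.utils import layer_utils  # import path assumed

w = tf.Variable(tf.zeros([3, 4]))
b = tf.Variable(tf.zeros([4]))
# Duplicates are counted once, and objects without a `shape` are skipped.
assert layer_utils.count_params([w, b, w]) == 3 * 4 + 4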

def filter_empty_layer_containers(layer_list)

Filter out empty Layer-like containers and uniquify.
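A self-contained sketch: the generator identifies Layer-like objects purely by the `_is_layer` attribute it checks, so hypothetical stand-in classes suffice for illustration (the import path is assumed to match this module):

from keras.utils import layer_utils  # import path assumed


class FakeLayer(object):       # stand-in: the generator looks for `_is_layer`
  _is_layer = True


class Container(object):       # stand-in for a container exposing `.layers`
  def __init__(self, layers):
    self.layers = layers


a, b = FakeLayer(), FakeLayer()
result = list(layer_utils.filter_empty_layer_containers(
    [a, a, Container([]), Container([b])]))
assert result == [a, b]  # duplicate dropped, empty container filtered out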

def get_source_inputs(tensor, layer=None, node_index=None)

Returns the list of input tensors necessary to compute tensor.

Output will always be a list of tensors (potentially with 1 element).

Args

tensor
The tensor to start from.
layer
Origin layer of the tensor. Will be determined via tensor._keras_history if not provided.
node_index
Origin node index of the tensor.

Returns

List of input tensors.
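A short sketch using the exported alias (layer sizes are illustrative):

import tensorflow as tf

inp = tf.keras.Input(shape=(4,))
hidden = tf.keras.layers.Dense(8)(inp)
out = tf.keras.layers.Dense(1)(hidden)

# Walks the layer graph from `out` back to the Input tensor(s); in the
# functional API the input node's tensor is `inp` itself.
sources = tf.keras.utils.get_source_inputs(out)
assert len(sources) == 1 and sources[0] is inp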

def is_builtin_layer(layer)
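Returns True if layer is exported under its own Keras API name, i.e. it is a built-in Keras layer; user subclasses inherit the base keras.layers.Layer export name and return False. A small sketch (class name illustrative, import path assumed):

import tensorflow as tf
from keras.utils import layer_utils  # import path assumed

assert layer_utils.is_builtin_layer(tf.keras.layers.Dense(1))

class MyLayer(tf.keras.layers.Layer):  # inherits the base export name
  pass

assert not layer_utils.is_builtin_layer(MyLayer())
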
def print_summary(model, line_length=None, positions=None, print_fn=None)

Prints a summary of a model.

Args

model
Keras model instance.
line_length
Total length of printed lines (e.g. set this to adapt the display to different terminal window sizes).
positions
Relative or absolute positions of log elements in each line. If not provided, defaults to [.33, .55, .67, 1.].
print_fn
Print function to use. It will be called on each line of the summary. You can set it to a custom function in order to capture the string summary. It defaults to print (prints to stdout).
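A sketch of capturing the summary as a string via print_fn (the model architecture is illustrative, and the import path is assumed to match this module):

import tensorflow as tf
from keras.utils import layer_utils  # import path assumed

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, input_shape=(4,)),
    tf.keras.layers.Dense(1),
])

lines = []
layer_utils.print_summary(model, line_length=60, print_fn=lines.append)
summary = '\n'.join(lines)
assert 'Total params' in summary

This is the routine Model.summary delegates to, so the same print_fn trick works there as well.
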
def validate_string_arg(input_data, allowable_strings, layer_name, arg_name, allow_none=False, allow_callables=False)

Validates the correctness of a string-based arg.
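A sketch with illustrative layer and argument names; the import path is assumed to match this module:

from keras.utils import layer_utils  # import path assumed

# Valid value: returns None silently.
layer_utils.validate_string_arg(
    'mean', allowable_strings={'mean', 'sum'},
    layer_name='my_layer', arg_name='reduction', allow_none=True)

# Invalid value: raises a ValueError naming the layer and the argument.
try:
  layer_utils.validate_string_arg(
      'max', allowable_strings={'mean', 'sum'},
      layer_name='my_layer', arg_name='reduction')
except ValueError as e:
  print(e)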
