Module keras.engine.input_spec

Contains the InputSpec class.

Expand source code
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
# pylint: disable=g-classes-have-attributes
"""Contains the InputSpec class."""

import tensorflow.compat.v2 as tf
from keras import backend
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.util.tf_export import tf_export


@keras_export('keras.layers.InputSpec',
              v1=['keras.layers.InputSpec',
                  'keras.__internal__.legacy.layers.InputSpec'])
@tf_export(v1=['layers.InputSpec'])
class InputSpec(object):
  """Specifies the rank, dtype and shape of every input to a layer.

  Layers can expose (if appropriate) an `input_spec` attribute:
  an instance of `InputSpec`, or a nested structure of `InputSpec` instances
  (one per input tensor). These objects enable the layer to run input
  compatibility checks for input structure, input rank, input shape, and
  input dtype.

  A None entry in a shape is compatible with any dimension;
  a None shape is compatible with any shape.

  Args:
    dtype: Expected DataType of the input. Stored as the dtype's string name.
    shape: Shape tuple, expected shape of the input
      (may include None for unchecked axes). Includes the batch size.
      When it has a known rank it also determines `ndim`, and the `ndim`
      argument is ignored.
    ndim: Integer, expected rank of the input. Only used when `shape` is
      None (or of unknown rank).
    max_ndim: Integer, maximum rank of the input.
    min_ndim: Integer, minimum rank of the input.
    axes: Dictionary mapping integer axes to
      a specific dimension value. Keys are converted with `int()`.
    allow_last_axis_squeeze: If True, then allow inputs of rank N+1 as long
      as the last axis of the input is 1, as well as inputs of rank N-1
      as long as the last axis of the spec is 1.
    name: Expected key corresponding to this input when passing data as
      a dictionary.

  Raises:
    TypeError: if a key of `axes` cannot be converted to an integer.
    ValueError: if the largest key of `axes` is out of range for `ndim`
      (or, when `ndim` is unknown, for `max_ndim`).

  Example:

  ```python
  class MyLayer(Layer):
      def __init__(self):
          super(MyLayer, self).__init__()
          # The layer will accept inputs with shape (?, 28, 28) & (?, 28, 28, 1)
          # and raise an appropriate error message otherwise.
          self.input_spec = InputSpec(
              shape=(None, 28, 28, 1),
              allow_last_axis_squeeze=True)
  ```
  """

  def __init__(self,
               dtype=None,
               shape=None,
               ndim=None,
               max_ndim=None,
               min_ndim=None,
               axes=None,
               allow_last_axis_squeeze=False,
               name=None):
    # Normalize the dtype to its canonical string name.
    self.dtype = tf.as_dtype(dtype).name if dtype is not None else None
    shape = tf.TensorShape(shape)
    if shape.rank is None:
      # No (known-rank) shape: the rank, if any, comes from `ndim`.
      self.ndim = ndim
      self.shape = None
    else:
      # A known-rank shape fixes the rank; the `ndim` argument is ignored.
      self.shape = tuple(shape.as_list())
      self.ndim = len(self.shape)
    self.max_ndim = max_ndim
    self.min_ndim = min_ndim
    self.name = name
    self.allow_last_axis_squeeze = allow_last_axis_squeeze
    try:
      axes = axes or {}
      self.axes = {int(k): axes[k] for k in axes}
    except (ValueError, TypeError):
      raise TypeError('The keys in axes must be integers.')

    if self.axes and (self.ndim is not None or self.max_ndim is not None):
      # Compare against None, not truthiness: `ndim == 0` (scalar inputs) is a
      # valid rank, and falling through to a None `max_ndim` would raise an
      # unrelated TypeError instead of the intended ValueError.
      max_dim = (self.ndim if self.ndim is not None else self.max_ndim) - 1
      max_axis = max(self.axes)
      if max_axis > max_dim:
        raise ValueError('Axis {} is greater than the maximum allowed value: {}'
                         .format(max_axis, max_dim))

  def __repr__(self):
    # Compare against None so legitimate zero values (e.g. `ndim=0` or
    # `shape=()` for scalar inputs) are still displayed.
    spec = [('dtype=' + str(self.dtype)) if self.dtype else '',
            ('shape=' + str(self.shape)) if self.shape is not None else '',
            ('ndim=' + str(self.ndim)) if self.ndim is not None else '',
            ('max_ndim=' + str(self.max_ndim))
            if self.max_ndim is not None else '',
            ('min_ndim=' + str(self.min_ndim))
            if self.min_ndim is not None else '',
            ('axes=' + str(self.axes)) if self.axes else '']
    return 'InputSpec(%s)' % ', '.join(x for x in spec if x)

  def get_config(self):
    """Returns the constructor arguments needed to rebuild this spec.

    Every constructor argument is included, so
    `InputSpec.from_config(spec.get_config())` preserves
    `allow_last_axis_squeeze` and `name` as well.
    """
    return {
        'dtype': self.dtype,
        'shape': self.shape,
        'ndim': self.ndim,
        'max_ndim': self.max_ndim,
        'min_ndim': self.min_ndim,
        'axes': self.axes,
        'allow_last_axis_squeeze': self.allow_last_axis_squeeze,
        'name': self.name}

  @classmethod
  def from_config(cls, config):
    """Creates an `InputSpec` from a dict produced by `get_config()`."""
    return cls(**config)


def to_tensor_shape(spec):
  """Builds a `tf.TensorShape` from the shape constraints of an InputSpec.

  Resolution order:
    * `spec.shape` is set: that (fully or partially known) shape is used.
    * Only `spec.ndim` is set: a shape of that rank is built whose dimensions
      are unknown, except those pinned by `spec.axes`.
    * Neither is set: a TensorShape of unknown rank (`TensorShape(None)`) is
      returned. It is a TensorShape object, not `None` itself.

  Args:
    spec: an InputSpec object.

  Returns:
    a tf.TensorShape object
  """
  if spec.shape is not None:
    return tf.TensorShape(spec.shape)
  if spec.ndim is None:
    return tf.TensorShape(None)
  # Start from a fully-unknown shape of the requested rank, then fill in the
  # axes the spec pins to a specific size.
  dims = [None for _ in range(spec.ndim)]
  for axis in spec.axes:
    dims[axis] = spec.axes[axis]
  return tf.TensorShape(dims)


def assert_input_compatibility(input_spec, inputs, layer_name):
  """Checks compatibility between the layer and provided inputs.

  This checks that the tensor(s) `inputs` satisfy the input assumptions
  of a layer (if any). If not, a clear and actionable exception is raised.

  Inputs of unknown rank cannot be checked and are skipped individually.
  The remaining inputs are still validated.

  Args:
      input_spec: An InputSpec instance, list of InputSpec instances, a nested
          structure of InputSpec instances, or None.
      inputs: Input tensor, list of input tensors, or a nested structure of
          input tensors. If `inputs` is a dict and every spec has a `name`,
          the values are matched to the specs by name.
      layer_name: String, name of the layer (for error message formatting).

  Raises:
      TypeError: if an input has no `shape` attribute (e.g. a Layer instance).
      ValueError: in case of mismatch between
          the provided inputs and the expectations of the layer.
  """
  if not input_spec:
    return

  input_spec = tf.nest.flatten(input_spec)
  if isinstance(inputs, dict):
    # Flatten `inputs` by reference order if input spec names are provided.
    # A None spec has no name, which disables name-based ordering instead of
    # crashing with an AttributeError.
    names = [spec.name if spec is not None else None for spec in input_spec]
    if all(names):
      list_inputs = []
      for name in names:
        if name not in inputs:
          raise ValueError('Missing data for input "%s". '
                           'You passed a data dictionary with keys %s. '
                           'Expected the following keys: %s' %
                           (name, list(inputs.keys()), names))
        list_inputs.append(inputs[name])
      inputs = list_inputs

  inputs = tf.nest.flatten(inputs)
  for x in inputs:
    # Having a shape/dtype is the only commonality of the various tensor-like
    # objects that may be passed. The most common kind of invalid type we are
    # guarding for is a Layer instance (Functional API), which does not
    # have a `shape` attribute.
    if not hasattr(x, 'shape'):
      raise TypeError('Inputs to a layer should be tensors. Got: %s' % (x,))

  if len(inputs) != len(input_spec):
    raise ValueError('Layer ' + layer_name + ' expects ' +
                     str(len(input_spec)) + ' input(s), '
                     'but it received ' + str(len(inputs)) +
                     ' input tensors. Inputs received: ' + str(inputs))
  for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
    if spec is None:
      continue

    # From here on use `shape`, not `x.shape`: it is always a TensorShape,
    # even when `x.shape` is a plain tuple.
    shape = tf.TensorShape(x.shape)
    if shape.rank is None:
      # Nothing can be checked for an input of unknown rank. Skip only this
      # input; returning would silently skip validation of all later inputs.
      continue
    ndim = shape.rank

    # Check ndim.
    if (spec.ndim is not None and not spec.allow_last_axis_squeeze and
        ndim != spec.ndim):
      raise ValueError('Input ' + str(input_index) + ' of layer ' +
                       layer_name + ' is incompatible with the layer: '
                       'expected ndim=' + str(spec.ndim) + ', found ndim=' +
                       str(ndim) + '. Full shape received: ' +
                       str(tuple(shape)))
    if spec.max_ndim is not None and ndim > spec.max_ndim:
      raise ValueError('Input ' + str(input_index) + ' of layer ' +
                       layer_name + ' is incompatible with the layer: '
                       'expected max_ndim=' + str(spec.max_ndim) +
                       ', found ndim=' + str(ndim))
    if spec.min_ndim is not None and ndim < spec.min_ndim:
      raise ValueError('Input ' + str(input_index) + ' of layer ' +
                       layer_name + ' is incompatible with the layer: '
                       'expected min_ndim=' + str(spec.min_ndim) +
                       ', found ndim=' + str(ndim) +
                       '. Full shape received: ' +
                       str(tuple(shape)))
    # Check dtype.
    if spec.dtype is not None and x.dtype.name != spec.dtype:
      raise ValueError('Input ' + str(input_index) + ' of layer ' +
                       layer_name + ' is incompatible with the layer: '
                       'expected dtype=' + str(spec.dtype) +
                       ', found dtype=' + str(x.dtype))

    # Check specific shape axes.
    shape_as_list = shape.as_list()
    if spec.axes:
      for axis, value in spec.axes.items():
        # Unwrap legacy Dimension-like objects to their integer value.
        if hasattr(value, 'value'):
          value = value.value
        if value is None:
          continue
        axis_index = int(axis)
        # An axis the input does not have cannot match the expected value.
        # Report it as the documented ValueError rather than an IndexError.
        if (not -ndim <= axis_index < ndim or
            shape_as_list[axis_index] not in {value, None}):
          raise ValueError(
              'Input ' + str(input_index) + ' of layer ' + layer_name + ' is'
              ' incompatible with the layer: expected axis ' + str(axis) +
              ' of input shape to have value ' + str(value) +
              ' but received input with shape ' + display_shape(shape))
    # Check shape.
    if spec.shape is not None:
      spec_shape = spec.shape
      if spec.allow_last_axis_squeeze:
        # Ignore a trailing axis of size 1 on either side.
        if shape_as_list and shape_as_list[-1] == 1:
          shape_as_list = shape_as_list[:-1]
        if spec_shape and spec_shape[-1] == 1:
          spec_shape = spec_shape[:-1]
      # zip() compares only the overlapping leading dimensions. Rank is
      # enforced by the ndim check above unless allow_last_axis_squeeze is set.
      for spec_dim, dim in zip(spec_shape, shape_as_list):
        if spec_dim is not None and dim is not None and spec_dim != dim:
          raise ValueError('Input ' + str(input_index) +
                           ' is incompatible with layer ' + layer_name +
                           ': expected shape=' + str(spec.shape) +
                           ', found shape=' + display_shape(shape))


def display_shape(shape):
  """Formats a shape as a tuple string for error messages, e.g. `(None, 3)`.

  Args:
    shape: An object exposing `as_list()` (such as a `tf.TensorShape`), or a
      plain sequence of dimensions (such as a tuple).

  Returns:
    The shape rendered as the `str` of a tuple.
  """
  # Objects like tf.TensorShape are converted with `as_list()`. Plain
  # sequences (e.g. tuple shapes) are used as-is instead of failing with an
  # AttributeError.
  if hasattr(shape, 'as_list'):
    shape = shape.as_list()
  return str(tuple(shape))


def to_tensor_spec(input_spec, default_dtype=None):
  """Converts a Keras InputSpec object to a TensorSpec.

  Args:
    input_spec: An `InputSpec`, or any other object. Anything that is not an
      `InputSpec` yields a TensorSpec of unknown shape.
    default_dtype: dtype used when `input_spec` does not define one. A falsy
      value falls back to `backend.floatx()`.

  Returns:
    A `tf.TensorSpec`.
  """
  if not default_dtype:
    default_dtype = backend.floatx()
  if not isinstance(input_spec, InputSpec):
    return tf.TensorSpec(None, default_dtype)
  spec_dtype = input_spec.dtype or default_dtype
  return tf.TensorSpec(to_tensor_shape(input_spec), spec_dtype)

Functions

def assert_input_compatibility(input_spec, inputs, layer_name)

Checks compatibility between the layer and provided inputs.

This checks that the tensor(s) inputs satisfy the input assumptions of a layer (if any). If not, a clear and actionable exception is raised.

Args

input_spec
An InputSpec instance, list of InputSpec instances, a nested structure of InputSpec instances, or None.
inputs
Input tensor, list of input tensors, or a nested structure of input tensors.
layer_name
String, name of the layer (for error message formatting).

Raises

ValueError
in case of mismatch between the provided inputs and the expectations of the layer.
Expand source code
def assert_input_compatibility(input_spec, inputs, layer_name):
  """Checks compatibility between the layer and provided inputs.

  This checks that the tensor(s) `inputs` satisfy the input assumptions
  of a layer (if any). If not, a clear and actionable exception is raised.

  Inputs of unknown rank cannot be checked and are skipped individually.
  The remaining inputs are still validated.

  Args:
      input_spec: An InputSpec instance, list of InputSpec instances, a nested
          structure of InputSpec instances, or None.
      inputs: Input tensor, list of input tensors, or a nested structure of
          input tensors. If `inputs` is a dict and every spec has a `name`,
          the values are matched to the specs by name.
      layer_name: String, name of the layer (for error message formatting).

  Raises:
      TypeError: if an input has no `shape` attribute (e.g. a Layer instance).
      ValueError: in case of mismatch between
          the provided inputs and the expectations of the layer.
  """
  if not input_spec:
    return

  input_spec = tf.nest.flatten(input_spec)
  if isinstance(inputs, dict):
    # Flatten `inputs` by reference order if input spec names are provided.
    # A None spec has no name, which disables name-based ordering instead of
    # crashing with an AttributeError.
    names = [spec.name if spec is not None else None for spec in input_spec]
    if all(names):
      list_inputs = []
      for name in names:
        if name not in inputs:
          raise ValueError('Missing data for input "%s". '
                           'You passed a data dictionary with keys %s. '
                           'Expected the following keys: %s' %
                           (name, list(inputs.keys()), names))
        list_inputs.append(inputs[name])
      inputs = list_inputs

  inputs = tf.nest.flatten(inputs)
  for x in inputs:
    # Having a shape/dtype is the only commonality of the various tensor-like
    # objects that may be passed. The most common kind of invalid type we are
    # guarding for is a Layer instance (Functional API), which does not
    # have a `shape` attribute.
    if not hasattr(x, 'shape'):
      raise TypeError('Inputs to a layer should be tensors. Got: %s' % (x,))

  if len(inputs) != len(input_spec):
    raise ValueError('Layer ' + layer_name + ' expects ' +
                     str(len(input_spec)) + ' input(s), '
                     'but it received ' + str(len(inputs)) +
                     ' input tensors. Inputs received: ' + str(inputs))
  for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
    if spec is None:
      continue

    # From here on use `shape`, not `x.shape`: it is always a TensorShape,
    # even when `x.shape` is a plain tuple.
    shape = tf.TensorShape(x.shape)
    if shape.rank is None:
      # Nothing can be checked for an input of unknown rank. Skip only this
      # input; returning would silently skip validation of all later inputs.
      continue
    ndim = shape.rank

    # Check ndim.
    if (spec.ndim is not None and not spec.allow_last_axis_squeeze and
        ndim != spec.ndim):
      raise ValueError('Input ' + str(input_index) + ' of layer ' +
                       layer_name + ' is incompatible with the layer: '
                       'expected ndim=' + str(spec.ndim) + ', found ndim=' +
                       str(ndim) + '. Full shape received: ' +
                       str(tuple(shape)))
    if spec.max_ndim is not None and ndim > spec.max_ndim:
      raise ValueError('Input ' + str(input_index) + ' of layer ' +
                       layer_name + ' is incompatible with the layer: '
                       'expected max_ndim=' + str(spec.max_ndim) +
                       ', found ndim=' + str(ndim))
    if spec.min_ndim is not None and ndim < spec.min_ndim:
      raise ValueError('Input ' + str(input_index) + ' of layer ' +
                       layer_name + ' is incompatible with the layer: '
                       'expected min_ndim=' + str(spec.min_ndim) +
                       ', found ndim=' + str(ndim) +
                       '. Full shape received: ' +
                       str(tuple(shape)))
    # Check dtype.
    if spec.dtype is not None and x.dtype.name != spec.dtype:
      raise ValueError('Input ' + str(input_index) + ' of layer ' +
                       layer_name + ' is incompatible with the layer: '
                       'expected dtype=' + str(spec.dtype) +
                       ', found dtype=' + str(x.dtype))

    # Check specific shape axes.
    shape_as_list = shape.as_list()
    if spec.axes:
      for axis, value in spec.axes.items():
        # Unwrap legacy Dimension-like objects to their integer value.
        if hasattr(value, 'value'):
          value = value.value
        if value is None:
          continue
        axis_index = int(axis)
        # An axis the input does not have cannot match the expected value.
        # Report it as the documented ValueError rather than an IndexError.
        if (not -ndim <= axis_index < ndim or
            shape_as_list[axis_index] not in {value, None}):
          raise ValueError(
              'Input ' + str(input_index) + ' of layer ' + layer_name + ' is'
              ' incompatible with the layer: expected axis ' + str(axis) +
              ' of input shape to have value ' + str(value) +
              ' but received input with shape ' + display_shape(shape))
    # Check shape.
    if spec.shape is not None:
      spec_shape = spec.shape
      if spec.allow_last_axis_squeeze:
        # Ignore a trailing axis of size 1 on either side.
        if shape_as_list and shape_as_list[-1] == 1:
          shape_as_list = shape_as_list[:-1]
        if spec_shape and spec_shape[-1] == 1:
          spec_shape = spec_shape[:-1]
      # zip() compares only the overlapping leading dimensions. Rank is
      # enforced by the ndim check above unless allow_last_axis_squeeze is set.
      for spec_dim, dim in zip(spec_shape, shape_as_list):
        if spec_dim is not None and dim is not None and spec_dim != dim:
          raise ValueError('Input ' + str(input_index) +
                           ' is incompatible with layer ' + layer_name +
                           ': expected shape=' + str(spec.shape) +
                           ', found shape=' + display_shape(shape))
def display_shape(shape)
Expand source code
def display_shape(shape):
  """Formats a shape as a tuple string for error messages, e.g. `(None, 3)`.

  Args:
    shape: An object exposing `as_list()` (such as a `tf.TensorShape`), or a
      plain sequence of dimensions (such as a tuple).

  Returns:
    The shape rendered as the `str` of a tuple.
  """
  # Objects like tf.TensorShape are converted with `as_list()`. Plain
  # sequences (e.g. tuple shapes) are used as-is instead of failing with an
  # AttributeError.
  if hasattr(shape, 'as_list'):
    shape = shape.as_list()
  return str(tuple(shape))
def to_tensor_shape(spec)

Returns a tf.TensorShape object that matches the shape specifications.

If the InputSpec's shape or ndim is defined, this method returns a fully or partially known shape. Otherwise, it returns a TensorShape of unknown rank (TensorShape(None)), not None itself.

Args

spec
an InputSpec object.

Returns

a tf.TensorShape object

Expand source code
def to_tensor_shape(spec):
  """Builds a `tf.TensorShape` from the shape constraints of an InputSpec.

  Resolution order:
    * `spec.shape` is set: that (fully or partially known) shape is used.
    * Only `spec.ndim` is set: a shape of that rank is built whose dimensions
      are unknown, except those pinned by `spec.axes`.
    * Neither is set: a TensorShape of unknown rank (`TensorShape(None)`) is
      returned. It is a TensorShape object, not `None` itself.

  Args:
    spec: an InputSpec object.

  Returns:
    a tf.TensorShape object
  """
  if spec.shape is not None:
    return tf.TensorShape(spec.shape)
  if spec.ndim is None:
    return tf.TensorShape(None)
  # Start from a fully-unknown shape of the requested rank, then fill in the
  # axes the spec pins to a specific size.
  dims = [None for _ in range(spec.ndim)]
  for axis in spec.axes:
    dims[axis] = spec.axes[axis]
  return tf.TensorShape(dims)
def to_tensor_spec(input_spec, default_dtype=None)

Converts a Keras InputSpec object to a TensorSpec.

Expand source code
def to_tensor_spec(input_spec, default_dtype=None):
  """Converts a Keras InputSpec object to a TensorSpec.

  Args:
    input_spec: An `InputSpec`, or any other object. Anything that is not an
      `InputSpec` yields a TensorSpec of unknown shape.
    default_dtype: dtype used when `input_spec` does not define one. A falsy
      value falls back to `backend.floatx()`.

  Returns:
    A `tf.TensorSpec`.
  """
  if not default_dtype:
    default_dtype = backend.floatx()
  if not isinstance(input_spec, InputSpec):
    return tf.TensorSpec(None, default_dtype)
  spec_dtype = input_spec.dtype or default_dtype
  return tf.TensorSpec(to_tensor_shape(input_spec), spec_dtype)

Classes

class InputSpec (dtype=None, shape=None, ndim=None, max_ndim=None, min_ndim=None, axes=None, allow_last_axis_squeeze=False, name=None)

Specifies the rank, dtype and shape of every input to a layer.

Layers can expose (if appropriate) an input_spec attribute: an instance of InputSpec, or a nested structure of InputSpec instances (one per input tensor). These objects enable the layer to run input compatibility checks for input structure, input rank, input shape, and input dtype.

A None entry in a shape is compatible with any dimension; a None shape is compatible with any shape.

Args

dtype
Expected DataType of the input.
shape
Shape tuple, expected shape of the input (may include None for unchecked axes). Includes the batch size.
ndim
Integer, expected rank of the input.
max_ndim
Integer, maximum rank of the input.
min_ndim
Integer, minimum rank of the input.
axes
Dictionary mapping integer axes to a specific dimension value.
allow_last_axis_squeeze
If True, then allow inputs of rank N+1 as long as the last axis of the input is 1, as well as inputs of rank N-1 as long as the last axis of the spec is 1.
name
Expected key corresponding to this input when passing data as a dictionary.

Example:

class MyLayer(Layer):
    def __init__(self):
        super(MyLayer, self).__init__()
        # The layer will accept inputs with shape (?, 28, 28) & (?, 28, 28, 1)
        # and raise an appropriate error message otherwise.
        self.input_spec = InputSpec(
            shape=(None, 28, 28, 1),
            allow_last_axis_squeeze=True)
Expand source code
class InputSpec(object):
  """Specifies the rank, dtype and shape of every input to a layer.

  Layers can expose (if appropriate) an `input_spec` attribute:
  an instance of `InputSpec`, or a nested structure of `InputSpec` instances
  (one per input tensor). These objects enable the layer to run input
  compatibility checks for input structure, input rank, input shape, and
  input dtype.

  A None entry in a shape is compatible with any dimension;
  a None shape is compatible with any shape.

  Args:
    dtype: Expected DataType of the input. Stored as the dtype's string name.
    shape: Shape tuple, expected shape of the input
      (may include None for unchecked axes). Includes the batch size.
      When it has a known rank it also determines `ndim`, and the `ndim`
      argument is ignored.
    ndim: Integer, expected rank of the input. Only used when `shape` is
      None (or of unknown rank).
    max_ndim: Integer, maximum rank of the input.
    min_ndim: Integer, minimum rank of the input.
    axes: Dictionary mapping integer axes to
      a specific dimension value. Keys are converted with `int()`.
    allow_last_axis_squeeze: If True, then allow inputs of rank N+1 as long
      as the last axis of the input is 1, as well as inputs of rank N-1
      as long as the last axis of the spec is 1.
    name: Expected key corresponding to this input when passing data as
      a dictionary.

  Raises:
    TypeError: if a key of `axes` cannot be converted to an integer.
    ValueError: if the largest key of `axes` is out of range for `ndim`
      (or, when `ndim` is unknown, for `max_ndim`).

  Example:

  ```python
  class MyLayer(Layer):
      def __init__(self):
          super(MyLayer, self).__init__()
          # The layer will accept inputs with shape (?, 28, 28) & (?, 28, 28, 1)
          # and raise an appropriate error message otherwise.
          self.input_spec = InputSpec(
              shape=(None, 28, 28, 1),
              allow_last_axis_squeeze=True)
  ```
  """

  def __init__(self,
               dtype=None,
               shape=None,
               ndim=None,
               max_ndim=None,
               min_ndim=None,
               axes=None,
               allow_last_axis_squeeze=False,
               name=None):
    # Normalize the dtype to its canonical string name.
    self.dtype = tf.as_dtype(dtype).name if dtype is not None else None
    shape = tf.TensorShape(shape)
    if shape.rank is None:
      # No (known-rank) shape: the rank, if any, comes from `ndim`.
      self.ndim = ndim
      self.shape = None
    else:
      # A known-rank shape fixes the rank; the `ndim` argument is ignored.
      self.shape = tuple(shape.as_list())
      self.ndim = len(self.shape)
    self.max_ndim = max_ndim
    self.min_ndim = min_ndim
    self.name = name
    self.allow_last_axis_squeeze = allow_last_axis_squeeze
    try:
      axes = axes or {}
      self.axes = {int(k): axes[k] for k in axes}
    except (ValueError, TypeError):
      raise TypeError('The keys in axes must be integers.')

    if self.axes and (self.ndim is not None or self.max_ndim is not None):
      # Compare against None, not truthiness: `ndim == 0` (scalar inputs) is a
      # valid rank, and falling through to a None `max_ndim` would raise an
      # unrelated TypeError instead of the intended ValueError.
      max_dim = (self.ndim if self.ndim is not None else self.max_ndim) - 1
      max_axis = max(self.axes)
      if max_axis > max_dim:
        raise ValueError('Axis {} is greater than the maximum allowed value: {}'
                         .format(max_axis, max_dim))

  def __repr__(self):
    # Compare against None so legitimate zero values (e.g. `ndim=0` or
    # `shape=()` for scalar inputs) are still displayed.
    spec = [('dtype=' + str(self.dtype)) if self.dtype else '',
            ('shape=' + str(self.shape)) if self.shape is not None else '',
            ('ndim=' + str(self.ndim)) if self.ndim is not None else '',
            ('max_ndim=' + str(self.max_ndim))
            if self.max_ndim is not None else '',
            ('min_ndim=' + str(self.min_ndim))
            if self.min_ndim is not None else '',
            ('axes=' + str(self.axes)) if self.axes else '']
    return 'InputSpec(%s)' % ', '.join(x for x in spec if x)

  def get_config(self):
    """Returns the constructor arguments needed to rebuild this spec.

    Every constructor argument is included, so
    `InputSpec.from_config(spec.get_config())` preserves
    `allow_last_axis_squeeze` and `name` as well.
    """
    return {
        'dtype': self.dtype,
        'shape': self.shape,
        'ndim': self.ndim,
        'max_ndim': self.max_ndim,
        'min_ndim': self.min_ndim,
        'axes': self.axes,
        'allow_last_axis_squeeze': self.allow_last_axis_squeeze,
        'name': self.name}

  @classmethod
  def from_config(cls, config):
    """Creates an `InputSpec` from a dict produced by `get_config()`."""
    return cls(**config)

Static methods

def from_config(config)
Expand source code
@classmethod
def from_config(cls, config):
  """Creates an instance from a config dictionary.

  Args:
    config: Mapping of constructor keyword arguments, as produced by
      `get_config()`.

  Returns:
    A new instance of `cls`.
  """
  instance = cls(**config)
  return instance

Methods

def get_config(self)
Expand source code
def get_config(self):
  """Returns the constructor arguments needed to rebuild this spec.

  Every constructor argument is included, so
  `InputSpec.from_config(spec.get_config())` preserves
  `allow_last_axis_squeeze` and `name` as well.

  Returns:
    A dict keyed by the `InputSpec` constructor argument names.
  """
  return {
      'dtype': self.dtype,
      'shape': self.shape,
      'ndim': self.ndim,
      'max_ndim': self.max_ndim,
      'min_ndim': self.min_ndim,
      'axes': self.axes,
      'allow_last_axis_squeeze': self.allow_last_axis_squeeze,
      'name': self.name}