Module keras.api.keras.utils

Public API for tf.keras.utils namespace.

Source code
# This file is MACHINE GENERATED! Do not edit.
# Generated by: tensorflow/python/tools/api/generator/create_python_api.py script.
"""Public API for tf.keras.utils namespace.
"""

from __future__ import print_function as _print_function

import sys as _sys

from keras.preprocessing.image import array_to_img
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import load_img
from keras.preprocessing.image import save_img
from keras.utils.data_utils import GeneratorEnqueuer
from keras.utils.data_utils import OrderedEnqueuer
from keras.utils.data_utils import Sequence
from keras.utils.data_utils import SequenceEnqueuer
from keras.utils.data_utils import get_file
from keras.utils.generic_utils import CustomObjectScope
from keras.utils.generic_utils import CustomObjectScope as custom_object_scope
from keras.utils.generic_utils import Progbar
from keras.utils.generic_utils import deserialize_keras_object
from keras.utils.generic_utils import get_custom_objects
from keras.utils.generic_utils import get_registered_name
from keras.utils.generic_utils import get_registered_object
from keras.utils.generic_utils import register_keras_serializable
from keras.utils.generic_utils import serialize_keras_object
from keras.utils.layer_utils import get_source_inputs
from keras.utils.np_utils import normalize
from keras.utils.np_utils import to_categorical
from keras.utils.vis_utils import model_to_dot
from keras.utils.vis_utils import plot_model

del _print_function

from tensorflow.python.util import module_wrapper as _module_wrapper

if not isinstance(_sys.modules[__name__], _module_wrapper.TFModuleWrapper):
  _sys.modules[__name__] = _module_wrapper.TFModuleWrapper(
      _sys.modules[__name__], "keras.utils", public_apis=None, deprecation=True,
      has_lite=False)

Functions

def array_to_img(x, data_format=None, scale=True, dtype=None)

Converts a 3D Numpy array to a PIL Image instance.

Usage:

from PIL import Image
img = np.random.random(size=(100, 100, 3))
pil_img = tf.keras.preprocessing.image.array_to_img(img)

Args

x
Input data, in any form that can be converted to a Numpy array.
data_format
Image data format, can be either "channels_first" or "channels_last". Defaults to None, in which case the global setting tf.keras.backend.image_data_format() is used (unless you changed it, it defaults to "channels_last").
scale
Whether to rescale the image such that minimum and maximum values are 0 and 255 respectively. Defaults to True.
dtype
Dtype to use. Defaults to None, in which case the global setting tf.keras.backend.floatx() is used (unless you changed it, it defaults to "float32").

Returns

A PIL Image instance.

Raises

ImportError
if PIL is not available.
ValueError
if invalid x or data_format is passed.
Source code
@keras_export('keras.utils.array_to_img',
              'keras.preprocessing.image.array_to_img')
def array_to_img(x, data_format=None, scale=True, dtype=None):
  """Converts a 3D Numpy array to a PIL Image instance.

  Usage:

  ```python
  from PIL import Image
  img = np.random.random(size=(100, 100, 3))
  pil_img = tf.keras.preprocessing.image.array_to_img(img)
  ```


  Args:
      x: Input data, in any form that can be converted to a Numpy array.
      data_format: Image data format, can be either "channels_first" or
        "channels_last". Defaults to `None`, in which case the global setting
        `tf.keras.backend.image_data_format()` is used (unless you changed it,
        it defaults to "channels_last").
      scale: Whether to rescale the image such that minimum and maximum values
        are 0 and 255 respectively. Defaults to `True`.
      dtype: Dtype to use. Defaults to `None`, in which case the global setting
        `tf.keras.backend.floatx()` is used (unless you changed it, it defaults
        to "float32").

  Returns:
      A PIL Image instance.

  Raises:
      ImportError: if PIL is not available.
      ValueError: if invalid `x` or `data_format` is passed.
  """

  if data_format is None:
    data_format = backend.image_data_format()
  kwargs = {}
  if 'dtype' in tf_inspect.getfullargspec(image.array_to_img)[0]:
    if dtype is None:
      dtype = backend.floatx()
    kwargs['dtype'] = dtype
  return image.array_to_img(x, data_format=data_format, scale=scale, **kwargs)
def deserialize_keras_object(identifier, module_objects=None, custom_objects=None, printable_module_name='object')

Turns the serialized form of a Keras object back into an actual object.

This function is for mid-level library implementers rather than end users.

Importantly, this utility requires you to provide the dict of module_objects to use for looking up the object config; this is not populated by default. If you need a deserialization utility that has preexisting knowledge of built-in Keras objects, use e.g. keras.layers.deserialize(config), keras.metrics.deserialize(config), etc.

Calling deserialize_keras_object() while underneath the SharedObjectLoadingScope context manager will cause any already-seen shared objects to be returned as-is rather than creating a new object.

Args

identifier
the serialized form of the object.
module_objects
A dictionary of built-in objects to look the name up in. Generally, module_objects is provided by midlevel library implementers.
custom_objects
A dictionary of custom objects to look the name up in. Generally, custom_objects is provided by the end user.
printable_module_name
A human-readable string representing the type of the object. Printed in case of exception.

Returns

The deserialized object.

Example:

A mid-level library implementer might want to implement a utility for retrieving an object from its config, as such:

def deserialize(config, custom_objects=None):
   return deserialize_keras_object(
     config,
     module_objects=globals(),
     custom_objects=custom_objects,
     printable_module_name="MyObjectType",
   )

This is how e.g. keras.layers.deserialize() is implemented.

Source code
@keras_export('keras.utils.deserialize_keras_object')
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
  """Turns the serialized form of a Keras object back into an actual object.

  This function is for mid-level library implementers rather than end users.

  Importantly, this utility requires you to provide the dict of `module_objects`
  to use for looking up the object config; this is not populated by default.
  If you need a deserialization utility that has preexisting knowledge of
  built-in Keras objects, use e.g. `keras.layers.deserialize(config)`,
  `keras.metrics.deserialize(config)`, etc.

  Calling `deserialize_keras_object` while underneath the
  `SharedObjectLoadingScope` context manager will cause any already-seen shared
  objects to be returned as-is rather than creating a new object.

  Args:
    identifier: the serialized form of the object.
    module_objects: A dictionary of built-in objects to look the name up in.
      Generally, `module_objects` is provided by midlevel library implementers.
    custom_objects: A dictionary of custom objects to look the name up in.
      Generally, `custom_objects` is provided by the end user.
    printable_module_name: A human-readable string representing the type of the
      object. Printed in case of exception.

  Returns:
    The deserialized object.

  Example:

  A mid-level library implementer might want to implement a utility for
  retrieving an object from its config, as such:

  ```python
  def deserialize(config, custom_objects=None):
     return deserialize_keras_object(
       config,
       module_objects=globals(),
       custom_objects=custom_objects,
       printable_module_name="MyObjectType",
     )
  ```

  This is how e.g. `keras.layers.deserialize()` is implemented.
  """
  if identifier is None:
    return None

  if isinstance(identifier, dict):
    # In this case we are dealing with a Keras config dictionary.
    config = identifier
    (cls, cls_config) = class_and_config_for_serialized_keras_object(
        config, module_objects, custom_objects, printable_module_name)

    # If this object has already been loaded (i.e. it's shared between multiple
    # objects), return the already-loaded object.
    shared_object_id = config.get(SHARED_OBJECT_KEY)
    shared_object = _shared_object_loading_scope().get(shared_object_id)  # pylint: disable=assignment-from-none
    if shared_object is not None:
      return shared_object

    if hasattr(cls, 'from_config'):
      arg_spec = tf_inspect.getfullargspec(cls.from_config)
      custom_objects = custom_objects or {}

      if 'custom_objects' in arg_spec.args:
        deserialized_obj = cls.from_config(
            cls_config,
            custom_objects=dict(
                list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                list(custom_objects.items())))
      else:
        with CustomObjectScope(custom_objects):
          deserialized_obj = cls.from_config(cls_config)
    else:
      # Then `cls` may be a function returning a class.
      # in this case by convention `config` holds
      # the kwargs of the function.
      custom_objects = custom_objects or {}
      with CustomObjectScope(custom_objects):
        deserialized_obj = cls(**cls_config)

    # Add object to shared objects, in case we find it referenced again.
    _shared_object_loading_scope().set(shared_object_id, deserialized_obj)

    return deserialized_obj

  elif isinstance(identifier, str):
    object_name = identifier
    if custom_objects and object_name in custom_objects:
      obj = custom_objects.get(object_name)
    elif object_name in _GLOBAL_CUSTOM_OBJECTS:
      obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
    else:
      obj = module_objects.get(object_name)
      if obj is None:
        raise ValueError(
            'Unknown {}: {}. Please ensure this object is '
            'passed to the `custom_objects` argument. See '
            'https://www.tensorflow.org/guide/keras/save_and_serialize'
            '#registering_the_custom_object for details.'
            .format(printable_module_name, object_name))

    # Classes passed by name are instantiated with no args, functions are
    # returned as-is.
    if tf_inspect.isclass(obj):
      return obj()
    return obj
  elif tf_inspect.isfunction(identifier):
    # If a function has already been deserialized, return as is.
    return identifier
  else:
    raise ValueError('Could not interpret serialized %s: %s' %
                     (printable_module_name, identifier))
def get_custom_objects()

Retrieves a live reference to the global dictionary of custom objects.

Updating and clearing custom objects using CustomObjectScope is preferred, but get_custom_objects() can be used to directly access the current collection of custom objects.

Example:

get_custom_objects().clear()
get_custom_objects()['MyObject'] = MyObject

Returns

Global dictionary of names to classes (_GLOBAL_CUSTOM_OBJECTS).

Source code
@keras_export('keras.utils.get_custom_objects')
def get_custom_objects():
  """Retrieves a live reference to the global dictionary of custom objects.

  Updating and clearing custom objects using `custom_object_scope`
  is preferred, but `get_custom_objects` can
  be used to directly access the current collection of custom objects.

  Example:

  ```python
  get_custom_objects().clear()
  get_custom_objects()['MyObject'] = MyObject
  ```

  Returns:
      Global dictionary of names to classes (`_GLOBAL_CUSTOM_OBJECTS`).
  """
  return _GLOBAL_CUSTOM_OBJECTS
def get_file(fname=None, origin=None, untar=False, md5_hash=None, file_hash=None, cache_subdir='datasets', hash_algorithm='auto', extract=False, archive_format='auto', cache_dir=None)

Downloads a file from a URL if it is not already in the cache.

By default the file at the url origin is downloaded to the cache_dir ~/.keras, placed in the cache_subdir datasets, and given the filename fname. The final location of a file example.txt would therefore be ~/.keras/datasets/example.txt.

Files in tar, tar.gz, tar.bz, and zip formats can also be extracted. Passing a hash will verify the file after download. The command line programs shasum and sha256sum can compute the hash.

Example:

path_to_downloaded_file = tf.keras.utils.get_file(
    "flower_photos",
    "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
    untar=True)

Args

fname
Name of the file. If an absolute path /path/to/file.txt is specified the file will be saved at that location. If None, the name of the file at origin will be used.
origin
Original URL of the file.
untar
Deprecated in favor of the extract argument. Boolean, whether the file should be decompressed.
md5_hash
Deprecated in favor of the file_hash argument. MD5 hash of the file for verification.
file_hash
The expected hash string of the file after download. The sha256 and md5 hash algorithms are both supported.
cache_subdir
Subdirectory under the Keras cache dir where the file is saved. If an absolute path /path/to/folder is specified the file will be saved at that location.
hash_algorithm
Select the hash algorithm to verify the file. Options are 'md5', 'sha256', and 'auto'. The default 'auto' detects the hash algorithm in use.
extract
If True, tries extracting the file as an archive, like tar or zip.
archive_format
Archive format to try for extracting the file. Options are 'auto', 'tar', 'zip', and None. 'tar' includes tar, tar.gz, and tar.bz files. The default 'auto' corresponds to ['tar', 'zip']. None or an empty list matches no formats, so nothing is extracted.
cache_dir
Location to store cached files; when None it defaults to ~/.keras/.

Returns

Path to the downloaded file.

Source code
@keras_export('keras.utils.get_file')
def get_file(fname=None,
             origin=None,
             untar=False,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
  """Downloads a file from a URL if it not already in the cache.

  By default the file at the url `origin` is downloaded to the
  cache_dir `~/.keras`, placed in the cache_subdir `datasets`,
  and given the filename `fname`. The final location of a file
  `example.txt` would therefore be `~/.keras/datasets/example.txt`.

  Files in tar, tar.gz, tar.bz, and zip formats can also be extracted.
  Passing a hash will verify the file after download. The command line
  programs `shasum` and `sha256sum` can compute the hash.

  Example:

  ```python
  path_to_downloaded_file = tf.keras.utils.get_file(
      "flower_photos",
      "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
      untar=True)
  ```

  Args:
      fname: Name of the file. If an absolute path `/path/to/file.txt` is
          specified the file will be saved at that location. If `None`, the
          name of the file at `origin` will be used.
      origin: Original URL of the file.
      untar: Deprecated in favor of `extract` argument.
          boolean, whether the file should be decompressed
      md5_hash: Deprecated in favor of `file_hash` argument.
          md5 hash of the file for verification
      file_hash: The expected hash string of the file after download.
          The sha256 and md5 hash algorithms are both supported.
      cache_subdir: Subdirectory under the Keras cache dir where the file is
          saved. If an absolute path `/path/to/folder` is
          specified the file will be saved at that location.
      hash_algorithm: Select the hash algorithm to verify the file.
          options are `'md5'`, `'sha256'`, and `'auto'`.
          The default 'auto' detects the hash algorithm in use.
      extract: True tries extracting the file as an Archive, like tar or zip.
      archive_format: Archive format to try for extracting the file.
          Options are `'auto'`, `'tar'`, `'zip'`, and `None`.
          `'tar'` includes tar, tar.gz, and tar.bz files.
          The default `'auto'` corresponds to `['tar', 'zip']`.
          None or an empty list will return no matches found.
      cache_dir: Location to store cached files, when None it
          defaults to the default directory `~/.keras/`.

  Returns:
      Path to the downloaded file
  """
  if origin is None:
    raise ValueError('Please specify the "origin" argument (URL of the file '
                     'to download).')

  if cache_dir is None:
    cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
  if md5_hash is not None and file_hash is None:
    file_hash = md5_hash
    hash_algorithm = 'md5'
  datadir_base = os.path.expanduser(cache_dir)
  if not os.access(datadir_base, os.W_OK):
    datadir_base = os.path.join('/tmp', '.keras')
  datadir = os.path.join(datadir_base, cache_subdir)
  _makedirs_exist_ok(datadir)

  fname = path_to_string(fname)
  if not fname:
    fname = os.path.basename(urlsplit(origin).path)
    if not fname:
      raise ValueError("Invalid origin '{}'".format(origin))

  if untar:
    if fname.endswith('.tar.gz'):
      fname = pathlib.Path(fname)
      # The 2 `.with_suffix()` are because of `.tar.gz` as pathlib
      # considers it as 2 suffixes.
      fname = fname.with_suffix('').with_suffix('')
      fname = str(fname)
    untar_fpath = os.path.join(datadir, fname)
    fpath = untar_fpath + '.tar.gz'
  else:
    fpath = os.path.join(datadir, fname)

  download = False
  if os.path.exists(fpath):
    # File found; verify integrity if a hash was provided.
    if file_hash is not None:
      if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
        print('A local file was found, but it seems to be '
              'incomplete or outdated because the ' + hash_algorithm +
              ' file hash does not match the original value of ' + file_hash +
              ' so we will re-download the data.')
        download = True
  else:
    download = True

  if download:
    print('Downloading data from', origin)

    class ProgressTracker(object):
      # Maintain progbar for the lifetime of download.
      # This design was chosen for Python 2.7 compatibility.
      progbar = None

    def dl_progress(count, block_size, total_size):
      if ProgressTracker.progbar is None:
        if total_size == -1:
          total_size = None
        ProgressTracker.progbar = Progbar(total_size)
      else:
        ProgressTracker.progbar.update(count * block_size)

    error_msg = 'URL fetch failure on {}: {} -- {}'
    try:
      try:
        urlretrieve(origin, fpath, dl_progress)
      except urllib.error.HTTPError as e:
        raise Exception(error_msg.format(origin, e.code, e.msg))
      except urllib.error.URLError as e:
        raise Exception(error_msg.format(origin, e.errno, e.reason))
    except (Exception, KeyboardInterrupt) as e:
      if os.path.exists(fpath):
        os.remove(fpath)
      raise
    ProgressTracker.progbar = None

  if untar:
    if not os.path.exists(untar_fpath):
      _extract_archive(fpath, datadir, archive_format='tar')
    return untar_fpath

  if extract:
    _extract_archive(fpath, datadir, archive_format)

  return fpath
def get_registered_name(obj)

Returns the name registered to an object within the Keras framework.

This function is part of the Keras serialization and deserialization framework. It maps objects to the string names associated with those objects for serialization/deserialization.

Args

obj
The object to look up.

Returns

The name associated with the object, or the default Python name if the object is not registered.

Source code
@keras_export('keras.utils.get_registered_name')
def get_registered_name(obj):
  """Returns the name registered to an object within the Keras framework.

  This function is part of the Keras serialization and deserialization
  framework. It maps objects to the string names associated with those objects
  for serialization/deserialization.

  Args:
    obj: The object to look up.

  Returns:
    The name associated with the object, or the default Python name if the
      object is not registered.
  """
  if obj in _GLOBAL_CUSTOM_NAMES:
    return _GLOBAL_CUSTOM_NAMES[obj]
  else:
    return obj.__name__
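
The sketch below is not part of the original docs; it shows how get_registered_name interacts with register_keras_serializable. MyLayer and my_fn are hypothetical names used only for illustration.

import tensorflow as tf

@tf.keras.utils.register_keras_serializable(package='MyPackage')
class MyLayer(tf.keras.layers.Layer):
  pass

# A registered class resolves to its 'package>name' key.
print(tf.keras.utils.get_registered_name(MyLayer))  # 'MyPackage>MyLayer'

# An unregistered object falls back to its Python __name__.
def my_fn(x):
  return x

print(tf.keras.utils.get_registered_name(my_fn))  # 'my_fn'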
def get_registered_object(name, custom_objects=None, module_objects=None)

Returns the class associated with name if it is registered with Keras.

This function is part of the Keras serialization and deserialization framework. It maps strings to the objects associated with them for serialization/deserialization.

Example:

def from_config(cls, config, custom_objects=None):
  if 'my_custom_object_name' in config:
    config['hidden_cls'] = tf.keras.utils.get_registered_object(
        config['my_custom_object_name'], custom_objects=custom_objects)
  return cls(**config)

Args

name
The name to look up.
custom_objects
A dictionary of custom objects to look the name up in. Generally, custom_objects is provided by the user.
module_objects
A dictionary of module-level objects to look the name up in. Generally, module_objects is provided by midlevel library implementers.

Returns

An instantiable class associated with 'name', or None if no such class exists.

Source code
@keras_export('keras.utils.get_registered_object')
def get_registered_object(name, custom_objects=None, module_objects=None):
  """Returns the class associated with `name` if it is registered with Keras.

  This function is part of the Keras serialization and deserialization
  framework. It maps strings to the objects associated with them for
  serialization/deserialization.

  Example:
  ```
  def from_config(cls, config, custom_objects=None):
    if 'my_custom_object_name' in config:
      config['hidden_cls'] = tf.keras.utils.get_registered_object(
          config['my_custom_object_name'], custom_objects=custom_objects)
    return cls(**config)
  ```

  Args:
    name: The name to look up.
    custom_objects: A dictionary of custom objects to look the name up in.
      Generally, custom_objects is provided by the user.
    module_objects: A dictionary of custom objects to look the name up in.
      Generally, module_objects is provided by midlevel library implementers.

  Returns:
    An instantiable class associated with 'name', or None if no such class
      exists.
  """
  if name in _GLOBAL_CUSTOM_OBJECTS:
    return _GLOBAL_CUSTOM_OBJECTS[name]
  elif custom_objects and name in custom_objects:
    return custom_objects[name]
  elif module_objects and name in module_objects:
    return module_objects[name]
  return None
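
As a complement to the from_config fragment above, here is a minimal self-contained lookup sketch; MyClass is a hypothetical stand-in for a user-defined class.

import tensorflow as tf

class MyClass:
  def get_config(self):
    return {}

# User-provided custom objects are consulted when the name is not
# found in the global custom-object registry.
cls = tf.keras.utils.get_registered_object(
    'MyClass', custom_objects={'MyClass': MyClass})
assert cls is MyClass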
def get_source_inputs(tensor, layer=None, node_index=None)

Returns the list of input tensors necessary to compute tensor.

Output will always be a list of tensors (potentially with 1 element).

Args

tensor
The tensor to start from.
layer
Origin layer of the tensor. Will be determined via tensor._keras_history if not provided.
node_index
Origin node index of the tensor.

Returns

List of input tensors.

Source code
@keras_export('keras.utils.get_source_inputs')
def get_source_inputs(tensor, layer=None, node_index=None):
  """Returns the list of input tensors necessary to compute `tensor`.

  Output will always be a list of tensors
  (potentially with 1 element).

  Args:
      tensor: The tensor to start from.
      layer: Origin layer of the tensor. Will be
          determined via tensor._keras_history if not provided.
      node_index: Origin node index of the tensor.

  Returns:
      List of input tensors.
  """
  if not hasattr(tensor, '_keras_history'):
    return tensor

  if layer is None or node_index:
    layer, node_index, _ = tensor._keras_history
  if not layer._inbound_nodes:
    return [tensor]
  else:
    node = layer._inbound_nodes[node_index]
    if node.is_input:
      # Reached an Input layer, stop recursion.
      return tf.nest.flatten(node.input_tensors)
    else:
      source_tensors = []
      for layer, node_index, _, tensor in node.iterate_inbound():
        previous_sources = get_source_inputs(tensor, layer, node_index)
        # Avoid input redundancy.
        for x in previous_sources:
          if all(x is not t for t in source_tensors):
            source_tensors.append(x)
      return source_tensors
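
A minimal sketch (layer sizes are illustrative) of tracing an output tensor of a functional graph back to its inputs:

import tensorflow as tf

inputs = tf.keras.Input(shape=(8,))
x = tf.keras.layers.Dense(4)(inputs)
outputs = tf.keras.layers.Dense(1)(x)

# Returns the list of input tensors that `outputs` depends on.
print(tf.keras.utils.get_source_inputs(outputs))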
def img_to_array(img, data_format=None, dtype=None)

Converts a PIL Image instance to a Numpy array.

Usage:

from PIL import Image
img_data = np.random.random(size=(100, 100, 3))
img = tf.keras.preprocessing.image.array_to_img(img_data)
array = tf.keras.preprocessing.image.img_to_array(img)

Args

img
Input PIL Image instance.
data_format
Image data format, can be either "channels_first" or "channels_last". Defaults to None, in which case the global setting tf.keras.backend.image_data_format() is used (unless you changed it, it defaults to "channels_last").
dtype
Dtype to use. Defaults to None, in which case the global setting tf.keras.backend.floatx() is used (unless you changed it, it defaults to "float32").

Returns

A 3D Numpy array.

Raises

ValueError
if invalid img or data_format is passed.
Source code
@keras_export('keras.utils.img_to_array',
              'keras.preprocessing.image.img_to_array')
def img_to_array(img, data_format=None, dtype=None):
  """Converts a PIL Image instance to a Numpy array.

  Usage:

  ```python
  from PIL import Image
  img_data = np.random.random(size=(100, 100, 3))
  img = tf.keras.preprocessing.image.array_to_img(img_data)
  array = tf.keras.preprocessing.image.img_to_array(img)
  ```


  Args:
      img: Input PIL Image instance.
      data_format: Image data format, can be either "channels_first" or
        "channels_last". Defaults to `None`, in which case the global setting
        `tf.keras.backend.image_data_format()` is used (unless you changed it,
        it defaults to "channels_last").
      dtype: Dtype to use. Defaults to `None`, in which case the global setting
        `tf.keras.backend.floatx()` is used (unless you changed it, it defaults
        to "float32").

  Returns:
      A 3D Numpy array.

  Raises:
      ValueError: if invalid `img` or `data_format` is passed.
  """

  if data_format is None:
    data_format = backend.image_data_format()
  kwargs = {}
  if 'dtype' in tf_inspect.getfullargspec(image.img_to_array)[0]:
    if dtype is None:
      dtype = backend.floatx()
    kwargs['dtype'] = dtype
  return image.img_to_array(img, data_format=data_format, **kwargs)
def load_img(path, grayscale=False, color_mode='rgb', target_size=None, interpolation='nearest')

Loads an image into PIL format.

Usage:

image = tf.keras.preprocessing.image.load_img(image_path)
input_arr = tf.keras.preprocessing.image.img_to_array(image)
input_arr = np.array([input_arr])  # Convert single image to a batch.
predictions = model.predict(input_arr)

Args

path
Path to image file.
grayscale
Deprecated: use color_mode="grayscale" instead.
color_mode
One of "grayscale", "rgb", "rgba". Default: "rgb". The desired image format.
target_size
Either None (default to original size) or tuple of ints (img_height, img_width).
interpolation
Interpolation method used to resample the image if the target size is different from that of the loaded image. Supported methods are "nearest", "bilinear", and "bicubic". If PIL version 1.1.3 or newer is installed, "lanczos" is also supported. If PIL version 3.4.0 or newer is installed, "box" and "hamming" are also supported. By default, "nearest" is used.

Returns

A PIL Image instance.

Raises

ImportError
if PIL is not available.
ValueError
if interpolation method is not supported.
Source code
@keras_export('keras.utils.load_img',
              'keras.preprocessing.image.load_img')
def load_img(path, grayscale=False, color_mode='rgb', target_size=None,
             interpolation='nearest'):
  """Loads an image into PIL format.

  Usage:

  ```
  image = tf.keras.preprocessing.image.load_img(image_path)
  input_arr = tf.keras.preprocessing.image.img_to_array(image)
  input_arr = np.array([input_arr])  # Convert single image to a batch.
  predictions = model.predict(input_arr)
  ```

  Args:
      path: Path to image file.
      grayscale: DEPRECATED use `color_mode="grayscale"`.
      color_mode: One of "grayscale", "rgb", "rgba". Default: "rgb".
          The desired image format.
      target_size: Either `None` (default to original size)
          or tuple of ints `(img_height, img_width)`.
      interpolation: Interpolation method used to resample the image if the
          target size is different from that of the loaded image.
          Supported methods are "nearest", "bilinear", and "bicubic".
          If PIL version 1.1.3 or newer is installed, "lanczos" is also
          supported. If PIL version 3.4.0 or newer is installed, "box" and
          "hamming" are also supported. By default, "nearest" is used.

  Returns:
      A PIL Image instance.

  Raises:
      ImportError: if PIL is not available.
      ValueError: if interpolation method is not supported.
  """
  return image.load_img(path, grayscale=grayscale, color_mode=color_mode,
                        target_size=target_size, interpolation=interpolation)
def model_to_dot(model, show_shapes=False, show_dtype=False, show_layer_names=True, rankdir='TB', expand_nested=False, dpi=96, subgraph=False, layer_range=None)

Convert a Keras model to dot format.

Args

model
A Keras model instance.
show_shapes
whether to display shape information.
show_dtype
whether to display layer dtypes.
show_layer_names
whether to display layer names.
rankdir
rankdir argument passed to PyDot, a string specifying the format of the plot: 'TB' creates a vertical plot; 'LR' creates a horizontal plot.
expand_nested
whether to expand nested models into clusters.
dpi
Dots per inch.
subgraph
whether to return a pydot.Cluster instance.
layer_range
A list containing two str items: the starting layer name and the ending layer name (both inclusive), indicating the range of layers for which the pydot.Dot will be generated. Regex patterns are also accepted instead of exact names; in that case, the start predicate is the first layer matching layer_range[0] and the end predicate is the last layer matching layer_range[1]. Defaults to None, which considers all layers of the model. Note that you must pass a range such that the resulting subgraph is complete.

Returns

A pydot.Dot instance representing the Keras model, or a pydot.Cluster instance representing the nested model if subgraph=True.

Raises

ImportError
if graphviz or pydot are not available.
Source code
@keras_export('keras.utils.model_to_dot')
def model_to_dot(model,
                 show_shapes=False,
                 show_dtype=False,
                 show_layer_names=True,
                 rankdir='TB',
                 expand_nested=False,
                 dpi=96,
                 subgraph=False,
                 layer_range=None):
  """Convert a Keras model to dot format.

  Args:
    model: A Keras model instance.
    show_shapes: whether to display shape information.
    show_dtype: whether to display layer dtypes.
    show_layer_names: whether to display layer names.
    rankdir: `rankdir` argument passed to PyDot,
        a string specifying the format of the plot:
        'TB' creates a vertical plot;
        'LR' creates a horizontal plot.
    expand_nested: whether to expand nested models into clusters.
    dpi: Dots per inch.
    subgraph: whether to return a `pydot.Cluster` instance.
    layer_range: input of `list` containing two `str` items, which is the
        starting layer name and ending layer name (both inclusive) indicating
        the range of layers for which the `pydot.Dot` will be generated. It
        also accepts regex patterns instead of exact name. In such case, start
        predicate will be the first element it matches to `layer_range[0]`
        and the end predicate will be the last element it matches to
        `layer_range[1]`. By default `None` which considers all layers of
        model. Note that you must pass range such that the resultant subgraph
        must be complete.

  Returns:
    A `pydot.Dot` instance representing the Keras model or
    a `pydot.Cluster` instance representing nested model if
    `subgraph=True`.

  Raises:
    ImportError: if graphviz or pydot are not available.
  """
  from keras.layers import wrappers
  from keras.engine import sequential
  from keras.engine import functional

  if not check_pydot():
    message = (
        'You must install pydot (`pip install pydot`) '
        'and install graphviz '
        '(see instructions at https://graphviz.gitlab.io/download/) ',
        'for plot_model/model_to_dot to work.')
    if 'IPython.core.magics.namespace' in sys.modules:
      # We don't raise an exception here in order to avoid crashing notebook
      # tests where graphviz is not available.
      print(message)
      return
    else:
      raise ImportError(message)

  if subgraph:
    dot = pydot.Cluster(style='dashed', graph_name=model.name)
    dot.set('label', model.name)
    dot.set('labeljust', 'l')
  else:
    dot = pydot.Dot()
    dot.set('rankdir', rankdir)
    dot.set('concentrate', True)
    dot.set('dpi', dpi)
    dot.set_node_defaults(shape='record')

  if layer_range:
    if len(layer_range) != 2:
      raise ValueError('layer_range must be of shape (2,)')
    if (not isinstance(layer_range[0], str) or
        not isinstance(layer_range[1], str)):
      raise ValueError('layer_range should contain string type only')
    layer_range = get_layer_index_bound_by_layer_name(model, layer_range)
    if layer_range[0] < 0 or layer_range[1] > len(model.layers):
      raise ValueError('Both values in layer_range should be in',
                       'range (%d, %d)' % (0, len(model.layers)))

  sub_n_first_node = {}
  sub_n_last_node = {}
  sub_w_first_node = {}
  sub_w_last_node = {}

  layers = model.layers
  if not model._is_graph_network:
    node = pydot.Node(str(id(model)), label=model.name)
    dot.add_node(node)
    return dot
  elif isinstance(model, sequential.Sequential):
    if not model.built:
      model.build()
    layers = super(sequential.Sequential, model).layers

  # Create graph nodes.
  for i, layer in enumerate(layers):
    if (layer_range) and (i < layer_range[0] or i > layer_range[1]):
      continue

    layer_id = str(id(layer))

    # Append a wrapped layer's label to node's label, if it exists.
    layer_name = layer.name
    class_name = layer.__class__.__name__

    if isinstance(layer, wrappers.Wrapper):
      if expand_nested and isinstance(layer.layer,
                                      functional.Functional):
        submodel_wrapper = model_to_dot(
            layer.layer,
            show_shapes,
            show_dtype,
            show_layer_names,
            rankdir,
            expand_nested,
            subgraph=True)
        # sub_w : submodel_wrapper
        sub_w_nodes = submodel_wrapper.get_nodes()
        sub_w_first_node[layer.layer.name] = sub_w_nodes[0]
        sub_w_last_node[layer.layer.name] = sub_w_nodes[-1]
        dot.add_subgraph(submodel_wrapper)
      else:
        layer_name = '{}({})'.format(layer_name, layer.layer.name)
        child_class_name = layer.layer.__class__.__name__
        class_name = '{}({})'.format(class_name, child_class_name)

    if expand_nested and isinstance(layer, functional.Functional):
      submodel_not_wrapper = model_to_dot(
          layer,
          show_shapes,
          show_dtype,
          show_layer_names,
          rankdir,
          expand_nested,
          subgraph=True)
      # sub_n : submodel_not_wrapper
      sub_n_nodes = submodel_not_wrapper.get_nodes()
      sub_n_first_node[layer.name] = sub_n_nodes[0]
      sub_n_last_node[layer.name] = sub_n_nodes[-1]
      dot.add_subgraph(submodel_not_wrapper)

    # Create node's label.
    if show_layer_names:
      label = '{}: {}'.format(layer_name, class_name)
    else:
      label = class_name

    # Rebuild the label as a table including the layer's dtype.
    if show_dtype:

      def format_dtype(dtype):
        if dtype is None:
          return '?'
        else:
          return str(dtype)

      label = '%s|%s' % (label, format_dtype(layer.dtype))

    # Rebuild the label as a table including input/output shapes.
    if show_shapes:

      def format_shape(shape):
        return str(shape).replace(str(None), 'None')

      try:
        outputlabels = format_shape(layer.output_shape)
      except AttributeError:
        outputlabels = '?'
      if hasattr(layer, 'input_shape'):
        inputlabels = format_shape(layer.input_shape)
      elif hasattr(layer, 'input_shapes'):
        inputlabels = ', '.join(
            [format_shape(ishape) for ishape in layer.input_shapes])
      else:
        inputlabels = '?'
      label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label,
                                                     inputlabels,
                                                     outputlabels)

    if not expand_nested or not isinstance(
        layer, functional.Functional):
      node = pydot.Node(layer_id, label=label)
      dot.add_node(node)

  # Connect nodes with edges.
  for i, layer in enumerate(layers):
    if (layer_range) and (i <= layer_range[0] or i > layer_range[1]):
      continue
    layer_id = str(id(layer))
    for i, node in enumerate(layer._inbound_nodes):
      node_key = layer.name + '_ib-' + str(i)
      if node_key in model._network_nodes:
        for inbound_layer in tf.nest.flatten(node.inbound_layers):
          inbound_layer_id = str(id(inbound_layer))
          if not expand_nested:
            assert dot.get_node(inbound_layer_id)
            assert dot.get_node(layer_id)
            add_edge(dot, inbound_layer_id, layer_id)
          else:
            # if inbound_layer is not Model or wrapped Model
            if (not isinstance(inbound_layer,
                               functional.Functional) and
                not is_wrapped_model(inbound_layer)):
              # if current layer is not Model or wrapped Model
              if (not isinstance(layer, functional.Functional) and
                  not is_wrapped_model(layer)):
                assert dot.get_node(inbound_layer_id)
                assert dot.get_node(layer_id)
                add_edge(dot, inbound_layer_id, layer_id)
              # if current layer is Model
              elif isinstance(layer, functional.Functional):
                add_edge(dot, inbound_layer_id,
                         sub_n_first_node[layer.name].get_name())
              # if current layer is wrapped Model
              elif is_wrapped_model(layer):
                add_edge(dot, inbound_layer_id, layer_id)
                name = sub_w_first_node[layer.layer.name].get_name()
                add_edge(dot, layer_id, name)
            # if inbound_layer is Model
            elif isinstance(inbound_layer, functional.Functional):
              name = sub_n_last_node[inbound_layer.name].get_name()
              if isinstance(layer, functional.Functional):
                output_name = sub_n_first_node[layer.name].get_name()
                add_edge(dot, name, output_name)
              else:
                add_edge(dot, name, layer_id)
            # if inbound_layer is wrapped Model
            elif is_wrapped_model(inbound_layer):
              inbound_layer_name = inbound_layer.layer.name
              add_edge(dot,
                       sub_w_last_node[inbound_layer_name].get_name(),
                       layer_id)
  return dot
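
A minimal sketch of direct use, assuming pydot and graphviz are installed; the tiny Sequential model is illustrative only:

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, input_shape=(4,)),
    tf.keras.layers.Dense(1),
])

# `dot` is a pydot.Dot instance; render it or inspect the DOT source.
dot = tf.keras.utils.model_to_dot(model, show_shapes=True, rankdir='LR')
print(dot.to_string())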
def normalize(x, axis=-1, order=2)

Normalizes a Numpy array.

Args

x
Numpy array to normalize.
axis
axis along which to normalize.
order
Normalization order (e.g. order=2 for L2 norm).

Returns

A normalized copy of the array.

Source code
@keras_export('keras.utils.normalize')
def normalize(x, axis=-1, order=2):
  """Normalizes a Numpy array.

  Args:
      x: Numpy array to normalize.
      axis: axis along which to normalize.
      order: Normalization order (e.g. `order=2` for L2 norm).

  Returns:
      A normalized copy of the array.
  """
  l2 = np.atleast_1d(np.linalg.norm(x, order, axis))
  l2[l2 == 0] = 1
  return x / np.expand_dims(l2, axis)
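
For example, L2-normalizing the row [3, 4] scales it by 1/5:

import numpy as np
import tensorflow as tf

x = np.array([[3.0, 4.0]])
# Each slice along the last axis is scaled to unit L2 norm.
print(tf.keras.utils.normalize(x, axis=-1, order=2))  # [[0.6 0.8]]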
def plot_model(model, to_file='model.png', show_shapes=False, show_dtype=False, show_layer_names=True, rankdir='TB', expand_nested=False, dpi=96, layer_range=None)

Converts a Keras model to dot format and saves it to a file.

Example:

input = tf.keras.Input(shape=(100,), dtype='int32', name='input')
x = tf.keras.layers.Embedding(
    output_dim=512, input_dim=10000, input_length=100)(input)
x = tf.keras.layers.LSTM(32)(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
model = tf.keras.Model(inputs=[input], outputs=[output])
dot_img_file = '/tmp/model_1.png'
tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)

Args

model
A Keras model instance
to_file
File name of the plot image.
show_shapes
whether to display shape information.
show_dtype
whether to display layer dtypes.
show_layer_names
whether to display layer names.
rankdir
rankdir argument passed to PyDot, a string specifying the format of the plot: 'TB' creates a vertical plot; 'LR' creates a horizontal plot.
expand_nested
Whether to expand nested models into clusters.
dpi
Dots per inch.
layer_range
A list containing two str items: the starting layer name and the ending layer name (both inclusive), indicating the range of layers for which the plot will be generated. Regex patterns are also accepted instead of exact names; in that case, the start predicate is the first layer matching layer_range[0] and the end predicate is the last layer matching layer_range[1]. Defaults to None, which considers all layers of the model. Note that you must pass a range such that the resulting subgraph is complete.

Returns

A Jupyter notebook Image object if Jupyter is installed. This enables in-line display of the model plots in notebooks.

Source code
@keras_export('keras.utils.plot_model')
def plot_model(model,
               to_file='model.png',
               show_shapes=False,
               show_dtype=False,
               show_layer_names=True,
               rankdir='TB',
               expand_nested=False,
               dpi=96,
               layer_range=None):
  """Converts a Keras model to dot format and save to a file.

  Example:

  ```python
  input = tf.keras.Input(shape=(100,), dtype='int32', name='input')
  x = tf.keras.layers.Embedding(
      output_dim=512, input_dim=10000, input_length=100)(input)
  x = tf.keras.layers.LSTM(32)(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)
  output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
  model = tf.keras.Model(inputs=[input], outputs=[output])
  dot_img_file = '/tmp/model_1.png'
  tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
  ```

  Args:
    model: A Keras model instance
    to_file: File name of the plot image.
    show_shapes: whether to display shape information.
    show_dtype: whether to display layer dtypes.
    show_layer_names: whether to display layer names.
    rankdir: `rankdir` argument passed to PyDot,
        a string specifying the format of the plot:
        'TB' creates a vertical plot;
        'LR' creates a horizontal plot.
    expand_nested: Whether to expand nested models into clusters.
    dpi: Dots per inch.
    layer_range: input of `list` containing two `str` items, which is the
        starting layer name and ending layer name (both inclusive) indicating
        the range of layers for which the plot will be generated. It also
        accepts regex patterns instead of exact name. In such case, start
        predicate will be the first element it matches to `layer_range[0]`
        and the end predicate will be the last element it matches to
        `layer_range[1]`. By default `None` which considers all layers of
        model. Note that you must pass range such that the resultant subgraph
        must be complete.

  Returns:
    A Jupyter notebook Image object if Jupyter is installed.
    This enables in-line display of the model plots in notebooks.
  """
  dot = model_to_dot(
      model,
      show_shapes=show_shapes,
      show_dtype=show_dtype,
      show_layer_names=show_layer_names,
      rankdir=rankdir,
      expand_nested=expand_nested,
      dpi=dpi,
      layer_range=layer_range)
  to_file = path_to_string(to_file)
  if dot is None:
    return
  _, extension = os.path.splitext(to_file)
  if not extension:
    extension = 'png'
  else:
    extension = extension[1:]
  # Save image to disk.
  dot.write(to_file, format=extension)
  # Return the image as a Jupyter Image object, to be displayed in-line.
  # Note that we cannot easily detect whether the code is running in a
  # notebook, and thus we always return the Image if Jupyter is available.
  if extension != 'pdf':
    try:
      from IPython import display
      return display.Image(filename=to_file)
    except ImportError:
      pass
def register_keras_serializable(package='Custom', name=None)

Registers an object with the Keras serialization framework.

This decorator injects the decorated class or function into the Keras custom object dictionary, so that it can be serialized and deserialized without needing an entry in the user-provided custom object dict. It also injects a function that Keras will call to get the object's serializable string key.

Note that to be serialized and deserialized, classes must implement the get_config() method. Functions do not have this requirement.

The object will be registered under the key 'package>name', where name defaults to the object's name if not passed.

Args

package
The package that this class belongs to.
name
The name to serialize this class under in this package. If None, the class' name will be used.

Returns

A decorator that registers the decorated class with the passed names.

Source code
@keras_export('keras.utils.register_keras_serializable')
def register_keras_serializable(package='Custom', name=None):
  """Registers an object with the Keras serialization framework.

  This decorator injects the decorated class or function into the Keras custom
  object dictionary, so that it can be serialized and deserialized without
  needing an entry in the user-provided custom object dict. It also injects a
  function that Keras will call to get the object's serializable string key.

  Note that to be serialized and deserialized, classes must implement the
  `get_config()` method. Functions do not have this requirement.

  The object will be registered under the key 'package>name', where `name`
  defaults to the object name if not passed.

  Args:
    package: The package that this class belongs to.
    name: The name to serialize this class under in this package. If None, the
      class' name will be used.

  Returns:
    A decorator that registers the decorated class with the passed names.
  """

  def decorator(arg):
    """Registers a class with the Keras serialization framework."""
    class_name = name if name is not None else arg.__name__
    registered_name = package + '>' + class_name

    if tf_inspect.isclass(arg) and not hasattr(arg, 'get_config'):
      raise ValueError(
          'Cannot register a class that does not have a get_config() method.')

    if registered_name in _GLOBAL_CUSTOM_OBJECTS:
      raise ValueError(
          '%s has already been registered to %s' %
          (registered_name, _GLOBAL_CUSTOM_OBJECTS[registered_name]))

    if arg in _GLOBAL_CUSTOM_NAMES:
      raise ValueError('%s has already been registered to %s' %
                       (arg, _GLOBAL_CUSTOM_NAMES[arg]))
    _GLOBAL_CUSTOM_OBJECTS[registered_name] = arg
    _GLOBAL_CUSTOM_NAMES[arg] = registered_name

    return arg

  return decorator
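
A round-trip sketch under stated assumptions; MyDense is a hypothetical custom layer used only for illustration:

import tensorflow as tf

@tf.keras.utils.register_keras_serializable(package='MyPackage', name='MyDense')
class MyDense(tf.keras.layers.Dense):
  pass

config = tf.keras.utils.serialize_keras_object(MyDense(units=3))
# No custom_objects dict is needed on the way back: the registry
# resolves the 'MyPackage>MyDense' key.
layer = tf.keras.layers.deserialize(config)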
def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs)

Saves an image stored as a Numpy array to a path or file object.

Args

path
Path or file object.
x
Numpy array.
data_format
Image data format, either "channels_first" or "channels_last".
file_format
Optional file format override. If omitted, the format to use is determined from the filename extension. If a file object was used instead of a filename, this parameter should always be used.
scale
Whether to rescale image values to be within [0, 255].
**kwargs
Additional keyword arguments passed to PIL.Image.save().
Source code
@keras_export('keras.utils.save_img',
              'keras.preprocessing.image.save_img')
def save_img(path,
             x,
             data_format=None,
             file_format=None,
             scale=True,
             **kwargs):
  """Saves an image stored as a Numpy array to a path or file object.

  Args:
      path: Path or file object.
      x: Numpy array.
      data_format: Image data format,
          either "channels_first" or "channels_last".
      file_format: Optional file format override. If omitted, the
          format to use is determined from the filename extension.
          If a file object was used instead of a filename, this
          parameter should always be used.
      scale: Whether to rescale image values to be within `[0, 255]`.
      **kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
  """
  if data_format is None:
    data_format = backend.image_data_format()
  image.save_img(path,
                 x,
                 data_format=data_format,
                 file_format=file_format,
                 scale=scale, **kwargs)
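
A minimal sketch; the path and array are illustrative:

import numpy as np
import tensorflow as tf

img = np.random.random((64, 64, 3))  # channels-last array in [0, 1]
# With scale=True (the default), values are rescaled to [0, 255].
tf.keras.utils.save_img('/tmp/example.png', img)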
def serialize_keras_object(instance)

Serialize a Keras object into a JSON-compatible representation.

Calls to serialize_keras_object() while underneath the SharedObjectSavingScope context manager will cause any objects re-used across multiple layers to be saved with a special shared object ID. This allows the network to be re-created properly during deserialization.

Args

instance
The object to serialize.

Returns

A dict-like, JSON-compatible representation of the object's config.

Source code
@keras_export('keras.utils.serialize_keras_object')
def serialize_keras_object(instance):
  """Serialize a Keras object into a JSON-compatible representation.

  Calls to `serialize_keras_object` while underneath the
  `SharedObjectSavingScope` context manager will cause any objects re-used
  across multiple layers to be saved with a special shared object ID. This
  allows the network to be re-created properly during deserialization.

  Args:
    instance: The object to serialize.

  Returns:
    A dict-like, JSON-compatible representation of the object's config.
  """
  _, instance = tf.__internal__.decorator.unwrap(instance)
  if instance is None:
    return None

  # pylint: disable=protected-access
  #
  # For v1 layers, checking supports_masking is not enough. We have to also
  # check whether compute_mask has been overridden.
  supports_masking = (getattr(instance, 'supports_masking', False)
                      or (hasattr(instance, 'compute_mask')
                          and not is_default(instance.compute_mask)))
  if supports_masking and is_default(instance.get_config):
    warnings.warn('Custom mask layers require a config and must override '
                  'get_config. When loading, the custom mask layer must be '
                  'passed to the custom_objects argument.',
                  category=CustomMaskWarning)
  # pylint: enable=protected-access

  if hasattr(instance, 'get_config'):
    name = get_registered_name(instance.__class__)
    try:
      config = instance.get_config()
    except NotImplementedError as e:
      if _SKIP_FAILED_SERIALIZATION:
        return serialize_keras_class_and_config(
            name, {_LAYER_UNDEFINED_CONFIG_KEY: True})
      raise e
    serialization_config = {}
    for key, item in config.items():
      if isinstance(item, str):
        serialization_config[key] = item
        continue

      # Any object of a different type needs to be converted to string or dict
      # for serialization (e.g. custom functions, custom classes)
      try:
        serialized_item = serialize_keras_object(item)
        if isinstance(serialized_item, dict) and not isinstance(item, dict):
          serialized_item['__passive_serialization__'] = True
        serialization_config[key] = serialized_item
      except ValueError:
        serialization_config[key] = item

    name = get_registered_name(instance.__class__)
    return serialize_keras_class_and_config(
        name, serialization_config, instance)
  if hasattr(instance, '__name__'):
    return get_registered_name(instance)
  raise ValueError('Cannot serialize', instance)
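
For example, serializing a built-in layer yields a plain dict (a sketch; the exact config keys vary by layer):

import tensorflow as tf

layer = tf.keras.layers.Dense(3, activation='relu')
config = tf.keras.utils.serialize_keras_object(layer)
# A JSON-compatible dict: {'class_name': 'Dense', 'config': {...}}
print(config['class_name'])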
def to_categorical(y, num_classes=None, dtype='float32')

Converts a class vector (integers) to binary class matrix.

E.g. for use with categorical_crossentropy.

Args

y
Class vector to be converted into a matrix (integers from 0 to num_classes - 1).
num_classes
Total number of classes. If None, this is inferred as (largest number in y) + 1.
dtype
The data type of the output matrix. Default: 'float32'.

Returns

A binary matrix representation of the input. The classes axis is placed last.

Example:

>>> a = tf.keras.utils.to_categorical([0, 1, 2, 3], num_classes=4)
>>> a = tf.constant(a, shape=[4, 4])
>>> print(a)
tf.Tensor(
  [[1. 0. 0. 0.]
   [0. 1. 0. 0.]
   [0. 0. 1. 0.]
   [0. 0. 0. 1.]], shape=(4, 4), dtype=float32)
>>> b = tf.constant([.9, .04, .03, .03,
...                  .3, .45, .15, .13,
...                  .04, .01, .94, .05,
...                  .12, .21, .5, .17],
...                 shape=[4, 4])
>>> loss = tf.keras.backend.categorical_crossentropy(a, b)
>>> print(np.around(loss, 5))
[0.10536 0.82807 0.1011  1.77196]
>>> loss = tf.keras.backend.categorical_crossentropy(a, a)
>>> print(np.around(loss, 5))
[0. 0. 0. 0.]

Raises

ValueError
If input contains string values.
Source code
@keras_export('keras.utils.to_categorical')
def to_categorical(y, num_classes=None, dtype='float32'):
  """Converts a class vector (integers) to binary class matrix.

  E.g. for use with categorical_crossentropy.

  Args:
      y: class vector to be converted into a matrix
          (integers from 0 to num_classes).
      num_classes: total number of classes. If `None`, this would be inferred
        as the (largest number in `y`) + 1.
      dtype: The data type expected by the input. Default: `'float32'`.

  Returns:
      A binary matrix representation of the input. The classes axis is placed
      last.

  Example:

  >>> a = tf.keras.utils.to_categorical([0, 1, 2, 3], num_classes=4)
  >>> a = tf.constant(a, shape=[4, 4])
  >>> print(a)
  tf.Tensor(
    [[1. 0. 0. 0.]
     [0. 1. 0. 0.]
     [0. 0. 1. 0.]
     [0. 0. 0. 1.]], shape=(4, 4), dtype=float32)

  >>> b = tf.constant([.9, .04, .03, .03,
  ...                  .3, .45, .15, .13,
  ...                  .04, .01, .94, .05,
  ...                  .12, .21, .5, .17],
  ...                 shape=[4, 4])
  >>> loss = tf.keras.backend.categorical_crossentropy(a, b)
  >>> print(np.around(loss, 5))
  [0.10536 0.82807 0.1011  1.77196]

  >>> loss = tf.keras.backend.categorical_crossentropy(a, a)
  >>> print(np.around(loss, 5))
  [0. 0. 0. 0.]

  Raises:
      ValueError: If input contains string values.

  """
  y = np.array(y, dtype='int')
  input_shape = y.shape
  if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
    input_shape = tuple(input_shape[:-1])
  y = y.ravel()
  if not num_classes:
    num_classes = np.max(y) + 1
  n = y.shape[0]
  categorical = np.zeros((n, num_classes), dtype=dtype)
  categorical[np.arange(n), y] = 1
  output_shape = input_shape + (num_classes,)
  categorical = np.reshape(categorical, output_shape)
  return categorical

Classes

class CustomObjectScope (*args)

Exposes custom classes/functions to Keras deserialization internals.

Under a scope with CustomObjectScope(objects_dict), Keras methods such as tf.keras.models.load_model or tf.keras.models.model_from_config will be able to deserialize any custom object referenced by a saved config (e.g. a custom layer or metric).

Example:

Consider a custom regularizer my_regularizer:

layer = Dense(3, kernel_regularizer=my_regularizer)
config = layer.get_config()  # Config contains a reference to `my_regularizer`
...
# Later:
with custom_object_scope({'my_regularizer': my_regularizer}):
  layer = Dense.from_config(config)

Args

*args
Dictionary or dictionaries of {name: object} pairs.
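
Usage sketch: loading a saved model that references a custom object. The activation function and file path below are hypothetical placeholders:

import tensorflow as tf

def my_activation(x):  # hypothetical custom function saved with the model
    return x * tf.sigmoid(x)

# Inside the scope, deserialization can resolve the name 'my_activation'.
with tf.keras.utils.CustomObjectScope({'my_activation': my_activation}):
    model = tf.keras.models.load_model('my_model.h5')  # hypothetical path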
Expand source code
class CustomObjectScope(object):
  """Exposes custom classes/functions to Keras deserialization internals.

  Under a scope `with custom_object_scope(objects_dict)`, Keras methods such
  as `tf.keras.models.load_model` or `tf.keras.models.model_from_config`
  will be able to deserialize any custom object referenced by a
  saved config (e.g. a custom layer or metric).

  Example:

  Consider a custom regularizer `my_regularizer`:

  ```python
  layer = Dense(3, kernel_regularizer=my_regularizer)
  config = layer.get_config()  # Config contains a reference to `my_regularizer`
  ...
  # Later:
  with custom_object_scope({'my_regularizer': my_regularizer}):
    layer = Dense.from_config(config)
  ```

  Args:
      *args: Dictionary or dictionaries of `{name: object}` pairs.
  """

  def __init__(self, *args):
    self.custom_objects = args
    self.backup = None

  def __enter__(self):
    self.backup = _GLOBAL_CUSTOM_OBJECTS.copy()
    for objects in self.custom_objects:
      _GLOBAL_CUSTOM_OBJECTS.update(objects)
    return self

  def __exit__(self, *args, **kwargs):
    _GLOBAL_CUSTOM_OBJECTS.clear()
    _GLOBAL_CUSTOM_OBJECTS.update(self.backup)
class custom_object_scope (*args)

Alias of CustomObjectScope. See the entry above for usage and arguments.
class GeneratorEnqueuer (generator, use_multiprocessing=False, random_seed=None)

Builds a queue out of a data generator.

The provided generator can be finite, in which case the class will raise a StopIteration exception.

Args

generator
a generator function which yields data
use_multiprocessing
use multiprocessing if True, otherwise threading
random_seed
Initial seed for workers, will be incremented by one for each worker.
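
Usage sketch (threading with a single worker, so generator thread-safety is not a concern; the toy generator is illustrative):

import numpy as np
from keras.utils.data_utils import GeneratorEnqueuer

def batches():
    # Infinite toy generator yielding (inputs, targets) tuples.
    while True:
        yield np.random.random((32, 10)), np.random.randint(0, 2, (32, 1))

enqueuer = GeneratorEnqueuer(batches(), use_multiprocessing=False)
enqueuer.start(workers=1, max_queue_size=10)
stream = enqueuer.get()
for _ in range(5):
    inputs, targets = next(stream)
enqueuer.stop()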
Expand source code
class GeneratorEnqueuer(SequenceEnqueuer):
  """Builds a queue out of a data generator.

  The provided generator can be finite, in which case the class will raise
  a `StopIteration` exception.

  Args:
      generator: a generator function which yields data
      use_multiprocessing: use multiprocessing if True, otherwise threading
      random_seed: Initial seed for workers,
          will be incremented by one for each worker.
  """

  def __init__(self, generator,
               use_multiprocessing=False,
               random_seed=None):
    super(GeneratorEnqueuer, self).__init__(generator, use_multiprocessing)
    self.random_seed = random_seed

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
      workers: Number of workers.

    Returns:
        A Function to initialize the pool
    """
    def pool_fn(seqs):
      pool = get_pool_class(True)(
          workers, initializer=init_pool_generator,
          initargs=(seqs, self.random_seed, get_worker_id_queue()))
      _DATA_POOLS.add(pool)
      return pool
    return pool_fn

  def _run(self):
    """Submits request to the executor and queue the `Future` objects."""
    self._send_sequence()  # Share the initial generator
    with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
      while True:
        if self.stop_signal.is_set():
          return

        self.queue.put(
            executor.apply_async(next_sample, (self.uid,)), block=True)

  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Yields:
        The next element in the queue, i.e. a tuple
        `(inputs, targets)` or
        `(inputs, targets, sample_weights)`.
    """
    try:
      while self.is_running():
        inputs = self.queue.get(block=True).get()
        self.queue.task_done()
        if inputs is not None:
          yield inputs
    except StopIteration:
      # Special case for finite generators
      last_ones = []
      while self.queue.qsize() > 0:
        last_ones.append(self.queue.get(block=True))
      # Wait for them to complete
      for f in last_ones:
        f.wait()
      # Keep the good ones
      last_ones = [future.get() for future in last_ones if future.successful()]
      for inputs in last_ones:
        if inputs is not None:
          yield inputs
    except Exception as e:  # pylint: disable=broad-except
      self.stop()
      if 'generator already executing' in str(e):
        raise RuntimeError(
            'Your generator is NOT thread-safe. '
            'Keras requires a thread-safe generator when '
            '`use_multiprocessing=False, workers > 1`. ')
      raise e

Ancestors

SequenceEnqueuer

Methods

def get(self)

Creates a generator to extract data from the queue.

Skip the data if it is None.

Yields

The next element in the queue, i.e. a tuple (inputs, targets) or (inputs, targets, sample_weights).

Expand source code
def get(self):
  """Creates a generator to extract data from the queue.

  Skip the data if it is `None`.

  Yields:
      The next element in the queue, i.e. a tuple
      `(inputs, targets)` or
      `(inputs, targets, sample_weights)`.
  """
  try:
    while self.is_running():
      inputs = self.queue.get(block=True).get()
      self.queue.task_done()
      if inputs is not None:
        yield inputs
  except StopIteration:
    # Special case for finite generators
    last_ones = []
    while self.queue.qsize() > 0:
      last_ones.append(self.queue.get(block=True))
    # Wait for them to complete
    for f in last_ones:
      f.wait()
    # Keep the good ones
    last_ones = [future.get() for future in last_ones if future.successful()]
    for inputs in last_ones:
      if inputs is not None:
        yield inputs
  except Exception as e:  # pylint: disable=broad-except
    self.stop()
    if 'generator already executing' in str(e):
      raise RuntimeError(
          'Your generator is NOT thread-safe. '
          'Keras requires a thread-safe generator when '
          '`use_multiprocessing=False, workers > 1`. ')
    raise e

Inherited members

class OrderedEnqueuer (sequence, use_multiprocessing=False, shuffle=False)

Builds an Enqueuer from a Sequence.

Args

sequence
A tf.keras.utils.data_utils.Sequence object.
use_multiprocessing
use multiprocessing if True, otherwise threading
shuffle
whether to shuffle the data at the beginning of each epoch
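
Usage sketch with a toy Sequence (the class below is illustrative, not part of the API):

import numpy as np
from keras.utils.data_utils import OrderedEnqueuer, Sequence

class ToySequence(Sequence):  # hypothetical Sequence for illustration
    def __len__(self):
        return 4  # four batches per epoch

    def __getitem__(self, idx):
        return np.random.random((32, 10)), np.random.randint(0, 2, (32, 1))

enqueuer = OrderedEnqueuer(ToySequence(), use_multiprocessing=False, shuffle=False)
enqueuer.start(workers=2, max_queue_size=4)
stream = enqueuer.get()
for _ in range(4):  # one epoch of batches, served in order
    x_batch, y_batch = next(stream)
enqueuer.stop()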
Expand source code
class OrderedEnqueuer(SequenceEnqueuer):
  """Builds a Enqueuer from a Sequence.

  Args:
      sequence: A `tf.keras.utils.data_utils.Sequence` object.
      use_multiprocessing: use multiprocessing if True, otherwise threading
      shuffle: whether to shuffle the data at the beginning of each epoch
  """

  def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
    super(OrderedEnqueuer, self).__init__(sequence, use_multiprocessing)
    self.shuffle = shuffle

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
        workers: Number of workers.

    Returns:
        Function, a Function to initialize the pool
    """
    def pool_fn(seqs):
      pool = get_pool_class(True)(
          workers, initializer=init_pool_generator,
          initargs=(seqs, None, get_worker_id_queue()))
      _DATA_POOLS.add(pool)
      return pool

    return pool_fn

  def _wait_queue(self):
    """Wait for the queue to be empty."""
    while True:
      time.sleep(0.1)
      if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set():
        return

  def _run(self):
    """Submits request to the executor and queue the `Future` objects."""
    sequence = list(range(len(self.sequence)))
    self._send_sequence()  # Share the initial sequence
    while True:
      if self.shuffle:
        random.shuffle(sequence)

      with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
        for i in sequence:
          if self.stop_signal.is_set():
            return

          self.queue.put(
              executor.apply_async(get_index, (self.uid, i)), block=True)

        # Done with the current epoch, waiting for the final batches
        self._wait_queue()

        if self.stop_signal.is_set():
          # We're done
          return

      # Call the internal on epoch end.
      self.sequence.on_epoch_end()
      self._send_sequence()  # Update the pool

  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Yields:
        The next element in the queue, i.e. a tuple
        `(inputs, targets)` or
        `(inputs, targets, sample_weights)`.
    """
    while self.is_running():
      try:
        inputs = self.queue.get(block=True, timeout=5).get()
        if self.is_running():
          self.queue.task_done()
        if inputs is not None:
          yield inputs
      except queue.Empty:
        pass
      except Exception as e:  # pylint: disable=broad-except
        self.stop()
        raise e

Ancestors

SequenceEnqueuer

Methods

def get(self)

Creates a generator to extract data from the queue.

Skip the data if it is None.

Yields

The next element in the queue, i.e. a tuple (inputs, targets) or (inputs, targets, sample_weights).

Expand source code
def get(self):
  """Creates a generator to extract data from the queue.

  Skip the data if it is `None`.

  Yields:
      The next element in the queue, i.e. a tuple
      `(inputs, targets)` or
      `(inputs, targets, sample_weights)`.
  """
  while self.is_running():
    try:
      inputs = self.queue.get(block=True, timeout=5).get()
      if self.is_running():
        self.queue.task_done()
      if inputs is not None:
        yield inputs
    except queue.Empty:
      pass
    except Exception as e:  # pylint: disable=broad-except
      self.stop()
      raise e

Inherited members

class Progbar (target, width=30, verbose=1, interval=0.05, stateful_metrics=None, unit_name='step')

Displays a progress bar.

Args

target
Total number of steps expected, None if unknown.
width
Progress bar width on screen.
verbose
Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
stateful_metrics
Iterable of string names of metrics that should not be averaged over time. Metrics in this list will be displayed as-is. All others will be averaged by the progbar before display.
interval
Minimum visual progress update interval (in seconds).
unit_name
Display name for step counts (usually "step" or "sample").
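
Usage sketch (the metric values are simulated):

import time
import tensorflow as tf

pbar = tf.keras.utils.Progbar(target=10, unit_name='step')
for step in range(10):
    time.sleep(0.01)  # simulated work
    # Non-stateful metrics are averaged over the steps seen so far.
    pbar.update(step + 1, values=[('loss', 1.0 / (step + 1))])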
Expand source code
class Progbar(object):
  """Displays a progress bar.

  Args:
      target: Total number of steps expected, None if unknown.
      width: Progress bar width on screen.
      verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
      stateful_metrics: Iterable of string names of metrics that should *not* be
        averaged over time. Metrics in this list will be displayed as-is. All
        others will be averaged by the progbar before display.
      interval: Minimum visual progress update interval (in seconds).
      unit_name: Display name for step counts (usually "step" or "sample").
  """

  def __init__(self,
               target,
               width=30,
               verbose=1,
               interval=0.05,
               stateful_metrics=None,
               unit_name='step'):
    self.target = target
    self.width = width
    self.verbose = verbose
    self.interval = interval
    self.unit_name = unit_name
    if stateful_metrics:
      self.stateful_metrics = set(stateful_metrics)
    else:
      self.stateful_metrics = set()

    self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
                              sys.stdout.isatty()) or
                             'ipykernel' in sys.modules or
                             'posix' in sys.modules or
                             'PYCHARM_HOSTED' in os.environ)
    self._total_width = 0
    self._seen_so_far = 0
    # We use a dict + list to avoid garbage collection
    # issues found in OrderedDict
    self._values = {}
    self._values_order = []
    self._start = time.time()
    self._last_update = 0

    self._time_after_first_step = None

  def update(self, current, values=None, finalize=None):
    """Updates the progress bar.

    Args:
        current: Index of current step.
        values: List of tuples: `(name, value_for_last_step)`. If `name` is in
          `stateful_metrics`, `value_for_last_step` will be displayed as-is.
          Else, an average of the metric over time will be displayed.
        finalize: Whether this is the last update for the progress bar. If
          `None`, defaults to `current >= self.target`.
    """
    if finalize is None:
      if self.target is None:
        finalize = False
      else:
        finalize = current >= self.target

    values = values or []
    for k, v in values:
      if k not in self._values_order:
        self._values_order.append(k)
      if k not in self.stateful_metrics:
        # In the case that progress bar doesn't have a target value in the first
        # epoch, both on_batch_end and on_epoch_end will be called, which will
        # cause 'current' and 'self._seen_so_far' to have the same value. Force
        # the minimal value to 1 here, otherwise stateful_metric will be 0s.
        value_base = max(current - self._seen_so_far, 1)
        if k not in self._values:
          self._values[k] = [v * value_base, value_base]
        else:
          self._values[k][0] += v * value_base
          self._values[k][1] += value_base
      else:
        # Stateful metrics output a numeric value. This representation
        # means "take an average from a single value" but keeps the
        # numeric formatting.
        self._values[k] = [v, 1]
    self._seen_so_far = current

    now = time.time()
    info = ' - %.0fs' % (now - self._start)
    if self.verbose == 1:
      if now - self._last_update < self.interval and not finalize:
        return

      prev_total_width = self._total_width
      if self._dynamic_display:
        sys.stdout.write('\b' * prev_total_width)
        sys.stdout.write('\r')
      else:
        sys.stdout.write('\n')

      if self.target is not None:
        numdigits = int(np.log10(self.target)) + 1
        bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target)
        prog = float(current) / self.target
        prog_width = int(self.width * prog)
        if prog_width > 0:
          bar += ('=' * (prog_width - 1))
          if current < self.target:
            bar += '>'
          else:
            bar += '='
        bar += ('.' * (self.width - prog_width))
        bar += ']'
      else:
        bar = '%7d/Unknown' % current

      self._total_width = len(bar)
      sys.stdout.write(bar)

      time_per_unit = self._estimate_step_duration(current, now)

      if self.target is None or finalize:
        if time_per_unit >= 1 or time_per_unit == 0:
          info += ' %.0fs/%s' % (time_per_unit, self.unit_name)
        elif time_per_unit >= 1e-3:
          info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name)
        else:
          info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name)
      else:
        eta = time_per_unit * (self.target - current)
        if eta > 3600:
          eta_format = '%d:%02d:%02d' % (eta // 3600,
                                         (eta % 3600) // 60, eta % 60)
        elif eta > 60:
          eta_format = '%d:%02d' % (eta // 60, eta % 60)
        else:
          eta_format = '%ds' % eta

        info = ' - ETA: %s' % eta_format

      for k in self._values_order:
        info += ' - %s:' % k
        if isinstance(self._values[k], list):
          avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
          if abs(avg) > 1e-3:
            info += ' %.4f' % avg
          else:
            info += ' %.4e' % avg
        else:
          info += ' %s' % self._values[k]

      self._total_width += len(info)
      if prev_total_width > self._total_width:
        info += (' ' * (prev_total_width - self._total_width))

      if finalize:
        info += '\n'

      sys.stdout.write(info)
      sys.stdout.flush()

    elif self.verbose == 2:
      if finalize:
        numdigits = int(np.log10(self.target)) + 1
        count = ('%' + str(numdigits) + 'd/%d') % (current, self.target)
        info = count + info
        for k in self._values_order:
          info += ' - %s:' % k
          avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
          if avg > 1e-3:
            info += ' %.4f' % avg
          else:
            info += ' %.4e' % avg
        info += '\n'

        sys.stdout.write(info)
        sys.stdout.flush()

    self._last_update = now

  def add(self, n, values=None):
    self.update(self._seen_so_far + n, values)

  def _estimate_step_duration(self, current, now):
    """Estimate the duration of a single step.

    Given the step number `current` and the corresponding time `now`
    this function returns an estimate for how long a single step
    takes. If this is called before one step has been completed
    (i.e. `current == 0`) then zero is given as an estimate. The duration
    estimate ignores the duration of the (assumed to be non-representative)
    first step for estimates when more steps are available (i.e. `current > 1`).

    Args:
      current: Index of current step.
      now: The current time.

    Returns:
      Estimate of the duration of a single step.
    """
    if current:
      # there are a few special scenarios here:
      # 1) somebody is calling the progress bar without ever supplying step 1
      # 2) somebody is calling the progress bar and supplies step one multiple
      #    times, e.g. as part of a finalizing call
      # in these cases, we just fall back to the simple calculation
      if self._time_after_first_step is not None and current > 1:
        time_per_unit = (now - self._time_after_first_step) / (current - 1)
      else:
        time_per_unit = (now - self._start) / current

      if current == 1:
        self._time_after_first_step = now
      return time_per_unit
    else:
      return 0

  def _update_stateful_metrics(self, stateful_metrics):
    self.stateful_metrics = self.stateful_metrics.union(stateful_metrics)

Methods

def add(self, n, values=None)

Advances the progress bar by n steps (delegates to update()).

Expand source code
def add(self, n, values=None):
  self.update(self._seen_so_far + n, values)
def update(self, current, values=None, finalize=None)

Updates the progress bar.

Args

current
Index of current step.
values
List of tuples: (name, value_for_last_step). If name is in stateful_metrics, value_for_last_step will be displayed as-is. Else, an average of the metric over time will be displayed.
finalize
Whether this is the last update for the progress bar. If None, defaults to current >= self.target.
Expand source code
def update(self, current, values=None, finalize=None):
  """Updates the progress bar.

  Args:
      current: Index of current step.
      values: List of tuples: `(name, value_for_last_step)`. If `name` is in
        `stateful_metrics`, `value_for_last_step` will be displayed as-is.
        Else, an average of the metric over time will be displayed.
      finalize: Whether this is the last update for the progress bar. If
        `None`, defaults to `current >= self.target`.
  """
  if finalize is None:
    if self.target is None:
      finalize = False
    else:
      finalize = current >= self.target

  values = values or []
  for k, v in values:
    if k not in self._values_order:
      self._values_order.append(k)
    if k not in self.stateful_metrics:
      # In the case that progress bar doesn't have a target value in the first
      # epoch, both on_batch_end and on_epoch_end will be called, which will
      # cause 'current' and 'self._seen_so_far' to have the same value. Force
      # the minimal value to 1 here, otherwise stateful_metric will be 0s.
      value_base = max(current - self._seen_so_far, 1)
      if k not in self._values:
        self._values[k] = [v * value_base, value_base]
      else:
        self._values[k][0] += v * value_base
        self._values[k][1] += value_base
    else:
      # Stateful metrics output a numeric value. This representation
      # means "take an average from a single value" but keeps the
      # numeric formatting.
      self._values[k] = [v, 1]
  self._seen_so_far = current

  now = time.time()
  info = ' - %.0fs' % (now - self._start)
  if self.verbose == 1:
    if now - self._last_update < self.interval and not finalize:
      return

    prev_total_width = self._total_width
    if self._dynamic_display:
      sys.stdout.write('\b' * prev_total_width)
      sys.stdout.write('\r')
    else:
      sys.stdout.write('\n')

    if self.target is not None:
      numdigits = int(np.log10(self.target)) + 1
      bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target)
      prog = float(current) / self.target
      prog_width = int(self.width * prog)
      if prog_width > 0:
        bar += ('=' * (prog_width - 1))
        if current < self.target:
          bar += '>'
        else:
          bar += '='
      bar += ('.' * (self.width - prog_width))
      bar += ']'
    else:
      bar = '%7d/Unknown' % current

    self._total_width = len(bar)
    sys.stdout.write(bar)

    time_per_unit = self._estimate_step_duration(current, now)

    if self.target is None or finalize:
      if time_per_unit >= 1 or time_per_unit == 0:
        info += ' %.0fs/%s' % (time_per_unit, self.unit_name)
      elif time_per_unit >= 1e-3:
        info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name)
      else:
        info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name)
    else:
      eta = time_per_unit * (self.target - current)
      if eta > 3600:
        eta_format = '%d:%02d:%02d' % (eta // 3600,
                                       (eta % 3600) // 60, eta % 60)
      elif eta > 60:
        eta_format = '%d:%02d' % (eta // 60, eta % 60)
      else:
        eta_format = '%ds' % eta

      info = ' - ETA: %s' % eta_format

    for k in self._values_order:
      info += ' - %s:' % k
      if isinstance(self._values[k], list):
        avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
        if abs(avg) > 1e-3:
          info += ' %.4f' % avg
        else:
          info += ' %.4e' % avg
      else:
        info += ' %s' % self._values[k]

    self._total_width += len(info)
    if prev_total_width > self._total_width:
      info += (' ' * (prev_total_width - self._total_width))

    if finalize:
      info += '\n'

    sys.stdout.write(info)
    sys.stdout.flush()

  elif self.verbose == 2:
    if finalize:
      numdigits = int(np.log10(self.target)) + 1
      count = ('%' + str(numdigits) + 'd/%d') % (current, self.target)
      info = count + info
      for k in self._values_order:
        info += ' - %s:' % k
        avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
        if avg > 1e-3:
          info += ' %.4f' % avg
        else:
          info += ' %.4e' % avg
      info += '\n'

      sys.stdout.write(info)
      sys.stdout.flush()

  self._last_update = now
class Sequence

Base object for fitting to a sequence of data, such as a dataset.

Every Sequence must implement the __getitem__ and the __len__ methods. If you want to modify your dataset between epochs, you may implement on_epoch_end. The method __getitem__ should return a complete batch.

Notes:

Sequence is a safer way to do multiprocessing. This structure guarantees that the network will train only once on each sample per epoch, which is not the case with generators.

Examples:

from skimage.io import imread
from skimage.transform import resize
import numpy as np
import math

# Here, `x_set` is a list of paths to the images
# and `y_set` are the associated classes.

class CIFAR10Sequence(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        return np.array([
            resize(imread(file_name), (200, 200))
               for file_name in batch_x]), np.array(batch_y)
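
A further sketch, assuming in-memory NumPy arrays: on_epoch_end can reshuffle the sample order between epochs (the class below is illustrative):

import math
import numpy as np
from keras.utils.data_utils import Sequence

class ShuffledArraySequence(Sequence):  # hypothetical example class

    def __init__(self, x, y, batch_size):
        self.x, self.y, self.batch_size = x, y, batch_size
        self.indices = np.arange(len(x))

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        ids = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        return self.x[ids], self.y[ids]

    def on_epoch_end(self):
        # Reshuffle so batch composition differs from epoch to epoch.
        np.random.shuffle(self.indices)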
Expand source code
class Sequence(object):
  """Base object for fitting to a sequence of data, such as a dataset.

  Every `Sequence` must implement the `__getitem__` and the `__len__` methods.
  If you want to modify your dataset between epochs you may implement
  `on_epoch_end`.
  The method `__getitem__` should return a complete batch.

  Notes:

  `Sequence` is a safer way to do multiprocessing. This structure guarantees
  that the network will train only once on each sample per epoch, which is
  not the case with generators.

  Examples:

  ```python
  from skimage.io import imread
  from skimage.transform import resize
  import numpy as np
  import math

  # Here, `x_set` is a list of paths to the images
  # and `y_set` are the associated classes.

  class CIFAR10Sequence(Sequence):

      def __init__(self, x_set, y_set, batch_size):
          self.x, self.y = x_set, y_set
          self.batch_size = batch_size

      def __len__(self):
          return math.ceil(len(self.x) / self.batch_size)

      def __getitem__(self, idx):
          batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
          batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

          return np.array([
              resize(imread(file_name), (200, 200))
                 for file_name in batch_x]), np.array(batch_y)
  ```
  """

  @abstractmethod
  def __getitem__(self, index):
    """Gets batch at position `index`.

    Args:
        index: position of the batch in the Sequence.

    Returns:
        A batch
    """
    raise NotImplementedError

  @abstractmethod
  def __len__(self):
    """Number of batch in the Sequence.

    Returns:
        The number of batches in the Sequence.
    """
    raise NotImplementedError

  def on_epoch_end(self):
    """Method called at the end of every epoch.
    """
    pass

  def __iter__(self):
    """Create a generator that iterate over the Sequence."""
    for item in (self[i] for i in range(len(self))):
      yield item

Subclasses

Methods

def on_epoch_end(self)

Method called at the end of every epoch.

Expand source code
def on_epoch_end(self):
  """Method called at the end of every epoch.
  """
  pass
class SequenceEnqueuer (sequence, use_multiprocessing=False)

Base class to enqueue inputs.

The task of an Enqueuer is to use parallelism to speed up preprocessing. This is done with processes or threads.

Example:

    enqueuer = SequenceEnqueuer(...)
    enqueuer.start()
    datas = enqueuer.get()
    for data in datas:
        # Use the inputs; training, evaluating, predicting.
        # ... stop sometime.
    enqueuer.stop()

The enqueuer.get() should yield an infinite stream of data.

Expand source code
class SequenceEnqueuer(object):
  """Base class to enqueue inputs.

  The task of an Enqueuer is to use parallelism to speed up preprocessing.
  This is done with processes or threads.

  Example:

  ```python
      enqueuer = SequenceEnqueuer(...)
      enqueuer.start()
      datas = enqueuer.get()
      for data in datas:
          # Use the inputs; training, evaluating, predicting.
          # ... stop sometime.
      enqueuer.stop()
  ```

  The `enqueuer.get()` should yield an infinite stream of data.
  """

  def __init__(self, sequence,
               use_multiprocessing=False):
    self.sequence = sequence
    self.use_multiprocessing = use_multiprocessing

    global _SEQUENCE_COUNTER
    if _SEQUENCE_COUNTER is None:
      try:
        _SEQUENCE_COUNTER = multiprocessing.Value('i', 0)
      except OSError:
        # In this case the OS does not allow us to use
        # multiprocessing. We resort to an int
        # for enqueuer indexing.
        _SEQUENCE_COUNTER = 0

    if isinstance(_SEQUENCE_COUNTER, int):
      self.uid = _SEQUENCE_COUNTER
      _SEQUENCE_COUNTER += 1
    else:
      # Doing Multiprocessing.Value += x is not process-safe.
      with _SEQUENCE_COUNTER.get_lock():
        self.uid = _SEQUENCE_COUNTER.value
        _SEQUENCE_COUNTER.value += 1

    self.workers = 0
    self.executor_fn = None
    self.queue = None
    self.run_thread = None
    self.stop_signal = None

  def is_running(self):
    return self.stop_signal is not None and not self.stop_signal.is_set()

  def start(self, workers=1, max_queue_size=10):
    """Starts the handler's workers.

    Args:
        workers: Number of workers.
        max_queue_size: queue size
            (when full, workers could block on `put()`)
    """
    if self.use_multiprocessing:
      self.executor_fn = self._get_executor_init(workers)
    else:
      # We do not need the init since it's threads.
      self.executor_fn = lambda _: get_pool_class(False)(workers)
    self.workers = workers
    self.queue = queue.Queue(max_queue_size)
    self.stop_signal = threading.Event()
    self.run_thread = threading.Thread(target=self._run)
    self.run_thread.daemon = True
    self.run_thread.start()

  def _send_sequence(self):
    """Sends current Iterable to all workers."""
    # For new processes that may spawn
    _SHARED_SEQUENCES[self.uid] = self.sequence

  def stop(self, timeout=None):
    """Stops running threads and wait for them to exit, if necessary.

    Should be called by the same thread which called `start()`.

    Args:
        timeout: maximum time to wait on `thread.join()`
    """
    self.stop_signal.set()
    with self.queue.mutex:
      self.queue.queue.clear()
      self.queue.unfinished_tasks = 0
      self.queue.not_full.notify()
    self.run_thread.join(timeout)
    _SHARED_SEQUENCES[self.uid] = None

  def __del__(self):
    if self.is_running():
      self.stop()

  @abstractmethod
  def _run(self):
    """Submits request to the executor and queue the `Future` objects."""
    raise NotImplementedError

  @abstractmethod
  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
        workers: Number of workers.

    Returns:
        Function, a Function to initialize the pool
    """
    raise NotImplementedError

  @abstractmethod
  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Returns:
        Generator yielding tuples `(inputs, targets)`
        or `(inputs, targets, sample_weights)`.
    """
    raise NotImplementedError

Subclasses

GeneratorEnqueuer
OrderedEnqueuer

Methods

def get(self)

Creates a generator to extract data from the queue.

Skip the data if it is None.

Returns

Generator yielding tuples (inputs, targets) or (inputs, targets, sample_weights).
Expand source code
@abstractmethod
def get(self):
  """Creates a generator to extract data from the queue.

  Skip the data if it is `None`.

  Returns:
      Generator yielding tuples `(inputs, targets)`
      or `(inputs, targets, sample_weights)`.
  """
  raise NotImplementedError
def is_running(self)

Whether the enqueuer has been started and not yet stopped.

Expand source code
def is_running(self):
  return self.stop_signal is not None and not self.stop_signal.is_set()
def start(self, workers=1, max_queue_size=10)

Starts the handler's workers.

Args

workers
Number of workers.
max_queue_size
queue size (when full, workers could block on put())
Expand source code
def start(self, workers=1, max_queue_size=10):
  """Starts the handler's workers.

  Args:
      workers: Number of workers.
      max_queue_size: queue size
          (when full, workers could block on `put()`)
  """
  if self.use_multiprocessing:
    self.executor_fn = self._get_executor_init(workers)
  else:
    # We do not need the init since it's threads.
    self.executor_fn = lambda _: get_pool_class(False)(workers)
  self.workers = workers
  self.queue = queue.Queue(max_queue_size)
  self.stop_signal = threading.Event()
  self.run_thread = threading.Thread(target=self._run)
  self.run_thread.daemon = True
  self.run_thread.start()
def stop(self, timeout=None)

Stops running threads and waits for them to exit, if necessary.

Should be called by the same thread which called start().

Args

timeout
maximum time to wait on thread.join()
Expand source code
def stop(self, timeout=None):
  """Stops running threads and wait for them to exit, if necessary.

  Should be called by the same thread which called `start()`.

  Args:
      timeout: maximum time to wait on `thread.join()`
  """
  self.stop_signal.set()
  with self.queue.mutex:
    self.queue.queue.clear()
    self.queue.unfinished_tasks = 0
    self.queue.not_full.notify()
  self.run_thread.join(timeout)
  _SHARED_SEQUENCES[self.uid] = None