Module keras.api.keras.utils
Public API for tf.keras.utils namespace.
Expand source code
# This file is MACHINE GENERATED! Do not edit.
# Generated by: tensorflow/python/tools/api/generator/create_python_api.py script.
"""Public API for tf.keras.utils namespace.
"""
from __future__ import print_function as _print_function
import sys as _sys
from keras.preprocessing.image import array_to_img
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import load_img
from keras.preprocessing.image import save_img
from keras.utils.data_utils import GeneratorEnqueuer
from keras.utils.data_utils import OrderedEnqueuer
from keras.utils.data_utils import Sequence
from keras.utils.data_utils import SequenceEnqueuer
from keras.utils.data_utils import get_file
from keras.utils.generic_utils import CustomObjectScope
from keras.utils.generic_utils import CustomObjectScope as custom_object_scope
from keras.utils.generic_utils import Progbar
from keras.utils.generic_utils import deserialize_keras_object
from keras.utils.generic_utils import get_custom_objects
from keras.utils.generic_utils import get_registered_name
from keras.utils.generic_utils import get_registered_object
from keras.utils.generic_utils import register_keras_serializable
from keras.utils.generic_utils import serialize_keras_object
from keras.utils.layer_utils import get_source_inputs
from keras.utils.np_utils import normalize
from keras.utils.np_utils import to_categorical
from keras.utils.vis_utils import model_to_dot
from keras.utils.vis_utils import plot_model
del _print_function
from tensorflow.python.util import module_wrapper as _module_wrapper
if not isinstance(_sys.modules[__name__], _module_wrapper.TFModuleWrapper):
_sys.modules[__name__] = _module_wrapper.TFModuleWrapper(
_sys.modules[__name__], "keras.utils", public_apis=None, deprecation=True,
has_lite=False)
Functions
def array_to_img(x, data_format=None, scale=True, dtype=None)
-
Converts a 3D Numpy array to a PIL Image instance.
Usage:
```python
from PIL import Image
import numpy as np
import tensorflow as tf

img = np.random.random(size=(100, 100, 3))
pil_img = tf.keras.preprocessing.image.array_to_img(img)
```
Args
x
- Input data, in any form that can be converted to a Numpy array.
data_format
- Image data format, can be either "channels_first" or "channels_last". Defaults to `None`, in which case the global setting `tf.keras.backend.image_data_format()` is used (unless you changed it, it defaults to "channels_last").
scale
- Whether to rescale the image such that minimum and maximum values are 0 and 255 respectively. Defaults to `True`.
dtype
- Dtype to use. Defaults to `None`, in which case the global setting `tf.keras.backend.floatx()` is used (unless you changed it, it defaults to "float32").
Returns
A PIL Image instance.
Raises
ImportError
- if PIL is not available.
ValueError
- if invalid `x` or `data_format` is passed.
Expand source code
```python
@keras_export('keras.utils.array_to_img',
              'keras.preprocessing.image.array_to_img')
def array_to_img(x, data_format=None, scale=True, dtype=None):
  """Converts a 3D Numpy array to a PIL Image instance."""
  if data_format is None:
    data_format = backend.image_data_format()
  kwargs = {}
  if 'dtype' in tf_inspect.getfullargspec(image.array_to_img)[0]:
    if dtype is None:
      dtype = backend.floatx()
    kwargs['dtype'] = dtype
  return image.array_to_img(x, data_format=data_format, scale=scale, **kwargs)
```
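For a concrete round trip under the defaults (a minimal sketch; assumes TensorFlow 2.x and Pillow are installed, with a random array standing in for real image data):
```python
import numpy as np
import tensorflow as tf

# A float image in channels-last layout, values in [0, 1].
arr = np.random.random(size=(64, 64, 3))

# scale=True (the default) maps the array's min and max to 0 and 255.
pil_img = tf.keras.preprocessing.image.array_to_img(arr)
print(pil_img.size)  # (64, 64)

# Convert back to a float Numpy array of shape (64, 64, 3).
arr_back = tf.keras.preprocessing.image.img_to_array(pil_img)
print(arr_back.shape, arr_back.dtype)
```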
def deserialize_keras_object(identifier, module_objects=None, custom_objects=None, printable_module_name='object')
-
Turns the serialized form of a Keras object back into an actual object.
This function is for mid-level library implementers rather than end users.
Importantly, this utility requires you to provide the dict of `module_objects` to use for looking up the object config; this is not populated by default. If you need a deserialization utility that has preexisting knowledge of built-in Keras objects, use e.g. `keras.layers.deserialize(config)`, `keras.metrics.deserialize(config)`, etc.
Calling `deserialize_keras_object()` while underneath the `SharedObjectLoadingScope` context manager will cause any already-seen shared objects to be returned as-is rather than creating a new object.
Args
identifier
- the serialized form of the object.
module_objects
- A dictionary of built-in objects to look the name up in. Generally, `module_objects` is provided by mid-level library implementers.
custom_objects
- A dictionary of custom objects to look the name up in. Generally, `custom_objects` is provided by the end user.
printable_module_name
- A human-readable string representing the type of the object. Printed in case of exception.
Returns
The deserialized object.
Example:
A mid-level library implementer might want to implement a utility for retrieving an object from its config, as such:
```python
def deserialize(config, custom_objects=None):
  return deserialize_keras_object(
      config,
      module_objects=globals(),
      custom_objects=custom_objects,
      printable_module_name='MyObjectType',
  )
```
This is how e.g. `keras.layers.deserialize()` is implemented.
Expand source code
```python
@keras_export('keras.utils.deserialize_keras_object')
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
  """Turns the serialized form of a Keras object back into an actual object."""
  if identifier is None:
    return None

  if isinstance(identifier, dict):
    # In this case we are dealing with a Keras config dictionary.
    config = identifier
    (cls, cls_config) = class_and_config_for_serialized_keras_object(
        config, module_objects, custom_objects, printable_module_name)

    # If this object has already been loaded (i.e. it's shared between
    # multiple objects), return the already-loaded object.
    shared_object_id = config.get(SHARED_OBJECT_KEY)
    shared_object = _shared_object_loading_scope().get(shared_object_id)  # pylint: disable=assignment-from-none
    if shared_object is not None:
      return shared_object

    if hasattr(cls, 'from_config'):
      arg_spec = tf_inspect.getfullargspec(cls.from_config)
      custom_objects = custom_objects or {}

      if 'custom_objects' in arg_spec.args:
        deserialized_obj = cls.from_config(
            cls_config,
            custom_objects=dict(
                list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                list(custom_objects.items())))
      else:
        with CustomObjectScope(custom_objects):
          deserialized_obj = cls.from_config(cls_config)
    else:
      # Then `cls` may be a function returning a class.
      # In this case by convention `config` holds
      # the kwargs of the function.
      custom_objects = custom_objects or {}
      with CustomObjectScope(custom_objects):
        deserialized_obj = cls(**cls_config)

    # Add object to shared objects, in case we find it referenced again.
    _shared_object_loading_scope().set(shared_object_id, deserialized_obj)

    return deserialized_obj

  elif isinstance(identifier, str):
    object_name = identifier
    if custom_objects and object_name in custom_objects:
      obj = custom_objects.get(object_name)
    elif object_name in _GLOBAL_CUSTOM_OBJECTS:
      obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
    else:
      obj = module_objects.get(object_name)
      if obj is None:
        raise ValueError(
            'Unknown {}: {}. Please ensure this object is '
            'passed to the `custom_objects` argument. See '
            'https://www.tensorflow.org/guide/keras/save_and_serialize'
            '#registering_the_custom_object for details.'
            .format(printable_module_name, object_name))

    # Classes passed by name are instantiated with no args, functions are
    # returned as-is.
    if tf_inspect.isclass(obj):
      return obj()
    return obj

  elif tf_inspect.isfunction(identifier):
    # If a function has already been deserialized, return as is.
    return identifier
  else:
    raise ValueError('Could not interpret serialized %s: %s' %
                     (printable_module_name, identifier))
```
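As a usage sketch (the `MyObject` class here is hypothetical, not part of Keras), both string and config-dict identifiers resolve through `custom_objects`:
```python
import tensorflow as tf

class MyObject(object):  # hypothetical custom class
  def __init__(self, units=8):
    self.units = units

  def get_config(self):
    return {'units': self.units}

  @classmethod
  def from_config(cls, config):
    return cls(**config)

# A string identifier is looked up in `custom_objects`; classes found
# this way are instantiated with no arguments.
obj = tf.keras.utils.deserialize_keras_object(
    'MyObject', custom_objects={'MyObject': MyObject})

# A dict identifier is treated as a Keras config with 'class_name'
# and 'config' keys.
obj2 = tf.keras.utils.deserialize_keras_object(
    {'class_name': 'MyObject', 'config': {'units': 16}},
    custom_objects={'MyObject': MyObject})
print(obj2.units)  # 16
```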
def get_custom_objects()
-
Retrieves a live reference to the global dictionary of custom objects.
Updating and clearing custom objects using `CustomObjectScope` is preferred, but `get_custom_objects()` can be used to directly access the current collection of custom objects.
Example:
```python
get_custom_objects().clear()
get_custom_objects()['MyObject'] = MyObject
```
Returns
Global dictionary of names to classes (`_GLOBAL_CUSTOM_OBJECTS`).
Expand source code
```python
@keras_export('keras.utils.get_custom_objects')
def get_custom_objects():
  """Retrieves a live reference to the global dictionary of custom objects."""
  return _GLOBAL_CUSTOM_OBJECTS
```
def get_file(fname=None, origin=None, untar=False, md5_hash=None, file_hash=None, cache_subdir='datasets', hash_algorithm='auto', extract=False, archive_format='auto', cache_dir=None)
-
Downloads a file from a URL if it is not already in the cache.
By default the file at the url `origin` is downloaded to the cache_dir `~/.keras`, placed in the cache_subdir `datasets`, and given the filename `fname`. The final location of a file `example.txt` would therefore be `~/.keras/datasets/example.txt`.
Files in tar, tar.gz, tar.bz, and zip formats can also be extracted. Passing a hash will verify the file after download. The command line programs `shasum` and `sha256sum` can compute the hash.
Example:
```python
path_to_downloaded_file = tf.keras.utils.get_file(
    "flower_photos",
    "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
    untar=True)
```
Args
fname
- Name of the file. If an absolute path `/path/to/file.txt` is specified, the file will be saved at that location. If `None`, the name of the file at `origin` will be used.
origin
- Original URL of the file.
untar
- Deprecated in favor of the `extract` argument. Boolean, whether the file should be decompressed.
md5_hash
- Deprecated in favor of the `file_hash` argument. md5 hash of the file for verification.
file_hash
- The expected hash string of the file after download. The sha256 and md5 hash algorithms are both supported.
cache_subdir
- Subdirectory under the Keras cache dir where the file is saved. If an absolute path `/path/to/folder` is specified, the file will be saved at that location.
hash_algorithm
- Select the hash algorithm to verify the file. Options are `'md5'`, `'sha256'`, and `'auto'`. The default `'auto'` detects the hash algorithm in use.
extract
- True tries extracting the file as an archive, like tar or zip.
archive_format
- Archive format to try for extracting the file. Options are `'auto'`, `'tar'`, `'zip'`, and `None`. `'tar'` includes tar, tar.gz, and tar.bz files. The default `'auto'` corresponds to `['tar', 'zip']`. None or an empty list will return no matches found.
cache_dir
- Location to store cached files; when None it defaults to the default directory `~/.keras/`.
Returns
Path to the downloaded file.
Expand source code
```python
@keras_export('keras.utils.get_file')
def get_file(fname=None,
             origin=None,
             untar=False,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
  """Downloads a file from a URL if it is not already in the cache."""
  if origin is None:
    raise ValueError('Please specify the "origin" argument (URL of the file '
                     'to download).')

  if cache_dir is None:
    cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
  if md5_hash is not None and file_hash is None:
    file_hash = md5_hash
    hash_algorithm = 'md5'
  datadir_base = os.path.expanduser(cache_dir)
  if not os.access(datadir_base, os.W_OK):
    datadir_base = os.path.join('/tmp', '.keras')
  datadir = os.path.join(datadir_base, cache_subdir)
  _makedirs_exist_ok(datadir)

  fname = path_to_string(fname)
  if not fname:
    fname = os.path.basename(urlsplit(origin).path)
    if not fname:
      raise ValueError("Invalid origin '{}'".format(origin))

  if untar:
    if fname.endswith('.tar.gz'):
      fname = pathlib.Path(fname)
      # The 2 `.with_suffix()` are because of `.tar.gz` as pathlib
      # considers it as 2 suffixes.
      fname = fname.with_suffix('').with_suffix('')
      fname = str(fname)
    untar_fpath = os.path.join(datadir, fname)
    fpath = untar_fpath + '.tar.gz'
  else:
    fpath = os.path.join(datadir, fname)

  download = False
  if os.path.exists(fpath):
    # File found; verify integrity if a hash was provided.
    if file_hash is not None:
      if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
        print('A local file was found, but it seems to be '
              'incomplete or outdated because the ' + hash_algorithm +
              ' file hash does not match the original value of ' + file_hash +
              ' so we will re-download the data.')
        download = True
  else:
    download = True

  if download:
    print('Downloading data from', origin)

    class ProgressTracker(object):
      # Maintain progbar for the lifetime of download.
      # This design was chosen for Python 2.7 compatibility.
      progbar = None

    def dl_progress(count, block_size, total_size):
      if ProgressTracker.progbar is None:
        if total_size == -1:
          total_size = None
        ProgressTracker.progbar = Progbar(total_size)
      else:
        ProgressTracker.progbar.update(count * block_size)

    error_msg = 'URL fetch failure on {}: {} -- {}'
    try:
      try:
        urlretrieve(origin, fpath, dl_progress)
      except urllib.error.HTTPError as e:
        raise Exception(error_msg.format(origin, e.code, e.msg))
      except urllib.error.URLError as e:
        raise Exception(error_msg.format(origin, e.errno, e.reason))
    except (Exception, KeyboardInterrupt) as e:
      if os.path.exists(fpath):
        os.remove(fpath)
      raise
    ProgressTracker.progbar = None

  if untar:
    if not os.path.exists(untar_fpath):
      _extract_archive(fpath, datadir, archive_format='tar')
    return untar_fpath

  if extract:
    _extract_archive(fpath, datadir, archive_format)

  return fpath
```
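A hash-verification sketch (the URL and hash below are placeholders to substitute with real values): when `file_hash` is set, a cached copy whose hash no longer matches is re-downloaded, and `extract=True` unpacks recognized tar/zip archives into the cache directory:
```python
import tensorflow as tf

path = tf.keras.utils.get_file(
    fname='mydata.tar.gz',                       # hypothetical file name
    origin='https://example.com/mydata.tar.gz',  # hypothetical URL
    file_hash='<sha256-of-the-file>',            # placeholder hash value
    hash_algorithm='sha256',
    extract=True,
    cache_subdir='mydata')
print(path)  # e.g. ~/.keras/mydata/mydata.tar.gz
```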
def get_registered_name(obj)
-
Returns the name registered to an object within the Keras framework.
This function is part of the Keras serialization and deserialization framework. It maps objects to the string names associated with those objects for serialization/deserialization.
Args
obj
- The object to look up.
Returns
The name associated with the object, or the default Python name if the object is not registered.
Expand source code
```python
@keras_export('keras.utils.get_registered_name')
def get_registered_name(obj):
  """Returns the name registered to an object within the Keras framework."""
  if obj in _GLOBAL_CUSTOM_NAMES:
    return _GLOBAL_CUSTOM_NAMES[obj]
  else:
    return obj.__name__
```
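A round-trip sketch pairing this with `register_keras_serializable` (the `MyLayer` class is hypothetical):
```python
import tensorflow as tf

@tf.keras.utils.register_keras_serializable(package='MyPackage')
class MyLayer(tf.keras.layers.Layer):  # hypothetical custom layer
  pass

# Registered objects report their 'package>name' key; unregistered
# objects fall back to their Python __name__.
print(tf.keras.utils.get_registered_name(MyLayer))                # MyPackage>MyLayer
print(tf.keras.utils.get_registered_name(tf.keras.layers.Dense))  # Dense
```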
def get_registered_object(name, custom_objects=None, module_objects=None)
-
Returns the class associated with `name` if it is registered with Keras.
This function is part of the Keras serialization and deserialization framework. It maps strings to the objects associated with them for serialization/deserialization.
Example:
```python
def from_config(cls, config, custom_objects=None):
  if 'my_custom_object_name' in config:
    config['hidden_cls'] = tf.keras.utils.get_registered_object(
        config['my_custom_object_name'], custom_objects=custom_objects)
```
Args
name
- The name to look up.
custom_objects
- A dictionary of custom objects to look the name up in. Generally, custom_objects is provided by the user.
module_objects
- A dictionary of custom objects to look the name up in. Generally, module_objects is provided by midlevel library implementers.
Returns
An instantiable class associated with 'name', or None if no such class exists.
Expand source code
```python
@keras_export('keras.utils.get_registered_object')
def get_registered_object(name, custom_objects=None, module_objects=None):
  """Returns the class associated with `name` if it is registered with Keras."""
  if name in _GLOBAL_CUSTOM_OBJECTS:
    return _GLOBAL_CUSTOM_OBJECTS[name]
  elif custom_objects and name in custom_objects:
    return custom_objects[name]
  elif module_objects and name in module_objects:
    return module_objects[name]
  return None
```
def get_source_inputs(tensor, layer=None, node_index=None)
-
Returns the list of input tensors necessary to compute `tensor`.
Output will always be a list of tensors (potentially with 1 element).
Args
tensor
- The tensor to start from.
layer
- Origin layer of the tensor. Will be determined via tensor._keras_history if not provided.
node_index
- Origin node index of the tensor.
Returns
List of input tensors.
Expand source code
```python
@keras_export('keras.utils.get_source_inputs')
def get_source_inputs(tensor, layer=None, node_index=None):
  """Returns the list of input tensors necessary to compute `tensor`."""
  if not hasattr(tensor, '_keras_history'):
    return tensor

  if layer is None or node_index:
    layer, node_index, _ = tensor._keras_history
  if not layer._inbound_nodes:
    return [tensor]
  else:
    node = layer._inbound_nodes[node_index]
    if node.is_input:
      # Reached an Input layer, stop recursion.
      return tf.nest.flatten(node.input_tensors)
    else:
      source_tensors = []
      for layer, node_index, _, tensor in node.iterate_inbound():
        previous_sources = get_source_inputs(tensor, layer, node_index)
        # Avoid input redundancy.
        for x in previous_sources:
          if all(x is not t for t in source_tensors):
            source_tensors.append(x)
      return source_tensors
```
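A minimal functional-model sketch of what this returns:
```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(4,), name='in')
x = tf.keras.layers.Dense(8)(inputs)
outputs = tf.keras.layers.Dense(1)(x)

# Walks the Keras history metadata back to the Input tensor(s)
# that `outputs` depends on.
sources = tf.keras.utils.get_source_inputs(outputs)
print([t.name for t in sources])  # e.g. ['in']
```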
def img_to_array(img, data_format=None, dtype=None)
-
Converts a PIL Image instance to a Numpy array.
Usage:
```python
from PIL import Image
import numpy as np
import tensorflow as tf

img_data = np.random.random(size=(100, 100, 3))
img = tf.keras.preprocessing.image.array_to_img(img_data)
array = tf.keras.preprocessing.image.img_to_array(img)
```
Args
img
- Input PIL Image instance.
data_format
- Image data format, can be either "channels_first" or "channels_last". Defaults to `None`, in which case the global setting `tf.keras.backend.image_data_format()` is used (unless you changed it, it defaults to "channels_last").
dtype
- Dtype to use. Defaults to `None`, in which case the global setting `tf.keras.backend.floatx()` is used (unless you changed it, it defaults to "float32").
Returns
A 3D Numpy array.
Raises
ValueError
- if invalid `img` or `data_format` is passed.
Expand source code
```python
@keras_export('keras.utils.img_to_array',
              'keras.preprocessing.image.img_to_array')
def img_to_array(img, data_format=None, dtype=None):
  """Converts a PIL Image instance to a Numpy array."""
  if data_format is None:
    data_format = backend.image_data_format()
  kwargs = {}
  if 'dtype' in tf_inspect.getfullargspec(image.img_to_array)[0]:
    if dtype is None:
      dtype = backend.floatx()
    kwargs['dtype'] = dtype
  return image.img_to_array(img, data_format=data_format, **kwargs)
```
def load_img(path, grayscale=False, color_mode='rgb', target_size=None, interpolation='nearest')
-
Loads an image into PIL format.
Usage:
```python
image = tf.keras.preprocessing.image.load_img(image_path)
input_arr = tf.keras.preprocessing.image.img_to_array(image)
input_arr = np.array([input_arr])  # Convert single image to a batch.
predictions = model.predict(input_arr)
```
Args
path
- Path to image file.
grayscale
- DEPRECATED use `color_mode="grayscale"`.
color_mode
- One of "grayscale", "rgb", "rgba". Default: "rgb". The desired image format.
target_size
- Either `None` (default to original size) or tuple of ints `(img_height, img_width)`.
interpolation
- Interpolation method used to resample the image if the target size is different from that of the loaded image. Supported methods are "nearest", "bilinear", and "bicubic". If PIL version 1.1.3 or newer is installed, "lanczos" is also supported. If PIL version 3.4.0 or newer is installed, "box" and "hamming" are also supported. By default, "nearest" is used.
Returns
A PIL Image instance.
Raises
ImportError
- if PIL is not available.
ValueError
- if interpolation method is not supported.
Expand source code
```python
@keras_export('keras.utils.load_img', 'keras.preprocessing.image.load_img')
def load_img(path, grayscale=False, color_mode='rgb', target_size=None,
             interpolation='nearest'):
  """Loads an image into PIL format."""
  return image.load_img(path, grayscale=grayscale, color_mode=color_mode,
                        target_size=target_size, interpolation=interpolation)
```
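A resizing sketch (the image path is hypothetical):
```python
import tensorflow as tf

# Load as grayscale and resize to 224x224 with bilinear resampling.
img = tf.keras.preprocessing.image.load_img(
    'photos/cat.jpg',  # hypothetical path
    color_mode='grayscale',
    target_size=(224, 224),
    interpolation='bilinear')
arr = tf.keras.preprocessing.image.img_to_array(img)
print(arr.shape)  # (224, 224, 1)
```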
def model_to_dot(model, show_shapes=False, show_dtype=False, show_layer_names=True, rankdir='TB', expand_nested=False, dpi=96, subgraph=False, layer_range=None)
-
Convert a Keras model to dot format.
Args
model
- A Keras model instance.
show_shapes
- whether to display shape information.
show_dtype
- whether to display layer dtypes.
show_layer_names
- whether to display layer names.
rankdir
- `rankdir` argument passed to PyDot, a string specifying the format of the plot: 'TB' creates a vertical plot; 'LR' creates a horizontal plot.
expand_nested
- whether to expand nested models into clusters.
dpi
- Dots per inch.
subgraph
- whether to return a `pydot.Cluster` instance.
layer_range
- input of `list` containing two `str` items, which are the starting layer name and ending layer name (both inclusive) indicating the range of layers for which the `pydot.Dot` will be generated. It also accepts regex patterns instead of an exact name, in which case the start predicate will be the first element that matches `layer_range[0]` and the end predicate will be the last element that matches `layer_range[1]`. By default `None`, which considers all layers of the model. Note that you must pass a range such that the resultant subgraph is complete.
Returns
A `pydot.Dot` instance representing the Keras model, or a `pydot.Cluster` instance representing the nested model if `subgraph=True`.
Raises
ImportError
- if graphviz or pydot are not available.
Expand source code
```python
@keras_export('keras.utils.model_to_dot')
def model_to_dot(model,
                 show_shapes=False,
                 show_dtype=False,
                 show_layer_names=True,
                 rankdir='TB',
                 expand_nested=False,
                 dpi=96,
                 subgraph=False,
                 layer_range=None):
  """Convert a Keras model to dot format."""
  from keras.layers import wrappers
  from keras.engine import sequential
  from keras.engine import functional

  if not check_pydot():
    message = (
        'You must install pydot (`pip install pydot`) '
        'and install graphviz '
        '(see instructions at https://graphviz.gitlab.io/download/) '
        'for plot_model/model_to_dot to work.')
    if 'IPython.core.magics.namespace' in sys.modules:
      # We don't raise an exception here in order to avoid crashing notebook
      # tests where graphviz is not available.
      print(message)
      return
    else:
      raise ImportError(message)

  if subgraph:
    dot = pydot.Cluster(style='dashed', graph_name=model.name)
    dot.set('label', model.name)
    dot.set('labeljust', 'l')
  else:
    dot = pydot.Dot()
    dot.set('rankdir', rankdir)
    dot.set('concentrate', True)
    dot.set('dpi', dpi)
    dot.set_node_defaults(shape='record')

  if layer_range:
    if len(layer_range) != 2:
      raise ValueError('layer_range must be of shape (2,)')
    if (not isinstance(layer_range[0], str) or
        not isinstance(layer_range[1], str)):
      raise ValueError('layer_range should contain string type only')
    layer_range = get_layer_index_bound_by_layer_name(model, layer_range)
    if layer_range[0] < 0 or layer_range[1] > len(model.layers):
      raise ValueError('Both values in layer_range should be in '
                       'range (%d, %d)' % (0, len(model.layers)))

  sub_n_first_node = {}
  sub_n_last_node = {}
  sub_w_first_node = {}
  sub_w_last_node = {}

  layers = model.layers
  if not model._is_graph_network:
    node = pydot.Node(str(id(model)), label=model.name)
    dot.add_node(node)
    return dot
  elif isinstance(model, sequential.Sequential):
    if not model.built:
      model.build()
    layers = super(sequential.Sequential, model).layers

  # Create graph nodes.
  for i, layer in enumerate(layers):
    if (layer_range) and (i < layer_range[0] or i > layer_range[1]):
      continue

    layer_id = str(id(layer))

    # Append a wrapped layer's label to node's label, if it exists.
    layer_name = layer.name
    class_name = layer.__class__.__name__

    if isinstance(layer, wrappers.Wrapper):
      if expand_nested and isinstance(layer.layer, functional.Functional):
        submodel_wrapper = model_to_dot(
            layer.layer,
            show_shapes,
            show_dtype,
            show_layer_names,
            rankdir,
            expand_nested,
            subgraph=True)
        # sub_w : submodel_wrapper
        sub_w_nodes = submodel_wrapper.get_nodes()
        sub_w_first_node[layer.layer.name] = sub_w_nodes[0]
        sub_w_last_node[layer.layer.name] = sub_w_nodes[-1]
        dot.add_subgraph(submodel_wrapper)
      else:
        layer_name = '{}({})'.format(layer_name, layer.layer.name)
        child_class_name = layer.layer.__class__.__name__
        class_name = '{}({})'.format(class_name, child_class_name)

    if expand_nested and isinstance(layer, functional.Functional):
      submodel_not_wrapper = model_to_dot(
          layer,
          show_shapes,
          show_dtype,
          show_layer_names,
          rankdir,
          expand_nested,
          subgraph=True)
      # sub_n : submodel_not_wrapper
      sub_n_nodes = submodel_not_wrapper.get_nodes()
      sub_n_first_node[layer.name] = sub_n_nodes[0]
      sub_n_last_node[layer.name] = sub_n_nodes[-1]
      dot.add_subgraph(submodel_not_wrapper)

    # Create node's label.
    if show_layer_names:
      label = '{}: {}'.format(layer_name, class_name)
    else:
      label = class_name

    # Rebuild the label as a table including the layer's dtype.
    if show_dtype:

      def format_dtype(dtype):
        if dtype is None:
          return '?'
        else:
          return str(dtype)

      label = '%s|%s' % (label, format_dtype(layer.dtype))

    # Rebuild the label as a table including input/output shapes.
    if show_shapes:

      def format_shape(shape):
        return str(shape).replace(str(None), 'None')

      try:
        outputlabels = format_shape(layer.output_shape)
      except AttributeError:
        outputlabels = '?'
      if hasattr(layer, 'input_shape'):
        inputlabels = format_shape(layer.input_shape)
      elif hasattr(layer, 'input_shapes'):
        inputlabels = ', '.join(
            [format_shape(ishape) for ishape in layer.input_shapes])
      else:
        inputlabels = '?'
      label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label, inputlabels,
                                                     outputlabels)

    if not expand_nested or not isinstance(layer, functional.Functional):
      node = pydot.Node(layer_id, label=label)
      dot.add_node(node)

  # Connect nodes with edges.
  for i, layer in enumerate(layers):
    if (layer_range) and (i <= layer_range[0] or i > layer_range[1]):
      continue
    layer_id = str(id(layer))
    for i, node in enumerate(layer._inbound_nodes):
      node_key = layer.name + '_ib-' + str(i)
      if node_key in model._network_nodes:
        for inbound_layer in tf.nest.flatten(node.inbound_layers):
          inbound_layer_id = str(id(inbound_layer))
          if not expand_nested:
            assert dot.get_node(inbound_layer_id)
            assert dot.get_node(layer_id)
            add_edge(dot, inbound_layer_id, layer_id)
          else:
            # if inbound_layer is not Model or wrapped Model
            if (not isinstance(inbound_layer, functional.Functional) and
                not is_wrapped_model(inbound_layer)):
              # if current layer is not Model or wrapped Model
              if (not isinstance(layer, functional.Functional) and
                  not is_wrapped_model(layer)):
                assert dot.get_node(inbound_layer_id)
                assert dot.get_node(layer_id)
                add_edge(dot, inbound_layer_id, layer_id)
              # if current layer is Model
              elif isinstance(layer, functional.Functional):
                add_edge(dot, inbound_layer_id,
                         sub_n_first_node[layer.name].get_name())
              # if current layer is wrapped Model
              elif is_wrapped_model(layer):
                add_edge(dot, inbound_layer_id, layer_id)
                name = sub_w_first_node[layer.layer.name].get_name()
                add_edge(dot, layer_id, name)
            # if inbound_layer is Model
            elif isinstance(inbound_layer, functional.Functional):
              name = sub_n_last_node[inbound_layer.name].get_name()
              if isinstance(layer, functional.Functional):
                output_name = sub_n_first_node[layer.name].get_name()
                add_edge(dot, name, output_name)
              else:
                add_edge(dot, name, layer_id)
            # if inbound_layer is wrapped Model
            elif is_wrapped_model(inbound_layer):
              inbound_layer_name = inbound_layer.layer.name
              add_edge(dot, sub_w_last_node[inbound_layer_name].get_name(),
                       layer_id)
  return dot
```
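A rendering sketch (assumes pydot and graphviz are installed; the file path is illustrative). The returned `pydot.Dot` can be written out in any graphviz-supported format:
```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, input_shape=(4,), name='hidden'),
    tf.keras.layers.Dense(1, name='out'),
])

dot = tf.keras.utils.model_to_dot(model, show_shapes=True, rankdir='LR')
with open('/tmp/model.svg', 'wb') as f:
  f.write(dot.create(prog='dot', format='svg'))
```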
def normalize(x, axis=-1, order=2)
-
Normalizes a Numpy array.
Args
x
- Numpy array to normalize.
axis
- axis along which to normalize.
order
- Normalization order (e.g. `order=2` for L2 norm).
Returns
A normalized copy of the array.
Expand source code
```python
@keras_export('keras.utils.normalize')
def normalize(x, axis=-1, order=2):
  """Normalizes a Numpy array."""
  l2 = np.atleast_1d(np.linalg.norm(x, order, axis))
  l2[l2 == 0] = 1
  return x / np.expand_dims(l2, axis)
```
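For example, L2-normalizing each row of a matrix:
```python
import numpy as np
import tensorflow as tf

x = np.array([[3.0, 4.0],
              [0.0, 5.0]])

# axis=-1 normalizes each row; order=2 selects the L2 norm.
print(tf.keras.utils.normalize(x, axis=-1, order=2))
# [[0.6 0.8]
#  [0.  1. ]]
```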
def plot_model(model, to_file='model.png', show_shapes=False, show_dtype=False, show_layer_names=True, rankdir='TB', expand_nested=False, dpi=96, layer_range=None)
-
Converts a Keras model to dot format and saves it to a file.
Example:
```python
input = tf.keras.Input(shape=(100,), dtype='int32', name='input')
x = tf.keras.layers.Embedding(
    output_dim=512, input_dim=10000, input_length=100)(input)
x = tf.keras.layers.LSTM(32)(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
model = tf.keras.Model(inputs=[input], outputs=[output])
dot_img_file = '/tmp/model_1.png'
tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
```
Args
model
- A Keras model instance
to_file
- File name of the plot image.
show_shapes
- whether to display shape information.
show_dtype
- whether to display layer dtypes.
show_layer_names
- whether to display layer names.
rankdir
- `rankdir` argument passed to PyDot, a string specifying the format of the plot: 'TB' creates a vertical plot; 'LR' creates a horizontal plot.
expand_nested
- Whether to expand nested models into clusters.
dpi
- Dots per inch.
layer_range
- input of `list` containing two `str` items, which are the starting layer name and ending layer name (both inclusive) indicating the range of layers for which the plot will be generated. It also accepts regex patterns instead of an exact name, in which case the start predicate will be the first element that matches `layer_range[0]` and the end predicate will be the last element that matches `layer_range[1]`. By default `None`, which considers all layers of the model. Note that you must pass a range such that the resultant subgraph is complete.
Returns
A Jupyter notebook Image object if Jupyter is installed. This enables in-line display of the model plots in notebooks.
Expand source code
```python
@keras_export('keras.utils.plot_model')
def plot_model(model,
               to_file='model.png',
               show_shapes=False,
               show_dtype=False,
               show_layer_names=True,
               rankdir='TB',
               expand_nested=False,
               dpi=96,
               layer_range=None):
  """Converts a Keras model to dot format and saves it to a file."""
  dot = model_to_dot(
      model,
      show_shapes=show_shapes,
      show_dtype=show_dtype,
      show_layer_names=show_layer_names,
      rankdir=rankdir,
      expand_nested=expand_nested,
      dpi=dpi,
      layer_range=layer_range)
  to_file = path_to_string(to_file)
  if dot is None:
    return
  _, extension = os.path.splitext(to_file)
  if not extension:
    extension = 'png'
  else:
    extension = extension[1:]
  # Save image to disk.
  dot.write(to_file, format=extension)
  # Return the image as a Jupyter Image object, to be displayed in-line.
  # Note that we cannot easily detect whether the code is running in a
  # notebook, and thus we always return the Image if Jupyter is available.
  if extension != 'pdf':
    try:
      from IPython import display
      return display.Image(filename=to_file)
    except ImportError:
      pass
```
def register_keras_serializable(package='Custom', name=None)
-
Registers an object with the Keras serialization framework.
This decorator injects the decorated class or function into the Keras custom object dictionary, so that it can be serialized and deserialized without needing an entry in the user-provided custom object dict. It also injects a function that Keras will call to get the object's serializable string key.
Note that to be serialized and deserialized, classes must implement the `get_config()` method. Functions do not have this requirement.
The object will be registered under the key 'package>name', where `name` defaults to the object name if not passed.
Args
package
- The package that this class belongs to.
name
- The name to serialize this class under in this package. If None, the class' name will be used.
Returns
A decorator that registers the decorated class with the passed names.
Expand source code
```python
@keras_export('keras.utils.register_keras_serializable')
def register_keras_serializable(package='Custom', name=None):
  """Registers an object with the Keras serialization framework."""

  def decorator(arg):
    """Registers a class with the Keras serialization framework."""
    class_name = name if name is not None else arg.__name__
    registered_name = package + '>' + class_name

    if tf_inspect.isclass(arg) and not hasattr(arg, 'get_config'):
      raise ValueError(
          'Cannot register a class that does not have a get_config() method.')

    if registered_name in _GLOBAL_CUSTOM_OBJECTS:
      raise ValueError(
          '%s has already been registered to %s' %
          (registered_name, _GLOBAL_CUSTOM_OBJECTS[registered_name]))

    if arg in _GLOBAL_CUSTOM_NAMES:
      raise ValueError('%s has already been registered to %s' %
                       (arg, _GLOBAL_CUSTOM_NAMES[arg]))
    _GLOBAL_CUSTOM_OBJECTS[registered_name] = arg
    _GLOBAL_CUSTOM_NAMES[arg] = registered_name

    return arg

  return decorator
```
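An end-to-end sketch (the activation function and save path are hypothetical): once registered, the object round-trips through save/load without a `custom_objects` dict:
```python
import tensorflow as tf

@tf.keras.utils.register_keras_serializable(package='MyPackage')
def my_activation(x):  # hypothetical custom function
  return tf.nn.relu(x) - 0.1

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(4, activation=my_activation, input_shape=(8,))])
model.save('/tmp/model_with_custom_fn')

# The registered key 'MyPackage>my_activation' is stored in the config,
# so no custom_objects argument should be needed here.
restored = tf.keras.models.load_model('/tmp/model_with_custom_fn')
```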
def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs)
-
Saves an image stored as a Numpy array to a path or file object.
Args
path
- Path or file object.
x
- Numpy array.
data_format
- Image data format, either "channels_first" or "channels_last".
file_format
- Optional file format override. If omitted, the format to use is determined from the filename extension. If a file object was used instead of a filename, this parameter should always be used.
scale
- Whether to rescale image values to be within `[0, 255]`.
**kwargs
- Additional keyword arguments passed to `PIL.Image.save()`.
Expand source code
```python
@keras_export('keras.utils.save_img', 'keras.preprocessing.image.save_img')
def save_img(path, x, data_format=None, file_format=None, scale=True,
             **kwargs):
  """Saves an image stored as a Numpy array to a path or file object."""
  if data_format is None:
    data_format = backend.image_data_format()
  image.save_img(path, x, data_format=data_format,
                 file_format=file_format, scale=scale, **kwargs)
```
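A minimal sketch writing a random array as a PNG (the path is illustrative):
```python
import numpy as np
import tensorflow as tf

img = np.random.random(size=(64, 64, 3))  # floats in [0, 1]

# scale=True (the default) maps values into [0, 255] before encoding;
# the file format is inferred from the .png extension.
tf.keras.preprocessing.image.save_img('/tmp/noise.png', img)
```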
def serialize_keras_object(instance)
-
Serialize a Keras object into a JSON-compatible representation.
Calls to `serialize_keras_object()` while underneath the `SharedObjectSavingScope` context manager will cause any objects re-used across multiple layers to be saved with a special shared object ID. This allows the network to be re-created properly during deserialization.
Args
instance
- The object to serialize.
Returns
A dict-like, JSON-compatible representation of the object's config.
Expand source code
```python
@keras_export('keras.utils.serialize_keras_object')
def serialize_keras_object(instance):
  """Serialize a Keras object into a JSON-compatible representation."""
  _, instance = tf.__internal__.decorator.unwrap(instance)
  if instance is None:
    return None

  # pylint: disable=protected-access
  #
  # For v1 layers, checking supports_masking is not enough. We have to also
  # check whether compute_mask has been overridden.
  supports_masking = (getattr(instance, 'supports_masking', False) or
                      (hasattr(instance, 'compute_mask') and
                       not is_default(instance.compute_mask)))
  if supports_masking and is_default(instance.get_config):
    warnings.warn('Custom mask layers require a config and must override '
                  'get_config. When loading, the custom mask layer must be '
                  'passed to the custom_objects argument.',
                  category=CustomMaskWarning)
  # pylint: enable=protected-access

  if hasattr(instance, 'get_config'):
    name = get_registered_name(instance.__class__)
    try:
      config = instance.get_config()
    except NotImplementedError as e:
      if _SKIP_FAILED_SERIALIZATION:
        return serialize_keras_class_and_config(
            name, {_LAYER_UNDEFINED_CONFIG_KEY: True})
      raise e
    serialization_config = {}
    for key, item in config.items():
      if isinstance(item, str):
        serialization_config[key] = item
        continue

      # Any object of a different type needs to be converted to string or
      # dict for serialization (e.g. custom functions, custom classes)
      try:
        serialized_item = serialize_keras_object(item)
        if isinstance(serialized_item, dict) and not isinstance(item, dict):
          serialized_item['__passive_serialization__'] = True
        serialization_config[key] = serialized_item
      except ValueError:
        serialization_config[key] = item

    name = get_registered_name(instance.__class__)
    return serialize_keras_class_and_config(
        name, serialization_config, instance)
  if hasattr(instance, '__name__'):
    return get_registered_name(instance)
  raise ValueError('Cannot serialize', instance)
```
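A round-trip sketch with a built-in layer:
```python
import tensorflow as tf

layer = tf.keras.layers.Dense(3, activation='relu')
config = tf.keras.utils.serialize_keras_object(layer)
print(config['class_name'])  # Dense

# The serialized dict can be fed back through the deserialization
# utilities, here via keras.layers.deserialize.
restored = tf.keras.layers.deserialize(config)
```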
def to_categorical(y, num_classes=None, dtype='float32')
-
Converts a class vector (integers) to binary class matrix.
E.g. for use with categorical_crossentropy.
Args
y
- class vector to be converted into a matrix (integers from 0 to num_classes).
num_classes
- total number of classes. If `None`, this would be inferred as the (largest number in `y`) + 1.
dtype
- The data type expected by the input. Default: `'float32'`.
Returns
A binary matrix representation of the input. The classes axis is placed last.
Example:
>>> a = tf.keras.utils.to_categorical([0, 1, 2, 3], num_classes=4)
>>> a = tf.constant(a, shape=[4, 4])
>>> print(a)
tf.Tensor(
  [[1. 0. 0. 0.]
   [0. 1. 0. 0.]
   [0. 0. 1. 0.]
   [0. 0. 0. 1.]], shape=(4, 4), dtype=float32)

>>> b = tf.constant([.9, .04, .03, .03,
...                  .3, .45, .15, .13,
...                  .04, .01, .94, .05,
...                  .12, .21, .5, .17],
...                 shape=[4, 4])
>>> loss = tf.keras.backend.categorical_crossentropy(a, b)
>>> print(np.around(loss, 5))
[0.10536 0.82807 0.1011 1.77196]

>>> loss = tf.keras.backend.categorical_crossentropy(a, a)
>>> print(np.around(loss, 5))
[0. 0. 0. 0.]
Raises
ValueError
- If input contains string value.
Expand source code
```python
@keras_export('keras.utils.to_categorical')
def to_categorical(y, num_classes=None, dtype='float32'):
  """Converts a class vector (integers) to binary class matrix."""
  y = np.array(y, dtype='int')
  input_shape = y.shape
  if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
    input_shape = tuple(input_shape[:-1])
  y = y.ravel()
  if not num_classes:
    num_classes = np.max(y) + 1
  n = y.shape[0]
  categorical = np.zeros((n, num_classes), dtype=dtype)
  categorical[np.arange(n), y] = 1
  output_shape = input_shape + (num_classes,)
  categorical = np.reshape(categorical, output_shape)
  return categorical
```
Classes
class CustomObjectScope (*args)
-
Exposes custom classes/functions to Keras deserialization internals.
Under a scope `with CustomObjectScope(objects_dict)`, Keras methods such as `tf.keras.models.load_model` or `tf.keras.models.model_from_config` will be able to deserialize any custom object referenced by a saved config (e.g. a custom layer or metric).
Example:
Consider a custom regularizer `my_regularizer`:
```python
layer = Dense(3, kernel_regularizer=my_regularizer)
config = layer.get_config()  # Config contains a reference to `my_regularizer`
...
# Later:
with custom_object_scope({'my_regularizer': my_regularizer}):
  layer = Dense.from_config(config)
```
Args
*args
- Dictionary or dictionaries of `{name: object}` pairs.
Expand source code
```python
class CustomObjectScope(object):
  """Exposes custom classes/functions to Keras deserialization internals."""

  def __init__(self, *args):
    self.custom_objects = args
    self.backup = None

  def __enter__(self):
    self.backup = _GLOBAL_CUSTOM_OBJECTS.copy()
    for objects in self.custom_objects:
      _GLOBAL_CUSTOM_OBJECTS.update(objects)
    return self

  def __exit__(self, *args, **kwargs):
    _GLOBAL_CUSTOM_OBJECTS.clear()
    _GLOBAL_CUSTOM_OBJECTS.update(self.backup)
```
class custom_object_scope (*args)
-
Alias of CustomObjectScope; see the entry above.
class GeneratorEnqueuer (generator, use_multiprocessing=False, random_seed=None)
-
Builds a queue out of a data generator.
The provided generator can be finite, in which case the class will throw a `StopIteration` exception.
Args
generator
- a generator function which yields data
use_multiprocessing
- use multiprocessing if True, otherwise threading
random_seed
- Initial seed for workers, will be incremented by one for each worker.
Expand source code
```python
class GeneratorEnqueuer(SequenceEnqueuer):
  """Builds a queue out of a data generator."""

  def __init__(self, generator, use_multiprocessing=False, random_seed=None):
    super(GeneratorEnqueuer, self).__init__(generator, use_multiprocessing)
    self.random_seed = random_seed

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing."""

    def pool_fn(seqs):
      pool = get_pool_class(True)(
          workers,
          initializer=init_pool_generator,
          initargs=(seqs, self.random_seed, get_worker_id_queue()))
      _DATA_POOLS.add(pool)
      return pool

    return pool_fn

  def _run(self):
    """Submits request to the executor and queue the `Future` objects."""
    self._send_sequence()  # Share the initial generator
    with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
      while True:
        if self.stop_signal.is_set():
          return

        self.queue.put(
            executor.apply_async(next_sample, (self.uid,)), block=True)

  def get(self):
    """Creates a generator to extract data from the queue."""
    try:
      while self.is_running():
        inputs = self.queue.get(block=True).get()
        self.queue.task_done()
        if inputs is not None:
          yield inputs
    except StopIteration:
      # Special case for finite generators
      last_ones = []
      while self.queue.qsize() > 0:
        last_ones.append(self.queue.get(block=True))
      # Wait for them to complete
      for f in last_ones:
        f.wait()
      # Keep the good ones
      last_ones = [future.get() for future in last_ones if future.successful()]
      for inputs in last_ones:
        if inputs is not None:
          yield inputs
    except Exception as e:  # pylint: disable=broad-except
      self.stop()
      if 'generator already executing' in str(e):
        raise RuntimeError(
            'Your generator is NOT thread-safe. '
            'Keras requires a thread-safe generator when '
            '`use_multiprocessing=False, workers > 1`. ')
      raise e
```
Ancestors
- SequenceEnqueuer
Methods
def get(self)
-
Creates a generator to extract data from the queue.
Skip the data if it is `None`.
Yields
The next element in the queue, i.e. a tuple `(inputs, targets)` or `(inputs, targets, sample_weights)`.
Expand source code
def get(self):
  """Creates a generator to extract data from the queue.

  Skip the data if it is `None`.

  Yields:
    The next element in the queue, i.e. a tuple
    `(inputs, targets)` or
    `(inputs, targets, sample_weights)`.
  """
  try:
    while self.is_running():
      inputs = self.queue.get(block=True).get()
      self.queue.task_done()
      if inputs is not None:
        yield inputs
  except StopIteration:
    # Special case for finite generators
    last_ones = []
    while self.queue.qsize() > 0:
      last_ones.append(self.queue.get(block=True))
    # Wait for them to complete
    for f in last_ones:
      f.wait()
    # Keep the good ones
    last_ones = [future.get() for future in last_ones if future.successful()]
    for inputs in last_ones:
      if inputs is not None:
        yield inputs
  except Exception as e:  # pylint: disable=broad-except
    self.stop()
    if 'generator already executing' in str(e):
      raise RuntimeError(
          'Your generator is NOT thread-safe. '
          'Keras requires a thread-safe generator when '
          '`use_multiprocessing=False, workers > 1`. ')
    raise e
Inherited members
- SequenceEnqueuer: is_running, start, stop
class OrderedEnqueuer (sequence, use_multiprocessing=False, shuffle=False)
-
Builds an Enqueuer from a Sequence.
Args
sequence
- A `tf.keras.utils.data_utils.Sequence` object.
use_multiprocessing
- Use multiprocessing if `True`, otherwise threading.
shuffle
- Whether to shuffle the data at the beginning of each epoch.
Expand source code
class OrderedEnqueuer(SequenceEnqueuer):
  """Builds an Enqueuer from a Sequence.

  Args:
    sequence: A `tf.keras.utils.data_utils.Sequence` object.
    use_multiprocessing: use multiprocessing if True, otherwise threading
    shuffle: whether to shuffle the data at the beginning of each epoch
  """

  def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
    super(OrderedEnqueuer, self).__init__(sequence, use_multiprocessing)
    self.shuffle = shuffle

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
      workers: Number of workers.

    Returns:
      Function, a Function to initialize the pool
    """

    def pool_fn(seqs):
      pool = get_pool_class(True)(
          workers, initializer=init_pool_generator,
          initargs=(seqs, None, get_worker_id_queue()))
      _DATA_POOLS.add(pool)
      return pool

    return pool_fn

  def _wait_queue(self):
    """Wait for the queue to be empty."""
    while True:
      time.sleep(0.1)
      if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set():
        return

  def _run(self):
    """Submits request to the executor and queue the `Future` objects."""
    sequence = list(range(len(self.sequence)))
    self._send_sequence()  # Share the initial sequence
    while True:
      if self.shuffle:
        random.shuffle(sequence)

      with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
        for i in sequence:
          if self.stop_signal.is_set():
            return

          self.queue.put(
              executor.apply_async(get_index, (self.uid, i)), block=True)

        # Done with the current epoch, waiting for the final batches
        self._wait_queue()

        if self.stop_signal.is_set():
          # We're done
          return

      # Call the internal on epoch end.
      self.sequence.on_epoch_end()
      self._send_sequence()  # Update the pool

  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Yields:
      The next element in the queue, i.e. a tuple
      `(inputs, targets)` or
      `(inputs, targets, sample_weights)`.
    """
    while self.is_running():
      try:
        inputs = self.queue.get(block=True, timeout=5).get()
        if self.is_running():
          self.queue.task_done()
        if inputs is not None:
          yield inputs
      except queue.Empty:
        pass
      except Exception as e:  # pylint: disable=broad-except
        self.stop()
        raise e
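A usage sketch (my addition, with a hypothetical toy `Sequence`): because batches are fetched by index via `__getitem__`, an `OrderedEnqueuer` can safely use several workers, and `shuffle=True` reshuffles the index order at the start of each epoch:

```python
import math
import numpy as np
from keras.utils.data_utils import OrderedEnqueuer, Sequence

class ToySequence(Sequence):  # Hypothetical Sequence for illustration.
    def __init__(self, n_samples=100, batch_size=16):
        self.x = np.random.random((n_samples, 8))
        self.y = np.random.randint(0, 2, size=(n_samples, 1))
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        s = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.x[s], self.y[s]

seq = ToySequence()
enqueuer = OrderedEnqueuer(seq, use_multiprocessing=False, shuffle=True)
enqueuer.start(workers=2, max_queue_size=10)
batches = enqueuer.get()
for _ in range(len(seq)):  # One epoch; `get()` itself never stops.
    x, y = next(batches)
enqueuer.stop()
```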
Ancestors
- SequenceEnqueuer
Methods
def get(self)
-
Creates a generator to extract data from the queue.
Skip the data if it is `None`.
Yields
The next element in the queue, i.e. a tuple `(inputs, targets)` or `(inputs, targets, sample_weights)`.
Expand source code
def get(self):
  """Creates a generator to extract data from the queue.

  Skip the data if it is `None`.

  Yields:
    The next element in the queue, i.e. a tuple
    `(inputs, targets)` or
    `(inputs, targets, sample_weights)`.
  """
  while self.is_running():
    try:
      inputs = self.queue.get(block=True, timeout=5).get()
      if self.is_running():
        self.queue.task_done()
      if inputs is not None:
        yield inputs
    except queue.Empty:
      pass
    except Exception as e:  # pylint: disable=broad-except
      self.stop()
      raise e
Inherited members
- SequenceEnqueuer: is_running, start, stop
class Progbar (target, width=30, verbose=1, interval=0.05, stateful_metrics=None, unit_name='step')
-
Displays a progress bar.
Args
target
- Total number of steps expected, None if unknown.
width
- Progress bar width on screen.
verbose
- Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
stateful_metrics
- Iterable of string names of metrics that should not be averaged over time. Metrics in this list will be displayed as-is. All others will be averaged by the progbar before display.
interval
- Minimum visual progress update interval (in seconds).
unit_name
- Display name for step counts (usually "step" or "sample").
Expand source code
class Progbar(object):
  """Displays a progress bar.

  Args:
    target: Total number of steps expected, None if unknown.
    width: Progress bar width on screen.
    verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
    stateful_metrics: Iterable of string names of metrics that should *not* be
      averaged over time. Metrics in this list will be displayed as-is. All
      others will be averaged by the progbar before display.
    interval: Minimum visual progress update interval (in seconds).
    unit_name: Display name for step counts (usually "step" or "sample").
  """

  def __init__(self,
               target,
               width=30,
               verbose=1,
               interval=0.05,
               stateful_metrics=None,
               unit_name='step'):
    self.target = target
    self.width = width
    self.verbose = verbose
    self.interval = interval
    self.unit_name = unit_name
    if stateful_metrics:
      self.stateful_metrics = set(stateful_metrics)
    else:
      self.stateful_metrics = set()

    self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
                              sys.stdout.isatty()) or
                             'ipykernel' in sys.modules or
                             'posix' in sys.modules or
                             'PYCHARM_HOSTED' in os.environ)
    self._total_width = 0
    self._seen_so_far = 0
    # We use a dict + list to avoid garbage collection
    # issues found in OrderedDict
    self._values = {}
    self._values_order = []
    self._start = time.time()
    self._last_update = 0

    self._time_after_first_step = None

  def update(self, current, values=None, finalize=None):
    """Updates the progress bar.

    Args:
      current: Index of current step.
      values: List of tuples: `(name, value_for_last_step)`. If `name` is in
        `stateful_metrics`, `value_for_last_step` will be displayed as-is.
        Else, an average of the metric over time will be displayed.
      finalize: Whether this is the last update for the progress bar. If
        `None`, defaults to `current >= self.target`.
    """
    if finalize is None:
      if self.target is None:
        finalize = False
      else:
        finalize = current >= self.target

    values = values or []
    for k, v in values:
      if k not in self._values_order:
        self._values_order.append(k)
      if k not in self.stateful_metrics:
        # In the case that progress bar doesn't have a target value in the first
        # epoch, both on_batch_end and on_epoch_end will be called, which will
        # cause 'current' and 'self._seen_so_far' to have the same value. Force
        # the minimal value to 1 here, otherwise stateful_metric will be 0s.
        value_base = max(current - self._seen_so_far, 1)
        if k not in self._values:
          self._values[k] = [v * value_base, value_base]
        else:
          self._values[k][0] += v * value_base
          self._values[k][1] += value_base
      else:
        # Stateful metrics output a numeric value. This representation
        # means "take an average from a single value" but keeps the
        # numeric formatting.
        self._values[k] = [v, 1]
    self._seen_so_far = current

    now = time.time()
    info = ' - %.0fs' % (now - self._start)
    if self.verbose == 1:
      if now - self._last_update < self.interval and not finalize:
        return

      prev_total_width = self._total_width
      if self._dynamic_display:
        sys.stdout.write('\b' * prev_total_width)
        sys.stdout.write('\r')
      else:
        sys.stdout.write('\n')

      if self.target is not None:
        numdigits = int(np.log10(self.target)) + 1
        bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target)
        prog = float(current) / self.target
        prog_width = int(self.width * prog)
        if prog_width > 0:
          bar += ('=' * (prog_width - 1))
          if current < self.target:
            bar += '>'
          else:
            bar += '='
        bar += ('.' * (self.width - prog_width))
        bar += ']'
      else:
        bar = '%7d/Unknown' % current

      self._total_width = len(bar)
      sys.stdout.write(bar)

      time_per_unit = self._estimate_step_duration(current, now)

      if self.target is None or finalize:
        if time_per_unit >= 1 or time_per_unit == 0:
          info += ' %.0fs/%s' % (time_per_unit, self.unit_name)
        elif time_per_unit >= 1e-3:
          info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name)
        else:
          info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name)
      else:
        eta = time_per_unit * (self.target - current)
        if eta > 3600:
          eta_format = '%d:%02d:%02d' % (eta // 3600,
                                         (eta % 3600) // 60, eta % 60)
        elif eta > 60:
          eta_format = '%d:%02d' % (eta // 60, eta % 60)
        else:
          eta_format = '%ds' % eta

        info = ' - ETA: %s' % eta_format

      for k in self._values_order:
        info += ' - %s:' % k
        if isinstance(self._values[k], list):
          avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
          if abs(avg) > 1e-3:
            info += ' %.4f' % avg
          else:
            info += ' %.4e' % avg
        else:
          info += ' %s' % self._values[k]

      self._total_width += len(info)
      if prev_total_width > self._total_width:
        info += (' ' * (prev_total_width - self._total_width))

      if finalize:
        info += '\n'

      sys.stdout.write(info)
      sys.stdout.flush()

    elif self.verbose == 2:
      if finalize:
        numdigits = int(np.log10(self.target)) + 1
        count = ('%' + str(numdigits) + 'd/%d') % (current, self.target)
        info = count + info
        for k in self._values_order:
          info += ' - %s:' % k
          avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
          if avg > 1e-3:
            info += ' %.4f' % avg
          else:
            info += ' %.4e' % avg
        info += '\n'

        sys.stdout.write(info)
        sys.stdout.flush()

    self._last_update = now

  def add(self, n, values=None):
    self.update(self._seen_so_far + n, values)

  def _estimate_step_duration(self, current, now):
    """Estimate the duration of a single step.

    Given the step number `current` and the corresponding time `now` this
    function returns an estimate for how long a single step takes. If this
    is called before one step has been completed (i.e. `current == 0`) then
    zero is given as an estimate. The duration estimate ignores the duration
    of the (assumed to be non-representative) first step for estimates when
    more steps are available (i.e. `current > 1`).

    Args:
      current: Index of current step.
      now: The current time.

    Returns: Estimate of the duration of a single step.
    """
    if current:
      # there are a few special scenarios here:
      # 1) somebody is calling the progress bar without ever supplying step 1
      # 2) somebody is calling the progress bar and supplies step one multiple
      #    times, e.g. as part of a finalizing call
      # in these cases, we just fall back to the simple calculation
      if self._time_after_first_step is not None and current > 1:
        time_per_unit = (now - self._time_after_first_step) / (current - 1)
      else:
        time_per_unit = (now - self._start) / current

      if current == 1:
        self._time_after_first_step = now
      return time_per_unit
    else:
      return 0

  def _update_stateful_metrics(self, stateful_metrics):
    self.stateful_metrics = self.stateful_metrics.union(stateful_metrics)
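A short usage sketch (my addition; the loop, the `time.sleep` call, and the `loss` value are hypothetical stand-ins for real work and real metrics):

```python
import time
from keras.utils.generic_utils import Progbar

pbar = Progbar(target=50, unit_name='step')
for step in range(1, 51):
    time.sleep(0.01)   # Stand-in for real per-step work.
    loss = 1.0 / step  # Stand-in metric value.
    pbar.update(step, values=[('loss', loss)])  # Averaged over time.
```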
Methods
def add(self, n, values=None)
-
Expand source code
def add(self, n, values=None):
  self.update(self._seen_so_far + n, values)
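`add(n, values)` simply advances the bar by `n` steps by delegating to `update(self._seen_so_far + n, values)`, which is convenient when only increments are known. A minimal sketch (hypothetical metric value):

```python
from keras.utils.generic_utils import Progbar

pbar = Progbar(target=10)
for _ in range(10):
    pbar.add(1, values=[('loss', 0.5)])  # Advance one step per batch.
```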
def update(self, current, values=None, finalize=None)
-
Updates the progress bar.
Args
current
- Index of current step.
values
- List of tuples: `(name, value_for_last_step)`. If `name` is in `stateful_metrics`, `value_for_last_step` will be displayed as-is. Else, an average of the metric over time will be displayed.
finalize
- Whether this is the last update for the progress bar. If `None`, defaults to `current >= self.target`.
Expand source code
def update(self, current, values=None, finalize=None):
  """Updates the progress bar.

  Args:
    current: Index of current step.
    values: List of tuples: `(name, value_for_last_step)`. If `name` is in
      `stateful_metrics`, `value_for_last_step` will be displayed as-is.
      Else, an average of the metric over time will be displayed.
    finalize: Whether this is the last update for the progress bar. If
      `None`, defaults to `current >= self.target`.
  """
  if finalize is None:
    if self.target is None:
      finalize = False
    else:
      finalize = current >= self.target

  values = values or []
  for k, v in values:
    if k not in self._values_order:
      self._values_order.append(k)
    if k not in self.stateful_metrics:
      # In the case that progress bar doesn't have a target value in the first
      # epoch, both on_batch_end and on_epoch_end will be called, which will
      # cause 'current' and 'self._seen_so_far' to have the same value. Force
      # the minimal value to 1 here, otherwise stateful_metric will be 0s.
      value_base = max(current - self._seen_so_far, 1)
      if k not in self._values:
        self._values[k] = [v * value_base, value_base]
      else:
        self._values[k][0] += v * value_base
        self._values[k][1] += value_base
    else:
      # Stateful metrics output a numeric value. This representation
      # means "take an average from a single value" but keeps the
      # numeric formatting.
      self._values[k] = [v, 1]
  self._seen_so_far = current

  now = time.time()
  info = ' - %.0fs' % (now - self._start)
  if self.verbose == 1:
    if now - self._last_update < self.interval and not finalize:
      return

    prev_total_width = self._total_width
    if self._dynamic_display:
      sys.stdout.write('\b' * prev_total_width)
      sys.stdout.write('\r')
    else:
      sys.stdout.write('\n')

    if self.target is not None:
      numdigits = int(np.log10(self.target)) + 1
      bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target)
      prog = float(current) / self.target
      prog_width = int(self.width * prog)
      if prog_width > 0:
        bar += ('=' * (prog_width - 1))
        if current < self.target:
          bar += '>'
        else:
          bar += '='
      bar += ('.' * (self.width - prog_width))
      bar += ']'
    else:
      bar = '%7d/Unknown' % current

    self._total_width = len(bar)
    sys.stdout.write(bar)

    time_per_unit = self._estimate_step_duration(current, now)

    if self.target is None or finalize:
      if time_per_unit >= 1 or time_per_unit == 0:
        info += ' %.0fs/%s' % (time_per_unit, self.unit_name)
      elif time_per_unit >= 1e-3:
        info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name)
      else:
        info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name)
    else:
      eta = time_per_unit * (self.target - current)
      if eta > 3600:
        eta_format = '%d:%02d:%02d' % (eta // 3600,
                                       (eta % 3600) // 60, eta % 60)
      elif eta > 60:
        eta_format = '%d:%02d' % (eta // 60, eta % 60)
      else:
        eta_format = '%ds' % eta

      info = ' - ETA: %s' % eta_format

    for k in self._values_order:
      info += ' - %s:' % k
      if isinstance(self._values[k], list):
        avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
        if abs(avg) > 1e-3:
          info += ' %.4f' % avg
        else:
          info += ' %.4e' % avg
      else:
        info += ' %s' % self._values[k]

    self._total_width += len(info)
    if prev_total_width > self._total_width:
      info += (' ' * (prev_total_width - self._total_width))

    if finalize:
      info += '\n'

    sys.stdout.write(info)
    sys.stdout.flush()

  elif self.verbose == 2:
    if finalize:
      numdigits = int(np.log10(self.target)) + 1
      count = ('%' + str(numdigits) + 'd/%d') % (current, self.target)
      info = count + info
      for k in self._values_order:
        info += ' - %s:' % k
        avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
        if avg > 1e-3:
          info += ' %.4f' % avg
        else:
          info += ' %.4e' % avg
      info += '\n'

      sys.stdout.write(info)
      sys.stdout.flush()

  self._last_update = now
class Sequence
-
Base object for fitting to a sequence of data, such as a dataset.
Every `Sequence` must implement the `__getitem__` and the `__len__` methods. If you want to modify your dataset between epochs, you may implement `on_epoch_end`. The method `__getitem__` should return a complete batch.
Notes:
`Sequence` is a safer way to do multiprocessing. This structure guarantees that the network will only train once on each sample per epoch, which is not the case with generators.
Examples:
from skimage.io import imread
from skimage.transform import resize
import numpy as np
import math

# Here, `x_set` is a list of paths to the images
# and `y_set` are the associated classes.

class CIFAR10Sequence(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        return np.array([
            resize(imread(file_name), (200, 200))
                for file_name in batch_x]), np.array(batch_y)
Expand source code
class Sequence(object):
  """Base object for fitting to a sequence of data, such as a dataset.

  Every `Sequence` must implement the `__getitem__` and the `__len__` methods.
  If you want to modify your dataset between epochs you may implement
  `on_epoch_end`. The method `__getitem__` should return a complete batch.

  Notes:

  `Sequence` is a safer way to do multiprocessing. This structure guarantees
  that the network will only train once on each sample per epoch which is not
  the case with generators.

  Examples:

  ```python
  from skimage.io import imread
  from skimage.transform import resize
  import numpy as np
  import math

  # Here, `x_set` is a list of paths to the images
  # and `y_set` are the associated classes.

  class CIFAR10Sequence(Sequence):

      def __init__(self, x_set, y_set, batch_size):
          self.x, self.y = x_set, y_set
          self.batch_size = batch_size

      def __len__(self):
          return math.ceil(len(self.x) / self.batch_size)

      def __getitem__(self, idx):
          batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
          batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

          return np.array([
              resize(imread(file_name), (200, 200))
                  for file_name in batch_x]), np.array(batch_y)
  ```
  """

  @abstractmethod
  def __getitem__(self, index):
    """Gets batch at position `index`.

    Args:
      index: position of the batch in the Sequence.

    Returns:
      A batch
    """
    raise NotImplementedError

  @abstractmethod
  def __len__(self):
    """Number of batches in the Sequence.

    Returns:
      The number of batches in the Sequence.
    """
    raise NotImplementedError

  def on_epoch_end(self):
    """Method called at the end of every epoch.
    """
    pass

  def __iter__(self):
    """Create a generator that iterates over the Sequence."""
    for item in (self[i] for i in range(len(self))):
      yield item
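Since `tf.keras.Model.fit` accepts a `Sequence` directly, here is a self-contained sketch (my addition) using an in-memory variant of the pattern above; the `ArraySequence` class and the random data are hypothetical:

```python
import math
import numpy as np
import tensorflow as tf

class ArraySequence(tf.keras.utils.Sequence):  # Hypothetical in-memory Sequence.
    def __init__(self, x, y, batch_size=32):
        self.x, self.y, self.batch_size = x, y, batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        s = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.x[s], self.y[s]

x = np.random.random((256, 8)).astype('float32')
y = np.random.randint(0, 2, size=(256, 1))

model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation='sigmoid')])
model.compile(optimizer='adam', loss='binary_crossentropy')
model.fit(ArraySequence(x, y), epochs=2)  # `fit` accepts a Sequence directly.
```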
Subclasses
Methods
def on_epoch_end(self)
-
Method called at the end of every epoch.
Expand source code
def on_epoch_end(self):
  """Method called at the end of every epoch.
  """
  pass
class SequenceEnqueuer (sequence, use_multiprocessing=False)
-
Base class to enqueue inputs.
The task of an Enqueuer is to use parallelism to speed up preprocessing. This is done with processes or threads.
Example:
enqueuer = SequenceEnqueuer(...)
enqueuer.start()
datas = enqueuer.get()
for data in datas:
    # Use the inputs; training, evaluating, predicting.
    # ... stop sometime.
enqueuer.stop()
The `enqueuer.get()` should be an infinite stream of data.
Expand source code
class SequenceEnqueuer(object):
  """Base class to enqueue inputs.

  The task of an Enqueuer is to use parallelism to speed up preprocessing.
  This is done with processes or threads.

  Example:

  ```python
      enqueuer = SequenceEnqueuer(...)
      enqueuer.start()
      datas = enqueuer.get()
      for data in datas:
          # Use the inputs; training, evaluating, predicting.
          # ... stop sometime.
      enqueuer.stop()
  ```

  The `enqueuer.get()` should be an infinite stream of data.
  """

  def __init__(self, sequence, use_multiprocessing=False):
    self.sequence = sequence
    self.use_multiprocessing = use_multiprocessing

    global _SEQUENCE_COUNTER
    if _SEQUENCE_COUNTER is None:
      try:
        _SEQUENCE_COUNTER = multiprocessing.Value('i', 0)
      except OSError:
        # In this case the OS does not allow us to use
        # multiprocessing. We resort to an int
        # for enqueuer indexing.
        _SEQUENCE_COUNTER = 0

    if isinstance(_SEQUENCE_COUNTER, int):
      self.uid = _SEQUENCE_COUNTER
      _SEQUENCE_COUNTER += 1
    else:
      # Doing Multiprocessing.Value += x is not process-safe.
      with _SEQUENCE_COUNTER.get_lock():
        self.uid = _SEQUENCE_COUNTER.value
        _SEQUENCE_COUNTER.value += 1

    self.workers = 0
    self.executor_fn = None
    self.queue = None
    self.run_thread = None
    self.stop_signal = None

  def is_running(self):
    return self.stop_signal is not None and not self.stop_signal.is_set()

  def start(self, workers=1, max_queue_size=10):
    """Starts the handler's workers.

    Args:
      workers: Number of workers.
      max_queue_size: queue size
        (when full, workers could block on `put()`)
    """
    if self.use_multiprocessing:
      self.executor_fn = self._get_executor_init(workers)
    else:
      # We do not need the init since it's threads.
      self.executor_fn = lambda _: get_pool_class(False)(workers)
    self.workers = workers
    self.queue = queue.Queue(max_queue_size)
    self.stop_signal = threading.Event()
    self.run_thread = threading.Thread(target=self._run)
    self.run_thread.daemon = True
    self.run_thread.start()

  def _send_sequence(self):
    """Sends current Iterable to all workers."""
    # For new processes that may spawn
    _SHARED_SEQUENCES[self.uid] = self.sequence

  def stop(self, timeout=None):
    """Stops running threads and waits for them to exit, if necessary.

    Should be called by the same thread which called `start()`.

    Args:
      timeout: maximum time to wait on `thread.join()`
    """
    self.stop_signal.set()
    with self.queue.mutex:
      self.queue.queue.clear()
      self.queue.unfinished_tasks = 0
      self.queue.not_full.notify()
    self.run_thread.join(timeout)
    _SHARED_SEQUENCES[self.uid] = None

  def __del__(self):
    if self.is_running():
      self.stop()

  @abstractmethod
  def _run(self):
    """Submits request to the executor and queue the `Future` objects."""
    raise NotImplementedError

  @abstractmethod
  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
      workers: Number of workers.

    Returns:
      Function, a Function to initialize the pool
    """
    raise NotImplementedError

  @abstractmethod
  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    # Returns
        Generator yielding tuples `(inputs, targets)`
            or `(inputs, targets, sample_weights)`.
    """
    raise NotImplementedError
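Note that `SequenceEnqueuer` itself is abstract (`_run`, `_get_executor_init`, and `get` raise `NotImplementedError`), so the schematic example above is realized in practice through a concrete subclass. A sketch (my addition) using `OrderedEnqueuer`, assuming `seq` is some `Sequence` instance such as the `CIFAR10Sequence` defined earlier on this page:

```python
from keras.utils.data_utils import OrderedEnqueuer

# `seq` is assumed to be an existing tf.keras.utils.Sequence instance.
enqueuer = OrderedEnqueuer(seq, use_multiprocessing=False)
enqueuer.start(workers=2, max_queue_size=10)
stream = enqueuer.get()   # Infinite stream of batches.
x, y = next(stream)       # Consume as needed ...
enqueuer.stop()           # ... then stop from the thread that called start().
```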
Subclasses
- GeneratorEnqueuer
- OrderedEnqueuer
Methods
def get(self)
-
Creates a generator to extract data from the queue.
Skip the data if it is `None`.
Returns
Generator yielding tuples `(inputs, targets)` or `(inputs, targets, sample_weights)`.
Expand source code
@abstractmethod
def get(self):
  """Creates a generator to extract data from the queue.

  Skip the data if it is `None`.

  # Returns
      Generator yielding tuples `(inputs, targets)`
          or `(inputs, targets, sample_weights)`.
  """
  raise NotImplementedError
def is_running(self)
-
Expand source code
def is_running(self):
  return self.stop_signal is not None and not self.stop_signal.is_set()
def start(self, workers=1, max_queue_size=10)
-
Starts the handler's workers.
Args
workers
- Number of workers.
max_queue_size
- Queue size (when full, workers could block on `put()`).
Expand source code
def start(self, workers=1, max_queue_size=10):
  """Starts the handler's workers.

  Args:
    workers: Number of workers.
    max_queue_size: queue size
      (when full, workers could block on `put()`)
  """
  if self.use_multiprocessing:
    self.executor_fn = self._get_executor_init(workers)
  else:
    # We do not need the init since it's threads.
    self.executor_fn = lambda _: get_pool_class(False)(workers)
  self.workers = workers
  self.queue = queue.Queue(max_queue_size)
  self.stop_signal = threading.Event()
  self.run_thread = threading.Thread(target=self._run)
  self.run_thread.daemon = True
  self.run_thread.start()
def stop(self, timeout=None)
-
Stops running threads and waits for them to exit, if necessary.
Should be called by the same thread which called `start()`.
Args
timeout
- Maximum time to wait on `thread.join()`.
Expand source code
def stop(self, timeout=None):
  """Stops running threads and waits for them to exit, if necessary.

  Should be called by the same thread which called `start()`.

  Args:
    timeout: maximum time to wait on `thread.join()`
  """
  self.stop_signal.set()
  with self.queue.mutex:
    self.queue.queue.clear()
    self.queue.unfinished_tasks = 0
    self.queue.not_full.notify()
  self.run_thread.join(timeout)
  _SHARED_SEQUENCES[self.uid] = None