Source code for toast.data

# Copyright (c) 2015-2020 by the parties listed in the AUTHORS file.
# All rights reserved.  Use of this source code is governed by
# a BSD-style license that can be found in the LICENSE file.

import re
from collections import OrderedDict
from collections.abc import MutableMapping

import numpy as np

from .accelerator import AcceleratorObject, accel_enabled
from .mpi import Comm
from .utils import Logger


class Data(MutableMapping):
    """Class which represents distributed data

    A Data object contains a list of observations assigned to
    each process group in the Comm.  It also implements the MutableMapping
    interface, which stores arbitrary key / value metadata alongside the
    observations.

    Args:
        comm (:class:`toast.Comm`):  The toast Comm class for distributing the
            data.  If None (the default), a new default-constructed Comm is
            created for this instance.
        view (bool):  If True, do not explicitly clear observation data on
            deletion.

    """

    def __init__(self, comm=None, view=False):
        # The default communicator is built at call time rather than in the
        # signature:  a `Comm()` default would be evaluated once when the class
        # is defined (an import-time side effect) and that single instance
        # would be shared by every default-constructed Data object.
        if comm is None:
            comm = Comm()
        self._comm = comm
        self._view = view
        self.obs = []
        """The list of observations.
        """
        self._internal = dict()

    # MutableMapping interface, backed by the metadata dictionary.

    def __getitem__(self, key):
        return self._internal[key]

    def __delitem__(self, key):
        del self._internal[key]

    def __setitem__(self, key, value):
        self._internal[key] = value

    def __iter__(self):
        return iter(self._internal)

    def __len__(self):
        return len(self._internal)

    def __repr__(self):
        parts = ["<Data with {} Observations:\n".format(len(self.obs))]
        parts.extend("{}\n".format(ob) for ob in self.obs)
        parts.append("Metadata:\n")
        parts.append("{}".format(self._internal))
        parts.append("\n>")
        return "".join(parts)

    def __del__(self):
        # If the constructor failed early, the attributes used by clear() may
        # not exist yet.
        if hasattr(self, "obs"):
            self.clear()

    @property
    def comm(self):
        """The toast.Comm over which the data is distributed."""
        return self._comm

    def _new_view(self):
        """Return an empty view Data sharing this object's comm and metadata."""
        view = Data(comm=self._comm, view=True)
        # Use a reference to the original metadata, not a copy.
        view._internal = self._internal
        return view

    def _require_group_consistent(self, iob, obs_key, obs_val):
        """Raise if `obs_val` differs between the processes of the group.

        Args:
            iob (int):  The observation index, used in the error message.
            obs_key (str):  The observation key, used in the error message.
            obs_val (object):  The local value of the key.

        Raises:
            RuntimeError:  If the gathered values are not all equal.

        """
        group_comm = self._comm.comm_group
        if group_comm is None:
            group_vals = [obs_val]
        else:
            group_vals = group_comm.allgather(obs_val)
        if group_vals.count(group_vals[0]) != len(group_vals):
            msg = f"observation {iob}, key '{obs_key}' has inconsistent values across processes"
            Logger.get().error_rank(msg, comm=group_comm)
            raise RuntimeError(msg)

    @staticmethod
    def _obs_managers(ob):
        """Return the (label, manager) pairs of one observation."""
        return (
            ("detdata", ob.detdata),
            ("shared", ob.shared),
            ("intervals", ob.intervals),
        )

    def clear(self):
        """Clear the list of observations.

        The list of observations is always emptied.  Unless this object is a
        view, the accelerator data is deleted and every observation is cleared
        first, and the metadata dictionary is cleared afterwards.

        """
        if not self._view:
            self.accel_clear()
            for ob in self.obs:
                ob.clear()
        self.obs.clear()
        if not self._view:
            self._internal.clear()

    def all_local_detectors(self, selection=None):
        """Get the superset of local detectors in all observations.

        This builds up the result from calling `select_local_detectors()` on
        all observations.

        Args:
            selection (object):  Passed unchanged as the `selection` argument
                of `select_local_detectors()` for every observation.

        Returns:
            (list):  The list of all local detectors across all observations,
                in order of first appearance.

        """
        all_dets = OrderedDict()
        for ob in self.obs:
            for det in ob.select_local_detectors(selection=selection):
                all_dets.setdefault(det, None)
        return list(all_dets)

    def detector_units(self, det_data):
        """Get the detector data units for a given field.

        This verifies that the specified detector data field has the same
        units in all observations where it occurs, and returns that unit.

        Args:
            det_data (str):  The detector data field.

        Returns:
            (Unit):  The unit used across all observations, or None if no
                local observation contains the field.

        Raises:
            RuntimeError:  If the units differ between local observations, or
                (on world rank 0) between processes.

        """
        log = Logger.get()
        local_units = None
        for ob in self.obs:
            if det_data not in ob.detdata:
                continue
            ob_units = ob.detdata[det_data].units
            if local_units is None:
                local_units = ob_units
            elif ob_units != local_units:
                msg = f"obs {ob.name} detdata {det_data} units "
                msg += f"{ob_units} != {local_units}"
                log.error(msg)
                raise RuntimeError(msg)
        if self.comm.comm_world is None:
            return local_units
        # NOTE(review): only world rank 0 compares the gathered values and
        # raises;  the other ranks return their local value unchecked.
        # Confirm that callers cope with a failure on rank 0 only.
        det_units = self.comm.comm_world.gather(local_units, root=0)
        if self.comm.world_rank == 0:
            for dtu in det_units:
                if dtu != local_units:
                    msg = "observations have different units "
                    msg += f"{dtu} != {local_units}"
                    log.error(msg)
                    raise RuntimeError(msg)
        return local_units

    def info(self, handle=None):
        """Print information about the distributed data.

        Information is written to the specified file handle.  Only the rank 0
        process writes.

        Args:
            handle (descriptor):  file descriptor supporting the write()
                method.  If None, use print().

        Returns:
            None

        """
        gcomm = self._comm.comm_group
        wcomm = self._comm.comm_world
        rcomm = self._comm.comm_group_rank

        def _emit(text):
            if handle is None:
                print(text, flush=True)
            else:
                handle.write(text)

        # Both headers end with a newline, so that output written to a handle
        # does not run into the first observation description.
        if wcomm is None:
            _emit("Data distributed over a single process (no MPI)\n")
        elif wcomm.rank == 0:
            _emit(
                "Data distributed over {} processes in {} groups\n".format(
                    self._comm.world_size, self._comm.ngroups
                )
            )

        # The first process of each group builds the description of the
        # observations of that group.
        groupstr = ""
        if self._comm.group_rank == 0:
            groupstr = "".join("{}\n".format(str(ob)) for ob in self.obs)

        # The world rank 0 process collects output from all groups and
        # writes to the handle
        if self._comm.world_rank == 0:
            _emit(groupstr)
        if wcomm is not None:
            for g in range(1, self._comm.ngroups):
                if wcomm.rank == 0:
                    _emit(rcomm.recv(source=g, tag=g))
                elif g == self._comm.group:
                    if gcomm.rank == 0:
                        rcomm.send(groupstr, dest=0, tag=g)
                wcomm.barrier()

    def split(
        self,
        obs_index=False,
        obs_name=False,
        obs_uid=False,
        obs_session_name=False,
        obs_key=None,
        require_full=False,
    ):
        """Split the Data object.

        Create new Data objects that have views into unique subsets of the
        observations (the observations are not copied).  Only one "criteria"
        may be used to perform this splitting operation.  The observations may
        be split by index in the original list, by name, by UID, by session,
        or by the value of a specified key.

        The new Data objects are returned in a dictionary whose keys are the
        value of the selection criteria (index, name, uid, or value of the
        key).  Any observation that cannot be placed (because it is missing a
        name, uid or key) will be ignored and not added to any of the returned
        Data objects.  If the `require_full` parameter is set to True, such
        situations will raise an exception.

        Args:
            obs_index (bool):  If True, split by index in original list of
                observations.
            obs_name (bool):  If True, split by observation name.
            obs_uid (bool):  If True, split by observation UID.
            obs_session_name (bool):  If True, split by session name.
            obs_key (str):  Split by values of this observation key.
            require_full (bool):  If True, every observation must be placed in
                the output.

        Returns:
            (OrderedDict):  The dictionary of new Data objects.

        Raises:
            RuntimeError:  If the number of criteria is not exactly one, if
                `require_full` is True and an observation cannot be placed, or
                if a key has inconsistent values across the process group.

        """
        check = (
            int(obs_index)
            + int(obs_name)
            + int(obs_uid)
            + int(obs_session_name)
            + int(obs_key is not None)
        )
        if check != 1:
            raise RuntimeError("You must specify exactly one split criteria")

        group_comm = self.comm.comm_group

        def _unplaced(iob, what):
            # An observation without the requested property is skipped, unless
            # the caller demanded that every observation is placed.
            if require_full:
                msg = f"require_full is True, but observation {iob} has no {what}"
                Logger.get().error_rank(msg, comm=group_comm)
                raise RuntimeError(msg)

        datasplit = OrderedDict()

        if obs_index:
            # Splitting by (unique) index
            for iob, ob in enumerate(self.obs):
                newdat = self._new_view()
                newdat.obs.append(ob)
                datasplit[iob] = newdat
        elif obs_name:
            # Splitting by (unique) name
            for iob, ob in enumerate(self.obs):
                if ob.name is None:
                    _unplaced(iob, "name")
                    continue
                newdat = self._new_view()
                newdat.obs.append(ob)
                datasplit[ob.name] = newdat
        elif obs_uid:
            # Splitting by UID
            for iob, ob in enumerate(self.obs):
                if ob.uid is None:
                    _unplaced(iob, "UID")
                    continue
                newdat = self._new_view()
                newdat.obs.append(ob)
                datasplit[ob.uid] = newdat
        elif obs_session_name:
            # Splitting by (non-unique) session name
            for iob, ob in enumerate(self.obs):
                if ob.session is None or ob.session.name is None:
                    _unplaced(iob, "session name")
                    continue
                sname = ob.session.name
                if sname not in datasplit:
                    datasplit[sname] = self._new_view()
                datasplit[sname].obs.append(ob)
        else:
            # Splitting by arbitrary key.  Unlike name / uid which are built in
            # to the observation class, arbitrary keys might be modified in
            # different ways across all processes in a group.  For this reason,
            # we do an additional check for consistent values across the
            # process group.
            for iob, ob in enumerate(self.obs):
                if obs_key not in ob:
                    _unplaced(iob, f"key '{obs_key}'")
                    continue
                obs_val = ob[obs_key]
                self._require_group_consistent(iob, obs_key, obs_val)
                if obs_val not in datasplit:
                    datasplit[obs_val] = self._new_view()
                datasplit[obs_val].obs.append(ob)
        return datasplit

    def select(
        self,
        obs_index=None,
        obs_name=None,
        obs_uid=None,
        obs_session_name=None,
        obs_key=None,
        obs_val=None,
    ):
        """Create a new Data object with a subset of observations.

        The returned Data object just has a view of the original observations
        (they are not copied).

        The list of observations in the new Data object is a logical OR of the
        criteria passed in:

            * Index location in the original list of observations
            * Name of the observation
            * UID of the observation
            * Session of the observation
            * Existence of the specified dictionary key
            * Required value of the specified dictionary key

        Args:
            obs_index (int):  Observation location in the original list.
            obs_name (str):  The observation name or a compiled regular
                expression object to use for matching.
            obs_uid (int):  The observation UID to select.
            obs_session_name (str):  The name of the session.
            obs_key (str):  The observation dictionary key to examine.
            obs_val (str):  The required value of the observation dictionary
                key or a compiled regular expression object to use for
                matching.

        Returns:
            (Data):  A new Data object with references to the original
                metadata and a subset of observations.

        Raises:
            RuntimeError:  If `obs_val` is given without `obs_key`, or if the
                examined key has inconsistent values across the process group.

        """
        if obs_val is not None and obs_key is None:
            raise RuntimeError("If you specify obs_val, you must also specify obs_key")

        def _selected(iob, ob):
            # The criteria are tested in a fixed order and the first match
            # wins, so the group-wide consistency check of the key is only
            # reached when no earlier criteria matched.
            if obs_index is not None and obs_index == iob:
                return True
            if obs_name is not None and ob.name is not None:
                if isinstance(obs_name, re.Pattern):
                    if obs_name.match(ob.name) is not None:
                        return True
                elif obs_name == ob.name:
                    return True
            if obs_uid is not None and ob.uid is not None and obs_uid == ob.uid:
                return True
            if (
                obs_session_name is not None
                and ob.session is not None
                and obs_session_name == ob.session.name
            ):
                return True
            if obs_key is not None and obs_key in ob:
                value = ob[obs_key]
                # Get the values from all processes in the group and check
                # for consistency.
                self._require_group_consistent(iob, obs_key, value)
                if obs_val is None:
                    # We have the key, and are accepting any value
                    return True
                if isinstance(obs_val, re.Pattern):
                    # Matches our regex
                    return obs_val.match(value) is not None
                return obs_val == value
            return False

        new_data = self._new_view()
        for iob, ob in enumerate(self.obs):
            if _selected(iob, ob):
                new_data.obs.append(ob)
        return new_data

    # Accelerator use

    def accel_create(self, names):
        """Create a set of data objects on the device.

        This takes a dictionary with the same format as those used by the
        Operator provides() and requires() methods.  If the data already
        exists on the device then no action is taken.

        Args:
            names (dict):  Dictionary of lists.  The keys "detdata", "shared",
                "intervals" and "global" are read.

        Returns:
            None

        """
        log = Logger.get()
        if not accel_enabled():
            log.verbose("accel_enabled is False, canceling accel_create.")
            return

        for ob in self.obs:
            for objname, objmgr in self._obs_managers(ob):
                for key in names[objname]:
                    if key not in objmgr:
                        msg = f"ob {ob.name} {objname} accel_create '{key}' "
                        msg += "not present, ignoring"
                        log.verbose(msg)
                        continue
                    if objmgr.accel_exists(key):
                        msg = f"ob {ob.name} {objname}: accel_create '{key}'"
                        msg += " already exists"
                        log.verbose(msg)
                    else:
                        log.verbose(f"ob {ob.name} {objname}: accel_create '{key}'")
                        objmgr.accel_create(key)

        # Global objects are looked up in the metadata dictionary.
        for key in names["global"]:
            val = self._internal.get(key, None)
            if not isinstance(val, AcceleratorObject):
                log.verbose(
                    f"Data accel_create: '{key}' ({type(val)}) is not an AcceleratorObject"
                )
            elif val.accel_exists():
                log.verbose(f"Data accel_create: '{key}' already on device")
            else:
                log.verbose(f"Data accel_create: '{key}'")
                val.accel_create(key)

    def accel_update_device(self, names):
        """Copy a set of data objects to the device.

        This takes a dictionary with the same format as those used by the
        Operator provides() and requires() methods.  Objects which are already
        in use on the device are skipped.

        Args:
            names (dict):  Dictionary of lists.  The keys "detdata", "shared",
                "intervals" and "global" are read.

        Returns:
            None

        Raises:
            RuntimeError:  If an observation object is present on the host but
                does not exist on the accelerator.

        """
        if not accel_enabled():
            return
        log = Logger.get()

        for ob in self.obs:
            for objname, objmgr in self._obs_managers(ob):
                for key in names[objname]:
                    if key not in objmgr:
                        msg = f"ob {ob.name} {objname} update_device key '{key}'"
                        msg += " not present, ignoring"
                        log.verbose(msg)
                        continue
                    if not objmgr.accel_exists(key):
                        msg = f"ob {ob.name} {objname} update_device key '{key}'"
                        msg += " does not exist on accelerator"
                        log.error(msg)
                        raise RuntimeError(msg)
                    if objmgr.accel_in_use(key):
                        msg = f"ob {ob.name} {objname}: skip update_device '{key}'"
                        msg += " already in use"
                        log.verbose(msg)
                    else:
                        log.verbose(f"ob {ob.name} {objname}: update_device '{key}'")
                        objmgr.accel_update_device(key)

        # Global objects are looked up in the metadata dictionary.
        for key in names["global"]:
            val = self._internal.get(key, None)
            if not isinstance(val, AcceleratorObject):
                msg = f"Data accel_update_device: '{key}' ({type(val)}) "
                msg += "is not an AcceleratorObject"
                log.verbose(msg)
            elif val.accel_in_use():
                msg = f"Skipping update_device for '{key}', "
                msg += "device data in use"
                log.verbose(msg)
            else:
                log.verbose(f"Calling Data update_device for '{key}'")
                val.accel_update_device()

    def accel_update_host(self, names):
        """Copy a set of data objects to the host.

        This takes a dictionary with the same format as those used by the
        Operator provides() and requires() methods.  Objects which do not
        exist on the accelerator, or which are not in use there, are skipped.

        Args:
            names (dict):  Dictionary of lists.  The keys "detdata", "shared",
                "intervals" and "global" are read.

        Returns:
            None

        """
        if not accel_enabled():
            return
        log = Logger.get()

        for ob in self.obs:
            for objname, objmgr in self._obs_managers(ob):
                for key in names[objname]:
                    if key not in objmgr:
                        msg = f"ob {ob.name} {objname} update_host key '{key}'"
                        msg += " not present, ignoring"
                        log.verbose(msg)
                        continue
                    if not objmgr.accel_exists(key):
                        msg = f"ob {ob.name} {objname} update_host key '{key}'"
                        msg += " does not exist on accelerator, ignoring"
                        log.verbose(msg)
                        continue
                    if objmgr.accel_in_use(key):
                        log.verbose(f"ob {ob.name} {objname}: update_host '{key}'")
                        objmgr.accel_update_host(key)
                    else:
                        msg = f"ob {ob.name} {objname}: skip update_host, '{key}'"
                        msg += " already on host"
                        log.verbose(msg)

        # Global objects are looked up in the metadata dictionary.
        for key in names["global"]:
            val = self._internal.get(key, None)
            if not isinstance(val, AcceleratorObject):
                msg = f"Data accel_update_host: '{key}' ({type(val)}) "
                msg += "is not an AcceleratorObject"
                log.verbose(msg)
            elif val.accel_in_use():
                log.verbose(f"Calling Data update_host for '{key}'")
                val.accel_update_host()
            else:
                msg = f"Skipping update_host for '{key}', "
                msg += "host data already in use"
                log.verbose(msg)

    def accel_delete(self, names):
        """Delete a specific set of device objects

        This takes a dictionary with the same format as those used by the
        Operator provides() and requires() methods.  Objects which do not
        exist on the device are skipped.

        Args:
            names (dict):  Dictionary of lists.  The keys "detdata", "shared",
                "intervals" and "global" are read.

        Returns:
            None

        """
        if not accel_enabled():
            return
        log = Logger.get()

        for ob in self.obs:
            for objname, objmgr in self._obs_managers(ob):
                for key in names[objname]:
                    if key not in objmgr:
                        msg = f"ob {ob.name} {objname} accel_delete key '{key}'"
                        msg += " not present, ignoring"
                        log.verbose(msg)
                        continue
                    if objmgr.accel_exists(key):
                        log.verbose(f"ob {ob.name} {objname}: accel_delete '{key}'")
                        objmgr.accel_delete(key)
                    else:
                        msg = f"ob {ob.name} {objname}: accel_delete '{key}'"
                        msg += " not present on device"
                        log.verbose(msg)

        # Global objects are looked up in the metadata dictionary.
        for key in names["global"]:
            val = self._internal.get(key, None)
            if isinstance(val, AcceleratorObject):
                if val.accel_exists():
                    log.verbose(f"Calling Data accel_delete for '{key}'")
                    val.accel_delete()
            else:
                msg = f"Data accel_delete: '{key}' ({type(val)}) "
                msg += "is not an AcceleratorObject"
                log.verbose(msg)

    def accel_clear(self):
        """Delete all accelerator data.

        This calls accel_clear() on every observation, and deletes the device
        copy of every metadata value which is an AcceleratorObject and exists
        on the device.

        """
        if not accel_enabled():
            return
        log = Logger.get()

        for ob in self.obs:
            ob.accel_clear()

        for key, val in self._internal.items():
            if isinstance(val, AcceleratorObject):
                if val.accel_exists():
                    val.accel_delete()
            else:
                msg = f"Data accel_clear: '{key}' ({type(val)}) "
                msg += "is not an AcceleratorObject"
                log.verbose(msg)