# Copyright (c) 2015-2020 by the parties listed in the AUTHORS file.
# All rights reserved. Use of this source code is governed by
# a BSD-style license that can be found in the LICENSE file.
import copy
import numbers
import sys
import types
from collections.abc import Mapping, MutableMapping, Sequence
import numpy as np
from astropy import units as u
from pshmem.utils import mpi_data_type
from .accelerator import AcceleratorObject
from .dist import distribute_samples
from .instrument import Session, Telescope
from .intervals import IntervalList, interval_dtype
from .mpi import MPI, comm_equal
from .observation_data import (
DetDataManager,
DetectorData,
IntervalsManager,
SharedDataManager,
)
from .observation_dist import DistDetSamp, redistribute_data
from .observation_view import DetDataView, SharedView, View, ViewInterface, ViewManager
from .timing import function_timer
from .utils import Logger, name_UID
default_values = None
def set_default_values(values=None):
    """Rebuild the module-level ``default_values`` namespace.

    The built-in defaults (field names, flag masks, ground-specific masks and
    interval names, and the detector data units) are assembled first, then
    updated with whatever the caller supplies.  The result replaces the global
    ``default_values`` object.

    Args:
        values (dict):  Optional overrides (or additional entries) applied on
            top of the built-in defaults.

    Returns:
        None

    """
    global default_values

    # Names of commonly used fields
    field_names = dict(
        times="times",
        shared_flags="flags",
        det_data="signal",
        det_flags="flags",
        hwp_angle="hwp_angle",
        azimuth="azimuth",
        elevation="elevation",
        boresight_azel="boresight_azel",
        boresight_radec="boresight_radec",
        position="position",
        velocity="velocity",
        pixels="pixels",
        weights="weights",
        quats="quats",
        quats_azel="quats_azel",
    )

    # Flag masks
    flag_masks = dict(
        shared_mask_invalid=1,
        shared_mask_unstable_scanrate=2,
        shared_mask_irregular=4,
        det_mask_invalid=1,
        det_mask_sso=1 + 2,
        det_mask_processing=2,
    )

    # Ground-specific flag masks
    ground_masks = dict(
        # This marks the turnarounds as "invalid".  To simulate data
        # in the turnarounds, set the turnaround mask in SimGround to
        # just "2".
        turnaround=1 + 2,
        scan_leftright=8,
        scan_rightleft=16,
        sun_up=32,
        sun_close=64,
        elnod=1 + 2 + 4,
    )

    # Ground-specific interval names
    ground_intervals = dict(
        scanning_interval="scanning",
        turnaround_interval="turnaround",
        throw_leftright_interval="throw_leftright",
        throw_rightleft_interval="throw_rightleft",
        throw_interval="throw",
        scan_leftright_interval="scan_leftright",
        scan_rightleft_interval="scan_rightleft",
        turn_leftright_interval="turn_leftright",
        turn_rightleft_interval="turn_rightleft",
        elnod_interval="elnod",
        sun_up_interval="sun_up",
        sun_close_interval="sun_close",
    )

    # Units
    units = dict(det_data_units=u.Kelvin)

    settings = {
        **field_names,
        **flag_masks,
        **ground_masks,
        **ground_intervals,
        **units,
    }
    if values is not None:
        settings.update(values)
    default_values = types.SimpleNamespace(**settings)
# Populate the module-level defaults the first time this module is imported,
# so that `default_values` is usable by the definitions that follow.
if default_values is None:
    set_default_values()
class Observation(MutableMapping):
"""Class representing the data for one observation.
An Observation stores information about data distribution across one or more MPI
processes and is a container for four types of objects:
* Local detector data (unique to each process).
* Shared data that has one common copy for every node spanned by the
observation.
* Intervals defining spans of data with some common characteristic.
* Other arbitrary small metadata.
Small metadata can be stored directly in the Observation using normal square
bracket "[]" access to elements (an Observation is a dictionary). Groups of
detector data (e.g. "signal", "flags", etc) can be accessed in the separate
detector data dictionary (the "detdata" attribute). Shared data can be similarly
stored in the "shared" attribute. Lists of intervals are accessed in the
"intervals" attribute and data views can use any interval list to access subsets
of detector and shared data.
**Notes on distributed use with MPI**
The detector data within an Observation is distributed among the processes in an
MPI communicator. The processes in the communicator are arranged in a rectangular
grid, with each process storing some number of detectors for a piece of time
covered by the observation. The most common configuration (and the default) is to
make this grid the size of the communicator in the "detector direction" and a size
of one in the "sample direction"::
    MPI           det1    sample(0), sample(1), sample(2), ...., sample(N-1)
    rank 0        det2    sample(0), sample(1), sample(2), ...., sample(N-1)
    ----------------------------------------------------------------------
    MPI           det3    sample(0), sample(1), sample(2), ...., sample(N-1)
    rank 1        det4    sample(0), sample(1), sample(2), ...., sample(N-1)
So each process has a subset of detectors for the whole span of the observation
time. You can override this shape by setting the process_rows to something
else. For example, process_rows=1 would result in this::
    MPI rank 0                        |        MPI rank 1
    ----------------------------------+----------------------------
    det1  sample(0), sample(1), ...,  |  ...., sample(N-1)
    det2  sample(0), sample(1), ...,  |  ...., sample(N-1)
    det3  sample(0), sample(1), ...,  |  ...., sample(N-1)
    det4  sample(0), sample(1), ...,  |  ...., sample(N-1)
Args:
comm (toast.Comm): The toast communicator containing information about the
process group for this observation.
telescope (Telescope): An instance of a Telescope object.
n_samples (int): The total number of samples for this observation.
name (str): (Optional) The observation name.
uid (int): (Optional) The Unique ID for this observation. If not specified,
the UID will be computed from a hash of the name.
session (Session): The observing session that this observation is contained
in or None.
detector_sets (list): (Optional) List of lists containing detector names.
These discrete detector sets are used to distribute detectors- a detector
set will always be within a single row of the process grid. If None,
every detector is a set of one.
sample_sets (list): (Optional) List of lists of chunk sizes (integer numbers of
samples). These discrete sample sets are used to distribute sample data.
A sample set will always be within a single column of the process grid. If
None, any distribution break in the sample direction will happen at an
arbitrary place. The sum of all chunks must equal the total number of
samples.
process_rows (int): (Optional) The size of the rectangular process grid
in the detector direction. This number must evenly divide into the size of
comm. If not specified, defaults to the size of the communicator.
"""
# Class-level ViewInterface instance; methods reach it as `self.view`
# (for example `self.view.clear()`).
# NOTE(review): presumably a descriptor giving per-observation views - confirm
# against observation_view.ViewInterface.
view = ViewInterface()
@function_timer
def __init__(
    self,
    comm,
    telescope,
    n_samples,
    name=None,
    uid=None,
    session=None,
    detector_sets=None,
    sample_sets=None,
    process_rows=None,
):
    """Store the metadata, build the data distribution and create empty managers.

    Args:
        comm (toast.Comm):  The toast communicator, handed to DistDetSamp.
        telescope (Telescope):  The telescope; its focalplane detectors are the
            detectors that get distributed.
        n_samples (int):  The total number of samples in the observation.
        name (str):  Optional observation name.
        uid (int):  Optional unique ID.  If None and a name is given, the UID
            is computed from the name with name_UID().
        session (Session):  Optional observing session.  If None and a name is
            given, a Session is created from the observation name and UID.
        detector_sets (list):  Optional detector sets passed to DistDetSamp.
        sample_sets (list):  Optional sample sets passed to DistDetSamp.
        process_rows (int):  Optional process grid size in the detector
            direction, passed to DistDetSamp.

    Raises:
        RuntimeError:  If session is neither None nor a Session instance.

    """
    self._telescope = telescope
    self._name = name
    self._uid = uid
    self._session = session

    # Without an explicit UID, derive one from the name when there is one.
    if self._uid is None and self._name is not None:
        self._uid = name_UID(self._name)

    if self._session is None:
        # A named observation with no session gets a session of its own,
        # with no start / end.  An unnamed one keeps session = None.
        if self._name is not None:
            self._session = Session(
                name=self._name,
                uid=self._uid,
                start=None,
                end=None,
            )
    elif not isinstance(self._session, Session):
        raise RuntimeError("session should be a Session instance or None")

    self.dist = DistDetSamp(
        n_samples,
        self._telescope.focalplane.detectors,
        sample_sets,
        detector_sets,
        comm,
        process_rows,
    )

    # The internal metadata dictionary
    self._internal = dict()

    # Set up the data managers
    self.detdata = DetDataManager(self.dist)
    self.shared = SharedDataManager(self.dist)
    self.intervals = IntervalsManager(self.dist, n_samples)

    # Set up local per-detector cutting:  every local detector starts unflagged.
    self._detflags = {x: 0 for x in self.dist.dets[self.dist.comm.group_rank]}
# Fully clear the observation
def clear(self):
    """Fully clear the observation.

    Calls clear() on the views, the intervals, the detector data, the shared
    data and finally the internal metadata dictionary, in that order.

    Returns:
        None

    """
    containers = (
        self.view,
        self.intervals,
        self.detdata,
        self.shared,
        self._internal,
    )
    for container in containers:
        container.clear()
# General properties
@property
def telescope(self):
    """
    (Telescope):  The telescope object this observation was constructed with.
    """
    return self._telescope
@property
def name(self):
    """
    (str):  The observation name (may be None).
    """
    return self._name
@property
def uid(self):
    """
    (int):  The unique ID of the observation (may be None).
    """
    return self._uid
@property
def session(self):
    """
    (Session):  The observing session of this observation (may be None).
    """
    return self._session
@property
def comm(self):
    """
    (toast.Comm):  The overall communicator, as held by the distribution.
    """
    distribution = self.dist
    return distribution.comm
# The MPI communicator along the current row of the process grid
@property
def comm_row(self):
    """
    (mpi4py.MPI.Comm):  Communicator of the processes sharing this row of the
        process grid (or None).
    """
    distribution = self.dist
    return distribution.comm_row
@property
def comm_row_size(self):
    """
    (int):  How many processes are in the row communicator.
    """
    distribution = self.dist
    return distribution.comm_row_size
@property
def comm_row_rank(self):
    """
    (int):  This process's rank within the row communicator.
    """
    distribution = self.dist
    return distribution.comm_row_rank
# The MPI communicator along the current column of the process grid
@property
def comm_col(self):
    """
    (mpi4py.MPI.Comm):  Communicator of the processes sharing this column of
        the process grid (or None).
    """
    distribution = self.dist
    return distribution.comm_col
@property
def comm_col_size(self):
    """
    (int):  How many processes are in the column communicator.
    """
    distribution = self.dist
    return distribution.comm_col_size
@property
def comm_col_rank(self):
    """
    (int):  This process's rank within the column communicator.
    """
    distribution = self.dist
    return distribution.comm_col_rank
# Detector distribution
@property
def all_detectors(self):
    """
    (list):  Every detector in this observation, across all processes.
    """
    distribution = self.dist
    return distribution.detectors
@property
def local_detectors(self):
    """
    (list):  The detectors that the distribution assigns to this process,
        looked up by the process's group rank.
    """
    distribution = self.dist
    group_rank = distribution.comm.group_rank
    return distribution.dets[group_rank]
@property
def local_detector_flags(self):
    """
    (dict):  Per-detector flag values on this process.  This is the internal
        dictionary itself, not a copy.
    """
    return self._detflags
def update_local_detector_flags(self, vals):
    """Merge new flag bits into the per-detector flags.

    Each supplied value is combined with the existing flag of that detector
    using a bitwise OR.  Entries are processed in order, so flags of detectors
    listed before an invalid name have already been updated when the error is
    raised.

    Args:
        vals (dict):  The flag values, keyed by detector name.

    Returns:
        None

    Raises:
        RuntimeError:  If a key is not one of the local detectors.

    """
    local = frozenset(self.local_detectors)
    for det, flag in vals.items():
        if det in local:
            self._detflags[det] |= int(flag)
            continue
        msg = f"Cannot update per-detector flag for '{det}', which is"
        msg += " not a local detector"
        raise RuntimeError(msg)
def set_local_detector_flags(self, vals):
    """Overwrite the per-detector flags.

    The flag of every listed detector is replaced by the supplied value;
    detectors that are not listed keep their current flag.  Entries are
    processed in order, so flags of detectors listed before an invalid name
    have already been replaced when the error is raised.

    Args:
        vals (dict):  The flag values, keyed by detector name.

    Returns:
        None

    Raises:
        RuntimeError:  If a key is not one of the local detectors.

    """
    local = frozenset(self.local_detectors)
    for det, flag in vals.items():
        if det in local:
            self._detflags[det] = int(flag)
            continue
        msg = f"Cannot set per-detector flag for '{det}', which is"
        msg += " not a local detector"
        raise RuntimeError(msg)
def select_local_detectors(
    self,
    selection=None,
    flagmask=(default_values.det_mask_invalid | default_values.det_mask_processing),
):
    """Return the local detectors, optionally filtered.

    The local detector list is pruned by the per-detector flags (unless
    flagmask is None) and by the requested selection (unless it is None).
    The result keeps the order of the local detector list.

    Note that the default flagmask is computed from `default_values` once,
    when this function is defined.

    Args:
        selection (list):  Only return detectors in this set.
        flagmask (uint8):  Apply this mask to per-detector flags and only
            include detectors with a result of zero (good).  If None, no
            flag-based cut is made.

    Returns:
        (list):  The selected detectors.

    """
    if flagmask is None:
        unflagged = set(self.local_detectors)
    else:
        unflagged = {
            det
            for det in self.local_detectors
            if (self.local_detector_flags[det] & flagmask) == 0
        }
    wanted = None if selection is None else set(selection)
    return [
        det
        for det in self.local_detectors
        if (wanted is None or det in wanted) and det in unflagged
    ]
# Detector set distribution
@property
def all_detector_sets(self):
    """
    (list):  The complete list of detector sets held by the distribution.
    """
    distribution = self.dist
    return distribution.detector_sets
@property
def local_detector_sets(self):
    """
    (list):  The detector sets assigned to this process, or None if the
        distribution has no detector sets.
    """
    all_sets = self.dist.detector_sets
    if all_sets is None:
        return None
    # This process owns a contiguous run of the global list of sets.
    mine = self.dist.det_sets[self.dist.comm.group_rank]
    first = mine.offset
    return [all_sets[first + i] for i in range(mine.n_elem)]
# Sample distribution
@property
def n_all_samples(self):
    """
    (int):  The total number of samples in the observation.
    """
    distribution = self.dist
    return distribution.samples
@property
def local_index_offset(self):
    """
    (int):  Index of the first sample held by this process, counted from the
        start of the observation.
    """
    distribution = self.dist
    group_rank = distribution.comm.group_rank
    return distribution.samps[group_rank].offset
@property
def n_local_samples(self):
    """
    (int):  How many samples are held by this process.
    """
    distribution = self.dist
    group_rank = distribution.comm.group_rank
    return distribution.samps[group_rank].n_elem
# Sample set distribution
@property
def all_sample_sets(self):
    """
    (list):  The complete list of sample sets held by the distribution.
    """
    distribution = self.dist
    return distribution.sample_sets
@property
def local_sample_sets(self):
    """
    (list):  The sample sets assigned to this process, or None if the
        distribution has no sample sets.
    """
    all_sets = self.dist.sample_sets
    if all_sets is None:
        return None
    # This process owns a contiguous run of the global list of sets.
    mine = self.dist.samp_sets[self.dist.comm.group_rank]
    first = mine.offset
    return [all_sets[first + i] for i in range(mine.n_elem)]
# Mapping methods
def __getitem__(self, key):
    """Return the metadata value stored under key (KeyError if absent)."""
    metadata = self._internal
    return metadata[key]
def __delitem__(self, key):
    """Remove the metadata entry stored under key (KeyError if absent)."""
    metadata = self._internal
    del metadata[key]
def __setitem__(self, key, value):
    """Store value in the metadata dictionary under key."""
    metadata = self._internal
    metadata[key] = value
def __iter__(self):
    """Iterate over the metadata keys."""
    metadata = self._internal
    return iter(metadata)
def __len__(self):
    """Return the number of metadata entries."""
    metadata = self._internal
    return len(metadata)
def __del__(self):
    """Clear the detector and shared data managers when the object is destroyed.

    Each attribute is checked first, since it may not exist (for example if
    construction failed before the managers were created).
    """
    for attr in ("detdata", "shared"):
        if hasattr(self, attr):
            getattr(self, attr).clear()
def __repr__(self):
    """Return a multi-line, human readable summary of the observation.

    The summary lists the name, UID, process group size, telescope, session,
    every metadata entry, the sample counts and the three data managers.
    """
    pieces = ["<Observation"]
    pieces.append(f"\n name = '{self.name}'")
    pieces.append(f"\n uid = '{self.uid}'")
    # The group description continues on the same line as the UID.
    if self.comm.comm_group is None:
        pieces.append(" group has a single process (no MPI)")
    else:
        pieces.append(f" group has {self.comm.group_size} processes")
    pieces.append(f"\n telescope = {self._telescope.__repr__()}")
    pieces.append(f"\n session = {self._session.__repr__()}")
    pieces.extend(f"\n {k} = {v}" for k, v in self._internal.items())
    pieces.append(
        f"\n {self.n_all_samples} total samples ({self.n_local_samples} local)"
    )
    pieces.append(f"\n shared: {self.shared}")
    pieces.append(f"\n detdata: {self.detdata}")
    pieces.append(f"\n intervals: {self.intervals}")
    pieces.append("\n>")
    return "".join(pieces)
def __eq__(self, other):
    """Compare two observations, including all of their data.

    Note that testing for equality is quite expensive, since it means testing
    all metadata and also all detector, shared, and interval data.  This is
    mainly used for unit tests.

    When the observation has a group communicator, the per-process results are
    combined with an allreduce, so every process must call this and every
    process gets the same answer.

    Args:
        other (Observation):  The observation to compare against.

    Returns:
        (bool):  True if no difference was found on any process.

    """
    log = Logger.get()
    rank = self.comm.world_rank
    fail = 0
    if self.name != other.name:
        fail = 1
        log.verbose(f"Proc {rank}: Obs names {self.name} != {other.name}")
    if self.uid != other.uid:
        fail = 1
        log.verbose(f"Proc {rank}: Obs uid {self.uid} != {other.uid}")
    if self.telescope != other.telescope:
        fail = 1
        log.verbose(f"Proc {rank}: Obs telescopes not equal")
    if self.session != other.session:
        fail = 1
        log.verbose(f"Proc {rank}: Obs sessions not equal")
    if self.dist != other.dist:
        fail = 1
        log.verbose(f"Proc {rank}: Obs distributions not equal")
    if set(self._internal.keys()) != set(other._internal.keys()):
        fail = 1
        log.verbose(f"Proc {rank}: Obs metadata keys not equal")
    else:
        # Only compare values when both sides have the same keys; otherwise
        # the lookup in the other observation could raise a KeyError.
        for k, v in self._internal.items():
            other_v = other._internal[k]
            try:
                differs = bool(v != other_v)
            except (ValueError, TypeError):
                # Array-like values: the elementwise result has no single
                # truth value, so compare the arrays as a whole.
                try:
                    differs = not np.array_equal(v, other_v)
                except Exception:
                    differs = True
            if not differs:
                continue
            # Values that are not exactly equal may still agree to within
            # floating point tolerance.
            try:
                feq = bool(np.allclose(v, other_v))
            except Exception:
                # Not floating point data
                feq = False
            if not feq:
                fail = 1
                log.verbose(f"Proc {rank}: Obs metadata[{k}]: {v} != {other_v}")
                break
    if self.shared != other.shared:
        fail = 1
        log.verbose(f"Proc {rank}: Obs shared data not equal")
    if self.detdata != other.detdata:
        fail = 1
        log.verbose(f"Proc {rank}: Obs detdata not equal")
    if self.intervals != other.intervals:
        fail = 1
        log.verbose(f"Proc {rank}: Obs intervals not equal")
    if self.comm.comm_group is not None:
        # A difference on any process makes the observations unequal everywhere.
        fail = self.comm.comm_group.allreduce(fail, op=MPI.SUM)
    return fail == 0
def __ne__(self, other):
    """Return the negation of the __eq__ comparison with other."""
    is_equal = self.__eq__(other)
    return not is_equal
def duplicate(
    self, times=None, meta=None, shared=None, detdata=None, intervals=None
):
    """Return a copy of the observation and all its data.

    The times field should be the name of the shared field containing
    timestamps.  This is used when copying interval lists to the new
    observation so that these objects reference the timestamps within the new
    observation (rather than the old one).  If this is not specified and some
    intervals exist, then an exception is raised.

    The meta, shared, detdata, and intervals lists specify which of those
    objects to copy to the new observation.  If these are None, then all
    objects are duplicated.  The times field is always copied when it is given
    together with a shared list.  The per-detector flags of the local
    detectors are always carried over.

    Args:
        times (str):  The name of the timestamps shared field.
        meta (list):  List of metadata objects to copy, or None.
        shared (list):  List of shared objects to copy, or None.
        detdata (list):  List of detdata objects to copy, or None.
        intervals (list):  List of intervals objects to copy, or None.

    Returns:
        (Observation):  The new copy of the observation.

    Raises:
        RuntimeError:  If times is None and the observation has intervals.

    """
    log = Logger.get()
    if times is None and len(self.intervals) > 0:
        msg = "You must specify the times field when duplicating observations "
        msg += "that have some intervals defined."
        log.error(msg)
        raise RuntimeError(msg)

    # The copy uses the same communicator and the same distribution inputs.
    new_obs = Observation(
        self.dist.comm,
        self.telescope,
        self.n_all_samples,
        name=self.name,
        uid=self.uid,
        session=self.session,
        detector_sets=self.all_detector_sets,
        sample_sets=self.all_sample_sets,
        process_rows=self.dist.process_rows,
    )

    # Carry over the per-detector flags.  Only the local detectors are passed,
    # since the setter rejects any other name.
    new_obs.set_local_detector_flags(
        {det: self.local_detector_flags[det] for det in self.local_detectors}
    )

    for k, v in self._internal.items():
        if meta is None or k in meta:
            new_obs[k] = copy.deepcopy(v)

    for name, data in self.detdata.items():
        if detdata is None or name in detdata:
            # NOTE(review): assignment through the manager presumably copies
            # the data - confirm against DetDataManager.
            new_obs.detdata[name] = data

    # The timestamps are needed by the copied intervals, so they are copied
    # even if the caller's shared list does not mention them.
    copy_shared = list()
    if times is not None:
        copy_shared.append(times)
    if shared is not None:
        copy_shared.extend(shared)
    for name, data in self.shared.items():
        if shared is None or name in copy_shared:
            # Create the object on the corresponding communicator in the new obs
            new_obs.shared.assign_mpishared(name, data, self.shared.comm_type(name))

    for name, data in self.intervals.items():
        if intervals is None or name in intervals:
            # Rebuild each interval list from its time spans against the
            # timestamps of the new observation.
            timespans = [(x.start, x.stop) for x in data]
            new_obs.intervals[name] = IntervalList(
                new_obs.shared[times], timespans=timespans
            )
    return new_obs
def memory_use(self):
    """Estimate the memory used by shared and detector data.

    The local detector data and interval memory is summed over all processes
    of the observation communicator and the shared data memory is added to
    that.  This function is blocking on the observation communicator.

    Returns:
        (int):  The number of bytes of memory used by timestream data.

    """
    # Local contribution:  detector data, plus the interval lists, which can
    # take up non-trivial space if there are many intervals.
    local_bytes = self.detdata.memory_use()
    local_bytes += sum(
        len(ilist) * interval_dtype.itemsize
        for _, ilist in self.intervals.items()
        if len(ilist) > 0
    )

    # Sum the aggregate local memory
    group = self.comm.comm_group
    if group is None:
        total_bytes = local_bytes
    else:
        total_bytes = group.allreduce(local_bytes, op=MPI.SUM)

    # The total shared memory use is already returned on every process by
    # this next function.
    return total_bytes + self.shared.memory_use()
# Redistribution
@function_timer
def redistribute(
    self,
    process_rows,
    times=None,
    override_sample_sets=False,
    override_detector_sets=False,
):
    """Take the currently allocated observation and redistribute in place.

    This changes the data distribution within the observation.  After
    re-assigning all detectors and samples, the currently allocated shared
    data objects and detector data objects are redistributed using the
    observation communicator.  Metadata objects that have a `redistribute`
    method are redistributed as well, and the per-detector flags are rebuilt
    for the new set of local detectors.

    If process_rows equals the current value nothing is done, even when
    override sets are given.

    Args:
        process_rows (int):  The size of the new process grid in the detector
            direction.  This number must evenly divide into the size of the
            observation communicator.
        times (str):  The shared data field representing the timestamps.  This
            is used to recompute the intervals after redistribution.
        override_sample_sets (False, None or list):  If not False, override
            existing sample set boundaries in the redistributed data.
        override_detector_sets (False, None or list):  If not False, override
            existing detector set boundaries in the redistributed data.

    Returns:
        None

    """
    if process_rows == self.dist.process_rows:
        # Nothing to do!
        return

    # Use identity checks against False:  the overrides may be None or
    # sequences, and "== False" raises for array-like values.
    if override_sample_sets is False:
        sample_sets = self.dist.sample_sets
    else:
        sample_sets = override_sample_sets

    if override_detector_sets is False:
        detector_sets = self.dist.detector_sets
    else:
        detector_sets = override_detector_sets

    # Get the total set of per-detector flags.  These are gathered along the
    # process column before the distribution changes.
    if self.comm_col_size == 1:
        all_det_flags = self.local_detector_flags
    else:
        pdflags = self.comm_col.gather(self.local_detector_flags, root=0)
        all_det_flags = None
        if self.comm_col_rank == 0:
            all_det_flags = dict()
            for pf in pdflags:
                all_det_flags.update(pf)
        all_det_flags = self.comm_col.bcast(all_det_flags, root=0)

    # Create the new distribution
    new_dist = DistDetSamp(
        self.dist.samples,
        self._telescope.focalplane.detectors,
        sample_sets,
        detector_sets,
        self.dist.comm,
        process_rows,
    )

    # Do the actual redistribution
    new_shr_manager, new_det_manager, new_intervals_manager = redistribute_data(
        self.dist,
        new_dist,
        self.shared,
        self.detdata,
        self.intervals,
        times=times,
        dbg=self.name,
    )

    # Redistribute any metadata objects that support it.
    for k, v in self._internal.items():
        if hasattr(v, "redistribute"):
            v.redistribute(self.dist, new_dist)

    # Replace our distribution and data managers with the new ones.
    del self.dist
    self.dist = new_dist

    self.shared.clear()
    del self.shared
    self.shared = new_shr_manager

    self.detdata.clear()
    del self.detdata
    self.detdata = new_det_manager

    self.intervals.clear()
    del self.intervals
    self.intervals = new_intervals_manager

    # Rebuild the detector flags for our new local detectors.  A fresh
    # dictionary is used so that entries for detectors that are no longer
    # local do not linger.
    self._detflags = {
        det: int(all_det_flags[det]) for det in self.local_detectors
    }
# Accelerator use
def accel_create(self, names):
    """Create a set of data objects on the device.

    This takes a dictionary with the same format as those used by the
    Operator provides() and requires() methods.  The "detdata", "shared" and
    "intervals" entries are handled in that order.  In addition, every
    metadata value that is an AcceleratorObject and does not yet exist on the
    device is created there.

    Args:
        names (dict):  Dictionary of lists.

    Returns:
        None

    """
    managers = (
        (self.detdata, "detdata"),
        (self.shared, "shared"),
        (self.intervals, "intervals"),
    )
    for manager, section in managers:
        for key in names[section]:
            manager.accel_create(key)
    for val in self._internal.values():
        if isinstance(val, AcceleratorObject) and not val.accel_exists():
            val.accel_create()
def accel_update_device(self, names):
    """Copy data objects to the device.

    This takes a dictionary with the same format as those used by the
    Operator provides() and requires() methods.  The "detdata", "shared" and
    "intervals" entries are handled in that order.  In addition, every
    metadata value that is an AcceleratorObject and is not in use on the
    device is updated there.

    Args:
        names (dict):  Dictionary of lists.

    Returns:
        None

    """
    managers = (
        (self.detdata, "detdata"),
        (self.shared, "shared"),
        (self.intervals, "intervals"),
    )
    for manager, section in managers:
        for key in names[section]:
            manager.accel_update_device(key)
    for val in self._internal.values():
        if isinstance(val, AcceleratorObject) and not val.accel_in_use():
            val.accel_update_device()
def accel_update_host(self, names):
    """Copy data objects from the device.

    This takes a dictionary with the same format as those used by the
    Operator provides() and requires() methods.  The "detdata", "shared" and
    "intervals" entries are handled in that order.  In addition, every
    metadata value that is an AcceleratorObject and is in use on the device is
    copied back to the host.

    Args:
        names (dict):  Dictionary of lists.

    Returns:
        None

    """
    managers = (
        (self.detdata, "detdata"),
        (self.shared, "shared"),
        (self.intervals, "intervals"),
    )
    for manager, section in managers:
        for key in names[section]:
            manager.accel_update_host(key)
    for val in self._internal.values():
        if isinstance(val, AcceleratorObject) and val.accel_in_use():
            val.accel_update_host()
def accel_clear(self):
    """Clear all data objects from the device.

    Calls accel_clear() on the detector data, shared data and intervals
    managers, in that order.  In addition, every metadata value that is an
    AcceleratorObject and exists on the device is deleted from it.

    Returns:
        None

    """
    for manager in (self.detdata, self.shared, self.intervals):
        manager.accel_clear()
    for val in self._internal.values():
        if isinstance(val, AcceleratorObject) and val.accel_exists():
            val.accel_delete()