# Copyright (c) 2015-2020 by the parties listed in the AUTHORS file.
# All rights reserved. Use of this source code is governed by
# a BSD-style license that can be found in the LICENSE file.
import numpy as np
from .mpi import Comm
# This is effectively the "Painter's Partition Problem".
def distribute_required_groups(A, max_per_group):
    """Count the contiguous groups needed to hold the items of ``A``.

    Items are taken in order and packed greedily: an item is added to the
    current group unless doing so would push the group's total above
    ``max_per_group``, in which case the item starts a new group.

    Args:
        A (array): 1D array of item weights (indexed through ``A.shape[0]``).
        max_per_group: The largest total allowed in a single group.

    Returns:
        (int): The number of groups used.  An empty ``A`` still counts as
            one group.

    """
    n_items = A.shape[0]
    groups_needed = 1
    running = 0
    for idx in range(n_items):
        item = A[idx]
        if running + item > max_per_group:
            # This item does not fit; it becomes the first item of a new group.
            running = item
            groups_needed += 1
        else:
            running += item
    return groups_needed
def distribute_partition(A, k):
    """Find the smallest per-group cap that packs ``A`` into at most ``k`` groups.

    A binary search is run over candidate caps between the largest single
    item and the sum of all items.  For each candidate the greedy count from
    ``distribute_required_groups`` decides which half of the range to keep.

    Args:
        A (array): 1D array of item weights.
        k (int): The maximum number of contiguous groups allowed.

    Returns:
        The smallest cap in the searched range for which the greedy packing
        needs no more than ``k`` groups.

    """
    lower = np.max(A)
    upper = np.sum(A)
    while lower < upper:
        candidate = lower + int((upper - lower) / 2)
        if distribute_required_groups(A, candidate) > k:
            # Too many groups at this cap: the answer must be larger.
            lower = candidate + 1
        else:
            # Feasible: try to tighten the cap further.
            upper = candidate
    return lower
def distribute_discrete(sizes, groups, pow=1.0, breaks=None):
    """Distribute indivisible blocks of items between groups.

    The blocks are contiguous and cannot be subdivided.  They are assigned,
    in order, to the requested number of groups so that the largest total
    handed to any group is kept small.  Each block is weighted by its size
    raised to ``pow``; the per-group cap comes from ``distribute_partition``
    applied to the integer-truncated weights.

    Args:
        sizes (list): The sizes of the indivisible blocks.
        groups (int): The number of groups.
        pow (float): The power applied to each block size to get its weight.
        breaks (list): Optional block indices at which a new group must
            start.  Only values strictly between 0 and the number of blocks
            are used.

    Returns:
        (list): One tuple per group.  The first element is the index of the
            first block assigned to the group and the second element is the
            number of blocks assigned to the group.

    Raises:
        RuntimeError: if there are more hard breaks than ``groups - 1``, or
            if the walk over the blocks does not end with exactly ``groups``
            groups.

    """
    chunks = np.array(sizes, dtype=np.int64)
    weights = np.power(chunks.astype(np.float64), pow)
    nchunk = weights.shape[0]

    # Cap on the weight of one group, and the ideal (mean) weight per group.
    limit = float(distribute_partition(weights.astype(np.int64), groups))
    target = np.sum(weights) / groups

    hard_breaks = None
    if breaks is not None:
        # Check that the problem makes sense
        hard_breaks = np.unique(breaks)
        hard_breaks = hard_breaks[hard_breaks > 0]
        hard_breaks = hard_breaks[hard_breaks < chunks.size]
        hard_breaks = np.sort(hard_breaks)
        if hard_breaks.size + 1 > groups:
            raise RuntimeError(
                "Cannot divide {} chunks to {} groups with {} breaks.".format(
                    chunks.size, groups, hard_breaks.size
                )
            )

    result = []
    start = 0
    running = 0.0
    split_here = False
    for idx in range(nchunk):
        weight = weights[idx]
        if split_here or running + weight > limit:
            # Close the current group before this block.  The amount by
            # which the closed group missed the target is carried into the
            # running weight of the new group.
            result.append((start, idx - start))
            running = weight + (running - target)
            start = idx
        else:
            running += weight
        if hard_breaks is not None:
            # A hard break before the next block forces a split there.
            split_here = (idx + 1) in hard_breaks

    # Whatever remains forms the final group.
    result.append((start, nchunk - start))

    if len(result) != groups:
        raise RuntimeError(
            "Number of distributed groups different than number requested"
        )
    return result
def distribute_uniform(totalsize, groups, breaks=None):
    """Uniformly distribute items between groups.

    Given some number of items and some number of groups, distribute the
    items between groups in the most uniform way possible.  If hard breaks
    are given, the items are first cut into segments at those breaks.  Each
    segment gets a share of the groups roughly proportional to its length
    (at least one group per segment), and the items of a segment are then
    spread uniformly over that segment's groups.

    Args:
        totalsize (int): The total number of items.
        groups (int): The number of groups.
        breaks (list): List of hard breaks in the data distribution.  Only
            values strictly between 0 and ``totalsize`` are used.

    Returns:
        (list): there is one tuple per group.  The first element of the tuple
            is the first item assigned to the group, and the second element is
            the number of items assigned to the group.

    Raises:
        RuntimeError: if there are more usable breaks than ``groups - 1``.

    """
    if breaks is None:
        groupcounts = [groups]
        groupsizes = [totalsize]
    else:
        # np.unique returns the distinct values already sorted.
        all_breaks = np.unique(breaks)
        all_breaks = all_breaks[(all_breaks > 0) & (all_breaks < totalsize)]
        nbreak = len(all_breaks)
        if nbreak > groups - 1:
            raise RuntimeError(
                "Cannot distribute {} chunks with {} breaks over {} groups"
                "".format(totalsize, nbreak, groups)
            )
        groupcounts = []
        groupsizes = []
        offset = 0
        groupsleft = groups
        totalleft = totalsize
        for ibrk, brk in enumerate(all_breaks):
            brk = int(brk)
            length = brk - offset
            groupcount = int(np.round(groupsleft * length / totalleft))
            # Every later segment (one per remaining break, plus the tail)
            # needs at least one group.  Without this reservation an early
            # segment could take all the groups and the items of a later
            # segment would be assigned to nobody.
            segments_after = nbreak - ibrk
            groupcount = min(groupcount, groupsleft - segments_after)
            groupcount = max(1, groupcount)
            groupcounts.append(groupcount)
            groupsizes.append(length)
            groupsleft -= groupcount
            totalleft -= length
            offset = brk
        # The tail segment takes whatever is left.
        groupcounts.append(groupsleft)
        groupsizes.append(totalleft)

    dist = []
    offset = 0
    for groupsize, groupcount in zip(groupsizes, groupcounts):
        if groupcount > 0:
            # The first `extra` groups of the segment hold one more item
            # than the rest.
            base, extra = divmod(groupsize, groupcount)
            for i in range(groupcount):
                if i < extra:
                    count = base + 1
                    first = i * count
                else:
                    count = base
                    first = (base + 1) * extra + base * (i - extra)
                dist.append((offset + first, count))
        offset += groupsize
    return dist
def distribute_samples(
    mpicomm,
    detectors,
    samples,
    detranks=1,
    detbreaks=None,
    sampsizes=None,
    sampbreaks=None,
):
    r"""Distribute data by detector and sample.

    Given a list of detectors and some number of samples, distribute
    the data in a load balanced way.  Optionally account for constraints
    on this distribution.  The samples may be grouped by indivisible
    chunks, and there may be forced breaks in the distribution in both
    the detector and chunk directions.

    ::

                        samples -->
                      +--------------+--------------
                    / | sampsize[0]  | sampsize[1] ...
        detrank = 0   +--------------+--------------
                    \ | sampsize[0]  | sampsize[1] ...
                      +--------------+--------------
                    / | sampsize[0]  | sampsize[1] ...
        detrank = 1   +--------------+--------------
                    \ | sampsize[0]  | sampsize[1] ...
                      +--------------+--------------
                      | ...

    Args:
        mpicomm (mpi4py.MPI.Comm): the MPI communicator over which the
            data is distributed.  If None, then all data is assigned to a
            single process.
        detectors (list): The list of detector names.
        samples (int): The total number of samples.  Only used when
            ``sampsizes`` is None.
        detranks (int): The dimension of the process grid in the detector
            direction.  The MPI communicator size must be evenly divisible
            by this number.
        detbreaks (list): Optional list of hard breaks in the detector
            distribution.
        sampsizes (list): Optional list of sample chunk sizes which
            cannot be split.
        sampbreaks (list): Optional list of hard breaks in the sample
            distribution.

    Returns:
        tuple of lists: the 3 lists returned contain information about
            the detector distribution, the sample distribution, and the chunk
            distribution.  The first list has one entry for each detrank and
            contains the list of detectors for that row of the process grid.
            The second list contains tuples of (first sample, N samples) for
            each column of the process grid.  The third list contains tuples
            of (first chunk, N chunks) for each column of the process grid.

    Raises:
        RuntimeError: if ``detranks`` does not evenly divide the number of
            processes.

    """
    nproc = 1 if mpicomm is None else mpicomm.size

    if nproc % detranks != 0:
        raise RuntimeError(
            "The number of detranks ({}) does not divide evenly "
            "into the number of processes ({})".format(detranks, nproc)
        )

    # Compute the other dimension of the process grid.
    sampranks = nproc // detranks

    # Distribute detectors uniformly, but respecting forced breaks in the
    # grouping specified by the calling code.
    dist_detsindx = distribute_uniform(len(detectors), detranks, breaks=detbreaks)
    dist_dets = [detectors[first : first + ndet] for first, ndet in dist_detsindx]

    # Distribute samples using both the chunking and the forced breaks
    if sampsizes is not None:
        dist_sizes = distribute_discrete(sampsizes, sampranks, breaks=sampbreaks)
        # Convert each (first chunk, N chunks) range into the matching
        # (first sample, N samples) range by accumulating chunk sizes.
        dist_samples = []
        off = 0
        for first_chunk, nchunk in dist_sizes:
            cursamp = np.sum(sampsizes[first_chunk : first_chunk + nchunk])
            dist_samples.append((off, cursamp))
            off += cursamp
    else:
        # Without chunk sizes, each process column holds a single chunk.
        dist_samples = distribute_uniform(samples, sampranks, breaks=sampbreaks)
        dist_sizes = [(x, 1) for x in range(sampranks)]

    return (dist_dets, dist_samples, dist_sizes)
class Data(object):
    """Class which represents distributed data.

    A Data object contains a list of observations assigned to
    each process group in the Comm.  It also carries a small dictionary of
    metadata which is accessed with the usual ``data[key]`` / ``key in data``
    syntax.

    Args:
        comm (:class:`toast.Comm`): the toast Comm class for distributing the
            data.  If None, a default ``Comm()`` is created when the object
            is constructed.

    """

    def __init__(self, comm=None):
        # The default communicator is built here rather than in the signature
        # so that it is not created once, at class-definition time.
        if comm is None:
            comm = Comm()
        self._comm = comm
        #: The list of observations.
        self.obs = []
        self._metadata = {}

    def __contains__(self, key):
        """Return True if ``key`` is present in the metadata."""
        return key in self._metadata

    def __getitem__(self, key):
        """Return the metadata value stored under ``key``."""
        return self._metadata[key]

    def __setitem__(self, key, value):
        """Store ``value`` in the metadata under ``key``."""
        self._metadata[key] = value

    @property
    def comm(self):
        """The toast.Comm over which the data is distributed."""
        return self._comm

    def clear(self):
        """Clear the list of observations.

        Calls ``clear()`` on every observation.  The list itself keeps its
        (now cleared) entries.

        """
        for ob in self.obs:
            ob.clear()
        return

    @staticmethod
    def _format_stats(values):
        """Return the min / max / mean / rms line for ``values``.

        An empty selection (for which the numpy reductions would raise) or a
        floating point error is reported as N/A.

        """
        not_available = " min = N/A, max = N/A, mean = N/A, rms = N/A\n"
        if np.size(values) == 0:
            return not_available
        try:
            return " min = {:.4e}, max = {:.4e}, mean = {:.4e}, rms = {:.4e}\n".format(
                np.min(values), np.max(values), np.mean(values), np.std(values)
            )
        except FloatingPointError:
            return not_available

    @staticmethod
    def _format_observation_header(ob, obs_id, tod, intrvl):
        """Return the per-observation text written by rank zero of a group."""
        parts = ["observation {}:\n".format(obs_id)]
        parts.extend(" key {}\n".format(ko) for ko in sorted(ob.keys()))
        if tod is not None:
            parts.append(
                " {} total samples, {} detectors\n".format(
                    tod.total_samples, len(tod.detectors)
                )
            )
        if intrvl is not None:
            parts.append(" {} intervals:\n".format(len(intrvl)))
            parts.extend(
                " {} --> {} ({} --> {})\n".format(it.first, it.last, it.start, it.stop)
                for it in intrvl
            )
        return "".join(parts)

    def _format_process_stats(self, tod, flag_mask, common_flag_mask):
        """Return the text describing this process' local piece of ``tod``."""
        parts = [" proc {}\n".format(self._comm.group_rank)]
        if tod is None:
            return "".join(parts)

        offset, nsamp = tod.local_samples
        dets = tod.local_dets
        local_chunks = tod.local_chunks

        my_chunks = 1 if local_chunks is None else local_chunks[1]
        parts.append(
            " sample range {} --> {} in {} chunks:\n".format(
                offset, (offset + nsamp - 1), my_chunks
            )
        )

        if local_chunks is not None:
            # Walk the local chunks, listing the sample range of each one.
            chkoff = offset
            for chk in range(local_chunks[1]):
                chklen = tod.total_chunks[local_chunks[0] + chk]
                parts.append(" {} --> {}\n".format(chkoff, chkoff + chklen - 1))
                chkoff += chklen

        if nsamp > 0:
            stamps = tod.local_times()
            parts.append(" timestamps {} --> {}\n".format(stamps[0], stamps[-1]))

            common = tod.local_common_flags()
            for det in dets:
                parts.append(" det {}:\n".format(det))

                pdata = tod.local_pointing(det)
                parts.append(
                    " pntg [{:.3e} {:.3e} {:.3e} {:.3e}] "
                    "--> [{:.3e} {:.3e} {:.3e} {:.3e}]\n".format(
                        pdata[0, 0],
                        pdata[0, 1],
                        pdata[0, 2],
                        pdata[0, 3],
                        pdata[-1, 0],
                        pdata[-1, 1],
                        pdata[-1, 2],
                        pdata[-1, 3],
                    )
                )

                data = tod.local_signal(det)
                flags = tod.local_flags(det)
                parts.append(
                    " {:.3e} ({}) --> {:.3e} ({})\n".format(
                        data[0], flags[0], data[-1], flags[-1]
                    )
                )

                # A sample is good when neither the detector flags nor the
                # common flags have any of the masked bits set.
                good = np.where(
                    ((flags & flag_mask) | (common & common_flag_mask)) == 0
                )[0]
                parts.append(" {} good samples\n".format(len(good)))
                parts.append(self._format_stats(data[good]))

            # NOTE(review): the original indentation was lost; the cache
            # listing is assumed to sit inside the `nsamp > 0` branch -- confirm.
            for cname in tod.cache.keys():
                parts.append(" cache {}:\n".format(cname))
                parts.append(self._format_stats(tod.cache.reference(cname)))

        return "".join(parts)

    def info(self, handle=None, flag_mask=255, common_flag_mask=255, intervals=None):
        """Print information about the distributed data.

        Information is written to the specified file handle.  Only the rank 0
        process writes.  Optional flag masks are used when computing the
        number of good samples.

        Args:
            handle (descriptor): file descriptor supporting the write()
                method.  If None, use print().
            flag_mask (int): bit mask to use when computing the number of
                good detector samples.
            common_flag_mask (int): bit mask to use when computing the
                number of good telescope pointings.
            intervals (str): optional name of an intervals object to print
                from each observation.

        Returns:
            None

        """
        gcomm = self._comm.comm_group
        wcomm = self._comm.comm_world
        rcomm = self._comm.comm_rank

        def _emit(text):
            if handle is None:
                print(text, flush=True)
            else:
                handle.write(text)

        if wcomm is None:
            _emit("Data distributed over a single process (no MPI)\n")
        elif wcomm.rank == 0:
            _emit(
                "Data distributed over {} processes in {} groups\n".format(
                    self._comm.world_size, self._comm.ngroups
                )
            )

        # Each process group gathers their output.  The text for all of the
        # group's observations is accumulated so that every observation is
        # reported, not just the last one.
        group_parts = []
        for ob in self.obs:
            obs_id = ob["id"]
            tod = ob["tod"] if "tod" in ob else None
            intrvl = None
            if intervals is not None and intervals in ob:
                intrvl = ob[intervals]

            # rank zero of the group will print general information,
            # and each process will get its statistics.
            if self._comm.group_rank == 0:
                group_parts.append(
                    self._format_observation_header(ob, obs_id, tod, intrvl)
                )

            procstr = self._format_process_stats(tod, flag_mask, common_flag_mask)

            if self._comm.group_rank == 0:
                group_parts.append(procstr)
            if gcomm is not None:
                # Collect the other processes' text in rank order.
                for p in range(1, self._comm.group_size):
                    if gcomm.rank == 0:
                        group_parts.append(gcomm.recv(source=p, tag=p))
                    elif p == gcomm.rank:
                        gcomm.send(procstr, dest=0, tag=p)
                    gcomm.barrier()
        groupstr = "".join(group_parts)

        # the world rank 0 process collects output from all groups and
        # writes to the handle
        # NOTE(review): this stage is placed after the observation loop (the
        # original indentation was lost).  Inside the loop, the world barrier
        # would need every group to hold the same number of observations.
        if self._comm.world_rank == 0:
            _emit(groupstr)
        if wcomm is not None:
            for g in range(1, self._comm.ngroups):
                if wcomm.rank == 0:
                    _emit(rcomm.recv(source=g, tag=g))
                elif g == self._comm.group:
                    if self._comm.group_rank == 0:
                        rcomm.send(groupstr, dest=0, tag=g)
                wcomm.barrier()
        return

    def split(self, key):
        """Split the Data object.

        Split the Data object based on the value of `key` in the
        observation dictionary.  When a world communicator exists, the set of
        values is gathered from all processes, so every process returns an
        entry for every value (possibly with no local observations).

        Args:
            key(str) :  Observation key to use.

        Returns:
            List of 2-tuples of the form (value, data), ordered by value.

        Raises:
            RuntimeError: if a local observation does not define `key`.

        """
        # Build a superset of all values
        values = set()
        for obs in self.obs:
            if key not in obs:
                raise RuntimeError(
                    'Cannot split data by "{}".  Key is not '
                    "defined for all observations.".format(key)
                )
            values.add(obs[key])

        if self._comm.comm_world is not None:
            for vals in self._comm.comm_world.allgather(values):
                values = values.union(vals)

        # Split the data, with the values in sorted order.
        datasplit = []
        for value in sorted(values):
            new_data = Data(comm=self._comm)
            new_data.obs.extend(obs for obs in self.obs if obs[key] == value)
            datasplit.append((value, new_data))

        return datasplit