Source code for toast.mpi

# Copyright (c) 2015-2020 by the parties listed in the AUTHORS file.
# All rights reserved.  Use of this source code is governed by
# a BSD-style license that can be found in the LICENSE file.

import os
import time

from ._libtoast import Logger

use_mpi = None
MPI = None


# traitlets require the MPI communicator type to be an actual class,
# even if None is a valid value
class MPI_Comm:
    pass


if use_mpi is None:
    log = Logger.get()
    # See if the user has explicitly disabled MPI.
    if "MPI_DISABLE" in os.environ:
        use_mpi = False
    else:
        # Special handling for running on a NERSC login node.  This is for convenience.
        # The same behavior could be implemented with environment variables set in a
        # shell resource file.
        at_nersc = False
        if "NERSC_HOST" in os.environ:
            at_nersc = True
        in_slurm = False
        if "SLURM_JOB_NAME" in os.environ:
            in_slurm = True
        if (not at_nersc) or in_slurm:
            try:
                # related to MPI_THREAD_MULTIPLE and errors when OpenMPI 4 is used.
                # https://github.com/mpi4py/mpi4py/issues/34#issuecomment-800233017
                import mpi4py

                mpi4py.rc.thread_level = "serialized"

                import mpi4py.MPI as MPI

                use_mpi = True
                MPI_Comm = MPI.Comm
            except:
                # There could be many possible exceptions raised...
                log.debug("mpi4py not found- using serial operations only")
                use_mpi = False

# Assign each process to an accelerator device
from .accelerator import accel_assign_device, use_accel_jax, use_accel_omp

if use_accel_omp or use_accel_jax:
    node_procs = 1
    node_rank = 0
    accel_gb = 2.0
    if "TOAST_GPU_MEM_GB" in os.environ:
        accel_gb = float(os.environ["TOAST_GPU_MEM_GB"])
    if use_mpi:
        # We need to compute which process goes to which device
        nodecomm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED, 0)
        node_procs = nodecomm.size
        node_rank = nodecomm.rank
        accel_assign_device(node_procs, node_rank, accel_gb, False)
        nodecomm.Free()
        del nodecomm
    else:
        accel_assign_device(node_procs, node_rank, accel_gb, False)
else:
    # Disabled == True
    accel_assign_device(1, 0, 0.0, True)

# We put other imports *after* the MPI check, since usually the MPI initialization
# is time sensitive and may timeout the job if it does not happen quickly enough.

import itertools
import sys
import traceback
from contextlib import contextmanager

import numpy as np
from pshmem import MPILock, MPIShared

from ._libtoast import Logger


def get_world():
    """Retrieve the default world communicator and its properties.

    If MPI is enabled, this returns MPI.COMM_WORLD and the process rank and number of
    processes.  If MPI is disabled, this returns None for the communicator, zero
    for the rank, and one for the number of processes.

    Returns:
        (tuple):  The communicator, number of processes, and rank.

    """
    rank = 0
    procs = 1
    world = None
    if use_mpi:
        world = MPI.COMM_WORLD
        rank = world.rank
        procs = world.size
    return world, procs, rank


[docs]class Comm(object): """Class which represents a two-level hierarchy of MPI communicators. A Comm object splits the full set of processes into groups of size "group". If group_size does not divide evenly into the size of the given communicator, then those processes remain idle. A Comm object stores several MPI communicators: The "world" communicator given here, which contains all processes to consider, a "group" communicator (one per group), and a "rank" communicator which contains the processes with the same group-rank across all groups. This object also stores a "node" communicator containing all processes with access to the same shared memory, and a "node rank" communicator for processes with the same rank on a node. There is a node rank communicator for all nodes and also one for within the group. Additionally, there is a mechanism for creating and caching row / column communicators for process grids within a group. If MPI is not enabled, then all communicators are set to None. Additionally, there may be cases where MPI is enabled in the environment, but the user wishes to disable it when creating a Comm object. This can be done by passing MPI.COMM_SELF as the world communicator. Args: world (mpi4py.MPI.Comm): the MPI communicator containing all processes. group (int): the size of each process group. """ def __init__(self, world=None, groupsize=0): log = Logger.get() if world is None: if use_mpi: # Default is COMM_WORLD world = MPI.COMM_WORLD else: # MPI is disabled, leave world as None. pass else: if use_mpi: # We were passed a communicator to use. Check that it is # actually a communicator, otherwise fall back to COMM_WORLD. if not isinstance(world, MPI.Comm): log.warning( "Specified world communicator is not a valid " "mpi4py.MPI.Comm object. Using COMM_WORLD." ) world = MPI.COMM_WORLD else: log.warning( "World communicator specified even though " "MPI is disabled. Ignoring this constructor " "argument." ) world = None # Special case, MPI available but the user wants a serial # data object if world == MPI.COMM_SELF: world = None self._wcomm = world self._wrank = 0 self._wsize = 1 self._nodecomm = None self._noderankcomm = None self._nodeprocs = 1 self._noderankprocs = 1 if self._wcomm is not None: self._wrank = self._wcomm.rank self._wsize = self._wcomm.size self._nodecomm = self._wcomm.Split_type(MPI.COMM_TYPE_SHARED, 0) self._nodeprocs = self._nodecomm.size myworldnode = self._wrank // self._nodeprocs self._noderankcomm = self._wcomm.Split(self._nodecomm.rank, myworldnode) self._noderankprocs = self._noderankcomm.size self._gsize = groupsize if (self._gsize < 0) or (self._gsize > self._wsize): log.warning( "Invalid groupsize ({}). Should be between {} " "and {}. Using single process group instead.".format( groupsize, 0, self._wsize ) ) self._gsize = 0 if self._gsize == 0: self._gsize = self._wsize self._ngroups = self._wsize // self._gsize if self._ngroups * self._gsize != self._wsize: msg = ( "World communicator size ({}) is not evenly divisible " "by requested group size ({}).".format(self._wsize, self._gsize) ) log.error(msg) raise RuntimeError(msg) if self._gsize > self._nodeprocs: if self._gsize % self._nodeprocs != 0: msg = f"Group size of {self._gsize} is not a whole number of " msg += f"nodes (there are {self._nodeprocs} processes per node)" log.error(msg) raise RuntimeError(msg) self._gnodes = self._gsize // self._nodeprocs else: self._gnodes = 1 self._group = self._wrank // self._gsize self._grank = self._wrank % self._gsize self._cleanup_group_comm = False self._gnodecomm = None self._gnoderankcomm = None self._gnodeprocs = 1 if self._ngroups == 1: # We just have one group with all processes. self._gcomm = self._wcomm self._gnodecomm = self._nodecomm self._gnoderankcomm = self._noderankcomm self._rcomm = None self._gnoderankprocs = self._noderankprocs else: # We need to split the communicator. This code is never executed # unless MPI is enabled and we have multiple groups. self._gcomm = self._wcomm.Split(self._group, self._grank) self._rcomm = self._wcomm.Split(self._grank, self._group) self._gnodecomm = self._gcomm.Split_type(MPI.COMM_TYPE_SHARED, 0) self._gnodeprocs = self._gnodecomm.size mygroupnode = self._grank // self._gnodeprocs self._gnoderankcomm = self._gcomm.Split(self._gnodecomm.rank, mygroupnode) self._gnoderankprocs = self._gnoderankcomm.size self._cleanup_group_comm = True msg = f"Comm on world rank {self._wrank}:\n" msg += f" world comm = {self._wcomm} with {self._wsize} ranks\n" msg += ( f" intra-node comm = {self._nodecomm} ({self._nodeprocs} ranks per node)\n" ) msg += f" inter-node rank comm = {self._noderankcomm} ({self._noderankprocs} ranks)\n" msg += ( f" in group {self._group + 1} / {self._ngroups} with rank {self._grank}\n" ) msg += f" intra-group comm = {self._gcomm} ({self._gsize} ranks)\n" msg += f" inter-group rank comm = {self._rcomm}\n" msg += f" intra-node group comm = {self._gnodecomm} ({self._gnodeprocs} ranks per node)\n" msg += f" inter-node group rank comm = {self._gnoderankcomm} ({self._noderankprocs} ranks)\n" log.verbose(msg) if self._gnoderankprocs != self._gnodes: msg = f"Number of group node rank procs ({self._gnoderankprocs}) does " msg += f"not match the number of nodes in a group ({self._gnodes})" log.error(msg) raise RuntimeError(msg) # Create a cache of row / column communicators for each group. These can # then be re-used for observations with the same grid shapes. self._rowcolcomm = dict() def close(self): # Explicitly free communicators if needed. # Go through the cache of row / column grid communicators and free if hasattr(self, "_rowcolcomm"): for process_rows, comms in self._rowcolcomm.items(): if comms["cleanup"]: # We previously allocated new communicators for this grid. # Free them now. for subcomm in [ "row_rank_node", "row_node", "col_rank_node", "col_node", "row", "col", ]: if comms[subcomm] is not None: comms[subcomm].Free() del comms[subcomm] del self._rowcolcomm # Optionally delete the group communicators if they were created. if hasattr(self, "_cleanup_group_comm") and self._cleanup_group_comm: self._gcomm.Free() self._rcomm.Free() self._gnodecomm.Free() self._gnoderankcomm.Free() del self._gcomm del self._rcomm del self._gnodecomm del self._gnoderankcomm del self._cleanup_group_comm # We always need to clean up the world node and node-rank communicators # if they exist if hasattr(self, "_noderankcomm") and self._noderankcomm is not None: self._noderankcomm.Free() del self._noderankcomm if hasattr(self, "_nodecomm") and self._nodecomm is not None: self._nodecomm.Free() del self._nodecomm return def __del__(self): self.close() @property def world_size(self): """The size of the world communicator.""" return self._wsize @property def world_rank(self): """The rank of this process in the world communicator.""" return self._wrank @property def ngroups(self): """The number of process groups.""" return self._ngroups @property def group(self): """The group containing this process.""" return self._group @property def group_size(self): """The size of the group containing this process.""" return self._gsize @property def group_rank(self): """The rank of this process in the group communicator.""" return self._grank # All the different types of relevant MPI communicators. @property def comm_world(self): """The world communicator.""" return self._wcomm @property def comm_world_node(self): """The communicator shared by world processes on the same node.""" return self._nodecomm @property def comm_world_node_rank(self): """The communicator shared by world processes with the same node rank across all nodes.""" return self._noderankcomm @property def comm_group(self): """The communicator shared by processes within this group.""" return self._gcomm @property def comm_group_rank(self): """The communicator shared by processes with the same group_rank.""" return self._rcomm @property def comm_group_node(self): """The communicator shared by group processes on the same node.""" return self._gnodecomm @property def comm_group_node_rank(self): """The communicator shared by group processes with the same node rank on nodes within the group.""" return self._gnoderankcomm
[docs] def comm_row_col(self, process_rows): """Return the row and column communicators for this group and grid shape. This function will create and / or return the communicators needed for a given process grid. The return value is a dictionary with the following keys: - "row": The row communicator. - "col": The column communicator. - "row_node": The node-local communicator within the row communicator - "col_node": The node-local communicator within the col communicator - "row_rank_node": The communicator across nodes among processes with the same node-rank within the row communicator. - "col_rank_node": The communicator across nodes among processes with the same node-rank within the column communicator. Args: process_rows (int): The number of rows in the process grid. Returns: (dict): The communicators for this grid shape. """ if process_rows not in self._rowcolcomm: # Does not exist yet, create it. if self._gcomm is None: if process_rows != 1: msg = "MPI not in use, so only process_rows == 1 is allowed" log.error(msg) raise ValueError(msg) self._rowcolcomm[process_rows] = { "row": None, "col": None, "row_node": None, "row_rank_node": None, "col_node": None, "col_rank_node": None, "cleanup": False, } else: if self._gcomm.size % process_rows != 0: msg = f"The number of process_rows ({process_rows}) " msg += f"does not divide evenly into the communicator " msg += f"size ({self._gcomm.size})" log.error(msg) raise RuntimeError(msg) process_cols = self._gcomm.size // process_rows if process_rows == 1: # We can re-use the group communicators as the grid column # communicators comm_row = self._gcomm comm_row_node = self._gnodecomm comm_row_rank_node = self._gnoderankcomm comm_col = None comm_col_node = None comm_col_rank_node = None cleanup = False elif process_cols == 1: # We can re-use the group communicators as the grid row # communicators comm_col = self._gcomm comm_col_node = self._gnodecomm comm_col_rank_node = self._gnoderankcomm comm_row = None comm_row_node = None comm_row_rank_node = None cleanup = False else: # We have to create new split communicators col_rank = self._gcomm.rank // process_cols row_rank = self._gcomm.rank % process_cols comm_row = self._gcomm.Split(col_rank, row_rank) comm_col = self._gcomm.Split(row_rank, col_rank) # Node and node-rank comms for each row and col. comm_row_node = comm_row.Split_type(MPI.COMM_TYPE_SHARED, 0) row_nodeprocs = comm_row_node.size row_node = comm_row.rank // row_nodeprocs comm_row_rank_node = comm_row.Split(comm_row_node.rank, row_node) comm_col_node = comm_col.Split_type(MPI.COMM_TYPE_SHARED, 0) col_nodeprocs = comm_col_node.size col_node = comm_col.rank // col_nodeprocs comm_col_rank_node = comm_col.Split(comm_col_node.rank, col_node) cleanup = True msg = f"Comm on world rank {self._wrank} create grid with {process_rows} rows:\n" msg += f" row comm = {comm_row}\n" msg += f" node comm = {comm_row_node}\n" msg += f" node rank comm = {comm_row_rank_node}\n" msg += f" col comm = {comm_col}\n" msg += f" node comm = {comm_col_node}\n" msg += f" node rank comm = {comm_col_rank_node}" log.verbose(msg) self._rowcolcomm[process_rows] = { "row": comm_row, "row_node": comm_row_node, "row_rank_node": comm_row_rank_node, "col": comm_col, "col_node": comm_col_node, "col_rank_node": comm_col_rank_node, "cleanup": cleanup, } return self._rowcolcomm[process_rows]
def __repr__(self): lines = [ " World MPI communicator = {}".format(self._wcomm), " World MPI size = {}".format(self._wsize), " World MPI rank = {}".format(self._wrank), " Group MPI communicator = {}".format(self._gcomm), " Group MPI size = {}".format(self._gsize), " Group MPI rank = {}".format(self._grank), " Rank MPI communicator = {}".format(self._rcomm), ] return "<toast.Comm\n{}\n>".format("\n".join(lines))
@contextmanager def exception_guard(comm=None, timeout=30): """Ensure that, if one MPI process raises an un-caught exception, the program shuts down properly. Args: comm (mpi4py.MPI.Comm): The MPI communicator or None. timeout (int): The number of seconds to wait before aborting all processes """ log = Logger.get() rank = 0 if comm is None else comm.rank try: yield except Exception: # Note that the intention of this function is to handle *any* exception. # The typical use case is to wrap main() and ensure that the job exits # cleanly. exc_type, exc_value, exc_traceback = sys.exc_info() lines = traceback.format_exception(exc_type, exc_value, exc_traceback) lines = [f"Proc {rank}: {x}" for x in lines] msg = "".join(lines) log.error(msg) # kills the job if comm is None or comm.size == 1: # Raising the exception allows for debugging raise else: if comm.size > 1: # gives other processes a bit of time to see whether # they encounter the same error time.sleep(timeout) comm.Abort(1) def comm_equal(comm_a, comm_b): """Compare communicators for equality. Returns True if both communicators are None, or if they are identical (e.g. the compare as MPI.IDENT). Args: comm_a (MPI.Comm): First communicator, or None. comm_b (MPI.Comm): Second communicator, or None. Returns: (bool): The result """ if comm_a is None: if comm_b is None: return True else: return False else: if comm_b is None: return False else: fail = 0 if MPI.Comm.Compare(comm_a, comm_b) != MPI.IDENT: fail = 1 fail = comm_a.allreduce(fail, op=MPI.SUM) return fail == 0 def comm_equivalent(comm_a, comm_b): """Compare communicators. Returns True if both communicators are None, or if they have the same size and ordering of ranks. Args: comm_a (MPI.Comm): First communicator, or None. comm_b (MPI.Comm): Second communicator, or None. Returns: (bool): The result """ if comm_a is None: if comm_b is None: return True else: return False else: if comm_b is None: return False else: fail = 0 if comm_a.size != comm_b.size: fail = 1 if comm_a.rank != comm_b.rank: fail = 1 if MPI.Comm.Compare(comm_a, comm_b) not in [MPI.IDENT, MPI.CONGRUENT]: fail = 1 fail = comm_a.allreduce(fail, op=MPI.SUM) return fail == 0