# Copyright (c) 2015-2020 by the parties listed in the AUTHORS file.
# All rights reserved. Use of this source code is governed by
# a BSD-style license that can be found in the LICENSE file.
import re
from time import time
import numpy as np
from .._libtoast import filter_polynomial
from ..op import Operator
from ..timing import function_timer
from .. import qarray as qa
# Cartesian unit vectors along x, y and z: the rows of the 3x3 identity matrix.
XAXIS, YAXIS, ZAXIS = np.identity(3)
class OpPolyFilter2D(Operator):
    """Operator to regress out 2D polynomials across the focal plane.

    For every sample of every valid interval, a 2D polynomial in the
    detector coordinates is fitted to the unflagged detector signals and
    subtracted from them.  The detector coordinates are the arcsines of
    the x and y components of the Z axis rotated by the detector
    quaternion found in ``obs["focalplane"][det]["quat"]``.

    Args:
        order (int): Order of the filtering polynomial.  The number of
            fitted modes is ``(order + 1) * (order + 2) // 2``, further
            limited by the number of detectors matching ``pattern``.
        pattern (str): Regex pattern to match against detector names.
            Only detectors that match the pattern are filtered.
        name (str): Name passed to ``tod.local_signal`` to select the
            signal that is filtered in place.
        common_flag_name (str): Name passed to ``tod.local_common_flags``.
        common_flag_mask (byte): Bitmask applied to the common flags
            when deciding which samples to exclude from the fit.
        flag_name (str): Name passed to ``tod.local_flags``.
        flag_mask (byte): Bitmask applied to the detector flags when
            deciding which samples to exclude from the fit.
        poly_flag_mask (byte): Bitmask OR'ed into the common flags of
            every sample for which the fit failed.
        intervals (str): Key of the valid intervals in the observation.

    """

    def __init__(
        self,
        order=1,
        pattern=r".*",
        name=None,
        common_flag_name=None,
        common_flag_mask=255,
        flag_name=None,
        flag_mask=255,
        poly_flag_mask=1,
        intervals="intervals",
    ):
        self._order = order
        # Number of monomials x^i * y^j with i + j <= order
        self._nmode = (order + 1) * (order + 2) // 2
        self._pattern = pattern
        self._name = name
        self._common_flag_name = common_flag_name
        self._common_flag_mask = common_flag_mask
        self._flag_name = flag_name
        self._flag_mask = flag_mask
        self._poly_flag_mask = poly_flag_mask
        self._intervals = intervals

        # Call the parent class constructor.
        super().__init__()

    @function_timer
    def exec(self, data):
        """Apply the 2D polynomial filter to the signal.

        The signal of every matching local detector is modified in
        place.  Samples where no polynomial could be fitted (all
        coefficients are zero) get ``poly_flag_mask`` raised in the
        common flags.  Observations with no matching detectors or no
        local intervals are skipped.

        Args:
            data (toast.Data): The distributed data.

        """
        norder = self._order + 1
        pat = re.compile(self._pattern)

        for obs in data.obs:
            tod = obs["tod"]
            comm = tod.grid_comm_row
            # Guard against a missing communicator so that the operator
            # can also run when no row communicator is available.
            if comm is None:
                nproc, rank = 1, 0
            else:
                nproc, rank = comm.size, comm.rank

            # Assign a dense index to every detector matching the pattern.
            # This uses the full detector list (not only the local ones)
            # so that the indexing agrees across processes.
            detector_index = {}
            for det in tod.detectors:
                if pat.match(det) is not None:
                    detector_index[det] = len(detector_index)
            ndet = len(detector_index)
            if ndet == 0:
                # Nothing to filter in this observation
                continue
            local_dets = [det for det in tod.local_dets if det in detector_index]

            # Number of detectors may limit the number of modes we can constrain
            nmode = min(self._nmode, ndet)

            # Exponents (xorder, yorder) of the first `nmode` monomials,
            # ordered by increasing total order.
            orders = [
                (order - yorder, yorder)
                for order in range(norder)
                for yorder in range(order + 1)
            ][:nmode]
            orders = np.array(orders, dtype=np.float64).reshape(-1, 2)
            xorders = orders[:, 0]
            yorders = orders[:, 1]

            # Evaluate every monomial at the position of each local detector
            focalplane = obs["focalplane"]
            detector_templates = np.zeros([ndet, nmode])
            for det in local_dets:
                idet = detector_index[det]
                det_quat = focalplane[det]["quat"]
                x, y, _ = qa.rotate(det_quat, ZAXIS)
                theta, phi = np.arcsin([x, y])
                detector_templates[idet] = theta ** xorders * phi ** yorders

            if self._intervals in obs:
                intervals = obs[self._intervals]
            else:
                intervals = None
            local_intervals = tod.local_intervals(intervals)
            if len(local_intervals) == 0:
                # No intervals to filter
                continue
            common_ref = tod.local_common_flags(self._common_flag_name)

            # Iterate over each interval
            for ival in local_intervals:
                ind = slice(ival.first, ival.last + 1)
                nsample = ival.last - ival.first + 1

                # Build the masked templates and project the signal onto them
                templates = np.zeros([ndet, nmode, nsample])
                proj = np.zeros([nmode, nsample])
                norms = np.zeros(nmode)

                for det in local_dets:
                    idet = detector_index[det]
                    ref = tod.local_signal(det, self._name)[ind]
                    flag_ref = tod.local_flags(det, self._flag_name)[ind]

                    flg = common_ref[ind] & self._common_flag_mask
                    flg |= flag_ref & self._flag_mask
                    mask = flg == 0

                    # We might want to remove the interval mean if the
                    # data were not already 1D-filtered
                    # ref -= np.mean(ref[mask])

                    template = detector_templates[idet]
                    # Flagged samples get zeroed templates, so they neither
                    # constrain the fit nor get modified when cleaning.
                    templates[idet] = np.outer(template, mask)
                    proj += np.outer(template, ref * mask)
                    norms += template ** 2
                    del ref
                    del flag_ref

                # Combine the contributions of all processes.  The lowercase
                # allreduce returns the reduced object instead of working
                # in place, so the result must be assigned back.
                if comm is not None:
                    templates = comm.allreduce(templates)
                    proj = comm.allreduce(proj)
                    norms = comm.allreduce(norms)

                # Normalize each mode by the inverse root of its norm.
                # Modes with a zero norm are left untouched.
                scale = np.ones(nmode)
                good = norms != 0
                scale[good] = norms[good] ** -0.5
                templates *= scale[np.newaxis, :, np.newaxis]
                proj *= scale[:, np.newaxis]

                # Solve the per-sample linear regression.  The samples are
                # distributed round-robin over the processes.
                templates = np.transpose(
                    templates, [2, 0, 1]
                ).copy()  # nsample x ndet x nmode
                proj = proj.T.copy()  # nsample x nmode
                coeff = np.zeros([nsample, nmode])
                for isample in range(rank, nsample, nproc):
                    design = templates[isample]  # ndet x nmode
                    ccinv = np.dot(design.T, design)  # nmode x nmode
                    try:
                        cc = np.linalg.inv(ccinv)
                    except np.linalg.LinAlgError:
                        # Singular system: leave the coefficients at zero
                        # so that the sample gets flagged below.
                        continue
                    coeff[isample] = np.dot(cc, proj[isample])
                if comm is not None:
                    coeff = comm.allreduce(coeff)

                # Flag the samples for which the fit failed
                failed = np.all(coeff == 0, axis=1)
                interval_flags = common_ref[ind]
                interval_flags[failed] |= self._poly_flag_mask
                del interval_flags

                # Subtract the fitted modes from the local detector signals
                templates = np.transpose(
                    templates, [1, 2, 0]
                ).copy()  # ndet x nmode x nsample
                coeff = coeff.T.copy()  # nmode x nsample

                for det in local_dets:
                    idet = detector_index[det]
                    ref = tod.local_signal(det, self._name)[ind]
                    for mode in range(nmode):
                        ref -= coeff[mode] * templates[idet, mode]
                    del ref

            del common_ref

        return
class OpPolyFilter(Operator):
    """Operator which applies polynomial filtering to the TOD.

    This applies polynomial filtering to the valid intervals of each TOD.

    Args:
        order (int): Order of the filtering polynomial.
        pattern (str): Regex pattern to match against detector names.
            Only detectors that match the pattern are filtered.
        name (str): Name of the output signal cache object will be
            <name_in>_<detector>.  If the object exists, it is used as
            input.  Otherwise signal is read using the tod read method.
        common_flag_name (str): Cache name of the output common flags.
            If it already exists, it is used.  Otherwise flags
            are read from the tod object and stored in the cache under
            common_flag_name.
        common_flag_mask (byte): Bitmask to use when flagging data
            based on the common flags.
        flag_name (str): Cache name of the output detector flags will
            be <flag_name>_<detector>.  If the object exists, it is
            used.  Otherwise flags are read from the tod object.
        flag_mask (byte): Bitmask to use when flagging data
            based on the detector flags.
        poly_flag_mask (byte): Bitmask to use when adding flags based
            on polynomial filter failures.
        intervals (str): Name of the valid intervals in observation.

    """

    def __init__(
        self,
        order=1,
        pattern=r".*",
        name=None,
        common_flag_name=None,
        common_flag_mask=255,
        flag_name=None,
        flag_mask=255,
        poly_flag_mask=1,
        intervals="intervals",
    ):
        self._order = order
        self._pattern = pattern
        self._name = name
        self._common_flag_name = common_flag_name
        self._common_flag_mask = common_flag_mask
        self._flag_name = flag_name
        self._flag_mask = flag_mask
        self._poly_flag_mask = poly_flag_mask
        self._intervals = intervals

        # Call the parent class constructor.
        super().__init__()

    @function_timer
    def exec(self, data):
        """Apply the polynomial filter to the signal.

        Observations without local intervals and detectors that do not
        match the pattern are skipped.

        Args:
            data (toast.Data): The distributed data.

        """
        # The pattern does not depend on the observation: compile it once
        pat = re.compile(self._pattern)

        for obs in data.obs:
            tod = obs["tod"]
            if self._intervals in obs:
                intervals = obs[self._intervals]
            else:
                intervals = None
            local_intervals = tod.local_intervals(intervals)
            if len(local_intervals) == 0:
                # No intervals to filter
                continue

            # First and last sample index of each interval.  These are
            # shared by all detectors, so build them once per observation.
            local_starts = np.array([ival.first for ival in local_intervals])
            local_stops = np.array([ival.last for ival in local_intervals])

            common_ref = tod.local_common_flags(self._common_flag_name)

            for det in tod.local_dets:
                # Test the detector pattern
                if pat.match(det) is None:
                    continue

                ref = tod.local_signal(det, self._name)
                flag_ref = tod.local_flags(det, self._flag_name)

                # Combined flags: nonzero wherever the masked common or
                # detector flags are raised
                flg = common_ref & self._common_flag_mask
                flg |= flag_ref & self._flag_mask

                # NOTE(review): presumably filter_polynomial filters `ref`
                # in place and raises `flg` where the fit failed -- confirm
                # against the _libtoast implementation.
                filter_polynomial(self._order, flg, [ref], local_starts, local_stops)

                # Every sample with a nonzero combined flag (including the
                # ones that were flagged on input) gets the polynomial
                # flag bit raised in the detector flags.
                flag_ref[flg != 0] |= self._poly_flag_mask

                del ref
                del flag_ref

            del common_ref

        return