Source code for seek.mitigation.sum_threshold

# SEEK is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# 
# SEEK is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with SEEK.  If not, see <http://www.gnu.org/licenses/>.


'''
Created on Jan 21, 2015

author: jakeret
'''
from __future__ import print_function, division, absolute_import, unicode_literals

import numpy as np
from scipy import ndimage

import hope

from seek.mitigation import sum_threshold_utils
from seek.utils.tod_utils import get_empty_mask
from seek.utils import filter

# Maximum neighbourhood size
MAX_PIXELS = 8

# smoothing default params
KERNEL_M = 40
KERNEL_N = 20
SIGMA_M = 7.5
SIGMA_N = 15

# dilation default params
STRUCT_SIZE = 3


@hope.jit
def _sumthreshold(data, mask, i, chi, ds0, ds1):
    """
    The operation of summing and thresholding.

    :param data: data
    :param mask: original mask
    :param i: number of iterations
    :param chi: thresholding criteria
    :param ds0: dimension of the first axis
    :param ds1: dimension of hte second axis

    :return: SumThredshold mask
    """
    tmp_mask = mask[:]
    for x in range(ds0):
        sum = 0.0
        cnt = 0
        
        for ii in range(0, i):
            if mask[x, ii] != True:
                sum += data[x, ii]
                cnt += 1
        
        for y in range(i, ds1):
            if sum > chi * cnt:
                for ii2 in range(0, i):
                    tmp_mask[x, y-ii2-1] = True
                    
            if mask[x, y] != True:
                sum += data[x, y]
                cnt += 1
            
            if mask[x, y-i] != 1:
                sum -= data[x, y-i]
                cnt -= 1
                
    return tmp_mask

def _run_sumthreshold(data, init_mask, eta, M, chi_i, sm_kwargs, plotting=True):
    """
    Perform one SumThreshold operation: sum the un-masked data after
    subtracting a smooth background and threshold it.

    :param data: data
    :param init_mask: initial mask
    :param eta: number that scales the chi value for each iteration
    :param M: number of iterations
    :param chi: thresholding criteria
    :param sm_kwargs: smoothing keyword
    :param plotting: whether to plot

    :return: SumThreshold mask

    """

    smoothed_data = filter.gaussian_filter(data, init_mask, **sm_kwargs)
    res = data-smoothed_data
    
    st_mask = init_mask.copy()

    for m, chi in zip(M, chi_i):
        chi = chi / eta
        if m==1:
            st_mask = st_mask | (chi<=res)
        else:
            st_mask = _sumthreshold(res,   st_mask,   m, chi, *res.shape)
            st_mask = _sumthreshold(res.T, st_mask.T, m, chi, *res.T.shape).T

    if plotting:
        sum_threshold_utils.plot_steps(data, st_mask, smoothed_data, res, "%s (%s)"%(eta, chi_i))
        
    return st_mask

[docs]def binary_mask_dilation(mask, struct_size_0, struct_size_1): """ Dilates the mask. :param mask: original mask :param struct_size_0: dilation parameter :param struct_size_1: dilation parameter :return: dilated mask """ struct = np.ones((struct_size_0, struct_size_1), np.bool) return ndimage.binary_dilation(mask, structure=struct, iterations=2)
[docs]def normalize(data, mask): """ Simple normalization of standing waves: subtracting the median over time for each frequency. :param data: data :param mask: mask :return: normalized data """ median = np.ma.median(np.ma.MaskedArray(data, mask), axis=1).reshape(data.shape[0], -1) data = np.abs(data - median) return data.data
[docs]def get_rfi_mask(tod, mask=None, chi_1=35000, eta_i=[0.5, 0.55, 0.62, 0.75, 1], normalize_standing_waves=True, suppress_dilation=False, plotting=True, sm_kwargs=None, di_kwargs=None): """ Computes a mask to cover the RFI in a data set. :param data: array containing the signal and RFI :param mask: the initial mask :param chi_1: First threshold :param eta_i: List of sensitivities :param normalize_standing_waves: whether to normalize standing waves :param suppress_dilation: if true, mask dilation is suppressed :param plotting: True if statistics plot should be displayed :param sm_kwargs: smoothing key words :param di_kwargs: dilation key words :return mask: the mask covering the identified RFI """ data = tod.data if mask is None: mask = get_empty_mask(data.shape) if sm_kwargs is None: sm_kwargs = get_sm_kwargs() if plotting: sum_threshold_utils.plot_moments(data) if normalize_standing_waves: data = normalize(data, mask) if plotting: sum_threshold_utils.plot_moments(data) p = 1.5 m = np.arange(1, MAX_PIXELS) M = 2**(m-1) chi_i = chi_1 / p**np.log2(m) st_mask = mask for eta in eta_i: st_mask = _run_sumthreshold(data, st_mask, eta, M, chi_i, sm_kwargs, plotting) dilated_mask = st_mask if not suppress_dilation: if di_kwargs is None: di_kwargs = get_di_kwrags() dilated_mask = binary_mask_dilation(dilated_mask - mask, **di_kwargs) if plotting: sum_threshold_utils.plot_dilation(st_mask, mask, dilated_mask) return dilated_mask+mask
[docs]def get_sm_kwargs(kernel_m=KERNEL_M, kernel_n=KERNEL_N, sigma_m=SIGMA_M, sigma_n=SIGMA_N): """ Creates a dict with the smoothing keywords. :param kernel_m: kernel window size in axis=1 :param kernel_n: kernel window size in axis=0 :param sigma_m: kernel sigma in axis=1 :param sigma_n: kernel sigma in axis=0 :return: dictionary with the smoothing keywords """ return dict(M=kernel_m, N=kernel_n, sigma_m=sigma_m, sigma_n=sigma_n)
[docs]def get_di_kwrags(struct_size_0=STRUCT_SIZE, struct_size_1=STRUCT_SIZE): """ Creates a dict with the dilation keywords. :param struct_size_0: struct size in axis=0 :param struct_size_1: struct size in axis=1 :return: dictionary with the dilation keywords """ return dict(struct_size_0=struct_size_0, struct_size_1=struct_size_1)
[docs]def get_sumthreshold_kwargs(params): """ Creates the smoothing and dilation kwargs from a params objects. :param params: the params object containing the configuration :return: smoothing and dilation kwargs """ sm_kwargs = get_sm_kwargs(params.sm_kernel_m, params.sm_kernel_n, params.sm_sigma_m, params.sm_sigma_n,) di_kwargs = get_di_kwrags(params.struct_size_0, params.struct_size_1) return sm_kwargs, di_kwargs
[docs]def rm_rfi(ctx): """ Call the main SumThreshold routine. :param ctx: context :return: SumThreshold RFI mask. """ sm_kwars, di_kwargs = get_sumthreshold_kwargs(ctx.params) rfi_mask_vx = get_rfi_mask(ctx.tod_vx, mask=ctx.tod_vx.mask, chi_1 = ctx.params.chi_1, eta_i=ctx.params.eta_i, plotting=False, sm_kwargs=sm_kwars, di_kwargs=di_kwargs) return rfi_mask_vx, rfi_mask_vx