# Source code for rascil.processing_components.image.operations

"""Image operations visible to the Execution Framework as Components"""

# Public API of this module: the names exported by ``from ... import *``.
# Note that ``rotate_image`` is defined in this module but is not listed here.
__all__ = [
    "add_image",
    "average_image_over_frequency",
    "create_w_term_like",
    "create_window",
    "sub_image",
    "polarisation_frame_from_wcs",
    "remove_continuum_image",
    "show_components",
    "show_image",
    "smooth_image",
    "apply_voltage_pattern_to_image",
]

import copy
import logging
import warnings

import numpy
from astropy.coordinates import SkyCoord
from astropy.wcs import FITSFixedWarning
from astropy.wcs.utils import skycoord_to_pixel
from ska_sdp_datamodels.image.image_model import Image
from ska_sdp_datamodels.science_data_model.polarisation_model import PolarisationFrame
from ska_sdp_func_python.calibration import apply_jones
from ska_sdp_func_python.fourier_transforms import w_beam
from ska_sdp_func_python.image import (
    convert_polimage_to_stokes,
    convert_stokes_to_polimage,
)

from rascil.processing_components.parameters import get_parameter

# Ignore astropy's FITSFixedWarning. warnings.simplefilter edits the
# process-wide warning filters, so this affects every module, not just this one.
warnings.simplefilter("ignore", FITSFixedWarning)
# All logging in this module goes through the logger named "rascil-logger".
log = logging.getLogger("rascil-logger")


def add_image(im1: Image, im2: Image) -> Image:
    """Return a new Image whose pixels are the element-wise sum of two images

    The polarisation frame and WCS of the result are taken from ``im1``.

    :param im1: First Image (supplies polarisation frame and WCS)
    :param im2: Second Image
    :return: Image holding ``im1 + im2``
    """
    summed_pixels = im1["pixels"].data + im2["pixels"].data
    frame = im1.image_acc.polarisation_frame
    coordinates = im1.image_acc.wcs
    return Image.constructor(
        data=summed_pixels,
        polarisation_frame=frame,
        wcs=coordinates,
    )
def show_image(
    im: Image,
    fig=None,
    title: str = "",
    pol=0,
    chan=0,
    cm="Greys",
    components=None,
    vmin=None,
    vmax=None,
    vscale=1.0,
):
    """Show an Image with coordinates using matplotlib, optionally with components

    :param im: Image
    :param fig: Matplotlib figure to draw into; a new figure is created if None
    :param title: String for title of plot
    :param pol: Polarisation to show (index)
    :param chan: Channel to show (index)
    :param cm: Matplotlib colormap name
    :param components: Optional components to be overlaid
    :param vmin: Clip to this minimum
    :param vmax: Clip to this maximum
    :param vscale: scale max, min by this amount
    :return: The matplotlib figure drawn into
    """
    import matplotlib.pyplot as plt

    # Honour a caller-supplied figure (previously it was always replaced).
    if fig is None:
        fig = plt.figure()

    wcs = im.image_acc.wcs
    ax = fig.add_subplot(1, 1, 1, projection=wcs.sub([1, 2]))

    # 4D images are indexed (chan, pol, y, x); anything else is shown whole.
    if len(im["pixels"].data.shape) == 4:
        data_array = numpy.real(im["pixels"].data[chan, pol, :, :])
    else:
        data_array = numpy.real(im["pixels"].data)

    if vmax is None:
        vmax = vscale * numpy.max(data_array)
    if vmin is None:
        vmin = vscale * numpy.min(data_array)

    # Keep the artist in its own name so the ``cm`` colormap argument is not shadowed.
    mappable = ax.imshow(data_array, origin="lower", cmap=cm, vmax=vmax, vmin=vmin)
    ax.set_xlabel(wcs.wcs.ctype[0])
    ax.set_ylabel(wcs.wcs.ctype[1])
    ax.set_title(title)
    fig.colorbar(mappable, orientation="vertical", shrink=0.7)

    if components is not None:
        for sc in components:
            x, y = skycoord_to_pixel(sc.direction, wcs, 0, "wcs")
            ax.plot(x, y, marker="+", color="red")

    return fig
def show_components(im, comps, npixels=128, fig=None, vmax=None, vmin=None, title=""):
    """Show cut-outs of an image centred on each of a list of components

    For each component, an ``npixels`` x ``npixels`` cut-out of plane [0, 0] is
    drawn with the component position marked by a red cross, and
    ``plt.show()`` is called.

    :param im: Canonical Image
    :param comps: Iterable of components; each needs ``direction``, ``name`` and ``flux``
    :param npixels: Side length of each cut-out in pixels
    :param fig: Matplotlib figure to use; a new one is created if not given
    :param vmax: Clip to this maximum (default: max of plane [0, 0])
    :param vmin: Clip to this minimum (default: min of plane [0, 0])
    :param title: Optional prefix for each plot title
    :return: None
    """
    import matplotlib.pyplot as plt

    # Validate before touching the pixel data.
    assert isinstance(im, Image), im
    assert im.image_acc.is_canonical()

    if vmax is None:
        vmax = numpy.max(im["pixels"].data[0, 0, ...])
    if vmin is None:
        vmin = numpy.min(im["pixels"].data[0, 0, ...])
    if not fig:
        fig = plt.figure()
    plt.clf()

    half = npixels // 2
    for sc in comps:
        newim = im.copy(deep=True)
        plt.subplot(111, projection=newim.image_acc.wcs.sub([1, 2]))
        # origin=0 so the pixel position can be used directly as a 0-based
        # numpy index (origin=1 shifted the cut-out by one pixel).
        centre = numpy.round(
            skycoord_to_pixel(sc.direction, newim.image_acc.wcs, 0, "wcs")
        ).astype("int")
        cx, cy = int(centre[0]), int(centre[1])
        # Clamp at zero: a negative start would wrap around in numpy slicing.
        xstart = max(cx - half, 0)
        ystart = max(cy - half, 0)
        newim["pixels"].data = newim["pixels"].data[
            :, :, ystart : (cy + half), xstart : (cx + half)
        ]
        # Shift the reference pixel by the same offset used for the cut-out.
        newim.image_acc.wcs.wcs.crpix[0] -= xstart
        newim.image_acc.wcs.wcs.crpix[1] -= ystart
        plt.imshow(
            newim["pixels"].data[0, 0, ...],
            origin="lower",
            cmap="Greys",
            vmax=vmax,
            vmin=vmin,
        )
        x, y = skycoord_to_pixel(sc.direction, newim.image_acc.wcs, 0, "wcs")
        plt.plot(x, y, marker="+", color="red")
        label = "Name = %s, flux = %s" % (sc.name, sc.flux)
        if title:
            label = "%s: %s" % (title, label)
        plt.title(label)
        plt.show()
def smooth_image(model: Image, width=1.0, normalise=True):
    """Smooth every channel/polarisation plane of an image with a 2D Gaussian

    :param model: Canonical Image to smooth
    :param width: Standard deviation of the Gaussian2DKernel, in pixels
    :param normalise: If True, multiply the result by 2*pi*width**2
        (intended to normalise the kernel peak to unity)
    :return: New Image containing the smoothed pixels
    """
    assert isinstance(model, Image), model
    assert model.image_acc.is_canonical()

    from astropy.convolution import convolve_fft
    from astropy.convolution.kernels import Gaussian2DKernel

    gaussian = Gaussian2DKernel(width)
    input_dtype = model["pixels"].data.dtype

    smoothed = Image.constructor(
        data=numpy.zeros_like(model["pixels"].data),
        polarisation_frame=model.image_acc.polarisation_frame,
        wcs=model.image_acc.wcs,
        clean_beam=model.attrs["clean_beam"],
    )

    nchan, npol, _, _ = model.image_acc.shape
    # Each (chan, pol) plane is convolved independently, so loop order is free.
    for chan in range(nchan):
        for pol in range(npol):
            plane = model["pixels"].data[chan, pol, :, :]
            smoothed["pixels"].data[chan, pol, :, :] = convolve_fft(
                plane,
                gaussian,
                normalize_kernel=False,
                allow_huge=True,
            )

    # convolve_fft appears to return an unexpected (object) dtype; restore the input's.
    smoothed["pixels"].data = smoothed["pixels"].data.astype(input_dtype)

    if normalise:
        smoothed["pixels"].data *= 2 * numpy.pi * width**2
    return smoothed
def average_image_over_frequency(im: Image) -> Image:
    """Average an image over its frequency (channel) axis

    The result has a single channel. Its WCS spectral axis (index 3) has
    crval set to the mean of the input channel frequencies and crpix set to 1.

    :param im: Canonical Image
    :return: Single-channel Image holding the channel-averaged pixels
    """
    assert isinstance(im, Image), im
    assert im.image_acc.is_canonical()
    # Check before averaging. isnan().any() avoids the false positive that
    # numpy.isnan(numpy.sum(...)) gives when the data hold both +inf and -inf.
    assert not numpy.isnan(im["pixels"].data).any(), "NaNs present in image data"

    newim_data = numpy.mean(im["pixels"].data, axis=0)[numpy.newaxis, ...]

    newim_wcs = copy.deepcopy(im.image_acc.wcs)
    newim_wcs.wcs.crval[3] = numpy.average(im.frequency.data)
    newim_wcs.wcs.crpix[3] = 1
    return Image.constructor(
        data=newim_data,
        polarisation_frame=im.image_acc.polarisation_frame,
        wcs=newim_wcs,
    )
def remove_continuum_image(im: Image, degree=1, mask=None):
    """Fit and remove a continuum (polynomial in frequency) from an image in place

    For every polarisation and pixel, a polynomial of the given degree is
    least-squares fitted to the spectrum along the channel axis. The fitted
    values are then subtracted from ``im``.

    Channels for which ``mask`` is True are given zero weight in the fit.
    NOTE(review): the assertion below counts the True entries of ``mask`` as if
    they were the channels used in the fit. Confirm the intended mask sense.

    :param im: Canonical Image, modified in place
    :param degree: Polynomial degree passed to ``numpy.polyfit``
        (0 is a constant, 1 a straight line, ...)
    :param mask: Optional boolean frequency mask, one entry per channel
    :return: The same Image, with the continuum fit subtracted
    """
    assert isinstance(im, Image), im
    assert im.image_acc.is_canonical()
    if mask is not None:
        assert numpy.sum(mask) > 2 * degree, "Insufficient channels for fit"

    nchan, npol, ny, nx = im["pixels"].data.shape
    channels = numpy.arange(nchan)
    frequency = im.image_acc.wcs.sub(["spectral"]).wcs_pix2world(channels, 0)[0]
    # Centre and scale the abscissa to improve the conditioning of the fit.
    # Scale by the largest |offset|: the largest signed offset is 0 for two
    # channels (and for one), which made the old division produce NaNs.
    frequency = frequency - frequency[nchan // 2]
    scale = numpy.max(numpy.abs(frequency))
    if scale > 0.0:
        frequency = frequency / scale

    wt = numpy.ones_like(frequency)
    if mask is not None:
        wt[mask] = 0.0

    # Design matrix such that vander @ coeffs == numpy.polyval(coeffs, frequency).
    vander = numpy.vander(frequency, degree + 1)
    data = im["pixels"].data
    for pol in range(npol):
        # Fit every pixel of this polarisation at once: polyfit treats each
        # column of a 2D y as an independent data set sharing x and w.
        spectra = data[:, pol, :, :].reshape(nchan, ny * nx)
        fit = numpy.polyfit(frequency, spectra, w=wt, deg=degree)
        prediction = vander @ fit
        data[:, pol, :, :] -= prediction.reshape(nchan, ny, nx)
    return im
def create_window(template, window_type, **kwargs):
    """Create a window image using one of a number of methods

    The window is 1.0 inside the selected region and 0.0 elsewhere.

    window types:
     'quarter': Inner quarter of the image (the central half in y and in x)
     'no_edge': 'window_edge' pixels (default 16) around each edge set to zero
     'threshold': template pixels whose value is below 'window_threshold' set
        to zero. This is a signed comparison, not on the absolute value. The
        default threshold is 10 times the standard deviation of the template.
     None: the window covers the entire image

    :param template: Template image (canonical)
    :param window_type: 'quarter' | 'no_edge' | 'threshold' | None
    :return: New image containing window
    :raises ValueError: if window_type is not recognised

    See also :py:func:`rascil.processing_components.image.deconvolution.deconvolve_cube`
    """
    assert isinstance(template, Image), template
    assert template.image_acc.is_canonical()

    window = Image.constructor(
        data=numpy.zeros_like(template["pixels"].data),
        polarisation_frame=template.image_acc.polarisation_frame,
        wcs=template.image_acc.wcs,
        clean_beam=template.attrs["clean_beam"],
    )
    ny = template["pixels"].shape[2]
    nx = template["pixels"].shape[3]

    if window_type == "quarter":
        qx = nx // 4
        qy = ny // 4
        # Symmetric bounds: qy rows / qx columns excluded at each edge.
        window["pixels"].data[..., qy : (ny - qy), qx : (nx - qx)] = 1.0
        log.info("create_mask: Cleaning inner quarter of each sky plane")
    elif window_type == "no_edge":
        edge = get_parameter(kwargs, "window_edge", 16)
        # Exactly ``edge`` pixels excluded on every side.
        window["pixels"].data[..., edge : (ny - edge), edge : (nx - edge)] = 1.0
        log.info("create_mask: Window omits %d-pixel edge of each sky plane" % (edge))
    elif window_type == "threshold":
        window_threshold = get_parameter(kwargs, "window_threshold", None)
        if window_threshold is None:
            window_threshold = 10.0 * numpy.std(template["pixels"].data)
        window["pixels"].data[template["pixels"].data >= window_threshold] = 1.0
        log.info("create_mask: Window omits all points below %g" % (window_threshold))
    elif window_type is None:
        # Match the log message: the window must cover the whole image.
        window["pixels"].data[...] = 1.0
        log.info("create_mask: Mask covers entire image")
    else:
        raise ValueError("Window shape %s is not recognized" % window_type)

    return window
def polarisation_frame_from_wcs(wcs, shape) -> PolarisationFrame:
    """Convert wcs to polarisation_frame

    See FITS definition in Table 29 of https://fits.gsfc.nasa.gov/standard40/fits_standard40draft1.pdf
    or subsequent revision

        1 I Standard Stokes unpolarized
        2 Q Standard Stokes linear
        3 U Standard Stokes linear
        4 V Standard Stokes circular
        -1 RR Right-right circular
        -2 LL Left-left circular
        -3 RL Right-left cross-circular
        -4 LR Left-right cross-circular
        -5 XX X parallel linear
        -6 YY Y parallel linear
        -7 XY XY cross linear
        -8 YX YX cross linear

        stokesI [1]
        stokesIQUV [1,2,3,4]
        circular [-1,-2,-3,-4]
        linear [-5,-6,-7,-8]

    For example::

        pol_frame = polarisation_frame_from_wcs(im.image_acc.wcs, im["pixels"].data.shape)

    :param wcs: World Coordinate System
    :param shape: Shape corresponding to wcs
    :returns: Polarisation_Frame object
    :raises ValueError: if the Stokes codes match no known frame
    """
    if len(shape) == 2:
        return PolarisationFrame("stokesI")

    # The polarisation axis is shape[1] of the array, i.e. the WCS axis
    # labelled "stokes".
    npol = shape[1]
    pol = wcs.sub(["stokes"]).wcs_pix2world(range(npol), 0)[0]
    # Round before the int cast: truncation would turn e.g. 0.9999999 into 0.
    pol = numpy.round(pol).astype("int")

    for key, codes in PolarisationFrame.fits_codes.items():
        if numpy.array_equal(pol, numpy.array(codes)):
            return PolarisationFrame(key)

    raise ValueError("Cannot determine polarisation code")
def sub_image(im: Image, shape):
    """Cut out the centre of an image to the desired shape

    Cuts equally from all edges of the y and x axes. This is appropriate for
    a standard 4D image with axes (freq, pol, y, x); the frequency and
    polarisation axes are not cut. The wcs crpix is adjusted appropriately.

    If ``shape`` already equals the image shape, ``im`` itself is returned.

    :param im: Image to be cut
    :param shape: Desired shape, either 4D (freq, pol, y, x) or 2D (y, x)
    :return: Cut-out image
    :raises ValueError: if ``shape`` is larger than the image shape on any axis
    """
    # Normalise to a tuple so a list/array shape compares equal to data.shape.
    shape = tuple(shape)
    oldshape = im["pixels"].data.shape
    if oldshape == shape:
        return im

    if len(shape) == 2:
        shape = (1, 1, shape[0], shape[1])

    for axis, size in enumerate(oldshape):
        if shape[axis] > size:
            raise ValueError(
                "Padded shape %s is larger than input shape %s" % (shape, oldshape)
            )

    ystart = oldshape[2] // 2 - shape[2] // 2
    xstart = oldshape[3] // 2 - shape[3] // 2

    # The reference pixel moves by the number of pixels cut from the low edge.
    newwcs = copy.deepcopy(im.image_acc.wcs)
    newwcs.wcs.crpix[0] = im.image_acc.wcs.wcs.crpix[0] - xstart
    newwcs.wcs.crpix[1] = im.image_acc.wcs.wcs.crpix[1] - ystart

    # Use the ndarray (.data), as the other constructors in this module do.
    newdata = im["pixels"].data[
        ..., ystart : ystart + shape[2], xstart : xstart + shape[3]
    ]
    return Image.constructor(
        data=newdata, polarisation_frame=im.image_acc.polarisation_frame, wcs=newwcs
    )
def create_w_term_like(
    im: Image, w, phasecentre=None, remove_shift=False, dopol=False
) -> Image:
    """Create an image with a w term phase term in it:

    .. math::

        I(l,m) = e^{-2 \\pi j w (\\sqrt{1-l^2-m^2}-1)}

    The phasecentre is used as the delay centre for the w term (i.e. where
    n==0). If phasecentre is not a SkyCoord, the WCS reference pixel is used.

    :param im: template image
    :param w: w value to evaluate
    :param phasecentre: SkyCoord definition of phasecentre
    :param remove_shift: passed through to ``w_beam``
    :param dopol: If True keep the template's polarisation axis length,
        otherwise produce a single polarisation plane
    :return: Image (complex)
    """
    fim_shape = list(im["pixels"].data.shape)
    if not dopol:
        fim_shape[1] = 1
    # NOTE(review): when dopol is False the template's polarisation frame is
    # still attached to a single-plane image; confirm this is intended.

    wcs = im.image_acc.wcs
    fim_array = numpy.zeros(fim_shape, dtype="complex")
    cellsize = abs(wcs.wcs.cdelt[0]) * numpy.pi / 180.0
    npixel = fim_shape[-1]

    # isinstance, not ``is``: ``phasecentre is SkyCoord`` compared against the
    # class object itself and was never true for an actual SkyCoord.
    if isinstance(phasecentre, SkyCoord):
        # 0-based pixel position, consistent with the crpix - 1 fallback below.
        wcentre = phasecentre.to_pixel(wcs, origin=0)
    else:
        wcentre = [wcs.wcs.crpix[0] - 1.0, wcs.wcs.crpix[1] - 1.0]

    fim_array[...] = w_beam(
        npixel,
        npixel * cellsize,
        w=w,
        cx=wcentre[0],
        cy=wcentre[1],
        remove_shift=remove_shift,
    )[numpy.newaxis, numpy.newaxis, ...]

    fim = Image.constructor(
        data=fim_array, polarisation_frame=im.image_acc.polarisation_frame, wcs=wcs
    )

    fov = npixel * cellsize
    fresnel = numpy.abs(w) * (0.5 * fov) ** 2
    log.debug(
        "create_w_term_image: For w = %.1f, field of view = %.6f, Fresnel number = %.2f"
        % (w, fov, fresnel)
    )

    return fim
def rotate_image(im, angle=0.0, order=5):
    """Rotate an image in the x, y (last two) axes

    The output keeps the input pixel grid (``reshape=False``) so that it can be
    stored back into the copied Image. Complex images are handled by rotating
    the real and imaginary parts separately.

    :param im: Image
    :param angle: Angle in radians
    :param order: Order of spline interpolation (0-5)
    :return: New rotated Image; ``im`` is not modified
    """
    # scipy.ndimage.interpolation is a deprecated namespace; import from scipy.ndimage.
    from scipy.ndimage import rotate

    degrees = numpy.rad2deg(angle)

    def _rotate(array):
        # reshape=False: scipy's default (True) enlarges the array for angles
        # that are not multiples of 90 degrees, which no longer fits the Image.
        return rotate(array, angle=degrees, axes=(-2, -1), reshape=False, order=order)

    newim = im.copy(deep=True)
    data = im["pixels"].data
    # iscomplexobj also catches complex64, which ``dtype == "complex"`` missed.
    if numpy.iscomplexobj(data):
        newim["pixels"].data = _rotate(data.real) + 1j * _rotate(data.imag)
    else:
        newim["pixels"].data = _rotate(data)
    return newim
def apply_voltage_pattern_to_image(
    im: Image, vp: Image, inverse=False, min_det=1e-1, **kwargs
) -> Image:
    """Apply a voltage pattern to an image

    For each pixel, the application is as follows:

        I_{corrected}(l,m) = vp(l,m) I(l,m) vp(l,m).H

    With ``inverse=True`` the inverse of the voltage pattern is applied instead.

    Scalar voltage pattern (one polarisation): the image is multiplied by
    (or, for ``inverse``, divided by) the power pattern vp * conj(vp). When
    dividing, pixels where the power pattern is not positive are left unchanged.

    Full Jones voltage pattern (four polarisations): the image is converted to
    the voltage pattern's polarisation frame, ``apply_jones`` is applied to
    each pixel's 2x2 matrix, and the result is converted back to Stokes.

    :param im: Image to have jones applied
    :param vp: Jones image to be applied; must have the same shape as ``im``
    :param inverse: Apply the inverse (default=False)
    :param min_det: Minimum determinant to correct (full Jones path only)
    :return: new Image with Jones applied
    """
    if inverse:
        log.debug("apply_voltage_pattern_to_image: Apply inverse voltage pattern image")
    else:
        log.debug("apply_voltage_pattern_to_image: Apply voltage pattern image")

    is_scalar = vp.image_acc.shape[1] == 1

    nchan, npol, ny, nx = im["pixels"].data.shape

    assert im["pixels"].data.shape == vp["pixels"].data.shape

    if is_scalar:
        log.debug("apply_voltage_pattern_to_image: Scalar voltage pattern")
        # Start from a copy of the input pixels. The previous code scaled a
        # zeros array, so this path always returned an empty image.
        newdata = numpy.array(im["pixels"].data, copy=True)
        for chan in range(nchan):
            voltage = vp["pixels"].data[chan, 0, ...]
            pb = (voltage * numpy.conjugate(voltage)).real
            if inverse:
                # Divide only where the power pattern is positive.
                mask = pb > 0.0
                newdata[chan, 0, ...][mask] /= pb[mask]
            else:
                newdata[chan, 0, ...] *= pb
        return Image.constructor(
            data=newdata,
            polarisation_frame=im.image_acc.polarisation_frame,
            wcs=im.image_acc.wcs,
            clean_beam=im.attrs["clean_beam"],
        )

    log.debug("apply_voltage_pattern_to_image: Full Jones voltage pattern")
    assert npol == 4
    polim = convert_stokes_to_polimage(im, vp.image_acc.polarisation_frame)
    # Reorder (chan, pol, y, x) -> (chan, y, x, 2, 2) so each pixel is a 2x2 matrix.
    im_t = numpy.transpose(polim["pixels"].data, (0, 2, 3, 1)).reshape(
        [nchan, ny, nx, 2, 2]
    )
    vp_t = numpy.transpose(vp["pixels"].data, (0, 2, 3, 1)).reshape(
        [nchan, ny, nx, 2, 2]
    )
    newim_t = numpy.zeros([nchan, ny, nx, 2, 2], dtype="complex")
    for chan in range(nchan):
        for y in range(ny):
            for x in range(nx):
                newim_t[chan, y, x] = apply_jones(
                    vp_t[chan, y, x], im_t[chan, y, x], inverse, min_det=min_det
                )

    newim = Image.constructor(
        data=newim_t.reshape([nchan, ny, nx, 4]).transpose((0, 3, 1, 2)),
        polarisation_frame=vp.image_acc.polarisation_frame,
        wcs=im.image_acc.wcs,
    )
    return convert_polimage_to_stokes(newim)