"""Image operations visible to the Execution Framework as Components"""
__all__ = [
"add_image",
"average_image_over_frequency",
"create_w_term_like",
"create_window",
"sub_image",
"polarisation_frame_from_wcs",
"remove_continuum_image",
"show_components",
"show_image",
"smooth_image",
"apply_voltage_pattern_to_image",
]
import copy
import logging
import warnings
import numpy
from astropy.coordinates import SkyCoord
from astropy.wcs import FITSFixedWarning
from astropy.wcs.utils import skycoord_to_pixel
from ska_sdp_datamodels.image.image_model import Image
from ska_sdp_datamodels.science_data_model.polarisation_model import PolarisationFrame
from ska_sdp_func_python.calibration import apply_jones
from ska_sdp_func_python.fourier_transforms import w_beam
from ska_sdp_func_python.image import (
convert_polimage_to_stokes,
convert_stokes_to_polimage,
)
from rascil.processing_components.parameters import get_parameter
warnings.simplefilter("ignore", FITSFixedWarning)
log = logging.getLogger("rascil-logger")
[docs]
def add_image(im1: Image, im2: Image) -> Image:
    """Return a new Image whose pixels are the sum of two images' pixels

    The polarisation frame and WCS of the result are taken from ``im1``;
    ``im2`` only contributes its pixel values.

    :param im1: Image supplying pixels and metadata of the result
    :param im2: Image whose pixels are added to those of im1
    :return: New Image
    """
    summed_pixels = im1["pixels"].data + im2["pixels"].data
    accessor = im1.image_acc
    return Image.constructor(
        data=summed_pixels,
        polarisation_frame=accessor.polarisation_frame,
        wcs=accessor.wcs,
    )
[docs]
def show_image(
    im: Image,
    fig=None,
    title: str = "",
    pol=0,
    chan=0,
    cm="Greys",
    components=None,
    vmin=None,
    vmax=None,
    vscale=1.0,
):
    """Show an Image with coordinates using matplotlib, optionally with components

    :param im: Image
    :param fig: Matplotlib figure to draw into; a new figure is created if None
    :param title: String for title of plot
    :param pol: Polarisation to show (index), used when the pixel data is 4D
    :param chan: Channel to show (index), used when the pixel data is 4D
    :param cm: Matplotlib colormap name
    :param components: Optional components to be overlaid; each must have a
        ``direction`` attribute (SkyCoord)
    :param vmin: Clip to this minimum (default: vscale * minimum of the plane)
    :param vmax: Clip to this maximum (default: vscale * maximum of the plane)
    :param vscale: scale the default max, min by this amount
    :return: The matplotlib figure drawn into
    """
    import matplotlib.pyplot as plt

    # Draw into the caller's figure when one is given.
    if fig is None:
        fig = plt.figure()

    wcs = im.image_acc.wcs
    ax = fig.add_subplot(1, 1, 1, projection=wcs.sub([1, 2]))

    if len(im["pixels"].data.shape) == 4:
        data_array = numpy.real(im["pixels"].data[chan, pol, :, :])
    else:
        data_array = numpy.real(im["pixels"].data)

    if vmax is None:
        vmax = vscale * numpy.max(data_array)
    if vmin is None:
        vmin = vscale * numpy.min(data_array)

    # Keep the colormap name and the image handle in separate variables.
    mappable = ax.imshow(data_array, origin="lower", cmap=cm, vmax=vmax, vmin=vmin)
    ax.set_xlabel(wcs.wcs.ctype[0])
    ax.set_ylabel(wcs.wcs.ctype[1])
    ax.set_title(title)
    fig.colorbar(mappable, orientation="vertical", shrink=0.7)

    if components is not None:
        for sc in components:
            # origin=0 matches imshow's 0-based pixel coordinates
            x, y = skycoord_to_pixel(sc.direction, wcs, 0, "wcs")
            ax.plot(x, y, marker="+", color="red")

    return fig
[docs]
def show_components(im, comps, npixels=128, fig=None, vmax=None, vmin=None, title=""):
    """Show a cut-out of the image around each component, marking the component

    For each component, an ``npixels`` x ``npixels`` cut-out of channel 0,
    polarisation 0 is drawn. The cut-out is centred on the component
    direction and clipped at the image edges. A red '+' is placed at the
    component position, and the panel is shown with plt.show().

    :param im: Canonical Image
    :param comps: Iterable of components having ``direction``, ``name``
        and ``flux`` attributes
    :param npixels: Size of each cut-out in pixels
    :param fig: Matplotlib figure; a new one is created if not given
    :param vmax: Display maximum (default: max of channel 0, pol 0)
    :param vmin: Display minimum (default: min of channel 0, pol 0)
    :param title: Optional overall title (figure suptitle) for each panel
    :return: None
    """
    import matplotlib.pyplot as plt

    assert isinstance(im, Image), im
    assert im.image_acc.is_canonical()

    plane = im["pixels"].data[0, 0, ...]
    if vmax is None:
        vmax = numpy.max(plane)
    if vmin is None:
        vmin = numpy.min(plane)
    if not fig:
        fig = plt.figure()
    plt.clf()

    ny, nx = plane.shape
    half = npixels // 2
    for sc in comps:
        # WCS.sub returns a new WCS object, so shifting its crpix below does
        # not touch the image's own WCS.
        cutout_wcs = im.image_acc.wcs.sub([1, 2])
        # 0-based pixel position, so it can be used directly as an array index
        cx, cy = numpy.round(
            skycoord_to_pixel(sc.direction, cutout_wcs, 0, "wcs")
        ).astype("int")
        # Clip the cut-out to the image; negative starts would wrap around
        xstart, xend = max(cx - half, 0), min(cx + half, nx)
        ystart, yend = max(cy - half, 0), min(cy + half, ny)
        cutout_wcs.wcs.crpix[0] -= xstart
        cutout_wcs.wcs.crpix[1] -= ystart

        plt.subplot(111, projection=cutout_wcs)
        plt.imshow(
            plane[ystart:yend, xstart:xend],
            origin="lower",
            cmap="Greys",
            vmax=vmax,
            vmin=vmin,
        )
        x, y = skycoord_to_pixel(sc.direction, cutout_wcs, 0, "wcs")
        plt.plot(x, y, marker="+", color="red")
        plt.title("Name = %s, flux = %s" % (sc.name, sc.flux))
        if title:
            plt.suptitle(title)
        plt.show()
[docs]
def smooth_image(model: Image, width=1.0, normalise=True):
    """Convolve every (channel, polarisation) plane with a 2D Gaussian kernel

    :param model: Canonical Image to be smoothed (its pixels are not modified)
    :param width: Standard deviation of the Gaussian kernel, in pixels
    :param normalise: If True, multiply the result by 2*pi*width**2 so that
        the kernel effectively has unit peak
    :return: New Image holding the smoothed pixels, with the input's dtype
    """
    assert isinstance(model, Image), model
    assert model.image_acc.is_canonical()

    from astropy.convolution import convolve_fft
    from astropy.convolution.kernels import Gaussian2DKernel

    gaussian = Gaussian2DKernel(width)
    source = model["pixels"].data
    original_dtype = source.dtype
    smoothed = Image.constructor(
        data=numpy.zeros_like(source),
        polarisation_frame=model.image_acc.polarisation_frame,
        wcs=model.image_acc.wcs,
        clean_beam=model.attrs["clean_beam"],
    )

    nchan, npol, _, _ = model.image_acc.shape
    target = smoothed["pixels"].data
    for chan, pol in numpy.ndindex(nchan, npol):
        target[chan, pol, :, :] = convolve_fft(
            source[chan, pol, :, :],
            gaussian,
            normalize_kernel=False,
            allow_huge=True,
        )

    # convolve_fft does not necessarily preserve the dtype, so cast back
    smoothed["pixels"].data = smoothed["pixels"].data.astype(original_dtype)
    if normalise:
        smoothed["pixels"].data *= 2 * numpy.pi * width**2
    return smoothed
[docs]
def average_image_over_frequency(im: Image) -> Image:
    """Average an image cube over its frequency (first) axis

    The result has a single channel. Its WCS is a deep copy of the input WCS
    with axis index 3 (the frequency axis of a canonical image) given the
    mean channel frequency as reference value, and reference pixel 1.

    :param im: Canonical Image
    :return: Single-channel averaged Image
    """
    assert isinstance(im, Image), im
    assert im.image_acc.is_canonical()

    pixels = im["pixels"].data
    averaged = numpy.mean(pixels, axis=0)[numpy.newaxis, ...]
    assert not numpy.isnan(numpy.sum(pixels)), "NaNs present in image data"

    averaged_wcs = copy.deepcopy(im.image_acc.wcs)
    spectral = averaged_wcs.wcs
    spectral.crval[3] = numpy.average(im.frequency.data)
    spectral.crpix[3] = 1

    return Image.constructor(
        data=averaged,
        polarisation_frame=im.image_acc.polarisation_frame,
        wcs=averaged_wcs,
    )
[docs]
def remove_continuum_image(im: Image, degree=1, mask=None):
    """Fit and remove a polynomial continuum from an image cube, in place

    A polynomial of the given degree in (centred, normalised) frequency is
    fitted to the spectrum of every (pol, y, x) pixel and subtracted from it.
    Channels where ``mask`` is True get zero weight in the fit.

    :param im: Canonical 4D Image (nchan, npol, ny, nx); modified in place
    :param degree: Polynomial degree as used by numpy.polyfit
        (0 is a constant, 1 is a straight line, etc.)
    :param mask: Optional boolean per-channel array; True channels are
        excluded from the fit
    :return: The same Image, with the fitted continuum subtracted
    """
    assert isinstance(im, Image), im
    assert im.image_acc.is_canonical()
    if mask is not None:
        # NOTE(review): this counts the True channels, which get zero weight
        # below; the fit actually uses the False channels. Confirm the
        # intended mask semantics.
        assert numpy.sum(mask) > 2 * degree, "Insufficient channels for fit"

    data = im["pixels"].data
    nchan = data.shape[0]
    channels = numpy.arange(nchan)
    frequency = im.image_acc.wcs.sub(["spectral"]).wcs_pix2world(channels, 0)[0]

    # Centre and scale the abscissa to keep the fit well conditioned. Using
    # the largest |offset| avoids dividing by zero when the centre channel is
    # the last one (nchan == 2) or the only one (nchan == 1).
    frequency = frequency - frequency[nchan // 2]
    scale = numpy.max(numpy.abs(frequency))
    if scale > 0.0:
        frequency = frequency / scale

    wt = numpy.ones_like(frequency)
    if mask is not None:
        wt[mask] = 0.0

    # numpy.polyfit fits each column of a 2D ordinate independently, so all
    # pixel spectra are fitted in a single call.
    spectra = data.reshape(nchan, -1)
    fit = numpy.polyfit(frequency, spectra, w=wt, deg=degree)
    # vander columns are x**degree ... x**0, matching polyfit's coefficient order
    prediction = numpy.vander(frequency, degree + 1) @ fit
    data -= prediction.reshape(data.shape)
    return im
[docs]
def create_window(template, window_type, **kwargs):
    """Create a window image using one of a number of methods

    The window is 1.0 where pixels are included and 0.0 where they are not.

    window types:
    'quarter': Inner quarter (half of each axis) of the image
    'no_edge': 'window_edge' pixels around edge set to zero
    'threshold': template image pixels < 'window_threshold' set to zero
    None: the whole image is included

    :param template: Template image (canonical)
    :param window_type: 'quarter' | 'no_edge' | 'threshold' | None
    :param kwargs: Optional 'window_edge' (default 16) for 'no_edge', and
        'window_threshold' (default 10 * std of the template pixels) for
        'threshold'
    :return: New image containing window
    :raises ValueError: If window_type is not one of the above

    See also
    :py:func:`rascil.processing_components.image.deconvolution.deconvolve_cube`
    """
    assert isinstance(template, Image), template
    assert template.image_acc.is_canonical()
    window = Image.constructor(
        data=numpy.zeros_like(template["pixels"].data),
        polarisation_frame=template.image_acc.polarisation_frame,
        wcs=template.image_acc.wcs,
        clean_beam=template.attrs["clean_beam"],
    )
    window_pixels = window["pixels"].data
    ny, nx = template["pixels"].shape[2:]

    if window_type == "quarter":
        qx = nx // 4
        qy = ny // 4
        # Rows qy+1 .. 3*qy-1 (and likewise columns) are symmetric about 2*qy
        window_pixels[..., (qy + 1) : 3 * qy, (qx + 1) : 3 * qx] = 1.0
        log.info("create_mask: Cleaning inner quarter of each sky plane")
    elif window_type == "no_edge":
        edge = get_parameter(kwargs, "window_edge", 16)
        window_pixels[..., (edge + 1) : (ny - edge), (edge + 1) : (nx - edge)] = 1.0
        log.info("create_mask: Window omits %d-pixel edge of each sky plane", edge)
    elif window_type == "threshold":
        window_threshold = get_parameter(kwargs, "window_threshold", None)
        if window_threshold is None:
            window_threshold = 10.0 * numpy.std(template["pixels"].data)
        # NOTE(review): this is a signed comparison, so strongly negative
        # pixels are excluded; wrap in numpy.abs() if absolute-value
        # thresholding is intended.
        window_pixels[template["pixels"].data >= window_threshold] = 1.0
        log.info("create_mask: Window omits all points below %g", window_threshold)
    elif window_type is None:
        # No window requested: include every pixel, as the log message says
        window_pixels[...] = 1.0
        log.info("create_mask: Mask covers entire image")
    else:
        raise ValueError("Window shape %s is not recognized" % window_type)
    return window
[docs]
def polarisation_frame_from_wcs(wcs, shape) -> PolarisationFrame:
    """Determine the polarisation frame described by a WCS

    The Stokes axis codes follow Table 29 of the FITS standard
    (https://fits.gsfc.nasa.gov/standard40/fits_standard40draft1.pdf, or a
    subsequent revision):

     1  I   Standard Stokes unpolarized
     2  Q   Standard Stokes linear
     3  U   Standard Stokes linear
     4  V   Standard Stokes circular
    -1  RR  Right-right circular
    -2  LL  Left-left circular
    -3  RL  Right-left cross-circular
    -4  LR  Left-right cross-circular
    -5  XX  X parallel linear
    -6  YY  Y parallel linear
    -7  XY  XY cross linear
    -8  YX  YX cross linear

    Frames and their code sequences:

    stokesI     [1]
    stokesIQUV  [1,2,3,4]
    circular    [-1,-2,-3,-4]
    linear      [-5,-6,-7,-8]

    For example::

        pol_frame = polarisation_frame_from_wcs(
            im.image_acc.wcs, im["pixels"].data.shape
        )

    :param wcs: World Coordinate System with a stokes axis (unless shape is 2D)
    :param shape: Shape corresponding to wcs. A 2D shape is treated as
        stokesI; otherwise shape[1] is the number of polarisations
    :returns: PolarisationFrame object
    :raises ValueError: If the stokes codes match no known frame
    """
    # A plain (y, x) image has no polarisation axis.
    if len(shape) == 2:
        return PolarisationFrame("stokesI")

    npol = shape[1]
    codes = numpy.array(
        wcs.sub(["stokes"]).wcs_pix2world(range(npol), 0)[0], dtype="int"
    )
    # The first frame whose code sequence matches exactly wins.
    for frame_name, frame_codes in PolarisationFrame.fits_codes.items():
        if numpy.array_equal(codes, numpy.array(frame_codes)):
            return PolarisationFrame(frame_name)

    raise ValueError("Cannot determine polarisation code")
[docs]
def sub_image(im: Image, shape):
    """Cut a sub-image of the desired shape out of the centre of an image

    Appropriate for standard 4D images with axes (freq, pol, y, x). Only the
    y and x axes are cut, removing pixels equally from both edges. The WCS
    reference pixel (crpix) is shifted so that world coordinates of the
    remaining pixels are unchanged.

    :param im: Image to be cut
    :param shape: Desired shape, either 4D (nchan, npol, ny, nx) or 2D (ny, nx)
    :return: ``im`` itself if it already has this shape, otherwise a new Image
        whose pixels are a view into ``im``'s pixels
    :raises ValueError: If any requested dimension exceeds the input's
    """
    old_shape = im["pixels"].data.shape
    if old_shape == shape:
        return im

    if len(shape) == 2:
        shape = (1, 1, shape[0], shape[1])

    for axis, existing in enumerate(old_shape):
        if shape[axis] > existing:
            raise ValueError(
                "Padded shape %s is larger than input shape %s" % (shape, old_shape)
            )

    newwcs = copy.deepcopy(im.image_acc.wcs)
    # Removing `start` pixels from the low edge moves the reference pixel down
    newwcs.wcs.crpix[0] += shape[3] // 2 - old_shape[3] // 2
    newwcs.wcs.crpix[1] += shape[2] // 2 - old_shape[2] // 2

    ystart = old_shape[2] // 2 - shape[2] // 2
    xstart = old_shape[3] // 2 - shape[3] // 2
    # Pass the underlying array, as every other Image.constructor call here
    # does; a DataArray slice would carry its own (conflicting) coordinates.
    newdata = im["pixels"].data[
        ..., ystart : ystart + shape[2], xstart : xstart + shape[3]
    ]
    return Image.constructor(
        data=newdata,
        polarisation_frame=im.image_acc.polarisation_frame,
        wcs=newwcs,
        clean_beam=im.attrs["clean_beam"],
    )
[docs]
def create_w_term_like(
    im: Image, w, phasecentre=None, remove_shift=False, dopol=False
) -> Image:
    """Create an image with a w term phase term in it:

    .. math::

        I(l,m) = e^{-2 \\pi j w (\\sqrt{1-l^2-m^2}-1)}

    The phasecentre is used as the delay centre for the w term (i.e. where
    n==0). If no SkyCoord phasecentre is given, the WCS reference pixel is used.

    :param im: template image
    :param w: w value to evaluate
    :param phasecentre: SkyCoord definition of phasecentre (optional)
    :param remove_shift: passed through to w_beam
    :param dopol: Do screen in polarisation? If False the result has a
        single polarisation plane
    :return: Image
    """
    fim_shape = list(im["pixels"].data.shape)
    if not dopol:
        fim_shape[1] = 1
    polarisation_frame = im.image_acc.polarisation_frame
    if fim_shape[1] != im["pixels"].data.shape[1]:
        # The single screen plane cannot carry a multi-polarisation frame
        polarisation_frame = PolarisationFrame("stokesI")

    wcs = im.image_acc.wcs
    fim_array = numpy.zeros(fim_shape, dtype="complex")
    cellsize = abs(wcs.wcs.cdelt[0]) * numpy.pi / 180.0
    npixel = fim_shape[-1]
    fov = npixel * cellsize

    if isinstance(phasecentre, SkyCoord):
        px, py = phasecentre.to_pixel(wcs, origin=0)
        wcentre = [float(px), float(py)]
    else:
        # crpix is 1-based; w_beam centres are 0-based pixel positions
        wcentre = [wcs.wcs.crpix[0] - 1.0, wcs.wcs.crpix[1] - 1.0]

    # The same 2D screen is broadcast to every channel/polarisation plane
    fim_array[...] = w_beam(
        npixel,
        fov,
        w=w,
        cx=wcentre[0],
        cy=wcentre[1],
        remove_shift=remove_shift,
    )[numpy.newaxis, numpy.newaxis, ...]

    fim = Image.constructor(
        data=fim_array, polarisation_frame=polarisation_frame, wcs=wcs
    )

    fresnel = numpy.abs(w) * (0.5 * fov) ** 2
    log.debug(
        "create_w_term_image: For w = %.1f, field of view = %.6f, Fresnel number = %.2f"
        % (w, fov, fresnel)
    )
    return fim
def rotate_image(im, angle=0.0, order=5):
    """Rotate an image in its last two (y, x) axes

    The output keeps the input shape (``reshape=False``). Regions rotated in
    from outside the original field take scipy's default constant fill
    value (0.0). Complex images are rotated by interpolating the real and
    imaginary parts separately.

    :param im: Image (not modified)
    :param angle: Angle in radians
    :param order: Order of spline interpolation (0-5)
    :return: New Image with rotated pixels
    """
    from scipy.ndimage import rotate

    degrees = numpy.rad2deg(angle)

    def _rotate(plane):
        # reshape=False keeps the array shape so it can be written back into
        # the image; scipy's default reshape=True grows it for most angles.
        return rotate(plane, angle=degrees, axes=(-2, -1), reshape=False, order=order)

    newim = im.copy(deep=True)
    data = im["pixels"].data
    # iscomplexobj also catches complex64, unlike a dtype == "complex" test
    if numpy.iscomplexobj(data):
        newim["pixels"].data = _rotate(data.real) + 1j * _rotate(data.imag)
    else:
        newim["pixels"].data = _rotate(data)
    return newim
[docs]
def apply_voltage_pattern_to_image(
    im: Image, vp: Image, inverse=False, min_det=1e-1, **kwargs
) -> Image:
    """Apply a voltage pattern to an image

    For a full-Jones voltage pattern each pixel is corrected as

        I_{corrected}(l,m) = vp(l,m) I(l,m) vp(l,m).H

    via ``apply_jones`` on per-pixel 2x2 matrices. For a scalar voltage
    pattern (one polarisation) the image is scaled by the power pattern
    pb = (vp * conj(vp)).real. With inverse=True it is multiplied by pb;
    otherwise it is divided by pb wherever pb > 0.

    :param im: Image to have jones applied
    :param vp: Jones image to be applied (same shape as im)
    :param inverse: Apply the inverse (default=False)
    :param min_det: Minimum determinant to correct (full-Jones case)
    :param kwargs: Accepted for compatibility; unused
    :return: new Image with Jones applied
    """
    # Start from a copy of the input pixels: the scalar branch scales them in
    # place, so starting from zeros would always return an empty image.
    newim = Image.constructor(
        data=im["pixels"].data.copy(),
        polarisation_frame=im.image_acc.polarisation_frame,
        wcs=im.image_acc.wcs,
        clean_beam=im.attrs["clean_beam"],
    )

    if inverse:
        log.debug("apply_gaintable: Apply inverse voltage pattern image")
    else:
        log.debug("apply_gaintable: Apply voltage pattern image")

    is_scalar = vp.image_acc.shape[1] == 1

    nchan, npol, ny, nx = im["pixels"].data.shape

    assert im["pixels"].data.shape == vp["pixels"].data.shape

    if is_scalar:
        log.debug("apply_voltage_pattern_to_image: Scalar voltage pattern")
        vp0 = vp["pixels"].data[:, 0, ...]
        pb = (vp0 * numpy.conjugate(vp0)).real
        # Basic slicing gives a view, so the in-place ops update newim
        plane = newim["pixels"].data[:, 0, ...]
        # NOTE(review): here inverse=True multiplies by the power pattern,
        # while the full-Jones branch passes `inverse` to apply_jones;
        # presumably that applies J^-1 when inverse=True, i.e. the opposite
        # sense. Confirm which convention callers expect.
        if inverse:
            plane *= pb
        else:
            mask = pb > 0.0
            plane[mask] /= pb[mask]
    else:
        log.debug("apply_voltage_pattern_to_image: Full Jones voltage pattern")
        polim = convert_stokes_to_polimage(im, vp.image_acc.polarisation_frame)
        assert npol == 4
        # Rearrange (nchan, 4, ny, nx) into per-pixel 2x2 matrices
        im_t = numpy.transpose(polim["pixels"].data, (0, 2, 3, 1)).reshape(
            [nchan, ny, nx, 2, 2]
        )
        vp_t = numpy.transpose(vp["pixels"].data, (0, 2, 3, 1)).reshape(
            [nchan, ny, nx, 2, 2]
        )
        newim_t = numpy.zeros([nchan, ny, nx, 2, 2], dtype="complex")
        for chan in range(nchan):
            for y in range(ny):
                for x in range(nx):
                    newim_t[chan, y, x] = apply_jones(
                        vp_t[chan, y, x], im_t[chan, y, x], inverse, min_det=min_det
                    )

        newim = Image.constructor(
            data=newim_t.reshape([nchan, ny, nx, 4]).transpose((0, 3, 1, 2)),
            polarisation_frame=vp.image_acc.polarisation_frame,
            wcs=im.image_acc.wcs,
        )
        newim = convert_polimage_to_stokes(newim)
    return newim