import sys
import jax.numpy as jnp
from jax import jit, random
from jax.lax import scan
from scipy.special import jv
from functools import partial
c = 2.99792458e8
[docs]
def rfi_vis(app_amplitude, c_distances, freqs, a1, a2):
"""
Calculate visibilities from distances to rfi sources.
Parameters
----------
app_amplitude: array_like (n_src, n_time, n_ant, n_freq)
Apparent amplitude at the antennas.
c_distances: array_like (n_src, n_time, n_ant)
The phase corrected distances between the rfi sources and the antennas in metres.
freqs: array_like (n_freq,)
Frequencies in Hz.
a1: array_like (n_bl,)
Antenna 1 indexes, between 0 and n_ant-1.
a2: array_like (n_bl,)
Antenna 2 indexes, between 0 and n_ant-1.
Returns
-------
vis: array_like (n_time, n_bl, n_freq)
The visibilities.
"""
n_src = app_amplitude.shape[0]
vis = _rfi_vis(app_amplitude[0, None], c_distances[0, None], freqs, a1, a2)
# This is a scan over the sources, but we can't use scan it unless we jit decorate this function
def _add_vis(vis, i):
return (
vis + _rfi_vis(app_amplitude[i, None], c_distances[i, None], freqs, a1, a2),
i,
)
return scan(_add_vis, vis, jnp.arange(1, n_src))[0]
# return _rfi_vis(app_amplitude, c_distances, freqs, a1, a2)
[docs]
def astro_vis(sources, uvw, lmn, freqs):
"""
Calculate visibilities from a set of point sources using DFT.
Parameters
----------
sources: array_like (n_src, n_time, n_freq)
Array of point source intensities in Jy.
uvw: array_like (ntime, n_bl, 3)
(u,v,w) coordinates of each baseline.
lmn: array_like (n_src, 3)
(l,m,n) coordinate of each source.
freqs: array_like (n_freq,)
Frequencies in Hz.
Returns
-------
vis: array_like (n_time, n_bl, n_freq)
Visibilities of the given set of sources and baselines.
"""
n_src = sources.shape[0]
vis = _astro_vis(sources[0, None], uvw, lmn[0, None], freqs)
# This is a scan over the sources, but we can't use scan it unless we jit decorate this function
@jit
def _add_vis(vis, i):
return vis + _astro_vis(sources[i, None], uvw, lmn[i, None], freqs), i
return scan(_add_vis, vis, jnp.arange(1, n_src))[0]
[docs]
def astro_vis_gauss(sources, major, minor, pos_angle, uvw, lmn, freqs):
"""
Calculate visibilities from a set of point sources using DFT.
Parameters
----------
sources: array_like (n_src, n_time, n_freq)
Array of point source intensities in Jy.
shapes: array_like (n_src,)
Array of standard deviations of the gaussian shape sources. These are
assumed to be circular gaussians for now.
uvw: array_like (ntime, n_bl, 3)
(u,v,w) coordinates of each baseline.
lmn: array_like (n_src, 3)
(l,m,n) coordinate of each source.
freqs: array_like (n_freq,)
Frequencies in Hz.
Returns
-------
vis: array_like (n_time, n_bl, n_freq)
Visibilities of the given set of sources and baselines.
"""
n_src = sources.shape[0]
vis = _astro_vis_gauss(
sources[0, None],
major[0, None],
minor[0, None],
pos_angle[0, None],
uvw,
lmn[0, None],
freqs,
)
# This is a scan over the sources, but we can't use scan it unless we jit decorate this function
def _add_vis(vis, i):
return (
vis
+ _astro_vis_gauss(
sources[i, None],
major[i, None],
minor[i, None],
pos_angle[i, None],
uvw,
lmn[i, None],
freqs,
),
i,
)
return scan(_add_vis, vis, jnp.arange(1, n_src))[0]
[docs]
def astro_vis_exp(sources, shapes, uvw, lmn, freqs):
"""
Calculate visibilities from a set of point sources using DFT.
Parameters
----------
sources: array_like (n_src, n_time, n_freq)
Array of point source intensities in Jy.
shapes: array_like (n_src,)
Array of standard deviations of the gaussian shape sources. These are
assumed to be circular gaussians for now.
uvw: array_like (ntime, n_bl, 3)
(u,v,w) coordinates of each baseline.
lmn: array_like (n_src, 3)
(l,m,n) coordinate of each source.
freqs: array_like (n_freq,)
Frequencies in Hz.
Returns
-------
vis: array_like (n_time, n_bl, n_freq)
Visibilities of the given set of sources and baselines.
"""
n_src = sources.shape[0]
vis = _astro_vis_exp(sources[0, None], shapes[0, None], uvw, lmn[0, None], freqs)
# This is a scan over the sources, but we can't use scan it unless we jit decorate this function
def _add_vis(vis, i):
return (
vis
+ _astro_vis_exp(
sources[i, None], shapes[i, None], uvw, lmn[i, None], freqs
),
i,
)
return scan(_add_vis, vis, jnp.arange(1, n_src))[0]
[docs]
def ants_to_bl(G, a1, a2):
"""
Calculate the complex gains for each baseline given the per antenna gains.
Parameters
----------
G: array_like (n_time, n_ant, n_freq)
Complex gains at each antenna over time.
a1: array_like (n_bl,)
Antenna 1 indexes, between 0 and n_ant-1.
a2: array_like (n_bl,)
Antenna 2 indexes, between 0 and n_ant-1.
Returns
-------
G_bl: array_like (n_time, n_bl, n_freq)
Complex gains on each baseline over time.
"""
return _ants_to_bl(G, a1, a2)
[docs]
def minus_two_pi_over_lamda(freqs):
"""Calculate -2pi/lambda for each frequency.
Args:
freqs (jnp.ndarray): Frequencies in Hz. (n_freq,)
Returns:
jnp.ndarray: -2pi/lambda for each frequency. (n_freq,)
"""
return -2.0 * jnp.pi * freqs / c
[docs]
def amp_to_intensity(amps, a1, a2):
"""Calculate intensity on a baseline ffrom the amplitudes at each antenna.
Args:
amps (jnp.ndarray): Amplitudes at the antennas. (n_src, n_time, n_int, n_ant, n_freq)
a1 (jnp.ndarray): Antenna 1 indexes, between 0 and n_ant-1. (n_bl,)
a2 (jnp.ndarray): Antenna 2 indexes, between 0 and n_ant-1. (n_bl,)
Returns:
jnp.ndarray: Intensity on baselines.
"""
return amps[:, :, :, a1] * jnp.conjugate(amps[:, :, :, a2])
[docs]
def phase_from_distances(distances, a1, a2, freqs):
"""Calculate phase differences between antennas from distances.
Args:
distances (jnp.ndarray): Distances to antennas. (n_src, n_time, n_int, n_ant)
a1 (jnp.ndarray): Antenna 1 indexes, between 0 and n_ant-1. (n_bl,)
a2 (jnp.ndarray): Antenna 2 indexes, between 0 and n_ant-1. (n_bl,)
freqs (jnp.ndarray): Frequencies in Hz. (n_freq,)
Returns:
jnp.ndarray: Phases on baselines.
"""
# Create array of shape (n_src, n_time, n_bl, n_freq)
freqs = freqs[None, None, None, None, :]
distances = distances[:, :, :, :, None]
phases = minus_two_pi_over_lamda(freqs) * (
distances[:, :, :, a1, :] - distances[:, :, :, a2, :]
)
return phases
def _rfi_vis(app_amplitude, c_distances, freqs, a1, a2):
# Create array of shape (n_src, n_time, n_bl, n_freq), then sum over n_src
app_amplitude = jnp.asarray(app_amplitude)
c_distances = jnp.asarray(c_distances)
freqs = jnp.asarray(freqs)
a1 = jnp.asarray(a1)
a2 = jnp.asarray(a2)
phase = phase_from_distances(c_distances, a1, a2, freqs)
intensity = amp_to_intensity(app_amplitude, a1, a2)
vis = jnp.sum(intensity * jnp.exp(1.0j * phase), axis=0)
vis_avg = jnp.mean(vis, axis=1)
return vis_avg
def _astro_vis(sources, uvw, lmn, freqs):
# Create array of shape (n_src, n_time, n_bl, n_freq), then sum over n_src
sources = jnp.asarray(sources[:, :, None, :]) # (n_src, 1, 1, n_freq)
freqs = jnp.asarray(freqs[None, None, None, :]) # (1, 1, 1, n_freq)
uvw = jnp.asarray(uvw[None, :, :, None, :]) # (1, n_time, n_bl, 1, 3)
lmn = jnp.asarray(lmn[:, None, None, None, :]) # (n_src, 1, 1, 1, 3)
s0 = jnp.array([0, 0, 1])[None, None, None, None, :] # (1, 1, 1, 1, 3)
phase = minus_two_pi_over_lamda(freqs) * jnp.sum(uvw * (lmn - s0), axis=-1)
vis = jnp.sum(sources * jnp.exp(-1.0j * phase), axis=0)
return vis
[docs]
def gauss(uvw, shapes, freqs):
uv_mag = jnp.linalg.norm(uvw[..., :-1], axis=-1) / (c / freqs)
sigmas = shapes / (2.0 * jnp.sqrt(2.0 * jnp.log(2)))
sigmas_uv = 1.0 / (2.0 * jnp.pi * sigmas)
return jnp.exp(-((uv_mag / sigmas_uv) ** 2))
[docs]
def source_to_abc(major, minor, pa):
"""Calculate the coefficients of the quadratic for a Gaussian source."""
sigma_factor = 2 * jnp.sqrt(2 * jnp.log(2))
sigma_x = jnp.deg2rad(minor / 3600) / sigma_factor
sigma_y = jnp.deg2rad(major / 3600) / sigma_factor
theta = jnp.deg2rad(pa)
a = jnp.cos(theta) ** 2 / (2 * sigma_x**2) + jnp.sin(theta) ** 2 / (2 * sigma_y**2)
b = jnp.sin(2 * theta) / (4 * sigma_x**2) - jnp.sin(2 * theta) / (4 * sigma_y**2)
c = jnp.sin(theta) ** 2 / (2 * sigma_x**2) + jnp.cos(theta) ** 2 / (2 * sigma_y**2)
return a, b, c
[docs]
def gauss_uv(uvw, major, minor, pos_a, freqs):
cc = 2.99792458e8
lamda = cc / freqs
u = uvw[..., 0] / lamda
v = uvw[..., 1] / lamda
a, b, c = source_to_abc(major, minor, pos_a)
det = (a * c - b**2) / (4 * jnp.pi**2)
return jnp.exp(-(c * u**2 - 2 * b * u * v + a * v**2) / (4 * det))
[docs]
def gauss_lm(l, m, a, b, c):
return jnp.exp(-(a * l**2 + 2 * b * l * m + c * m**2))
def _astro_vis_gauss(sources, major, minor, pos_angle, uvw, lmn, freqs):
# Create array of shape (n_src, n_time, n_bl, n_freq), then sum over n_src
sources = jnp.asarray(sources[:, :, None, :]) # (n_src, n_time, 1, n_freq)
major = jnp.asarray(major[:, None, None, None]) # (n_src, 1, 1, 1)
minor = jnp.asarray(minor[:, None, None, None]) # (n_src, 1, 1, 1)
pos_angle = jnp.asarray(pos_angle[:, None, None, None]) # (n_src, 1, 1, 1)
freqs = jnp.asarray(freqs[None, None, None, :]) # (1, 1, 1, n_freq)
uvw = jnp.asarray(uvw[None, :, :, None, :]) # (1, n_time, n_bl, 1, 3)
lmn = jnp.asarray(lmn[:, None, None, None, :]) # (n_src, 1, 1, 1, 3)
s0 = jnp.array([0, 0, 1])[None, None, None, None, :] # (1, 1, 1, 1, 3)
phase = minus_two_pi_over_lamda(freqs) * jnp.sum(uvw * (lmn - s0), axis=-1)
uv_filter = gauss_uv(uvw, major, minor, pos_angle, freqs)
vis = jnp.sum(uv_filter * sources * jnp.exp(-1.0j * phase), axis=0)
return vis
[docs]
def exp_uv(uvw, shapes, freqs):
U = jnp.linalg.norm(uvw[..., :-1], axis=-1) / (c / freqs)
return 1.0 / (1.0 + (2 * jnp.pi * shapes * U) ** 2) ** 1.5
def _astro_vis_exp(sources, shapes, uvw, lmn, freqs):
# Create array of shape (n_src, n_time, n_bl, n_freq), then sum over n_src
sources = jnp.asarray(sources[:, :, None, :]) # (n_src, n_time, 1, n_freq)
shapes = jnp.asarray(shapes[:, None, None, None]) # (n_src, 1, 1, 1)
freqs = jnp.asarray(freqs[None, None, None, :]) # (1, 1, 1, n_freq)
uvw = jnp.asarray(uvw[None, :, :, None, :]) # (1, n_time, n_bl, 1, 3)
lmn = jnp.asarray(lmn[:, None, None, None, :]) # (n_src, 1, 1, 1, 3)
s0 = jnp.array([0, 0, 1])[None, None, None, None, :] # (1, 1, 1, 1, 3)
phase = minus_two_pi_over_lamda(freqs) * jnp.sum(uvw * (lmn - s0), axis=-1)
vis = jnp.sum(exp_uv(uvw, shapes, freqs) * sources * jnp.exp(-1.0j * phase), axis=0)
return vis
def _ants_to_bl(G, a1, a2):
G_bl = G[:, a1, :] * jnp.conjugate(G[:, a2, :])
return G_bl
[docs]
def airy_beam(theta: jnp.ndarray, freqs: jnp.ndarray, dish_d: float):
"""
Calculate the primary beam voltage at a given angular distance from the
pointing direction. The beam intensity model is the Airy disk as
defined by the dish diameter. This is the same a the CASA default.
Parameters
----------
theta: (n_src, n_time, n_ant)
The angular separation (in degrees) between the pointing direction and the
source.
freqs: (n_freq,)
The frequencies at which to calculate the beam in Hz.
dish_d: float
The diameter of the dish in meters.
Returns
-------
E: ndarray (n_src, n_time, n_ant, n_freq)
The beam voltage at each frequency.
"""
theta = jnp.asarray(theta[:, :, :, None])
freqs = jnp.asarray(freqs)
dish_d = jnp.asarray(dish_d).flatten()[0] # type: ignore
# mask = jnp.where(theta > 90.0, 0, 1)
theta = jnp.deg2rad(theta)
x = jnp.where(
theta == 0.0,
sys.float_info.epsilon,
jnp.pi * freqs[None, None, None, :] * dish_d * jnp.sin(theta) / c,
)
return 2 * jv(1, x) / x
# return (2 * jv(1, x) / x) * mask
[docs]
def Pv_to_Sv(Pv: jnp.ndarray, d: jnp.ndarray) -> jnp.ndarray:
"""
Convert emission power to received intensity in Jy. Assumes constant
power across the bandwidth.
Parameters
----------
Pv: ndarray (n_src, n_time, n_freq)
Specific emission power in W/Hz.
d: ndarray (n_src, n_time, n_ant)
Distances from source to receiving antennas in m.
Returns
-------
Sv: ndarray (n_src, n_time, n_ant, n_freq)
Spectral flux density at the receiving antennas in Jy.
"""
Pv = jnp.asarray(Pv)
d = jnp.asarray(d)
return Pv[:, :, None, :] / (4 * jnp.pi * d[:, :, :, None] ** 2) * 1e26
[docs]
def add_noise(vis: jnp.ndarray, noise_std: jnp.ndarray, key: jnp.ndarray):
"""
Add complex gaussian noise to the integrated visibilities. The real and
imaginary components will each get this level of noise.
Parameters
----------
vis: ndarray (n_time, n_bl, n_freq)
The visibilities to add noise to.
noise_std: (n_freq, )
Standard deviation of the complex noise.
key: jax.random.PRNGKey
Random number generator key.
"""
vis = jnp.asarray(vis)
noise_std = jnp.asarray(noise_std)
key = jnp.asarray(key)
noise = (
random.normal(key, shape=vis.shape, dtype=jnp.complex128)
* noise_std[None, None, :]
)
return vis + noise, noise
[docs]
def SEFD_to_noise_std(
SEFD: jnp.ndarray, chan_width: jnp.ndarray, int_time: jnp.ndarray
):
"""Calculate the standard deviation of the complex noise in a visibility
given the system equivalent flux density, the channel width and integration time.
Parameters
----------
SEFD: ndarray (n_freq, )
System equivalent flux density in Jy.
chan_width: ndarray (n_time, n_ant, n_freq)
Channel width in Hz.
int_time: float
Integration time in seconds.
Returns
-------
noise_std: ndarray (n_time, n_ant, n_freq)
Standard deviation of the complex noise in a visibility.
"""
SEFD = jnp.asarray(SEFD)
chan_width = jnp.asarray(chan_width)
int_time = jnp.asarray(int_time)
return SEFD / jnp.sqrt(2 * chan_width * int_time)
[docs]
def int_sample_times(times: jnp.ndarray, n_int_samples: int = 1):
"""Calculate the times at which to sample the visibilities given the time centroids.
This shoudl produce `n_int_samples` times per integration time that are evenly
spaced around the time centroid.
Parameters
----------
times: ndarray (n_time, )
The time centroids at which to sample the visibilities.
n_int_samples: int
The number of samples to take per integration time.
Returns
-------
times_fine: ndarray (n_time * n_int_samples, )
The times at which to sample the visibilities.
"""
times = jnp.asarray(times)
n_int_samples = jnp.asarray(n_int_samples) # type: ignore
int_time = times[1] - times[0]
times_fine = (
int_time / (2 * n_int_samples)
+ jnp.arange(
times[0] - int_time / 2,
times[-1] + int_time / 2,
int_time / n_int_samples,
)[: n_int_samples * len(times)]
)
return times_fine
[docs]
def generate_gains(
G0_mean: complex,
G0_std: float,
Gt_std_amp: float,
Gt_std_phase: float,
times: jnp.ndarray,
n_ant: int,
n_freq: int,
key: jnp.ndarray,
):
"""
Generate complex antenna gains. Gain amplitudes and phases
are modelled as linear time-variates. Gains for all antennas at t = 0
are randomly sampled from a Gaussian described by the G0 parameters.
The rate of change of both ampltudes and phases are sampled from a zero
mean Gaussian with standard deviation as provided.
Parameters
----------
G0_mean: complex
Mean of Gaussian at t = 0.
G0_std: float
Standard deviation of Gaussian at t = 0.
Gt_std_amp: float
Standard deviation of Gaussian describing the rate of change in the
gain amplitudes in 1/seconds.
Gt_std_phase: float
Standard deviation of Gaussian describing the rate of change in the
gain phases in rad/seconds.
key: jax.random.PRNGKey
Random number generator key.
"""
G0_mean = jnp.asarray(G0_mean) # type: ignore
G0_std = jnp.asarray(G0_std) # type: ignore
Gt_std_amp = jnp.asarray(Gt_std_amp) # type: ignore
Gt_std_phase = jnp.asarray(Gt_std_phase) # type: ignore
times = jnp.asarray(times) - jnp.asarray(times[0])
n_ant = jnp.asarray(n_ant) # type: ignore
n_freq = jnp.asarray(n_freq) # type: ignore
key = jnp.asarray(key)
G0 = G0_mean * jnp.exp(
1.0j * jnp.pi * (random.uniform(key, (1, n_ant, n_freq)) - 0.5)
)
key, subkey = random.split(key)
gains_noise = G0_std * random.normal(key, (n_ant, n_freq), dtype=jnp.complex128)
key, subkey = random.split(key)
gains_amp = Gt_std_amp * random.normal(key, (1, n_ant, 1)) * (times)[:, None, None]
key, subkey = random.split(key)
gains_phase = (
Gt_std_phase * random.normal(key, (1, n_ant, 1)) * (times)[:, None, None]
)
key, subkey = random.split(key)
gains_ants = G0 + gains_noise + gains_amp * jnp.exp(1.0j * gains_phase)
gains_ants = gains_ants.at[:, -1, :].set(jnp.abs(gains_ants[:, -1, :]))
return gains_ants
[docs]
def apply_gains(
vis_ast: jnp.ndarray,
vis_rfi: jnp.ndarray,
gains: jnp.ndarray,
a1: jnp.ndarray,
a2: jnp.ndarray,
):
"""Apply antenna gains to visibilities.
Parameters
----------
vis_ast: ndarray (n_time, n_bl, n_freq)
The astronomical visibilities.
vis_rfi: ndarray (n_time, n_bl, n_freq)
The RFI visibilities.
gains: ndarray (n_time, n_ant, n_freq)
The antenna gains.
a1: ndarray (n_bl,)
The first antenna index for each baseline.
a2: ndarray (n_bl,)
The second antenna index for each baseline.
Returns
-------
vis_obs: ndarray (n_time, n_bl, n_freq)
The visibilities with gains applied.
"""
vis_ast = jnp.asarray(vis_ast)
vis_rfi = jnp.asarray(vis_rfi)
gains = jnp.asarray(gains)
a1 = jnp.asarray(a1)
a2 = jnp.asarray(a2)
vis_obs = gains[:, a1] * (vis_ast + vis_rfi) * jnp.conj(gains[:, a2])
return vis_obs
[docs]
@partial(jit, static_argnums=(1,))
def time_avg(vis: jnp.ndarray, n_int_samples: int = 1):
"""Average visibilities in time.
Parameters
----------
vis: ndarray (n_time_fine, n_bl, n_freq)
The visibilities to average in time.
n_int_samples: int
The number of samples to take per integration time.
Returns
-------
vis_avg: ndarray (n_time, n_bl, n_freq)
The averaged visibilities.
"""
vis = jnp.asarray(vis)
# n_int_samples = jnp.asarray(n_int_samples)
vis_avg = jnp.mean(
jnp.reshape(vis, (-1, n_int_samples, vis.shape[1], vis.shape[2])),
axis=1,
)
return vis_avg
[docs]
def db_to_lin(dB: float):
"""
Convert deciBels to linear units.
Parameters
----------
dB: float, ndarray
deciBel value to convert.
Returns
-------
lin: float, ndarray
"""
dB = jnp.asarray(dB) # type: ignore
return 10.0 ** (dB / 10.0)