# Source code for TidalPy.utilities.math.numba_special

""" This module provides several special functions that are specifically designed to work with TidalPy and its
dependencies (looking at you, Numba).

"""

from typing import TYPE_CHECKING

import numpy as np

from TidalPy.utilities.performance import njit, use_numba

if TYPE_CHECKING:
    from TidalPy.utilities.types import NumArray


def _sqrt_neg_python(z: 'NumArray', is_real: bool = False) -> 'NumArray':
    """ Principal square root that accepts negative and complex input.

    The computation is written with boolean masks rather than branches on the value of `z`, so the same
    expression works for scalars and for arrays.

    Parameters
    ----------
    z : NumArray
        Input value (real or complex; scalar or array).
    is_real : bool = False
        Set to True when `z` is known to carry no imaginary component. A shorter calculation is then used
        that only inspects the sign of the real part.

    Returns
    -------
    z_sqrt : NumArray
        Square root of `z`. The result is always complex typed: non-negative real input gives a zero imaginary
        part, negative real input gives a purely imaginary value.

    """

    if not is_real:
        # General case: the input may be negative, complex, or both.
        re_in = np.real(z)
        im_in = np.imag(z)
        modulus = np.sqrt(re_in * re_in + im_in * im_in)

        # Half-angle style decomposition of the root into real and imaginary magnitudes.
        root_re = np.sqrt((modulus + re_in) / 2.)
        root_im = np.sqrt((modulus - re_in) / 2.)

        # When an imaginary component is present, the root's imaginary part takes its sign.
        signed_term = (im_in != 0.) * root_im * np.sign(im_in) * 1.0j
        # Purely real input: a negative value lands on the positive imaginary axis.
        unsigned_term = (im_in == 0.) * root_im * 1.0j

        return root_re + signed_term + unsigned_term

    # Real-only shortcut: take the root of the magnitude, then decide which axis it belongs on.
    re_in = np.real(z)
    magnitude_root = np.sqrt(np.abs(z))

    # Each mask is a Boolean (or an array of Booleans); an input of zero fails both and yields zero.
    on_real_axis = (re_in > 0.) * magnitude_root
    on_imag_axis = (re_in < 0.) * magnitude_root * 1.0j

    return on_real_axis + on_imag_axis


# Imaginary square roots
if use_numba:
    # TODO: Numba currently does not support wrapping np.lib.scimath.sqrt, so we have to define our own function.
    #    However, numba.njit of np.sqrt is about 10x faster than np.sqrt with floats and about 2x faster with arrays.
    #    So the hand-written _sqrt_neg_python is actually pretty efficient once compiled.
    sqrt_neg = njit(cacheable=True)(_sqrt_neg_python)
else:
    # Numpy already has a built in function to handle this. Use it instead
    def sqrt_neg(z, is_real: bool = False):
        """ Square root - Allows for negative values (pure NumPy fallback used when numba is disabled)

        Parameters
        ----------
        z : NumArray
            Input value (real or complex; scalar or array).
        is_real : bool = False
            Ignored by this fallback. It is accepted only so that the call signature matches the
            numba-compiled version of `sqrt_neg`.

        Returns
        -------
        z_sqrt : NumArray
            Result of `np.lib.scimath.sqrt(z)`.

        """
        return np.lib.scimath.sqrt(z)