# Source code for TidalPy.utilities.integration.julia_helper

""" Helper functions to interface with Julia / Diffeqpy's integration suite """

from typing import Tuple

import numpy as np

from . import _de, _ode, julia_installed

# Read more about Julia's ode solvers here: https://diffeq.sciml.ai/dev/solvers/ode_solve/

# Solver names accepted by `get_julia_solver` (lookups there are case-insensitive).
known_integration_methods = (
    # Non-stiff solvers (resolved through `_ode` in `get_julia_solver`)
    'rk4',
    'rk45',
    'tsit5',
    'rko65',
    'tsitpap8',
    'feagin10',
    'feagin12',
    'feagin14',
    'bs3',
    'bs5',
    'vern6',
    'vern7',
    'vern8',
    'vern9',
    'kuttaprk2p5',
    # Stiff solvers (resolved through `_de` in `get_julia_solver`)
    'rosenbrock23',
    'rodas4',
    'rodas5',
)


def get_julia_solver(solver_name: str):
    """ Build an integrator function that is backed by an ODE solver from the Julia diffeq package.

    Read more about Julia's ode solvers here: https://diffeq.sciml.ai/dev/solvers/ode_solve/

    Parameters
    ----------
    solver_name : str
        Name of the Julia ode solver. The lookup is case-insensitive.

    Returns
    -------
    julia_integrator : callable
        Function with the signature
        `julia_integrator(diffeq, time_span, initial_condition, args, rtol, atol, max_step, first_step, method, t_eval)`
        which solves the problem with the requested solver and returns the tuple
        `(time_domain, y_results, success, message)`. `time_domain` and `y_results` are `None` when the
        integration was not successful. `max_step`, `first_step` and `method` are accepted for signature
        consistency only and are ignored.

    Raises
    ------
    ImportError
        If `julia_installed` is falsy.
    KeyError
        If `solver_name` does not match any known non-stiff or stiff solver.
    """

    if not julia_installed:
        raise ImportError('Julia package not found. Can not load ODE solver.')

    non_stiff_solvers = {
        # The canonical Runge-Kutta Order 4 method. Uses a defect control for adaptive stepping using maximum
        #    error over the whole interval.
        'rk4'        : _ode.RK4,
        # NOTE(review): 'rk45' resolves to the same solver as 'rk4'. Julia's Dormand-Prince 5/4 method is
        #    exposed as `DP5`; confirm whether this alias is intentional before changing it.
        'rk45'       : _ode.RK4,
        # Tsitouras 5/4 Runge-Kutta method. (free 4th order interpolant).
        'tsit5'      : _ode.Tsit5,
        # Tsitouras' Runge-Kutta-Oliver 6 stage 5th order method. This method is robust on problems which
        #    have a singularity at t=0.
        'rko65'      : _ode.RKO65,
        # Tsitouras-Papakostas 8/7 Runge-Kutta method.
        'tsitpap8'   : _ode.TsitPap8,
        # Feagin's 10th-order Runge-Kutta method.
        'feagin10'   : _ode.Feagin10,
        # Feagin's 12th-order Runge-Kutta method.
        'feagin12'   : _ode.Feagin12,
        # Feagin's 14th-order Runge-Kutta method.
        'feagin14'   : _ode.Feagin14,
        # Bogacki-Shampine 3/2 Runge-Kutta method.
        'bs3'        : _ode.BS3,
        # Additionally, the following algorithms have a lazy interpolant:
        # Bogacki-Shampine 5/4 Runge-Kutta method. (lazy 5th order interpolant).
        'bs5'        : _ode.BS5,
        # Verner's "Most Efficient" 6/5 Runge-Kutta method. (lazy 6th order interpolant).
        'vern6'      : _ode.Vern6,
        # Verner's "Most Efficient" 7/6 Runge-Kutta method. (lazy 7th order interpolant).
        'vern7'      : _ode.Vern7,
        # Verner's "Most Efficient" 8/7 Runge-Kutta method. (lazy 8th order interpolant).
        'vern8'      : _ode.Vern8,
        # Verner's "Most Efficient" 9/8 Runge-Kutta method. (lazy 9th order interpolant).
        'vern9'      : _ode.Vern9,
        # A 5 parallel, 2 processor explicit Runge-Kutta method of 5th order.
        # These methods utilize multithreading on the f calls to parallelize the problem. This requires that
        #    simultaneous calls to f are thread-safe.
        'kuttaprk2p5': _ode.KuttaPRK2p5
        }

    # Stiff solvers are taken from the full `_de` interface rather than `_ode`.
    stiff_solvers = {
        # Rosenbrock 2/3 method.
        'rosenbrock23': _de.Rosenbrock23,
        # Rosenbrock 4(3) method.
        'rodas4'      : _de.Rodas4,
        # Rosenbrock 5(4) method.
        'rodas5'      : _de.Rodas5
        }

    # Non-stiff solvers take precedence; the module the solver came from is also used to build and solve the
    #    problem below.
    solver_key = solver_name.lower()
    if solver_key in non_stiff_solvers:
        solver = non_stiff_solvers[solver_key]
        ode_system = _ode
    elif solver_key in stiff_solvers:
        solver = stiff_solvers[solver_key]
        ode_system = _de
    else:
        raise KeyError(f'Unknown Julia Integration Model: {solver_name}.')

    def julia_integrator(
            diffeq, time_span: Tuple[float, float], initial_condition: np.ndarray, args: Tuple = None,
            rtol: float = 1.0e-6, atol: float = 1.0e-8, max_step: float = np.inf, first_step: float = 0.,
            method: int = 1, t_eval: np.ndarray = np.empty((0,), dtype=np.float64)
            ):
        """ Solve `diffeq` over `time_span` with the Julia solver selected by `get_julia_solver`.

        Parameters
        ----------
        diffeq : callable
            Differential equation called as `diffeq(t, y, *args)`; its output is converted to a list.
        time_span : Tuple[float, float]
            Integration domain handed to Julia's `ODEProblem`.
        initial_condition : np.ndarray
            Initial values of the dependent variables.
        args : Tuple = None
            Extra arguments forwarded to `diffeq`. `None` means no extra arguments.
        rtol : float = 1.0e-6
            Relative tolerance (Julia's `reltol`).
        atol : float = 1.0e-8
            Absolute tolerance (Julia's `abstol`).
        max_step : float = np.inf
            Unused. Kept for signature consistency.
        first_step : float = 0.
            Unused. Kept for signature consistency.
        method : int = 1
            Unused. Kept for signature consistency.
        t_eval : np.ndarray = np.empty((0,), dtype=np.float64)
            If it holds more than one point, the solution is linearly interpolated onto these points.

        Returns
        -------
        time_domain : np.ndarray
            Domain at which the solution is reported (`None` on failure).
        y_results : np.ndarray
            Solution with one row per dependent variable (`None` on failure).
        success : bool
            True when the Julia return code is 'default' or 'success'.
        message : str
            The Julia return code.
        """

        # Some inputs are unused. Input structure is kept for consistency
        del method, max_step, first_step

        # Change the diffeq to match the desired format
        def diffeq_julia(u, p, r):
            # Julia integrator flips the order of the variables for the differential equation.
            if args is None:
                # No extra arguments were supplied; star-unpacking `None` would raise a TypeError.
                output = diffeq(r, u)
            else:
                output = diffeq(r, u, *p)
            return list(output)

        # Setup Julia ODE problem
        problem = ode_system.ODEProblem(diffeq_julia, initial_condition, time_span, args)

        # Solve the ode
        solution = ode_system.solve(problem, solver(), abstol=atol, reltol=rtol)

        # Find integration codes
        # NOTE(review): assumes `retcode` is returned as a string - confirm for the installed diffeqpy version.
        message = solution.retcode
        success = message.lower() in ['default', 'success']

        if not success:
            return None, None, success, message

        # Pull out y values and transpose so they have the same shape as Scipy's
        y_results = np.transpose(solution.u)
        time_domain = solution.t

        # NOTE(review): a single-element t_eval is ignored by this check - confirm that is intended.
        if t_eval is not None and t_eval.size > 1:
            # Julia does not have the same t_eval. There is the "saveat" keyword but can cause issues.
            # So perform an interpolation onto the requested points instead.
            y_results_reduced = np.zeros((y_results.shape[0], t_eval.size), dtype=initial_condition.dtype)
            for row_i, y_row in enumerate(y_results):
                y_results_reduced[row_i, :] = np.interp(t_eval, time_domain, y_row)
            time_domain = t_eval
            y_results = y_results_reduced

        return time_domain, y_results, success, message

    return julia_integrator