Optimization of a Voigt profile
===============================

.. code:: ipython3

    from exojax.spec.lpf import voigt
    import jax.numpy as jnp
    import matplotlib.pyplot as plt

Let’s optimize the Voigt function :math:`V(\nu, \beta, \gamma_L)` using
exojax! :math:`V(\nu, \beta, \gamma_L)` is the convolution of a Gaussian
with a standard deviation of :math:`\beta` and a Lorentzian with a width
parameter of :math:`\gamma_L`.
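
For reference, this convolution can be written explicitly as (a standard
definition of the Voigt profile; the same normalization as ``voigt`` is
assumed here):

.. math::

    V(\nu, \beta, \gamma_L) = \int_{-\infty}^{\infty}
    \frac{e^{-\nu'^2/(2\beta^2)}}{\sqrt{2\pi}\,\beta}\,
    \frac{\gamma_L}{\pi\left[(\nu - \nu')^2 + \gamma_L^2\right]}\, d\nu'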

.. code:: ipython3

    nu = jnp.linspace(-10, 10, 100)
    plt.plot(nu, voigt(nu, 1.0, 2.0))  # beta=1.0, gamma_L=2.0

.. parsed-literal::

    [<matplotlib.lines.Line2D at 0x7fb8a0491000>]

.. image:: optimize_voigt_files/optimize_voigt_3_1.png

Optimization of a simple absorption model
-----------------------------------------

Next, we try to fit a simple absorption model to mock data. The
absorption model is

.. math::

    f = e^{-a V(\nu, \beta, \gamma_L)}

.. code:: ipython3

    def absmodel(nu, a, beta, gamma_L):
        return jnp.exp(-a * voigt(nu, beta, gamma_L))

Adding noise to make the mock data…

.. code:: ipython3

    from numpy.random import normal

    data = absmodel(nu, 2.0, 1.0, 2.0) + normal(0.0, 0.01, len(nu))
    plt.plot(nu, data, ".")

.. parsed-literal::

    [<matplotlib.lines.Line2D at 0x7fb8a03dbfd0>]

.. image:: optimize_voigt_files/optimize_voigt_8_1.png

Let’s optimize the three parameters :math:`(a, \beta, \gamma_L)`
simultaneously.

.. code:: ipython3

    from jax import grad, vmap

We define the objective function as :math:`\mathrm{obj} = |d - f|^2`,
where :math:`d` is the mock data and :math:`f` is the model.

.. code:: ipython3

    # loss or objective function
    def obj(a, beta, gamma_L):
        f = data - absmodel(nu, a, beta, gamma_L)
        g = jnp.dot(f, f)
        return g

.. code:: ipython3

    # These are the derivatives of the objective function
    h_a = grad(obj, argnums=0)
    h_beta = grad(obj, argnums=1)
    h_gamma_L = grad(obj, argnums=2)
    print(h_a(2.0, 1.0, 2.0), h_beta(2.0, 1.0, 2.0), h_gamma_L(2.0, 1.0, 2.0))

.. parsed-literal::

    0.010246746 -0.00011916496 -0.0035553267
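
As a side note, ``jax.grad`` also accepts a tuple for ``argnums``, so the
three partial derivatives can be obtained with a single call. A minimal
sketch, equivalent to the three calls above (the name ``h_all`` is just for
illustration):

.. code:: ipython3

    # returns the tuple (d obj/d a, d obj/d beta, d obj/d gamma_L)
    h_all = grad(obj, argnums=(0, 1, 2))
    print(h_all(2.0, 1.0, 2.0))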

.. code:: ipython3

    from jax import jit

    # single optimization step: compute the loss and gradients, then update the state
    # (opt_update and get_params are created by the optimizer constructor below)
    @jit
    def step(t, opt_state):
        a, beta, gamma_L = get_params(opt_state)
        value = obj(a, beta, gamma_L)
        grads_a = h_a(a, beta, gamma_L)
        grads_beta = h_beta(a, beta, gamma_L)
        grads_gamma_L = h_gamma_L(a, beta, gamma_L)
        grads = jnp.array([grads_a, grads_beta, grads_gamma_L])
        opt_state = opt_update(t, grads, opt_state)
        return value, opt_state

    # run Nstep optimization steps from the initial guess r0 and record the trajectory
    def doopt(r0, opt_init, get_params, Nstep):
        opt_state = opt_init(r0)
        traj = [r0]
        for t in range(Nstep):
            value, opt_state = step(t, opt_state)
            p = get_params(opt_state)
            traj.append(p)
        return traj, p
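
As an aside, the loss and all three gradients can also be obtained from a
single evaluation with ``jax.value_and_grad``. A minimal sketch of an
alternative step function (not used below; the names ``vg`` and ``step_vg``
are just for illustration):

.. code:: ipython3

    from jax import value_and_grad

    # loss and the tuple of gradients w.r.t. (a, beta, gamma_L) in one pass
    vg = value_and_grad(obj, argnums=(0, 1, 2))

    @jit
    def step_vg(t, opt_state):
        a, beta, gamma_L = get_params(opt_state)
        value, grads = vg(a, beta, gamma_L)
        opt_state = opt_update(t, jnp.array(grads), opt_state)
        return value, opt_state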

Here, we use the Adam optimizer.

.. code:: ipython3

    # Adam optimizer
    # from jax.experimental import optimizers  # for older versions of JAX
    from jax.example_libraries import optimizers

    opt_init, opt_update, get_params = optimizers.adam(1.e-1)
    r0 = jnp.array([1.5, 1.5, 1.5])
    trajadam, padam = doopt(r0, opt_init, get_params, 1000)

The optimized values are stored in ``padam``.

.. code:: ipython3

    padam

.. parsed-literal::

    Array([1.9930655 , 0.88781667, 2.0753138 ], dtype=float32)

.. code:: ipython3

    traj = jnp.array(trajadam)
    plt.plot(traj[:, 0], label="$\\alpha$")
    plt.plot(traj[:, 1], ls="dashed", label="$\\beta$")
    plt.plot(traj[:, 2], ls="dotted", label="$\\gamma_L$")
    plt.xscale("log")
    plt.legend()
    plt.show()

.. image:: optimize_voigt_files/optimize_voigt_19_0.png

.. code:: ipython3

    plt.plot(nu, data, ".", label="data")
    plt.plot(nu, absmodel(nu, padam[0], padam[1], padam[2]), label="optimized")
    plt.show()

.. image:: optimize_voigt_files/optimize_voigt_20_0.png
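
As a quick sanity check (not part of the original notebook), one can
evaluate the objective at the optimized parameters; for a good fit it
should be of the order of :math:`N\sigma^2 = 100 \times 0.01^2 = 0.01`:

.. code:: ipython3

    # residual sum of squares at the optimum
    print(obj(padam[0], padam[1], padam[2]))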

If you use SGD instead, you need to increase the number of iterations to
reach convergence.

.. code:: ipython3

    # SGD optimizer
    # from jax.experimental import optimizers  # for older versions of JAX
    from jax.example_libraries import optimizers

    opt_init, opt_update, get_params = optimizers.sgd(1.e-1)
    r0 = jnp.array([1.5, 1.5, 1.5])
    trajsgd, psgd = doopt(r0, opt_init, get_params, 10000)

.. code:: ipython3

    traj = jnp.array(trajsgd)
    plt.plot(traj[:, 0], label="$\\alpha$")
    plt.plot(traj[:, 1], ls="dashed", label="$\\beta$")
    plt.plot(traj[:, 2], ls="dotted", label="$\\gamma_L$")
    plt.xscale("log")
    plt.legend()
    plt.show()

.. image:: optimize_voigt_files/optimize_voigt_23_0.png