documents/tutorials/optimize_voigt_JAXopt.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Optimization of a Voigt profile using JAXopt"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-20T05:48:04.542522Z",
"iopub.status.busy": "2022-10-20T05:48:04.539874Z",
"iopub.status.idle": "2022-10-20T05:48:06.557642Z",
"shell.execute_reply": "2022-10-20T05:48:06.557320Z"
}
},
"outputs": [],
"source": [
"from exojax.spec.lpf import voigt\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"import jaxopt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's optimize the parameters of a Voigt profile $V(\\nu, \\beta, \\gamma_L)$ using ExoJAX and JAXopt!\n",
"$V(\\nu, \\beta, \\gamma_L)$ is the convolution of a Gaussian with standard deviation $\\beta$ and a Lorentzian with gamma (scale) parameter $\\gamma_L$.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-20T05:48:06.568609Z",
"iopub.status.busy": "2022-10-20T05:48:06.568321Z",
"iopub.status.idle": "2022-10-20T05:48:07.102258Z",
"shell.execute_reply": "2022-10-20T05:48:07.101928Z"
},
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7fbcef7569a0>]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"nu = jnp.linspace(-10, 10, 100)\n",
"# Label the axes before plotting so the cell still displays the Line2D list from plt.plot.\n",
"plt.title(\"Voigt profile\")\n",
"plt.xlabel(r\"$\\nu$\")\n",
"plt.ylabel(r\"$V(\\nu, \\beta=1, \\gamma_L=2)$\")\n",
"plt.plot(nu, voigt(nu, 1.0, 2.0))  # beta=1.0, gamma_L=2.0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Optimization of a simple absorption model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we fit a simple absorption model to mock data.\n",
"The absorption model is\n",
"\n",
"$f = e^{-a V(\\nu,\\beta,\\gamma_L)}$,\n",
"\n",
"where $a$ sets the depth of the absorption line.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-20T05:48:07.104653Z",
"iopub.status.busy": "2022-10-20T05:48:07.104342Z",
"iopub.status.idle": "2022-10-20T05:48:07.105779Z",
"shell.execute_reply": "2022-10-20T05:48:07.106056Z"
}
},
"outputs": [],
"source": [
"def absmodel(nu, a, beta, gamma_L):\n",
"    \"\"\"Absorption model f = exp(-a * V(nu, beta, gamma_L)).\n",
"\n",
"    V is the exojax Voigt profile with Gaussian parameter beta and\n",
"    Lorentzian parameter gamma_L; a scales the exponent, i.e. the line depth.\n",
"    \"\"\"\n",
"    exponent = a * voigt(nu, beta, gamma_L)\n",
"    return jnp.exp(-exponent)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
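"We create mock data by adding Gaussian noise (standard deviation 0.01) to the model with $a=2$, $\\beta=1$, and $\\gamma_L=2$.\n"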
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-20T05:48:07.108361Z",
"iopub.status.busy": "2022-10-20T05:48:07.108032Z",
"iopub.status.idle": "2022-10-20T05:48:07.310843Z",
"shell.execute_reply": "2022-10-20T05:48:07.310515Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7fb90d0398b0>]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from numpy.random import default_rng\n",
"\n",
"# Fixed seed so the mock data, and hence the fitted parameters below, are reproducible.\n",
"rng = default_rng(42)\n",
"a_true, beta_true, gamma_L_true = 2.0, 1.0, 2.0\n",
"noise_std = 0.01\n",
"data = absmodel(nu, a_true, beta_true, gamma_L_true) + rng.normal(0.0, noise_std, len(nu))\n",
"plt.plot(nu, data, \".\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's optimize all three parameters $(a, \\beta, \\gamma_L)$ simultaneously."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We define the objective function as the sum of squared residuals, $\\mathrm{obj} = |d - f|^2$, where $d$ is the mock data and $f$ is the model."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-20T05:48:07.316921Z",
"iopub.status.busy": "2022-10-20T05:48:07.316618Z",
"iopub.status.idle": "2022-10-20T05:48:07.317971Z",
"shell.execute_reply": "2022-10-20T05:48:07.318232Z"
}
},
"outputs": [],
"source": [
"# loss or objective function\n",
"def objective(params):\n",
"    \"\"\"Sum of squared residuals between the mock data and absmodel.\n",
"\n",
"    params is the tuple (a, beta, gamma_L); nu and data are read from the\n",
"    notebook's global scope.\n",
"    \"\"\"\n",
"    a, beta, gamma_L = params\n",
"    residual = data - absmodel(nu, a, beta, gamma_L)\n",
"    return jnp.dot(residual, residual)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-20T05:48:07.320127Z",
"iopub.status.busy": "2022-10-20T05:48:07.319790Z",
"iopub.status.idle": "2022-10-20T05:48:07.320959Z",
"shell.execute_reply": "2022-10-20T05:48:07.321203Z"
}
},
"outputs": [],
"source": [
"# Fit with gradient descent (jaxopt.GradientDescent)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-20T05:48:07.323961Z",
"iopub.status.busy": "2022-10-20T05:48:07.323658Z",
"iopub.status.idle": "2022-10-20T05:48:11.061847Z",
"shell.execute_reply": "2022-10-20T05:48:11.062106Z"
}
},
"outputs": [],
"source": [
"gd = jaxopt.GradientDescent(fun=objective, maxiter=10)\n",
"# Initial guess for (a, beta, gamma_L); run() returns an OptStep(params, state).\n",
"res = gd.run((1.5, 0.7, 1.5))\n",
"params = res.params\n",
"state = res.state"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-20T05:48:11.064618Z",
"iopub.status.busy": "2022-10-20T05:48:11.064301Z",
"iopub.status.idle": "2022-10-20T05:48:11.066350Z",
"shell.execute_reply": "2022-10-20T05:48:11.066576Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"(DeviceArray(1.9579332, dtype=float32, weak_type=True),\n",
" DeviceArray(1.0382165, dtype=float32, weak_type=True),\n",
" DeviceArray(1.8850585, dtype=float32, weak_type=True))"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Label the fitted values; objective() unpacks params in the order (a, beta, gamma_L).\n",
"dict(zip((\"a\", \"beta\", \"gamma_L\"), params))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-20T05:48:11.071839Z",
"iopub.status.busy": "2022-10-20T05:48:11.071551Z",
"iopub.status.idle": "2022-10-20T05:48:11.180374Z",
"shell.execute_reply": "2022-10-20T05:48:11.180644Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7fb90cf3d490>]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Compare the gradient-descent fit with the mock data.\n",
"model = absmodel(nu, *params)  # params is (a, beta, gamma_L)\n",
"plt.plot(nu, model, label=\"gradient-descent fit\")\n",
"plt.plot(nu, data, \".\", label=\"mock data\")\n",
"plt.xlabel(r\"$\\nu$\")\n",
"plt.ylabel(\"f\")\n",
"plt.legend();"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-20T05:48:11.182786Z",
"iopub.status.busy": "2022-10-20T05:48:11.182469Z",
"iopub.status.idle": "2022-10-20T05:48:11.184079Z",
"shell.execute_reply": "2022-10-20T05:48:11.183790Z"
}
},
"outputs": [],
"source": [
"# Fit with nonlinear conjugate gradient (jaxopt.NonlinearCG)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-20T05:48:11.187312Z",
"iopub.status.busy": "2022-10-20T05:48:11.186821Z",
"iopub.status.idle": "2022-10-20T05:48:18.162965Z",
"shell.execute_reply": "2022-10-20T05:48:18.162656Z"
}
},
"outputs": [],
"source": [
"# Same initial guess as the gradient-descent run, but maxiter is 100 here (vs 10),\n",
"# so the two fits are not an equal-iteration-budget comparison.\n",
"ncg = jaxopt.NonlinearCG(fun=objective, maxiter=100)\n",
"res = ncg.run(init_params=(1.5, 0.7, 1.5))\n",
"params, state = res"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-20T05:48:18.165241Z",
"iopub.status.busy": "2022-10-20T05:48:18.164923Z",
"iopub.status.idle": "2022-10-20T05:48:18.167385Z",
"shell.execute_reply": "2022-10-20T05:48:18.167620Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"(DeviceArray(1.9526778, dtype=float32),\n",
" DeviceArray(1.0492882, dtype=float32),\n",
" DeviceArray(1.8708111, dtype=float32))"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Label the fitted values; objective() unpacks params in the order (a, beta, gamma_L).\n",
"dict(zip((\"a\", \"beta\", \"gamma_L\"), params))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-20T05:48:18.172661Z",
"iopub.status.busy": "2022-10-20T05:48:18.172366Z",
"iopub.status.idle": "2022-10-20T05:48:18.469013Z",
"shell.execute_reply": "2022-10-20T05:48:18.469353Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7fb90c0d6eb0>]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Compare the nonlinear-CG fit with the mock data.\n",
"model = absmodel(nu, *params)  # params is (a, beta, gamma_L)\n",
"plt.plot(nu, model, label=\"nonlinear-CG fit\")\n",
"plt.plot(nu, data, \".\", label=\"mock data\")\n",
"plt.xlabel(r\"$\\nu$\")\n",
"plt.ylabel(\"f\")\n",
"plt.legend();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.8 ('base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
},
"vscode": {
"interpreter": {
"hash": "72bc7f8b1808a6f5ada3c6a20601509b8b1843160436d276d47f2ba819b3753b"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}