Source code for lasy.utils.phase_retrieval
import copy
from lasy.backend import xp
[docs]
def gerchberg_saxton_algo(
laserPos1,
laserPos2,
dz,
condition="max_iterations",
max_iterations=10,
amplitude_error=1.0,
debug=False,
):
"""
Implement the Gerchberg-Saxton Algorithm.
Given two laser profiles and a distance betweent them, calculate
the spatial phase profile of the laser in both planes using the
Gerchberg-Saxton algorithm. Returns the phase of the laser in both
planes.
Parameters
----------
laserPos1, laserPos2 : instance of Laser
dz : float (meters)
Distance between the two laser pulses
condition : string
The condition on which to stop the iterative loop.
Options are: 'max_iterations', stopping after a fixed number of
iterations or 'amplitude_error', waiting for the residual of
the e-field amplitude from laserPos1 to fall below 'amplitude_error'
max_iterations : int
Maximum number of iterations to perform
amplitude_error : float
Residual value for amplitude given as a fraction of the maximum
initial amplitude value for laserPos1.
debug : boolean (default: False)
if True, the error at each iteration is printed to standard output
Returns
-------
phase1, phase2 : ndarray of floats (rad)
Phase profiles of the laser pulse at the locations where
laserPos1 and laserPos2 are defined.
"""
laser1 = copy.deepcopy(laserPos1)
laser2 = copy.deepcopy(laserPos2)
amp1 = xp.abs(laser1.grid.get_temporal_field())
amp1_summed = xp.sum(amp1)
amp2 = xp.abs(laser2.grid.get_temporal_field())
phase1 = xp.zeros_like(amp1)
if condition == "max_iterations":
def breakout(i):
return i < max_iterations
cond = 0
elif condition == "amplitude_error":
def breakout(amp):
return amp / amp1_summed > amplitude_error
cond = 9e30
i = 0
while breakout(cond):
laser1.grid.set_temporal_field(amp1 * xp.exp(1j * phase1))
laser1.propagate(dz, verbose=False)
phase2 = xp.angle(laser1.grid.get_temporal_field())
laser2.grid.set_temporal_field(amp2 * xp.exp(1j * phase2))
laser2.propagate(-dz, verbose=False)
field2 = laser2.grid.get_temporal_field()
phase1 = xp.angle(field2)
amp1_calc = xp.abs(field2)
amp_error_summed = xp.sum(xp.abs(amp1_calc) - amp1)
if debug:
i += 1
print(
"Iteration %i : Amplitude Error (summed) = %.2e"
% (i, amp_error_summed / amp1_summed)
)
if condition == "max_iterations":
cond += 1
else:
cond = amp_error_summed
return phase1, phase2, amp_error_summed