Hello, I encountered a problem with the smoothness cost function when switching between the scipy and jax engines. With the same code, simply changing the engine from scipy to jax can give very different outputs:
My base case sets Cm=10000, Co=1, Cv=0. When Cx=Cy=Cz=0, the scipy and jax engines produce similar results.
However, the two engines behave differently once Cx, Cy, and Cz are non-zero.
Under the scipy engine, I gradually tuned Cx, Cy, and Cz from 0 up to a very large value (e.g. 1000000). At the 0 end, there is a sharp jump in wind speed and direction at the boundary of my observation coverage, as expected: with no constraint outside the coverage, the field there stays at its initial value. When Cx, Cy, and Cz are extremely large, the entire resulting wind field collapses to a single direction, and the observation data appears to be ignored. At intermediate values (roughly 100 to 1000), the region covered by observations approximately matches the observation data, with a smooth transition at the boundary between the observed region and the initial field. Overall, the scipy engine works as expected.
In contrast, the jax engine outputs the same or similar results no matter what values I use for Cx, Cy, and Cz. Even if I also tune Cm and Co down to 0.0001 while keeping Cx, Cy, and Cz very large, the results stay the same.
What should I set Cx, Cy, and Cz to in the jax engine?
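For context, here is a hypothetical 1D toy analogue of the behaviour I expect (plain numpy/scipy, not PyDDA's actual code; the names Co, Cs, obs, and mask are just my own stand-ins for the weights above): with a small smoothness weight the solution follows the observations with a sharp edge at the coverage boundary, while a very large smoothness weight flattens the whole field and effectively ignores the data.

```python
import numpy as np
from scipy.ndimage import laplace
from scipy.optimize import minimize

# Observations cover only the middle of a 1D domain.
n = 40
obs = np.zeros(n)
obs[10:30] = np.linspace(0.0, 5.0, 20)
mask = np.zeros(n, dtype=bool)
mask[10:30] = True
Co = 1.0  # observation weight (analogue of Co)

def cost(u, Cs):
    # data term inside coverage + squared-Laplacian smoothness term
    return Co * np.sum((u[mask] - obs[mask]) ** 2) + Cs * np.sum(laplace(u) ** 2)

def grad(u, Cs):
    # the 1D 'reflect' Laplacian is symmetric, so the exact gradient of the
    # smoothness term is 2*Cs*laplace(laplace(u))
    g = 2.0 * Cs * laplace(laplace(u))
    g[mask] += 2.0 * Co * (u[mask] - obs[mask])
    return g

u0 = np.zeros(n)
u_small = minimize(cost, u0, args=(1e-4,), jac=grad, method="L-BFGS-B").x
u_large = minimize(cost, u0, args=(1e6,), jac=grad, method="L-BFGS-B").x
# u_small tracks obs inside the mask with a sharp edge at its boundary;
# u_large is nearly uniform, as if the observations were not used
```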
I looked into the cost_function and gradient code for the smoothness term to see what values should be set for the jax engine. For the smoothness cost_function itself, all three engines (scipy, jax, and tensorflow) appear to use the same algorithm. However, if I understand correctly, the three engines compute the gradient differently:
Scipy: applies the numpy/scipy Laplacian twice along each direction, which is essentially a fourth-order derivative (PyDDA/pydda/cost_functions/_cost_functions_numpy.py at main · openradar/PyDDA · GitHub)
Tensorflow: takes the tensorflow gradient of the smoothness cost function. Since the smoothness cost is a squared second-order derivative, the chain rule makes this gradient effectively fourth order as well. Also, unlike scipy, the u, v, and w directions may be mixed, since the x, y, and z terms in the cost function already contain u, v, and w. (PyDDA/pydda/cost_functions/_cost_functions_tensorflow.py at main · openradar/PyDDA · GitHub)
Jax: rather complicated; similar to, but not exactly the same as, taking the smoothness cost function of the smoothness cost function. One important difference may be that grad_u, grad_v, and grad_w are always non-negative, since there is a square in the final calculation. (PyDDA/pydda/cost_functions/_cost_functions_jax.py at main · openradar/PyDDA · GitHub)
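To illustrate what I mean about the gradient: for a toy 1D smoothness cost J(u) = sum(laplace(u)**2), the exact gradient is 2*laplace(laplace(u)) (the fourth-order, laplace-of-laplace form the scipy engine uses), and a correct gradient must take both signs, unlike a squared quantity. A hypothetical standalone check (plain numpy/scipy, not PyDDA code), comparing against a central finite-difference gradient as a stand-in for autodiff:

```python
import numpy as np
from scipy.ndimage import laplace

rng = np.random.default_rng(0)
u = rng.normal(size=30)

def J(u):
    # toy smoothness cost: squared discrete Laplacian
    return np.sum(laplace(u) ** 2)

# The 1D Laplacian with the default 'reflect' boundary is a symmetric
# operator L, so grad J = 2 * L^T L u = 2 * laplace(laplace(u)) exactly.
analytic = 2.0 * laplace(laplace(u))

# Central finite differences (exact to roundoff here, since J is quadratic)
eps = 1e-6
fd = np.array([
    (J(u + eps * np.eye(u.size)[i]) - J(u - eps * np.eye(u.size)[i])) / (2 * eps)
    for i in range(u.size)
])

# fd matches the fourth-order analytic form, and it takes both signs,
# which an always-non-negative grad_u cannot do
```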
Thanks.