Smoothness costs in different engines

Hello, I encountered a problem with the smoothness cost function when switching between the scipy and jax engines. With the same code, simply changing the engine from scipy to jax can give very different outputs:

My base case sets Cm=10000, Co=1, Cv=0. When Cx=Cy=Cz=0, both the scipy and jax engines output similar results.

However, the scipy and jax engines behave differently when Cx, Cy, and Cz are non-zero.

Under the scipy engine, I gradually tuned Cx, Cy, and Cz from 0 to a very large value (e.g. 1000000). At the 0 end, there is a sharp jump in wind speed and wind direction at the boundary of my observations, as expected: there is no constraint outside the observation coverage, so the field there stays at the initial field. When Cx, Cy, and Cz are extremely large, the entire resultant wind field collapses to a single direction, and the observation data appears to be ignored. At intermediate values (roughly between 100 and 1000), the region covered by observations approximately aligns with the observation data, while the boundary between the observations and the initial field shows a smooth transition. Overall the scipy engine works as expected.

In contrast, the jax engine outputs the same (or very similar) results no matter what values I input for Cx, Cy, and Cz. Even if I further tune down both Cm and Co to 0.0001 while keeping Cx, Cy, and Cz at very large values, the results stay the same.

What should I set Cx, Cy, and Cz to in the jax engine?

I looked into the cost_function and gradient parts for smoothness to see what values should be set in the jax engine. For the smoothness cost_function, all three engines (scipy, jax, and tensorflow) seem to use the same algorithm. However, if I understand correctly, the three engines appear to use different algorithms for the gradient:

Scipy: takes the numpy laplace twice along each respective direction, essentially fourth-order derivatives (PyDDA/pydda/cost_functions/_cost_functions_numpy.py at main · openradar/PyDDA · GitHub)

Tensorflow: takes the tensorflow gradient of the smoothness cost function. Since the smoothness function is already a second-order derivative, the calculated gradient is essentially third-order derivatives. Also, unlike scipy, the u, v, w directions may be mixed, as the x, y, z terms in the cost function already contain u, v, w. (PyDDA/pydda/cost_functions/_cost_functions_tensorflow.py at main · openradar/PyDDA · GitHub)

Jax: rather complicated; similar but not exactly equivalent to taking the smoothness cost function of the smoothness cost function. One important difference may be that grad_u, grad_v, and grad_w are always positive, since there is a square in the final calculation. (PyDDA/pydda/cost_functions/_cost_functions_jax.py at main · openradar/PyDDA · GitHub)
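For intuition on why the numpy engine differentiates twice: if the smoothness cost is the squared Laplacian of the field, its analytic gradient is (up to boundary effects) twice the bi-Laplacian. A minimal numpy/scipy sketch, using a toy random field rather than the actual PyDDA code, checking this against a finite-difference gradient at an interior point:

```python
import numpy as np
from scipy.ndimage import laplace

rng = np.random.default_rng(0)
u = rng.normal(size=(8, 8, 8))

def cost(u):
    # Squared-Laplacian smoothness cost, summed over the grid
    return np.sum(laplace(u) ** 2)

# The Laplacian stencil is symmetric, so away from the boundaries it is
# self-adjoint and dJ/du = 2 * laplace(laplace(u)) -- the bi-Laplacian.
grad_analytic = 2.0 * laplace(laplace(u))

# Finite-difference gradient at one interior point for comparison
i = (4, 4, 4)
eps = 1e-5
up, um = u.copy(), u.copy()
up[i] += eps
um[i] -= eps
grad_numeric = (cost(up) - cost(um)) / (2 * eps)

print(np.isclose(grad_analytic[i], grad_numeric, rtol=1e-3, atol=1e-5))  # True
```

So whether an engine builds the fourth-order operator by hand (numpy) or differentiates the cost automatically (tensorflow/jax), the two should agree at interior points.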

Thanks.

After looking more closely into the smoothness cost function and reading the paper:

It appears to me that the intention of the cost function is to calculate the Laplacian of the wind field. There are two calculation schemes among the three engines:

Scipy & Jax:

dudx = np.gradient(u, dx, axis=2)
dudy = np.gradient(u, dy, axis=1)
dudz = np.gradient(u, dz, axis=0)
dvdx = np.gradient(v, dx, axis=2)
dvdy = np.gradient(v, dy, axis=1)
dvdz = np.gradient(v, dz, axis=0)
dwdx = np.gradient(w, dx, axis=2)
dwdy = np.gradient(w, dy, axis=1)
dwdz = np.gradient(w, dz, axis=0)

x_term = (
    Cx
    * (
        np.gradient(dudx, dx, axis=2)
        + np.gradient(dvdx, dx, axis=1)
        + np.gradient(dwdx, dx, axis=2)
    )
    ** 2
)
y_term = (
    Cy
    * (
        np.gradient(dudy, dy, axis=2)
        + np.gradient(dvdy, dy, axis=1)
        + np.gradient(dwdy, dy, axis=2)
    )
    ** 2
)
z_term = (
    Cz
    * (
        np.gradient(dudz, dz, axis=2)
        + np.gradient(dvdz, dz, axis=1)
        + np.gradient(dwdz, dz, axis=2)
    )
    ** 2
)

TensorFlow:

dudx = _tf_gradient(u, dx, axis=2)
dudy = _tf_gradient(u, dy, axis=1)
dudz = _tf_gradient(u, dz, axis=0)
dvdx = _tf_gradient(v, dx, axis=2)
dvdy = _tf_gradient(v, dy, axis=1)
dvdz = _tf_gradient(v, dz, axis=0)
dwdx = _tf_gradient(w, dx, axis=2)
dwdy = _tf_gradient(w, dy, axis=1)
dwdz = _tf_gradient(w, dz, axis=0)

x_term = (
    Cx
    * (
        _tf_gradient(dudx, dx, axis=2)
        + _tf_gradient(dvdx, dx, axis=2)
        + _tf_gradient(dwdx, dx, axis=2)
    )
    ** 2
)
y_term = (
    Cy
    * (
        _tf_gradient(dudy, dy, axis=1)
        + _tf_gradient(dvdy, dy, axis=1)
        + _tf_gradient(dwdy, dy, axis=1)
    )
    ** 2
)
z_term = (
    Cz
    * (
        _tf_gradient(dudz, dz, axis=0)
        + _tf_gradient(dvdz, dz, axis=0)
        + _tf_gradient(dwdz, dz, axis=0)
    )
    ** 2
)

The difference lies in the axes used for the x, y, and z terms. If the intention is to calculate the Laplacian, maybe the version in TensorFlow is the correct one? Could someone help confirm? Thanks.
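The mismatch is easy to demonstrate numerically. A small self-contained check (arbitrary random fields, unit spacing, only the inner sum of x_term compared) shows the two quoted schemes produce different results:

```python
import numpy as np

rng = np.random.default_rng(1)
u, v, w = rng.normal(size=(3, 6, 6, 6))
dx = 1.0

dudx = np.gradient(u, dx, axis=2)
dvdx = np.gradient(v, dx, axis=2)
dwdx = np.gradient(w, dx, axis=2)

# x_term inner sum as quoted from the numpy/jax engines (mixed axes)
x_mixed = (
    np.gradient(dudx, dx, axis=2)
    + np.gradient(dvdx, dx, axis=1)
    + np.gradient(dwdx, dx, axis=2)
)

# x_term inner sum as quoted from the tensorflow engine (consistent axis)
x_consistent = (
    np.gradient(dudx, dx, axis=2)
    + np.gradient(dvdx, dx, axis=2)
    + np.gradient(dwdx, dx, axis=2)
)

print(np.allclose(x_mixed, x_consistent))  # False: the schemes differ
```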

@SunnysChan
The gradient of a Laplacian cost function is a bi-Laplacian. All of these codes are doing exactly that. Going by the behaviour you are describing, the numpy-based engine is working as intended. The Jax code is calculating the

However, I do see something overlooked in the Jax engine that could be affecting your result. I noticed that the Cx, Cy, and Cz terms are being carried in twice: once when calculating the cost function, and then again in the second block calculating the cost function’s Laplacian. I suspect this is not correct, and it could be causing rounding errors since we are essentially taking derivatives of these terms twice. If you have time, I wonder what happens when you remove the Cx, Cy, and Cz terms in lines 307-338 of _cost_functions_jax.py, where the second Laplacian operation is performed? The numpy engine looks to be fine.
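To illustrate why carrying the weight in twice matters (a toy linear-operator sketch, not the actual jax code): for a cost like J = C * sum((∇²u)²), the correct gradient scales linearly with C, but also applying C inside the second Laplacian inflates the gradient by another factor of C:

```python
import numpy as np
from scipy.ndimage import laplace

rng = np.random.default_rng(2)
u = rng.normal(size=(8, 8, 8))
C = 20000.0  # a smoothness weight, e.g. Cx

lap = laplace(u)
grad_correct = 2.0 * C * laplace(lap)      # weight applied once
grad_doubled = 2.0 * C * laplace(C * lap)  # weight carried in twice

# Because the Laplacian is linear, the double-weighted gradient is
# exactly C times too large:
print(np.allclose(grad_doubled, C * grad_correct))  # True
```

With C on the order of 10^4, a factor-of-C error makes the smoothness gradient dwarf every other term, which would match the insensitivity to Co and Cm described above.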

Thanks for the clarification. I tried removing the Cx, Cy, and Cz terms in lines 307-338. It does help bring the smoothness constraints into the jax engine.

There is one interesting observation, though:

When Co and Cm are tuned to very small values (e.g. 0.01 or even smaller), setting Cx, Cy, and Cz to, say, 20000 is sufficient to bring the resultant wind to the single-direction regime. (Setting Cx, Cy, and Cz to 0 gives the “normal” result, where observations are considered in the wind field.)

When I set Co and Cm back to my base case (Cm=10000, Co=1), with Cx, Cy, and Cz kept at 20000, the resultant wind field is roughly in the ideal regime, where the transition between the observations and the initial field is smoother. The interesting thing is that if I tune Cx, Cy, and Cz up even larger, say to 20000000, the effect of the smoothness constraint appears to be “capped” or “saturated”. It does not reach the single-direction regime.

Originally I thought what matters is the relative strength between the different cost functions, since they are added together. Is there any target value for the total or individual costs?

As a comparison, the scipy engine reached the single-direction regime with Cm=10000, Co=1, Cx=Cy=Cz=20000.

It looks like you are seeing the effect of rounding errors cancelling out the contribution of Co and Cm to the gradient when they are very small. At machine precision, the gradient essentially only reflects smoothness, hence your constant wind field. However, when Co and Cm are larger (closer to the magnitude of Cx, Cy, and Cz), this becomes less of an issue.
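Worth noting here: jax defaults to 32-bit floats unless jax.config.update("jax_enable_x64", True) is set, which makes this kind of absorption very plausible. A one-line numpy illustration (the magnitudes are made up for the example) of a small gradient term being swallowed by a large one in float32:

```python
import numpy as np

big = np.float32(2.0e7)   # e.g. a smoothness gradient term with huge Cxyz
small = np.float32(0.01)  # e.g. an observation gradient term with tiny Co

# float32 carries roughly 7 significant digits, so the small term
# vanishes entirely when added to the large one:
print(big + small == big)  # True
```

In float64 the same sum would still resolve the small term, so enabling 64-bit mode in jax may also change the behaviour you observed.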

I also took a second look at the Jax smoothness function, and the Cx, Cy, and Cz terms in lines 266-292 also don’t belong there; those terms are already applied in the final line of the function. I hope getting rid of those brings the numpy and jax engines into consistency. Please let me know if that improves things. If so, would you mind submitting a PR with your fixes to the jax engine?

Thanks for the suggestion. After also deleting the Cx, Cy, and Cz terms in lines 266-292, the result aligns better with that of scipy (numpy). Within the jax engine, there are now noticeable differences when Cx, Cy, and Cz are set from 0 to very large values.

Sure, I will prepare a PR for it. Thank you.

Oh, so sorry, I was busy with other stuff a few weeks back and did not submit the PR until now, but I just found that you updated the code last week. Thanks a lot :slight_smile: