Parameters
Parameter
Bases: Variable[T]
Parameter base class.
All trainable parameters in GPJax should inherit from this class.
NonNegativeReal
PositiveReal
SigmoidBounded
LowerTriangular
transform
Transforms parameters using a bijector.
Example
>>> from gpjax.parameters import PositiveReal, transform
>>> import jax.numpy as jnp
>>> import numpyro.distributions.transforms as npt
>>> from flax import nnx
>>> params = nnx.State(
...     {
...         "a": PositiveReal(jnp.array([1.0])),
...         "b": PositiveReal(jnp.array([2.0])),
...     }
... )
>>> params_bijection = {'positive': npt.SoftplusTransform()}
>>> transformed_params = transform(params, params_bijection)
>>> print(transformed_params["a"].value)
[1.3132617]
Parameters:
- params (State): An nnx.State object containing parameters to be transformed.
- params_bijection (Dict[str, Transform]): A dictionary mapping parameter types to bijectors.
- inverse (bool, default: False): Whether to apply the inverse transformation.
Returns:
- State: A new nnx.State object containing the transformed parameters.
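The sketch below illustrates, in plain Python, the idea behind the inverse flag: the forward bijector maps a value into the constrained space, and the inverse maps it back. It is not GPJax's implementation. The names transform_sketch, softplus, and softplus_inverse are hypothetical, and parameters are modelled as simple (tag, value) pairs of floats rather than nnx.State entries holding JAX arrays.

```python
import math


def softplus(x: float) -> float:
    # Forward map: any real number -> strictly positive real number.
    return math.log1p(math.exp(x))


def softplus_inverse(y: float) -> float:
    # Inverse map: strictly positive real number -> real number.
    return math.log(math.expm1(y))


# Hypothetical stand-ins for the real objects: each parameter is a
# (tag, value) pair, and each tag maps to a (forward, inverse) pair.
params = {"a": ("positive", 1.0), "b": ("positive", 2.0)}
bijections = {"positive": (softplus, softplus_inverse)}


def transform_sketch(params, bijections, inverse=False):
    # Look up the bijector by tag and apply it in the requested direction,
    # returning a new dictionary rather than mutating the input.
    out = {}
    for name, (tag, value) in params.items():
        forward, backward = bijections[tag]
        fn = backward if inverse else forward
        out[name] = (tag, fn(value))
    return out


constrained = transform_sketch(params, bijections)
recovered = transform_sketch(constrained, bijections, inverse=True)

print(constrained["a"][1])  # softplus(1.0), roughly 1.3133
print(recovered["a"][1])    # back to roughly 1.0
```

Applying the forward map and then the inverse recovers the original values up to floating-point error, which is the round-trip property one would expect from calling transform twice, once with inverse=False and once with inverse=True.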