pollux.models
Root module
- class pollux.models.Cannon(label_size: int, output_size: int, poly_degree: int = 2, include_bias: bool = True, coeffs: Array | None = None, scatter: Array | None = None)
Bases: Module
The Cannon: a data-driven model for stellar spectra.
Learns a polynomial relationship between stellar labels and spectra (or other outputs). For each output pixel, the model fits:
output_λ = Σ_j θ_{λj} · feature_j(labels)
where the features are polynomial combinations of the labels up to the specified degree.
- Parameters:
label_size (int) – Number of input labels (e.g., 3 for Teff, logg, [Fe/H]).
output_size (int) – Number of output dimensions (e.g., number of spectral pixels).
poly_degree (int) – Maximum polynomial degree for feature expansion. Default is 2.
include_bias (bool) – Whether to include a bias term in the polynomial features. Default is True.
coeffs (Array | None)
scatter (Array | None)
- coeffs
Fitted coefficients, shape (output_size, n_features). None before fitting.
- scatter
Fitted per-pixel scatter, shape (output_size,). None before fitting.
- n_features
Number of polynomial features (computed from label_size, poly_degree, and include_bias).
Examples
>>> import jax.numpy as jnp
>>> from pollux.models import Cannon
Create a Cannon model for 3 labels and 100 spectral pixels:
>>> cannon = Cannon(label_size=3, output_size=100, poly_degree=2)
>>> cannon.n_features  # 1 + 3 + 6 = 10 for degree 2 with 3 labels
10
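With include_bias=True, the number of features is the number of monomials of total degree at most poly_degree in label_size variables, a combinations-with-replacement count: C(label_size + poly_degree, poly_degree) = C(3 + 2, 2) = 10.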
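As a quick check of the formula above, the sketch below counts features for a few settings using only the standard library. The helper name n_poly_features is hypothetical (not part of pollux), and the include_bias=False branch assumes that dropping the bias removes only the constant term.

```python
from math import comb


def n_poly_features(label_size: int, poly_degree: int, include_bias: bool = True) -> int:
    """Hypothetical helper (not part of pollux): count monomials of degree <= poly_degree."""
    n = comb(label_size + poly_degree, poly_degree)
    # Assumption: include_bias=False drops only the constant (degree-0) term.
    return n if include_bias else n - 1


print(n_poly_features(3, 2))  # 10, matching the example above
print(n_poly_features(2, 2))  # 6
print(n_poly_features(3, 2, include_bias=False))  # 9
```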
- fit(labels: Array, output: Array, output_ivar: Array | None = None, regularization: float = 0.0)
Fit the Cannon using weighted least squares.
For each output pixel, solves the weighted least squares problem:
argmin_θ Σ_i w_i (y_i - f_i @ θ)^2 + λ ||θ||^2
where i runs over training stars, f_i is the polynomial feature vector of star i, y_i is its output at that pixel, w_i is the corresponding inverse variance, and λ is the regularization strength.
- Parameters:
labels (Array) – Training stellar labels, shape (n_stars, label_size).
output (Array) – Training output (e.g., spectra), shape (n_stars, output_size).
output_ivar (Array | None) – Inverse variance of the output, shape (n_stars, output_size). If None, uniform weights (1.0) are used.
regularization (float) – L2 regularization strength (λ). Default is 0.0 (no regularization). Larger values shrink coefficients toward zero.
- Returns:
A new Cannon instance with fitted coefficients and scatter.
- Return type:
Cannon
Notes
This method uses JAX’s vmap for efficient vectorized fitting across all pixels. The solution for each pixel is:
θ = (F.T @ W @ F + λI)^{-1} @ F.T @ W @ y
where F is the design matrix (polynomial features), W is a diagonal weight matrix, and y is the output vector.
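A minimal NumPy sketch of the closed-form solution above, batched over pixels with einsum rather than vmap. NumPy stands in for jax.numpy here (the same expressions should carry over to jnp). The function name fit_weighted_ridge is hypothetical, this is not pollux's implementation, and it omits the per-pixel scatter that fit() also estimates.

```python
import numpy as np


def fit_weighted_ridge(F, Y, ivar=None, lam=0.0):
    """Hypothetical sketch: theta = (F.T W F + lam I)^{-1} F.T W y for every pixel.

    F:    (n_stars, n_features) design matrix of polynomial features
    Y:    (n_stars, n_pixels) training outputs
    ivar: (n_stars, n_pixels) inverse variances; uniform weights if None
    Returns coefficients with shape (n_pixels, n_features).
    """
    if ivar is None:
        ivar = np.ones_like(Y)
    n_features = F.shape[1]
    # F.T @ W @ F for every pixel at once: (n_pixels, n_features, n_features)
    FtWF = np.einsum("ip,ij,iq->jpq", F, ivar, F) + lam * np.eye(n_features)
    # F.T @ W @ y for every pixel: (n_pixels, n_features)
    FtWy = np.einsum("ip,ij,ij->jp", F, ivar, Y)
    # Batched linear solve rather than forming an explicit inverse
    return np.linalg.solve(FtWF, FtWy[..., None])[..., 0]


# Noiseless toy data: 50 stars, 2 labels, degree-1 features [1, x1, x2]
rng = np.random.default_rng(0)
labels = rng.normal(size=(50, 2))
F = np.column_stack([np.ones(len(labels)), labels])
true_theta = np.array([[1.0, 2.0, -0.5], [0.3, 0.0, 1.5]])  # 2 pixels
Y = F @ true_theta.T
theta = fit_weighted_ridge(F, Y)  # recovers true_theta when lam=0
theta_ridge = fit_weighted_ridge(F, Y, lam=100.0)  # shrunk toward zero
```

Solving the small per-pixel systems directly is cheap because n_features is usually much smaller than the number of stars.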
Examples
>>> labels = jnp.array([[5000., 4.0], [5500., 3.5], [6000., 4.5]])
>>> spectra = jnp.array([[1.0, 2.0], [1.5, 2.5], [2.0, 3.0]])
>>> cannon = Cannon(label_size=2, output_size=2, poly_degree=1)
>>> cannon = cannon.fit(labels, spectra)
>>> cannon.is_fitted
True
- get_coeffs_as_transform_pars()
Get fitted coefficients in transform parameter format.
Returns the fitted coefficients in the format expected by TransformSequence/LuxModel. This allows using Cannon-fitted parameters as initial values or fixed parameters in LuxModel.
- Returns:
Parameter dictionary in the format:
{"data": [{"A": coeffs.T}]}
The coefficients are transposed because LinearTransform expects shape (output_size, latent_size), where latent_size = n_features.
- Raises:
RuntimeError – If the model has not been fitted.
Examples
>>> cannon = cannon.fit(labels, spectra)
>>> pars = cannon.get_coeffs_as_transform_pars()
>>> # Use with LuxModel
>>> model.predict_outputs(labels, {"flux": pars})
- get_features(labels: Array)
Expand labels into polynomial features.
- Parameters:
labels (Array) – Stellar labels, shape (n_stars, label_size).
- Returns:
Polynomial features, shape (n_stars, n_features).
- Return type:
array
Examples
>>> labels = jnp.array([[1.0, 2.0]])  # 1 star, 2 labels
>>> cannon = Cannon(label_size=2, output_size=10, poly_degree=2)
>>> features = cannon.get_features(labels)
>>> features.shape
(1, 6)
>>> features
Array([[1., 1., 2., 1., 2., 4.]], dtype=float...)
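The expansion in the example above can be reproduced with itertools.combinations_with_replacement. This is a sketch, not pollux's implementation: poly_features is a hypothetical name, and the column ordering (by degree, then by label index) is chosen to match the example output and is an assumption for other settings.

```python
from itertools import combinations_with_replacement

import numpy as np


def poly_features(labels, poly_degree=2, include_bias=True):
    """Hypothetical sketch: monomials of the labels up to total degree poly_degree.

    Columns are ordered by degree, then by label index; for two labels and
    degree 2 this gives [1, x1, x2, x1^2, x1*x2, x2^2].
    """
    labels = np.asarray(labels, dtype=float)
    n_labels = labels.shape[1]
    columns = []
    for degree in range(0 if include_bias else 1, poly_degree + 1):
        for idx in combinations_with_replacement(range(n_labels), degree):
            # Product over the chosen label columns; the empty product is the bias of ones
            columns.append(np.prod(labels[:, list(idx)], axis=1))
    return np.stack(columns, axis=1)


features = poly_features([[1.0, 2.0]])  # 1 star, 2 labels
print(features.shape)  # (1, 6)
print(features)        # [[1. 1. 2. 1. 2. 4.]]
```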
- predict(labels: Array)
Predict output for given labels.
- Parameters:
labels (Array) – Stellar labels, shape (n_stars, label_size).
- Returns:
Predicted output, shape (n_stars, output_size).
- Return type:
array
- Raises:
RuntimeError – If the model has not been fitted.
Examples
>>> cannon = cannon.fit(train_labels, train_spectra)
>>> predicted = cannon.predict(test_labels)
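Given the shapes documented above (coeffs is (output_size, n_features)), prediction amounts to one matrix product between each star's features and the coefficients. This NumPy sketch assumes predict applies no further transformation (for example, it does not use scatter); predict_from_coeffs is a hypothetical name.

```python
import numpy as np


def predict_from_coeffs(features, coeffs):
    """Hypothetical sketch: one linear combination of features per output pixel.

    features: (n_stars, n_features)
    coeffs:   (output_size, n_features), the shape documented for Cannon.coeffs
    Returns:  (n_stars, output_size)
    """
    return features @ coeffs.T


labels = np.array([[5000.0, 4.0], [5500.0, 3.5]])
features = np.column_stack([np.ones(len(labels)), labels])  # degree-1 features [1, Teff, logg]
coeffs = np.array([[0.5, 1e-4, 0.0], [1.0, 0.0, 0.1]])  # 2 output pixels
predicted = predict_from_coeffs(features, coeffs)
```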
- to_transform_sequence()
Convert to a TransformSequence for use with LuxModel.
Returns a TransformSequence that can be used with LuxModel for Bayesian inference or more complex models. The sequence consists of:
1. PolyFeatureTransform: labels → polynomial features (no learnable params)
2. LinearTransform: features → output (learnable A matrix)
- Returns:
A transform sequence that can be registered with LuxModel.
- Return type:
TransformSequence
Notes
This method creates a new TransformSequence whose LinearTransform A matrix is sampled from priors during numpyro inference. If the Cannon has been fitted, you can use the fitted coefficients (from get_coeffs_as_transform_pars()) as initial values or fixed parameters.
Examples
>>> import pollux as plx
>>> cannon = Cannon(label_size=3, output_size=128, poly_degree=2)
>>> transform = cannon.to_transform_sequence()
>>> model = plx.LuxModel(latent_size=3)  # latent_size = label_size
>>> model.register_output("flux", transform)
- class pollux.models.Lux(latent_size: int)
Bases: Module
A latent variable model with multiple outputs.
Lux is a generative, latent variable model for output data. It is a general framework for constructing multi-output or multi-task models in which each object's outputs are generated by transforming an embedded (latent) vector representation of that object. While this class and model structure can be used in a broad range of applications, this package and implementation were written with applications to stellar spectroscopic data in mind.
- Parameters:
latent_size (int) – The size of the latent vector representation for each object (i.e. the embedded dimensionality).
Notes
Parameter Format
The optimize() method returns parameters in a nested format:
{
    "output_name": {
        "data": {"A": array, ...},  # Transform parameters
        "err": {"s": array, ...}    # Error transform parameters
    },
    "latents": array  # Per-object latent vectors
}
This same format should be used when passing parameters to predict_outputs().
Naming Restrictions
Output names and transform parameter names cannot contain colons (':') as they are reserved for internal parameter naming in numpyro.
- default_numpyro_model(data: PolluxData, latents_prior: Distribution | None | bool = None, fixed_pars: dict[str, Any] | None = None, names: list[str] | None = None, custom_model: Callable[[Float[Array, '#stars latents'], dict[str, Any], PolluxData], None] | None = None)
Create the default numpyro model for this Lux model.
The default model uses the specified latent vector prior and assumes that the data are Gaussian distributed away from the true (predicted) values given the specified errors.
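To make the Gaussian assumption concrete, the quantity being modeled is the sum of independent Normal log-densities of the observed data around the predicted values. This NumPy sketch assumes the errors enter only through a per-value inverse variance ivar (it ignores any err_transform, which may modify the errors in pollux); gaussian_log_likelihood is a hypothetical name, not numpyro's or pollux's implementation.

```python
import numpy as np


def gaussian_log_likelihood(y_obs, y_pred, ivar):
    """Hypothetical sketch: sum of Normal(y_pred, 1/sqrt(ivar)) log-densities.

    y_obs, y_pred, ivar: arrays of the same shape, e.g. (n_stars, output_size).
    """
    resid = y_obs - y_pred
    # log N(y | mu, sigma^2) = 0.5 * log(ivar / (2 pi)) - 0.5 * ivar * resid^2
    return 0.5 * np.sum(np.log(ivar / (2 * np.pi)) - ivar * resid**2)
```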
- Parameters:
data (PolluxData) – A dictionary of observed data.
latents_prior (Distribution | None | bool) – The prior distribution for the latent vectors. If not specified, use a unit Gaussian. If False, use an improper uniform prior.
fixed_pars (dict[str, Any] | None) – A dictionary of fixed parameters to condition on. If None, all parameters will be sampled.
names (list[str] | None) – A list of output names to include in the model. If None, include all outputs.
custom_model (Callable[[Float[Array, '#stars latents'], dict[str, Any], PolluxData], None] | None) – Optional callable that takes latents, pars, and data and adds custom modeling components.
- optimize(data: PolluxData, num_steps: int, rng_key: Array, optimizer: _NumPyroOptim | Optimizer | Any | None = None, latents_prior: Distribution | None | bool = None, custom_model: Callable[[Float[Array, '#stars latents'], dict[str, Any], PolluxData], None] | None = None, fixed_pars: dict[str, dict[str, Any] | Array] | None = None, names: list[str] | None = None, svi_run_kwargs: dict[str, Any] | None = None, guide: type[AutoGuide] | AutoGuide | None = None)
Optimize the model parameters using SVI.
- Parameters:
data (PolluxData) – The observed data to optimize against.
num_steps (int) – Number of SVI optimization steps.
rng_key (Array) – JAX random key for the optimization.
optimizer (_NumPyroOptim | Optimizer | Any | None) – Numpyro optimizer to use. Defaults to numpyro.optim.Adam().
latents_prior (Distribution | None | bool) – Prior distribution for the latent vectors. If None, uses a unit Gaussian. If False, uses an improper uniform prior.
custom_model (Callable[[Float[Array, '#stars latents'], dict[str, Any], PolluxData], None] | None) – Optional callable for custom modeling components.
fixed_pars (dict[str, dict[str, Any] | Array] | None) – Parameters to hold fixed during optimization.
names (list[str] | None) – Output names to include. If None, includes all outputs.
svi_run_kwargs (dict[str, Any] | None) – Additional keyword arguments passed to SVI.run().
guide (type[AutoGuide] | AutoGuide | None) – The autoguide to use for variational inference. Can be:
  - None (default): uses AutoDelta for MAP estimation.
  - A guide class (e.g. AutoNormal): will be instantiated with the model function.
  - A guide instance: used directly (must already be constructed with the model function).
- optimize_iterative(data: PolluxData, blocks: list[ParameterBlock] | list[str] | None = None, fixed_pars: dict[str, dict[str, Any] | Array] | None = None, max_cycles: int = 10, tol: float = 0.0001, rng_key: Array | None = None, initial_params: dict[str, dict[str, Any] | Array] | None = None, latents_prior: Distribution | None = None, progress: bool = True, record_history: bool = False)
Optimize using iterative parameter block coordinate descent.
For models with purely linear outputs, this method exploits the linear structure for faster convergence. For linear transforms, each sub-problem is solved exactly using weighted least squares.
The default strategy alternates between:
1. Optimizing the latents (with output parameters fixed)
2. Optimizing each output's parameters (with latents fixed)
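To make the alternating strategy concrete, here is a minimal NumPy sketch for the simplest linear case, a single output Y ≈ Z @ A.T with no bias, where both blocks have closed-form ridge solutions. It illustrates the idea under these assumptions and is not pollux's optimize_iterative: the function name is hypothetical, weights are uniform, and the penalty lam stands in for Gaussian priors on the latents and on A. The stopping rule mirrors the documented tol (relative change in loss).

```python
import numpy as np


def alternating_least_squares(Y, latent_size, max_cycles=50, tol=1e-4, lam=1.0, seed=0):
    """Hypothetical sketch: block coordinate descent for Y ≈ Z @ A.T with ridge penalties.

    Y: (n_objects, output_size). Returns latents Z (n_objects, latent_size),
    output parameters A (output_size, latent_size), and the loss after each cycle.
    """
    rng = np.random.default_rng(seed)
    Z = rng.normal(size=(Y.shape[0], latent_size))
    A = rng.normal(size=(Y.shape[1], latent_size))
    eye = np.eye(latent_size)
    losses = []
    for _ in range(max_cycles):
        # Block 1: latents, with the output parameters A held fixed
        Z = np.linalg.solve(A.T @ A + lam * eye, A.T @ Y.T).T
        # Block 2: output parameters, with the latents Z held fixed
        A = np.linalg.solve(Z.T @ Z + lam * eye, Z.T @ Y).T
        resid = Y - Z @ A.T
        losses.append(0.5 * np.sum(resid**2) + 0.5 * lam * (np.sum(Z**2) + np.sum(A**2)))
        # Stop when the relative change in loss between cycles drops below tol
        if len(losses) > 1 and abs(losses[-2] - losses[-1]) / abs(losses[-2]) < tol:
            break
    return Z, A, losses


rng = np.random.default_rng(1)
Y = rng.normal(size=(200, 2)) @ rng.normal(size=(30, 2)).T + 0.01 * rng.normal(size=(200, 30))
Z, A, losses = alternating_least_squares(Y, latent_size=2)
```

Because each block is solved exactly for the same objective, the loss cannot increase from one cycle to the next.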
- Parameters:
data (PolluxData) – The training data.
blocks (list[ParameterBlock] | list[str] | None) – List of ParameterBlock specifications, or a list of strings naming which parameter groups to optimize (e.g. ["latents"]). When strings are provided, ParameterBlock instances are constructed automatically with an inferred optimizer. If None, uses a default strategy that alternates between latents and each output.
fixed_pars (dict[str, dict[str, Any] | Array] | None) – Parameters to hold fixed during optimization. When provided alongside string blocks, the function initializes the optimized parameters (e.g. latents to zero) and merges fixed_pars with them before returning, so result.params is a complete parameter dict.
max_cycles (int) – Maximum number of full optimization cycles.
tol (float) – Convergence tolerance. Stops when the relative change in loss is < tol.
rng_key (Array | None) – JAX random key. Required when any block uses SVI (i.e., optimizer != "least_squares"). If None and initial_params is also None, falls back to jax.random.PRNGKey(0) for initialization from priors.
initial_params (dict[str, dict[str, Any] | Array] | None) – Initial parameter values. If None and fixed_pars is provided, built automatically. If both are None, initialized from priors.
latents_prior (Distribution | None) – Prior distribution for latents. If None, uses Normal(0, 1). Used to determine regularization strength for latent least squares.
progress (bool) – Whether to display a tqdm progress bar showing optimization progress.
record_history (bool) – Whether to record detailed per-block loss history.
- Returns:
The optimization result containing:
  - params: Optimized parameters in unpacked format (includes fixed params when fixed_pars is provided)
  - losses_per_cycle: Loss values at the end of each cycle
  - n_cycles: Number of cycles completed
  - converged: Whether optimization converged
  - history: Optional detailed history (if record_history=True)
- Return type:
IterativeOptimizationResult
Notes
For blocks with linear transforms (LinearTransform, AffineTransform, OffsetTransform), each sub-problem is solved exactly via weighted least squares. For non-linear transforms, SVI is used with numpyro.optim.Adam at step_size=1e-3 by default; override this via optimizer_kwargs on the block, e.g. ParameterBlock(..., optimizer_kwargs={"step_size": 1e-4}).
Regularization is automatically extracted from the priors on the transform parameters.
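The link between priors and regularization can be made explicit: with a Gaussian likelihood weighted by inverse variances and a zero-mean Gaussian prior of width σ, the MAP estimate is a weighted ridge solution with λ = 1/σ². The NumPy sketch below checks this for a single latent vector with the output parameters A held fixed. The helper names are hypothetical, and exactly how pollux reads λ off its priors (for example, with non-zero prior means) is not shown here.

```python
import numpy as np


def map_latent(A, y, ivar, prior_sigma):
    """Hypothetical sketch: MAP z for y ~ Normal(A @ z, 1/ivar), z ~ Normal(0, prior_sigma**2).

    Up to a constant, the negative log posterior is
        0.5 * sum(ivar * (y - A @ z)**2) + 0.5 * sum(z**2) / prior_sigma**2,
    i.e. weighted least squares with ridge strength lam = 1 / prior_sigma**2.
    """
    lam = 1.0 / prior_sigma**2
    AtW = A.T * ivar  # A.T @ diag(ivar) without forming the diagonal matrix
    return np.linalg.solve(AtW @ A + lam * np.eye(A.shape[1]), AtW @ y)


def neg_log_posterior_grad(z, A, y, ivar, prior_sigma):
    """Gradient of the negative log posterior above with respect to z."""
    return -A.T @ (ivar * (y - A @ z)) + z / prior_sigma**2


rng = np.random.default_rng(0)
A = rng.normal(size=(20, 3))
y = rng.normal(size=20)
ivar = rng.uniform(0.5, 2.0, size=20)
z_hat = map_latent(A, y, ivar, prior_sigma=1.0)  # gradient vanishes here
```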
Examples
Basic usage:
>>> result = model.optimize_iterative(data, max_cycles=20)
>>> opt_params = result.params
With custom blocks:
>>> from pollux.models import ParameterBlock
>>> blocks = [
...     ParameterBlock("latents", "latents", optimizer="least_squares"),
...     ParameterBlock("flux", "flux:data", optimizer="least_squares"),
... ]
>>> result = model.optimize_iterative(data, blocks=blocks)
Optimizing only latents with fixed output parameters (e.g. applying a trained model to test data):
>>> result = model.optimize_iterative(
...     test_data, blocks=["latents"], fixed_pars=trained_pars
... )
>>> test_opt_pars = result.params  # contains fixed + optimized params
- pack_numpyro_pars(pars: dict[str, dict[str, Any] | Array], ignore_missing: bool = False)
Pack parameters into a flat dictionary keyed on numpyro names.
This method is the inverse of unpack_numpyro_pars. It takes a nested dictionary of parameters and flattens it into a dictionary keyed on numpyro parameter names.
- Parameters:
pars (dict[str, dict[str, Any] | Array]) – A nested dictionary with keys as output names. Each output name should be a key with a dict value containing "data" and optionally "err" keys. The "err" key can be omitted if there are no error parameters for that output. For TransformSequence outputs, "data" values should be tuples/lists of parameter dictionaries. Non-output parameters (like "latents") can exist at the top level.
Example structure:
{
    "flux": {"data": {...} or (...)},  # err key optional
    "label": {"data": {...}, "err": {...}},  # err key included
    "latents": array
}
ignore_missing (bool)
- Returns:
A dictionary of numpyro parameters. The keys are in the format "output_name:param_name" for data parameters and "output_name:err:param_name" for error parameters.
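The key convention described above can be sketched in a few lines of plain Python. pack_pars is a hypothetical stand-in, not pollux's implementation; in particular, the "output_name:{index}:param_name" keys for TransformSequence outputs are an assumption based on the "{index}:{param}" format mentioned under unpack_numpyro_pars, and ignore_missing is not modeled.

```python
def pack_pars(pars):
    """Hypothetical sketch of flattening nested pollux-style parameters.

    Single transforms:   {"flux": {"data": {"A": a}}}  -> {"flux:A": a}
    Error parameters:    {"flux": {"err": {"s": s}}}   -> {"flux:err:s": s}
    Sequences (assumed): {"flux": {"data": ({"A": a}, {"b": b})}}
                         -> {"flux:0:A": a, "flux:1:b": b}
    Non-output entries such as "latents" are kept as-is.
    """
    packed = {}
    for name, value in pars.items():
        if not isinstance(value, dict):
            packed[name] = value  # e.g. "latents"
            continue
        for kind, prefix in (("data", f"{name}:"), ("err", f"{name}:err:")):
            group = value.get(kind, {})
            if isinstance(group, (tuple, list)):
                # Assumed TransformSequence layout: prefix each param with its index
                for i, sub in enumerate(group):
                    for key, val in sub.items():
                        packed[f"{prefix}{i}:{key}"] = val
            else:
                for key, val in group.items():
                    packed[f"{prefix}{key}"] = val
    return packed


pars = {
    "flux": {"data": {"A": 1.0}},
    "label": {"data": {"A": 2.0}, "err": {"s": 0.1}},
    "latents": [0.0, 1.0],
}
packed = pack_pars(pars)
```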
- predict_outputs(latents: Float[Array, '#stars latents'], pars: dict[str, Any], names: list[str] | str | None = None)
Predict output values for given latent vectors and parameters.
- Parameters:
latents (Float[Array, '#stars latents']) – The latent vectors that transform into the outputs. Shape should be (n_objects, latent_size).
pars (dict[str, Any]) – A dictionary of parameters for each output transformation in the model. Should be in the nested format returned by optimize():
{
    "output_name": {
        "data": {...} or [...],  # Transform parameters
        "err": {...}             # Error transform parameters
    },
    "latents": array  # Optional, not used here
}
For single transforms, "data" is a dict: {"A": array, "b": array}
For TransformSequence, "data" is a tuple of dicts: ({"A": array}, {"b": array})
Deprecated: Passing parameters in direct format (without the "data"/"err" wrapper) is deprecated and will be removed in a future version.
names (list[str] | str | None) – A single string or a list of output names to predict. If None, predict all outputs (default).
- Returns:
A dictionary of predicted output values, where the keys are the output names and values are arrays of shape (n_objects, output_size).
- register_output(name: str, data_transform: AbstractSingleTransform | TransformSequence, err_transform: AbstractSingleTransform | TransformSequence | None = None)
Register a new output of the model given a specified transform.
- Parameters:
name (str) – The name of the output. If you intend to use this model with numpyro and observed data, this name should correspond to the name of the data passed in via a pollux.data.PolluxData object. The name cannot contain colons (':') as they are reserved for internal parameter naming.
data_transform (AbstractSingleTransform | TransformSequence) – A specification of the transformation function that takes a latent vector representation in and predicts the output values.
err_transform (AbstractSingleTransform | TransformSequence | None)
- setup_numpyro(latents: Float[Array, '#stars latents'], data: PolluxData, names: list[str] | None = None)
Sample parameters and set up basic numpyro model.
- Parameters:
latents (Float[Array, '#stars latents']) – The latent vectors that transform into the outputs. In the case of the Paton, these are the (unknown) latent vectors. In the case of the Cannon, these are the observed latents for the training set (combinations of stellar labels).
data (PolluxData) – A dictionary-like object of observed data for each output. The keys should correspond to the output names.
names (list[str] | None) – A single string or a list of output names to set up. If None, set up all outputs (default).
- Returns:
A dictionary of sampled parameters for each output.
- unpack_numpyro_pars(pars: dict[str, Any], ignore_missing: bool = False)
Unpack numpyro parameters into separate data and error parameter structures.
numpyro parameters use names like "output_name:param_name" to make the numpyro internal names unique. This method unpacks these into a nested dictionary that keeps the data transform parameters and the error transform parameters of each output separate.
For TransformSequence outputs, data parameters are further unpacked from the flattened "{index}:{param}" format into a tuple of parameter dictionaries.
- Parameters:
pars (dict[str, Any])
ignore_missing (bool)
- Returns:
A nested dictionary with keys as output names. Each output name is a key with a dict value containing "data" and "err" keys:
  - For single transforms, "data" values are parameter dictionaries.
  - For TransformSequence, "data" values are tuples of parameter dictionaries.
  - "err" values follow the same structure as "data" for the error transforms.
  - "err" will be an empty dict {} if there are no error parameters.
  - Non-output parameters (like "latents") are passed through at the top level.
Example structure:
{
    "flux": {"data": {...} or (...), "err": {}},  # err empty if no error pars
    "label": {"data": {...}, "err": {...}},
    "latents": array
}
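Going the other way, a hypothetical plain-Python inverse of the naming scheme shows how flat numpyro-style keys map back to the nested "data"/"err" structure. It is a sketch, not pollux's implementation: sequence keys are assumed to look like "output_name:{index}:param_name", missing indices are filled with empty dicts, and, unlike the real method, it cannot know about parameter-free transforms after the last indexed one because it never sees the model.

```python
def unpack_pars(packed):
    """Hypothetical inverse of the "output:param" / "output:err:param" naming."""
    outputs, top_level = {}, {}
    for key, val in packed.items():
        if ":" not in key:
            top_level[key] = val  # non-output parameters such as "latents"
            continue
        name, rest = key.split(":", 1)
        entry = outputs.setdefault(name, {"data": {}, "err": {}})
        kind = "data"
        if rest.startswith("err:"):
            kind, rest = "err", rest[len("err:"):]
        head, _, tail = rest.partition(":")
        if tail and head.isdigit():
            # Assumed TransformSequence key "output:{index}:param": group by index
            entry[kind].setdefault(int(head), {})[tail] = val
        else:
            entry[kind][rest] = val
    for entry in outputs.values():
        for kind, group in list(entry.items()):
            if group and all(isinstance(k, int) for k in group):
                # Turn {index: pars} into a tuple, filling gaps with empty dicts
                entry[kind] = tuple(group.get(i, {}) for i in range(max(group) + 1))
    return {**outputs, **top_level}


packed = {"flux:A": 1.0, "label:A": 2.0, "label:err:s": 0.1, "latents": [0.0, 1.0]}
unpacked = unpack_pars(packed)
```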
- class pollux.models.LuxModel(*args: Any, **kwargs: Any)
Bases: Lux
Deprecated alias for the Lux class; it inherits all of the methods documented for Lux above.
Deprecated: Use Lux instead. LuxModel will be removed in a future version.
- pollux.models.optimize_iterative(model: Lux, data: PolluxData, blocks: list[ParameterBlock] | list[str] | None = None, fixed_pars: dict[str, Any] | None = None, max_cycles: int = 100, tol: float = 0.0001, rng_key: Array | None = None, initial_params: dict[str, Any] | None = None, latents_prior: dist.Distribution | None = None, progress: bool = True, record_history: bool = False)
Optimize model using iterative block coordinate descent.
This implements an alternating optimization strategy that cycles through parameter blocks, optimizing each while holding the others fixed. For linear models, each sub-problem can be solved exactly using weighted least squares.
The default strategy alternates between:
1. Optimizing the latents (with output parameters fixed)
2. Optimizing each output's parameters (with latents and other outputs fixed)
- Parameters:
model (Lux) – The Lux model to optimize.
data (PolluxData) – The training data.
blocks (list[ParameterBlock] | list[str] | None) – List of ParameterBlock specifications, or a list of strings naming which parameter groups to optimize (e.g. ["latents"]). If strings are given, ParameterBlock instances are constructed automatically with an inferred optimizer ("least_squares" for linear transforms). If None, uses a default strategy that alternates between the latents and each output.
fixed_pars (dict[str, Any] | None) – Parameters to hold fixed during optimization. When provided alongside string blocks, the function initializes the latents to zero and merges fixed_pars with the optimized parameters before returning, so the result contains a complete parameter dict. Ignored when initial_params is also provided (the caller is responsible for merging in that case).
max_cycles (int) – Maximum number of full optimization cycles.
tol (float) – Convergence tolerance. Optimization stops when the relative change in loss is less than tol.
rng_key (Array | None) – JAX random key. Required when any block uses SVI (i.e., optimizer != "least_squares"). When initial_params is None, it is also used to sample initial values from the model priors; in that case it falls back to jax.random.PRNGKey(0) if not provided.
initial_params (dict[str, Any] | None) – Initial parameter values. If None and fixed_pars is provided, built automatically by merging fixed_pars with zero-initialized optimized parameters. If both are None, initialized from the priors.
latents_prior (Distribution | None) – Prior distribution for the latents. If None, uses Normal(0, 1). Used to set the regularization strength of the latent least-squares solve.
progress (bool) – Whether to display a tqdm progress bar showing optimization progress.
record_history (bool) – Whether to record detailed per-block loss history.
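The latents_prior entry above notes that the prior sets the regularization strength of the latent least-squares solve. As a hedged illustration, assume a linear output model y = A @ z with per-pixel inverse variances w and an independent Normal(prior_mean, prior_std) prior on each latent; the MAP latents then solve (A^T W A + I / prior_std**2) z = A^T W y + prior_mean / prior_std**2. The NumPy sketch below (map_latents is a hypothetical helper, not part of pollux) shows how a narrower prior pulls the solution toward the prior mean.

```python
import numpy as np


def map_latents(A, y, w, prior_mean=0.0, prior_std=1.0):
    """MAP z for y ~ Normal(A @ z, 1 / sqrt(w)) with z ~ Normal(prior_mean, prior_std)."""
    latent_size = A.shape[1]
    AtW = A.T * w  # A^T diag(w)
    lhs = AtW @ A + np.eye(latent_size) / prior_std**2
    rhs = AtW @ y + prior_mean / prior_std**2 * np.ones(latent_size)
    return np.linalg.solve(lhs, rhs)


rng = np.random.default_rng(0)
A = rng.normal(size=(30, 3))
z_true = np.array([0.5, -1.0, 2.0])
y = A @ z_true                   # noiseless toy data
w = np.full(30, 100.0)           # inverse variances

z_tight = map_latents(A, y, w, prior_std=0.01)   # strong prior: shrinks z toward 0
z_loose = map_latents(A, y, w, prior_std=100.0)  # weak prior: essentially recovers z_true
```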
- Returns:
The optimization result, containing the optimized parameters and convergence information. When fixed_pars is provided, result.params includes both the fixed and the optimized parameters.
- Return type:
IterativeOptimizationResult
Notes
When a block has optimizer=None, SVI is run with numpyro.optim.Adam at step_size=1e-3. Override this via optimizer_kwargs on the block, e.g. ParameterBlock(..., optimizer_kwargs={"step_size": 1e-4}).
Examples
Basic usage with default blocks:
>>> result = optimize_iterative(model, data, max_cycles=20)
>>> opt_params = result.params
Custom block specification:
>>> blocks = [
...     ParameterBlock("latents", "latents", optimizer="least_squares"),
...     ParameterBlock("flux", "flux:data", optimizer="least_squares"),
...     ParameterBlock("labels", "label:data", num_steps=500),
... ]
>>> result = optimize_iterative(model, data, blocks=blocks)
Optimizing only the latents with fixed output parameters (e.g. applying a trained model to new test data):
>>> result = optimize_iterative(
...     model, test_data, blocks=["latents"], fixed_pars=trained_pars
... )
>>> test_opt_pars = result.params  # already contains fixed + optimized
- class pollux.models.AbstractSingleTransform(output_size: int, transform: Callable[[...], Float[Array, 'output']], priors: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]] = ImmutableMap({}), shapes: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]] = ImmutableMap({}), vmap: bool = True, param_priors: Any = None, param_shapes: Any = None)
Bases:
AbstractTransform
Base class providing common functionality for atomic transforms.
"Single" transforms apply a single operation to convert latent vectors to outputs.
- Parameters:
output_size (int) – Size of the output vector.
priors (Mapping[TypeVar(K), TypeVar(V)] | tuple[TypeVar(K), TypeVar(V)] | Iterable[tuple[TypeVar(K), TypeVar(V)]]) – Prior distributions for transform parameters.
shapes (Mapping[TypeVar(K), TypeVar(V)] | tuple[TypeVar(K), TypeVar(V)] | Iterable[tuple[TypeVar(K), TypeVar(V)]]) – Shape specifications for transform parameters.
transform (Callable[..., Float[Array, 'output']]) – The transform function. It should take the latents as its first argument, followed by any parameters.
vmap (bool) – Whether to automatically vectorize the transform over the batch dimension. If True (default), the transform function should be written for a single sample (latents of shape (latent_size,)), and JAX's vmap is applied to handle batches; parameters are shared across all samples. If False, the transform function must handle batching itself. This is useful when parameters are per-sample (e.g., per-star nuisance parameters) or when the function has custom batching requirements.
param_priors (Any) – Deprecated. Use priors instead.
param_shapes (Any) – Deprecated. Use shapes instead.
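The vmap flag above decides who handles the batch dimension. The NumPy sketch below mirrors the two conventions; a Python loop stands in for jax.vmap, and all function names are hypothetical. With vmap=True the function sees one latent vector and shared parameters; with vmap=False it receives the whole batch and can therefore accept per-star parameters.

```python
import numpy as np


def single_sample_linear(z, A):
    """vmap=True style: written for one star, z has shape (latent_size,)."""
    return A @ z


def emulate_vmap(fn, latents, **shared_pars):
    """Stand-in for jax.vmap: map fn over the leading (star) axis, sharing parameters."""
    return np.stack([fn(z, **shared_pars) for z in latents])


def batched_with_offsets(latents, A, offset):
    """vmap=False style: handles the batch itself; offset is per star, shape (n_stars,)."""
    return latents @ A.T + offset[:, None]


rng = np.random.default_rng(0)
latents = rng.normal(size=(5, 3))  # (n_stars, latent_size)
A = rng.normal(size=(4, 3))        # shared across stars
offset = rng.normal(size=5)        # one value per star

out_vmap = emulate_vmap(single_sample_linear, latents, A=A)
out_batched = batched_with_offsets(latents, A, np.zeros(5))
out_per_star = batched_with_offsets(latents, A, offset)
```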
- apply(latents: Float[Array, '#stars latents'], **pars: Any)
Apply the transform to input latent vectors.
Extracts the required parameters from the keyword arguments and applies the transform function to the latents, handling vectorization automatically.
- Parameters:
latents (Float[Array, '#stars latents'])
pars (Any)
- Return type:
Float[Array, '#stars output']
- get_expanded_priors(latent_size: int, data_size: int | None = None)
Get expanded parameter priors.
Expands the parameter prior distributions to the concrete shapes needed for the transform, based on the latent size and the optional data size.
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]
- pack_pars(nested_pars: dict[str, Any], ignore_missing: bool = False)
Pack parameters (identity for single transforms).
- param_priors: ImmutableMap[str, Distribution] | None = None
- priors: ImmutableMap[str, Distribution] = ImmutableMap({})
- class pollux.models.AbstractTransform(output_size: int)
Bases:
Module
Base class defining the transform interface.
Transforms convert latent vectors to observable quantities through parameterized functions. They define the mapping from the latent space to the output spaces.
- Parameters:
output_size (int)
- abstractmethod apply(latents: Float[Array, '#stars latents'], **pars: Any)
Apply the transform to input latent vectors.
Takes a batch of latent vectors and transforms them using the provided parameters to produce output values.
- Parameters:
latents (Float[Array, '#stars latents'])
pars (Any)
- Return type:
Float[Array, '#stars output']
- abstractmethod get_expanded_priors(latent_size: int, data_size: int | None = None)
Get expanded parameter priors.
Expands the parameter prior distributions to the concrete shapes needed for the transform, based on the latent size and the optional data size.
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]
- class pollux.models.AdditiveOffsetTransform(base_transform: ~pollux.models.transforms.AbstractSingleTransform | ~pollux.models.transforms.TransformSequence, offset_prior: ~numpyro.distributions.distribution.Distribution = <factory>)
Bases:
Module
Transform that wraps a base transform and adds a per-star scalar offset.
This transform is useful for modeling per-object nuisance parameters such as the distance modulus, where each object has its own offset that applies uniformly to all output dimensions. It complements AffineTransform and OffsetTransform: in those classes the offset varies per output dimension, whereas here it varies per object.
In other words, unlike OffsetTransform, which has a fixed offset vector of shape (output_size,), this transform samples a separate scalar offset for each object in the dataset, with shape (data_size,). The offset is then broadcast to all output dimensions.
- Parameters:
base_transform (AbstractSingleTransform | TransformSequence) – The underlying transform to wrap (e.g., LinearTransform).
offset_prior (Distribution) – Prior distribution for the per-object offset. It is expanded to shape (data_size,) during inference.
Examples
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> from pollux.models.transforms import AdditiveOffsetTransform, LinearTransform
Model apparent magnitudes as absolute magnitudes plus the distance modulus:
>>> phot_trans = AdditiveOffsetTransform(
...     base_transform=LinearTransform(output_size=3),  # 3 photometric bands
...     offset_prior=dist.Normal(11.0, 3.0),  # distance modulus prior
... )
The offset adapts to the data size automatically:
>>> import pollux as plx
>>> model = plx.Lux(latent_size=8)
>>> model.register_output("phot", phot_trans)
>>> # During training with 1000 stars, the offset has shape (1000,)
>>> # During testing with 500 stars, the offset has shape (500,)
Notes
The per-star offset is broadcast to all output dimensions, meaning the same offset value is added to every element of the output for a given object. This is appropriate for distance modulus (which shifts all magnitudes equally) but may not be appropriate for other use cases.
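As a shape-level illustration of the broadcasting described above, here is a NumPy sketch (not the pollux implementation; the magnitudes and distance moduli are arbitrary toy numbers): a per-star offset of shape (n_stars,) is added to every band of that star's (n_stars, n_bands) base output.

```python
import numpy as np

# Hypothetical base-transform output: absolute magnitudes, shape (n_stars, n_bands).
abs_mag = np.array([
    [4.8, 4.2, 3.9],
    [1.1, 0.6, 0.3],
])
# One scalar offset (distance modulus) per star, shape (n_stars,).
dist_mod = np.array([10.0, 12.5])

# Broadcast the per-star offset across all output dimensions.
app_mag = abs_mag + dist_mod[:, None]
```

Because the same offset shifts every band of a star, colors (differences between bands) are unchanged, which is why this structure suits the distance modulus.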
- apply(latents: Float[Array, '#stars latents'], **params: Any)
Apply the base transform and add the per-star offset.
- Parameters:
latents (Float[Array, '#stars latents']) – Input latent vectors of shape (n_samples, latent_size).
**params (Any) – Parameters including the base transform parameters (prefixed with "base:") and the "offset" parameter of shape (n_samples,).
- Returns:
Output of shape (n_samples, output_size).
- Return type:
array
- get_expanded_priors(latent_size: int, data_size: int | None = None)
Get expanded parameter priors, including the per-star offset.
- Parameters:
latent_size (int)
data_size (int | None)
- Returns:
Dictionary of priors including the base transform priors (prefixed with "base:") and the offset prior with shape (data_size,).
- Return type:
ParamPriorsT
- Raises:
ValueError – If data_size is None.
- pack_pars(nested_pars: dict[str, Any], ignore_missing: bool = False)
Pack nested parameters into a flat structure.
- unpack_pars(flat_pars: dict[str, Any], ignore_missing: bool = False)
Unpack flat parameters into a nested structure.
- base_transform: AbstractSingleTransform | TransformSequence
- offset_prior: Distribution
- class pollux.models.AffineTransform(output_size: int, transform: ~collections.abc.Callable[[...], ~jaxtyping.Float[Array, 'output']] = <function _affine_transform>, priors: ~collections.abc.Mapping[~xmmutablemap._core.K, ~xmmutablemap._core.V] | tuple[~xmmutablemap._core.K, ~xmmutablemap._core.V] | ~collections.abc.Iterable[tuple[~xmmutablemap._core.K, ~xmmutablemap._core.V]] = ImmutableMap({'A': <numpyro.distributions.continuous.Normal object at 0x78dd89aaac30 with batch shape () and event shape ()>, 'b': <numpyro.distributions.continuous.Normal object at 0x78dd89aab440 with batch shape () and event shape ()>}), shapes: ~xmmutablemap._core.ImmutableMap[str, ~pollux.models.transforms.ShapeSpec | tuple[int, ...]] = ImmutableMap({'A': ShapeSpec(dims=('output_size', 'latent_size')), 'b': ShapeSpec(dims=('output_size', 'one'))}), vmap: bool = True, param_priors: ~typing.Any = None, param_shapes: ~typing.Any = None)
Bases:
AbstractSingleTransform
Affine transformation combining a linear transform and an offset.
Implements the transformation y = A @ z + b, where A is a matrix, z is a latent vector, and b is a bias vector.
- Parameters:
- apply(latents: Float[Array, '#stars latents'], **pars: Any)
Apply the transform to input latent vectors.
Extracts the required parameters from the keyword arguments and applies the transform function to the latents, handling vectorization automatically.
- Parameters:
latents (Float[Array, '#stars latents'])
pars (Any)
- Return type:
Float[Array, '#stars output']
- get_expanded_priors(latent_size: int, data_size: int | None = None)
Get expanded parameter priors.
Expands the parameter prior distributions to the concrete shapes needed for the transform, based on the latent size and the optional data size.
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]
- pack_pars(nested_pars: dict[str, Any], ignore_missing: bool = False)
Pack parameters (identity for single transforms).
- param_priors: ImmutableMap[str, Distribution] | None = None
- priors: ImmutableMap[str, Distribution] = ImmutableMap({'A': <numpyro.distributions.continuous.Normal object at 0x78dd89aaac30 with batch shape () and event shape ()>, 'b': <numpyro.distributions.continuous.Normal object at 0x78dd89aab440 with batch shape () and event shape ()>})
- shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]] = ImmutableMap({'A': ShapeSpec(dims=('output_size', 'latent_size')), 'b': ShapeSpec(dims=('output_size', 'one'))})
- transform(z: Float[Array, 'latents'], A: Float[Array, 'output latents'], b: Float[Array, 'output'])
Apply an affine transformation.
Computes a linear transformation followed by an offset: A @ z + b.
- Parameters:
z (Float[Array, 'latents'])
A (Float[Array, 'output latents'])
b (Float[Array, 'output'])
- Return type:
Float[Array, 'output']
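A minimal NumPy sketch of the per-sample affine map and its batched matrix form, for illustration only. The shapes follow the shapes attribute above; since b is declared with ShapeSpec(("output_size", "one")), the sketch flattens it to (output_size,) before adding, which is an assumption about how the bias is applied.

```python
import numpy as np


def affine(z, A, b):
    """y = A @ z + b for a single latent vector z of shape (latent_size,)."""
    return A @ z + b.reshape(-1)  # flatten an (output_size, 1) bias to (output_size,)


rng = np.random.default_rng(0)
output_size, latent_size, n_stars = 6, 3, 4
A = rng.normal(size=(output_size, latent_size))
b = rng.normal(size=(output_size, 1))
Z = rng.normal(size=(n_stars, latent_size))

per_star = np.stack([affine(z, A, b) for z in Z])  # what a vmap over stars computes
batched = Z @ A.T + b.reshape(1, -1)               # equivalent batched matrix form
```

Setting b to zero gives the LinearTransform case, y = A @ z.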
- class pollux.models.EquinoxNNTransform(output_size: int, nn_factory: ~typing.Any, weight_prior: ~numpyro.distributions.distribution.Distribution = <factory>, bias_prior: ~numpyro.distributions.distribution.Distribution = <factory>, priors: ~xmmutablemap._core.ImmutableMap[str, ~numpyro.distributions.distribution.Distribution] = <factory>, shapes: ~xmmutablemap._core.ImmutableMap[str, ~pollux.models.transforms.ShapeSpec | tuple[int, ...]] = <factory>, _param_paths: tuple[str, ...] = (), _template_nn: ~typing.Any = None)
Bases:
AbstractTransform
Neural network transform using an Equinox module.
This transform wraps an Equinox neural network module and exposes its parameters for Bayesian inference with numpyro. The network structure is defined by a factory function that creates the network given the input size, the output size, and a random key.
- Parameters:
output_size (int) – The output dimension of the transform.
nn_factory (Any) – A callable that creates an Equinox module. It should have the signature nn_factory(in_size: int, out_size: int, key: jax.Array) -> eqx.Module.
weight_prior (Distribution) – Prior distribution for weight parameters. Default is Normal(0, 1).
bias_prior (Distribution) – Prior distribution for bias parameters. Default is Normal(0, 1).
priors (ImmutableMap[str, Distribution])
_template_nn (Any)
Examples
>>> import jax
>>> import equinox as eqx
>>> import numpyro.distributions as dist
>>> from pollux.models.transforms import EquinoxNNTransform
Create a simple MLP transform:
>>> nn_trans = EquinoxNNTransform(
...     output_size=128,
...     nn_factory=lambda in_size, out_size, key: eqx.nn.MLP(
...         in_size=in_size,
...         out_size=out_size,
...         width_size=64,
...         depth=2,
...         key=key,
...     ),
...     weight_prior=dist.Normal(0, 0.1),
...     bias_prior=dist.Normal(0, 0.01),
... )
Use with Lux:
>>> import pollux as plx
>>> model = plx.Lux(latent_size=8)
>>> model.register_output("flux", nn_trans)
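To get a feel for how many parameters the priors of the example MLP must cover, the plain-Python sketch below lists the weight and bias arrays and counts their entries. It does not call Equinox; it assumes that eqx.nn.MLP with depth=2 consists of depth + 1 = 3 linear layers, each with a weight of shape (out, in) and a bias of shape (out,), and that the network input size equals the model's latent_size of 8. The path-style names are illustrative, not the keys pollux uses.

```python
import math


def mlp_param_shapes(in_size, out_size, width_size, depth):
    """(name, shape) of each weight and bias array in an MLP with `depth` hidden layers.

    Assumes depth + 1 linear layers, each with a weight of shape (out, in)
    and a bias of shape (out,). Names are hypothetical labels.
    """
    sizes = [in_size] + [width_size] * depth + [out_size]
    shapes = []
    for layer, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        shapes.append((f"layers[{layer}].weight", (n_out, n_in)))
        shapes.append((f"layers[{layer}].bias", (n_out,)))
    return shapes


# The example above: latent_size = 8 feeds the MLP, output_size = 128.
shapes = mlp_param_shapes(in_size=8, out_size=128, width_size=64, depth=2)
n_arrays = len(shapes)                                     # weight and bias arrays
n_scalars = sum(math.prod(shape) for _, shape in shapes)   # total parameter entries
```

Even this small network has over ten thousand parameter entries, which is why the choice of weight_prior and bias_prior scales matters.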
- apply(latents: Float[Array, '#stars latents'], **params: Any)
Apply the neural network transform.
- Parameters:
latents (Float[Array, '#stars latents']) – Input latent vectors of shape (n_samples, latent_size).
**params (Any) – Neural network parameters, keyed by their path names.
- Returns:
Output of shape (n_samples, output_size).
- Return type:
array
- get_expanded_priors(latent_size: int, data_size: int | None = None)
Create one prior per neural network parameter.
- pack_pars(nested_pars: dict[str, Any], ignore_missing: bool = False)
Pack parameters to flat format (for compatibility with TransformSequence).
- property param_priors: ImmutableMap[str, Distribution]
Deprecated. Use priors instead.
- property param_shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]]
Deprecated. Use shapes instead.
- unpack_pars(flat_pars: dict[str, Any], ignore_missing: bool = False)
Unpack flat parameters (for compatibility with TransformSequence).
- weight_prior: Distribution
- bias_prior: Distribution
- priors: ImmutableMap[str, Distribution]
- class pollux.models.FunctionTransform(output_size: int, transform: Callable[[...], Float[Array, 'output']], priors: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]] = ImmutableMap({}), shapes: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]] = ImmutableMap({}), vmap: bool = True, param_priors: Any = None, param_shapes: Any = None)
Bases:
AbstractSingleTransform
Custom transformation using a user-defined function.
This transform allows arbitrary user-defined transformations. It is particularly useful for modeling complex relationships or per-sample nuisance parameters.
- Parameters:
output_size (int) – Size of the output vector.
transform (Callable[..., Float[Array, 'output']]) – The transform function. It should take the latents as its first argument, followed by any parameters defined in priors.
priors (Mapping[TypeVar(K), TypeVar(V)] | tuple[TypeVar(K), TypeVar(V)] | Iterable[tuple[TypeVar(K), TypeVar(V)]]) – Prior distributions for transform parameters. Use ParamPriorsT (an ImmutableMap[str, dist.Distribution]).
shapes (Mapping[TypeVar(K), TypeVar(V)] | tuple[TypeVar(K), TypeVar(V)] | Iterable[tuple[TypeVar(K), TypeVar(V)]]) – Shape specifications for transform parameters. Use ParamShapesT (an ImmutableMap[str, ShapeSpec | tuple[int, ...]]). Use ShapeSpec when shapes depend on latent_size or data_size.
vmap (bool) – Whether to automatically vectorize the transform over the batch dimension. Set to False when parameters are per-sample (e.g., per-star continuum corrections) and the function handles batching internally.
param_priors (ImmutableMap[str, Distribution] | None)
param_shapes (ImmutableMap[str, ShapeSpec | tuple[int, ...]] | None)
Examples
Define a custom linear transform with learnable weights:
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> from xmmutablemap import ImmutableMap
>>> from pollux.models.transforms import FunctionTransform, ShapeSpec
>>>
>>> def my_transform(z, A):
...     return jnp.dot(A, z)
>>>
>>> custom = FunctionTransform(
...     output_size=128,
...     transform=my_transform,
...     priors=ImmutableMap({"A": dist.Normal(0, 1)}),
...     shapes=ImmutableMap({"A": ShapeSpec(("output_size", "latent_size"))}),
... )
The parameter A will have shape (128, latent_size), where latent_size is determined when the transform is registered with a model.
See also the "Inferring Continuum Model Parameters" tutorial for an example of using FunctionTransform with per-star parameters and vmap=False.
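For the vmap=False case mentioned above, the transform receives the full batch, so per-star parameters can be indexed directly. The NumPy sketch below shows such a function; the name continuum_scaled, the parameter names A and log_cont, and the (data_size,) shape of log_cont are hypothetical. In pollux they would be declared through priors and shapes (for example, with a ShapeSpec that uses the "data_size" dimension).

```python
import numpy as np


def continuum_scaled(latents, A, log_cont):
    """Batched transform (vmap=False style).

    latents:  (n_stars, latent_size)
    A:        (output_size, latent_size), shared by all stars
    log_cont: (n_stars,), one continuum scale per star
    """
    model = latents @ A.T                      # (n_stars, output_size)
    return model * np.exp(log_cont)[:, None]   # scale each star's output by its own factor


rng = np.random.default_rng(0)
latents = rng.normal(size=(3, 2))
A = rng.normal(size=(5, 2))
log_cont = np.log(np.array([1.0, 2.0, 0.5]))
flux = continuum_scaled(latents, A, log_cont)
```

Writing this function for a single sample would not work with a shared vmap, because log_cont has one entry per star rather than being shared across the batch.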
- apply(latents: Float[Array, '#stars latents'], **pars: Any)
Apply the transform to input latent vectors.
Extracts the required parameters from the keyword arguments and applies the transform function to the latents, handling vectorization automatically.
- Parameters:
latents (Float[Array, '#stars latents'])
pars (Any)
- Return type:
Float[Array, '#stars output']
- get_expanded_priors(latent_size: int, data_size: int | None = None)
Get expanded parameter priors.
Expands the parameter prior distributions to the concrete shapes needed for the transform, based on the latent size and the optional data size.
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]
- pack_pars(nested_pars: dict[str, Any], ignore_missing: bool = False)
Pack parameters (identity for single transforms).
- param_priors: ImmutableMap[str, Distribution] | None = None
- priors: ImmutableMap[str, Distribution] = ImmutableMap({})
- class pollux.models.LinearTransform(output_size: int, transform: ~collections.abc.Callable[[...], ~jaxtyping.Float[Array, 'output']] = <function _linear_transform>, priors: ~collections.abc.Mapping[~xmmutablemap._core.K, ~xmmutablemap._core.V] | tuple[~xmmutablemap._core.K, ~xmmutablemap._core.V] | ~collections.abc.Iterable[tuple[~xmmutablemap._core.K, ~xmmutablemap._core.V]] = ImmutableMap({'A': <numpyro.distributions.continuous.Normal object at 0x78dd89aab230 with batch shape () and event shape ()>}), shapes: ~xmmutablemap._core.ImmutableMap[str, ~pollux.models.transforms.ShapeSpec | tuple[int, ...]] = ImmutableMap({'A': ShapeSpec(dims=('output_size', 'latent_size'))}), vmap: bool = True, param_priors: ~typing.Any = None, param_shapes: ~typing.Any = None)
Bases:
AbstractSingleTransform
Linear transformation from latent space to output space.
Implements the transformation y = A @ z, where A is a matrix and z is a latent vector.
- Parameters:
- apply(latents: Float[Array, '#stars latents'], **pars: Any)
Apply the transform to input latent vectors.
Extracts the required parameters from the keyword arguments and applies the transform function to the latents, handling vectorization automatically.
- Parameters:
latents (Float[Array, '#stars latents'])
pars (Any)
- Return type:
Float[Array, '#stars output']
- get_expanded_priors(latent_size: int, data_size: int | None = None)
Get expanded parameter priors.
Expands the parameter prior distributions to the concrete shapes needed for the transform, based on the latent size and the optional data size.
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]
- pack_pars(nested_pars: dict[str, Any], ignore_missing: bool = False)
Pack parameters (identity for single transforms).
- param_priors: ImmutableMap[str, Distribution] | None = None
- priors: ImmutableMap[str, Distribution] = ImmutableMap({'A': <numpyro.distributions.continuous.Normal object at 0x78dd89aab230 with batch shape () and event shape ()>})
- shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]] = ImmutableMap({'A': ShapeSpec(dims=('output_size', 'latent_size'))})
- transform(z: Float[Array, 'latents'], A: Float[Array, 'output latents'])
Apply a linear transformation.
Computes the matrix product A @ z.
- Parameters:
z (Float[Array, 'latents'])
A (Float[Array, 'output latents'])
- Return type:
Float[Array, 'output']
- class pollux.models.NoOpTransform(output_size: int = 0, transform: ~collections.abc.Callable[[...], ~jaxtyping.Float[Array, 'output']] = <function _noop_transform>, priors: ~xmmutablemap._core.ImmutableMap[str, ~numpyro.distributions.distribution.Distribution] = ImmutableMap({}), shapes: ~xmmutablemap._core.ImmutableMap[str, ~pollux.models.transforms.ShapeSpec | tuple[int, ...]] = ImmutableMap({}), vmap: bool = True, param_priors: ~typing.Any = None, param_shapes: ~typing.Any = None)
Bases:
AbstractSingleTransform
No-op transformation.
- Parameters:
- apply(latents: Float[Array, '#stars latents'], **pars: Any)
Apply the transform to input latent vectors.
Extracts the required parameters from the keyword arguments and applies the transform function to the latents, handling vectorization automatically.
- Parameters:
latents (Float[Array, '#stars latents'])
pars (Any)
- Return type:
Float[Array, '#stars output']
- get_expanded_priors(latent_size: int, data_size: int | None = None)
Get expanded parameter priors.
Expands the parameter prior distributions to the concrete shapes needed for the transform, based on the latent size and the optional data size.
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]
- pack_pars(nested_pars: dict[str, Any], ignore_missing: bool = False)
Pack parameters (identity for single transforms).
- param_priors: ImmutableMap[str, Distribution] | None = None
- priors: ImmutableMap[str, Distribution] = ImmutableMap({})
- transform(z: Float[Array, 'latents'])
No-op transformation.
- Parameters:
z (Float[Array, 'latents'])
- Return type:
Float[Array, 'output']
- class pollux.models.OffsetTransform(output_size: int, transform: ~collections.abc.Callable[[...], ~jaxtyping.Float[Array, 'output']] = <function _offset_transform>, priors: ~collections.abc.Mapping[~xmmutablemap._core.K, ~xmmutablemap._core.V] | tuple[~xmmutablemap._core.K, ~xmmutablemap._core.V] | ~collections.abc.Iterable[tuple[~xmmutablemap._core.K, ~xmmutablemap._core.V]] = ImmutableMap({'b': <numpyro.distributions.continuous.Normal object at 0x78dd89aaac60 with batch shape () and event shape ()>}), shapes: ~xmmutablemap._core.ImmutableMap[str, ~pollux.models.transforms.ShapeSpec | tuple[int, ...]] = ImmutableMap({'b': ShapeSpec(dims=('output_size', 'one'))}), vmap: bool = True, param_priors: ~typing.Any = None, param_shapes: ~typing.Any = None)
Bases:
AbstractSingleTransform
Offset transformation that adds a bias vector to its inputs.
Implements the transformation y = z + b, where b is a bias vector.
- Parameters:
- apply(latents: Float[Array, '#stars latents'], **pars: Any)
Apply the transform to input latent vectors.
Extracts the required parameters from the keyword arguments and applies the transform function to the latents, handling vectorization automatically.
- Parameters:
latents (Float[Array, '#stars latents'])
pars (Any)
- Return type:
Float[Array, '#stars output']
- get_expanded_priors(latent_size: int, data_size: int | None = None)
Get expanded parameter priors.
Expands the parameter prior distributions to the concrete shapes needed for the transform, based on the latent size and the optional data size.
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]
- pack_pars(nested_pars: dict[str, Any], ignore_missing: bool = False)
Pack parameters (identity for single transforms).
- param_priors: ImmutableMap[str, Distribution] | None = None
- priors: ImmutableMap[str, Distribution] = ImmutableMap({'b': <numpyro.distributions.continuous.Normal object at 0x78dd89aaac60 with batch shape () and event shape ()>})
- shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]] = ImmutableMap({'b': ShapeSpec(dims=('output_size', 'one'))})
- transform(z: Float[Array, 'latents'], b: Float[Array, 'output'])
Apply an offset transformation.
Adds a bias vector b to the input: z + b.
- Parameters:
z (Float[Array, 'latents'])
b (Float[Array, 'output'])
- Return type:
Float[Array, 'output']
- class pollux.models.PolyFeatureTransform(output_size: int = 0, degree: int = 2, include_bias: bool = True, priors: ImmutableMap[str, Distribution] = ImmutableMap({}), shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]] = ImmutableMap({}))
Bases:
AbstractTransform
Polynomial feature expansion transform.
Expands input features into polynomial combinations up to the specified degree. This transform has no learnable parameters; it is a deterministic feature expansion.
This is useful for implementing The Cannon, where labels are expanded into polynomial features before a linear transformation predicts the spectra.
- Parameters:
output_size (int)
degree (int) – Maximum polynomial degree of the expansion. Default is 2.
include_bias (bool) – Whether to include a constant (bias) feature. Default is True.
priors (ImmutableMap[str, Distribution])
shapes (ImmutableMap[str, ShapeSpec | tuple[int, ...]])
Examples
>>> import jax.numpy as jnp
>>> from pollux.models.transforms import PolyFeatureTransform, LinearTransform
>>> from pollux.models.transforms import TransformSequence
Create a Cannon-style transform (polynomial features -> linear):
>>> cannon = TransformSequence((
...     PolyFeatureTransform(degree=2),
...     LinearTransform(output_size=128),
... ))
With 3 input labels, the polynomial transform produces 10 features (including the bias):
- degree 0: 1 (bias)
- degree 1: x1, x2, x3
- degree 2: x1^2, x1*x2, x1*x3, x2^2, x2*x3, x3^2
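The expansion listed above can be reproduced with itertools.combinations_with_replacement. The NumPy sketch below is illustrative only: poly_features is a hypothetical helper, and its feature ordering is an assumption that may differ from PolyFeatureTransform. It also checks the feature count C(n_labels + degree, degree) quoted in the Cannon class documentation, and finishes with a Cannon-style linear map using made-up coefficients.

```python
import itertools
import math

import numpy as np


def poly_features(x, degree=2, include_bias=True):
    """Expand labels x of shape (n_samples, n_labels) into polynomial features."""
    n_labels = x.shape[1]
    columns = [np.ones(x.shape[0])] if include_bias else []
    for d in range(1, degree + 1):
        # Each unordered index tuple gives one monomial, e.g. (0, 1) -> x1 * x2.
        for idx in itertools.combinations_with_replacement(range(n_labels), d):
            columns.append(np.prod(x[:, list(idx)], axis=1))
    return np.stack(columns, axis=1)


labels = np.array([[2.0, 3.0, 5.0]])  # one star with labels x1=2, x2=3, x3=5
feats = poly_features(labels, degree=2)
n_expected = math.comb(labels.shape[1] + 2, 2)  # C(3 + 2, 2) = 10

# Cannon-style linear step: hypothetical coefficients of shape (output_size, n_features).
theta = np.ones((4, feats.shape[1]))
spectra = feats @ theta.T
```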
- apply(latents: Float[Array, '#stars latents'], **_pars: Any)
Apply polynomial feature expansion.
- Parameters:
latents (Float[Array, '#stars latents']) – Input array of shape (n_samples, n_features).
**_pars (Any) – Ignored (no learnable parameters).
- Returns:
Polynomial features of shape (n_samples, n_poly_features).
- Return type:
array
- get_expanded_priors(latent_size: int, data_size: int | None = None)
Return empty priors (no learnable parameters).
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]
- pack_pars(_nested_pars: dict[str, Any], _ignore_missing: bool = False)
For compatibility with TransformSequence (returns an empty dict).
- property param_priors: ImmutableMap[str, Distribution]
Deprecated. Use priors instead.
- property param_shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]]
Deprecated. Use shapes instead.
- priors: ImmutableMap[str, Distribution] = ImmutableMap({})
- class pollux.models.QuadraticTransform(output_size: int, transform: ~collections.abc.Callable[[...], ~jaxtyping.Float[Array, 'output']] = <function _quadratic_transform>, priors: ~collections.abc.Mapping[~xmmutablemap._core.K, ~xmmutablemap._core.V] | tuple[~xmmutablemap._core.K, ~xmmutablemap._core.V] | ~collections.abc.Iterable[tuple[~xmmutablemap._core.K, ~xmmutablemap._core.V]] = ImmutableMap({'Q': <numpyro.distributions.continuous.Normal object at 0x78dd89aab6b0 with batch shape () and event shape ()>, 'A': <numpyro.distributions.continuous.Normal object at 0x78dd89aab200 with batch shape () and event shape ()>, 'b': <numpyro.distributions.continuous.Normal object at 0x78dd89aab710 with batch shape () and event shape ()>}), shapes: ~xmmutablemap._core.ImmutableMap[str, ~pollux.models.transforms.ShapeSpec | tuple[int, ...]] = ImmutableMap({'Q': ShapeSpec(dims=('output_size', 'latent_size', 'latent_size')), 'A': ShapeSpec(dims=('output_size', 'latent_size')), 'b': ShapeSpec(dims=('output_size', 'one'))}), vmap: bool = True, param_priors: ~typing.Any = None, param_shapes: ~typing.Any = None)
Bases:
AbstractSingleTransform
Quadratic transformation of latent vectors.
Implements the transformation y = z^T Q z + A @ z + b, where Q is a tensor of shape (output_size, latent_size, latent_size), A is a matrix, z is a latent vector, and b is a bias vector. The quadratic term is evaluated separately for each output component o as z^T Q[o] z.
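The per-output quadratic form can be written compactly with np.einsum. The NumPy sketch below checks it against an explicit loop over output components; it is illustrative, not the pollux implementation, and it flattens the (output_size, 1) bias to (output_size,) as an assumption.

```python
import numpy as np


def quadratic(z, Q, A, b):
    """y[o] = z^T Q[o] z + A[o] @ z + b[o] for a single latent vector z."""
    quad = np.einsum("k,okl,l->o", z, Q, z)  # one quadratic form per output component
    return quad + A @ z + b.reshape(-1)      # flatten an (output_size, 1) bias


rng = np.random.default_rng(0)
output_size, latent_size = 4, 3
Q = rng.normal(size=(output_size, latent_size, latent_size))
A = rng.normal(size=(output_size, latent_size))
b = rng.normal(size=(output_size, 1))
z = rng.normal(size=latent_size)

y = quadratic(z, Q, A, b)
y_loop = np.array([z @ Q[o] @ z + A[o] @ z + b[o, 0] for o in range(output_size)])

# Only the symmetric part of each Q[o] affects the output.
y_sym = quadratic(z, 0.5 * (Q + Q.transpose(0, 2, 1)), A, b)
```

The last line illustrates a general property of quadratic forms: the antisymmetric part of each Q[o] contributes nothing, so Q carries more free parameters than the output can constrain.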
- Parameters:
- apply(latents: Float[Array, '#stars latents'], **pars: Any)
Apply the transform to input latent vectors.
Extracts the required parameters from the keyword arguments and applies the transform function to the latents, handling vectorization automatically.
- Parameters:
latents (Float[Array, '#stars latents'])
pars (Any)
- Return type:
Float[Array, '#stars output']
- get_expanded_priors(latent_size: int, data_size: int | None = None)
Get expanded parameter priors.
Expands the parameter prior distributions to the concrete shapes needed for the transform, based on the latent size and the optional data size.
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]
- pack_pars(nested_pars: dict[str, Any], ignore_missing: bool = False)
Pack parameters (identity for single transforms).
- param_priors: ImmutableMap[str, Distribution] | None = None
- priors: ImmutableMap[str, Distribution] = ImmutableMap({'Q': <numpyro.distributions.continuous.Normal object at 0x78dd89aab6b0 with batch shape () and event shape ()>, 'A': <numpyro.distributions.continuous.Normal object at 0x78dd89aab200 with batch shape () and event shape ()>, 'b': <numpyro.distributions.continuous.Normal object at 0x78dd89aab710 with batch shape () and event shape ()>})
- shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]] = ImmutableMap({'Q': ShapeSpec(dims=('output_size', 'latent_size', 'latent_size')), 'A': ShapeSpec(dims=('output_size', 'latent_size')), 'b': ShapeSpec(dims=('output_size', 'one'))})
- transform(z: Float[Array, 'latents'], Q: Float[Array, 'output latents latents'], A: Float[Array, 'output latents'], b: Float[Array, 'output'])
Apply a quadratic transformation.
Computes a quadratic form plus a linear term and an offset: z^T Q z + A @ z + b.
- Parameters:
z (Float[Array, 'latents'])
Q (Float[Array, 'output latents latents'])
A (Float[Array, 'output latents'])
b (Float[Array, 'output'])
- Return type:
Float[Array, 'output']
- class pollux.models.ShapeSpec(dims: tuple[str | int, ...])
Bases:
object
Specification for parameter shapes using named dimensions.
ShapeSpec lets you define parameter shapes that depend on dimensions known only at model construction time (such as latent_size). Named dimensions are resolved to concrete integers when the model is built.
- Parameters:
dims (tuple[str | int, ...]) – Tuple of dimension names (strings) or concrete sizes (integers).
Available named dimensions:
- "output_size" – The output dimension of the transform.
- "latent_size" – The latent space dimension (set when registering with a model).
- "data_size" – The number of samples in the batch (useful for per-sample parameters).
- "one" – Always resolves to 1 (useful for bias terms).
Examples
Define a weight matrix shape that depends on output and latent dimensions:
>>> from pollux.models.transforms import ShapeSpec
>>> shape = ShapeSpec(("output_size", "latent_size"))
>>> shape.resolve({"output_size": 128, "latent_size": 8})
(128, 8)
Define a bias vector shape using the special "one" dimension:
>>> bias_shape = ShapeSpec(("output_size", "one"))
>>> bias_shape.resolve({"output_size": 128})
(128, 1)
Use with FunctionTransform to define custom transforms:
>>> from xmmutablemap import ImmutableMap
>>> import numpyro.distributions as dist
>>> shapes = ImmutableMap({"A": ShapeSpec(("output_size", "latent_size"))})
>>> priors = ImmutableMap({"A": dist.Normal(0, 1)})
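The resolution rule for named dimensions can be sketched in a few lines. The class below is a hypothetical stand-in (SimpleShapeSpec), not pollux's ShapeSpec; it assumes integers pass through unchanged, "one" always becomes 1, and any other name is looked up in the supplied mapping.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SimpleShapeSpec:
    """Hypothetical stand-in for ShapeSpec: a tuple of dimension names or ints."""

    dims: tuple

    def resolve(self, sizes):
        resolved = []
        for dim in self.dims:
            if isinstance(dim, int):
                resolved.append(dim)         # concrete size: keep as is
            elif dim == "one":
                resolved.append(1)           # special name, always 1
            else:
                resolved.append(sizes[dim])  # named dimension, e.g. "latent_size"
        return tuple(resolved)


weight = SimpleShapeSpec(("output_size", "latent_size"))
bias = SimpleShapeSpec(("output_size", "one"))
per_star = SimpleShapeSpec(("data_size",))
mixed = SimpleShapeSpec((2, "latent_size"))
```

In this sketch a missing name raises KeyError; the real class may report the error differently.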
- class pollux.models.TransformSequence(transforms: tuple[AbstractSingleTransform, ...])
Bases:
AbstractTransform
A sequence of transforms applied in order.
Composes multiple transforms, where the output of each transform becomes the input to the next transform in the sequence.
Parameters are stored as tuples of dictionaries, one element per transform.
- Parameters:
transforms (tuple[AbstractSingleTransform, ...])
- apply(latents: Float[Array, '#stars latents'], *args: dict[str, Any], **kwargs: Any)
Apply the sequence of transforms to input latent vectors.
Parameters can be provided in two ways:
1. As positional arguments: one dictionary per transform, in sequence order.
2. As keyword arguments, using the "{transform_index}:{param}" naming scheme; for example, a parameter named "A" in transform 0 of the sequence is passed as "0:A".
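The chaining rule (each transform's output becomes the next transform's input, so its "latent_size" is the previous transform's output size) can be sketched with plain functions. The NumPy example below is hypothetical: apply_sequence, append_squares, and linear are not pollux functions, and parameters are passed positionally, one dict per step, mirroring the first calling convention above.

```python
import numpy as np


def apply_sequence(latents, steps, *step_pars):
    """Apply steps in order; step_pars holds one parameter dict per step."""
    if len(step_pars) != len(steps):
        raise ValueError("expected one parameter dict per step")
    out = latents
    for step, pars in zip(steps, step_pars):
        out = step(out, **pars)  # this step's input is the previous step's output
    return out


def append_squares(x):
    """Parameter-free step: (n, k) -> (n, 2k)."""
    return np.concatenate([x, x**2], axis=1)


def linear(x, A):
    """Linear step: (n, k) -> (n, output_size) with A of shape (output_size, k)."""
    return x @ A.T


rng = np.random.default_rng(0)
latents = rng.normal(size=(5, 3))  # the model's latent_size is 3
A = rng.normal(size=(7, 6))        # step 1 sees an input size of 6, the output size of step 0
out = apply_sequence(latents, (append_squares, linear), {}, {"A": A})
```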
- get_expanded_priors(latent_size: int, data_size: int | None = None)#
Get expanded parameter priors using flat naming scheme.
Returns flattened parameter priors with index-based naming for compatibility with the AbstractTransform interface. Parameter names will be in the format: “{transform_index}:{param_name}”
Note: For transform sequences, each transform’s “latent_size” is the output size of the previous transform (or the model’s latent_size for the first transform).
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]
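As a rough illustration of the note above, the hypothetical helper below computes the input ("latent") size that each stage of a sequence would see, given the model's latent_size and each stage's output size. It assumes every stage's output size is known up front, which may not hold for transforms such as PolyFeatureTransform, whose output size depends on its input.

```python
# Hypothetical illustration: each stage's input size is the previous
# stage's output size; the first stage uses the model's latent_size.
def stage_latent_sizes(model_latent_size, output_sizes):
    sizes = []
    current = model_latent_size
    for out in output_sizes:
        sizes.append(current)  # input size seen by this stage
        current = out          # the next stage consumes this stage's output
    return sizes


# e.g. 3 labels -> 10 polynomial features -> 128 pixels
print(stage_latent_sizes(3, [10, 128]))  # [3, 10]
```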
- pack_pars(nested_pars: list[dict[str, Any]], ignore_missing: bool = False)#
Convert nested parameter structure to flat naming scheme.
Takes a list of parameter dictionaries and converts them to flat parameter names like “0:A”, “1:p1”.
- Parameters:
nested_pars (list[dict[str, Any]])
ignore_missing (bool)
- Returns:
Dictionary with parameter names in format “{transform_index}:{param_name}”.
- property param_priors: tuple[ImmutableMap[str, Distribution], ...]#
Deprecated. Use priors instead.
- property param_shapes: tuple[ImmutableMap[str, ShapeSpec | tuple[int, ...]], ...]#
Deprecated. Use shapes instead.
- property priors: tuple[ImmutableMap[str, Distribution], ...]#
Collect parameter priors from all transforms in the sequence.
- property shapes: tuple[ImmutableMap[str, ShapeSpec | tuple[int, ...]], ...]#
Collect parameter shapes from all transforms in the sequence.
- unpack_pars(flat_pars: dict[str, Any], ignore_missing: bool = False)#
Convert flat parameter names to nested tuple structure.
Takes parameters with names like “0:A”, “1:p1” and converts them to a list of parameter dictionaries: [{“A”: value}, {“p1”: value}]
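The “{transform_index}:{param_name}” convention can be sketched as a pair of plain-Python conversions. `to_flat_names` and `from_flat_names` are hypothetical helpers that mirror what pack_pars and unpack_pars are described as doing. They omit the ignore_missing option and are not the library's code.

```python
# Hypothetical sketch of the "{transform_index}:{param_name}" convention.
def to_flat_names(nested_pars):
    """List of per-transform dicts -> one flat dict keyed by "index:name"."""
    return {
        f"{i}:{name}": value
        for i, pars in enumerate(nested_pars)
        for name, value in pars.items()
    }


def from_flat_names(flat_pars, n_transforms):
    """Flat "index:name" dict -> list of per-transform dicts."""
    nested = [{} for _ in range(n_transforms)]
    for key, value in flat_pars.items():
        index, name = key.split(":", 1)  # split on the first colon only
        nested[int(index)][name] = value
    return nested


flat = to_flat_names([{"A": 1.0}, {"p1": 2.0}])
print(flat)                      # {'0:A': 1.0, '1:p1': 2.0}
print(from_flat_names(flat, 2))  # [{'A': 1.0}, {'p1': 2.0}]
```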
- transforms: tuple[AbstractSingleTransform, ...]#
pollux.models.transforms#
Transforms for mapping latent vectors to output quantities.
- class pollux.models.transforms.AbstractSingleTransform(output_size: int, transform: Callable[[...], Float[Array, 'output']], priors: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]] = ImmutableMap({}), shapes: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]] = ImmutableMap({}), vmap: bool = True, param_priors: Any = None, param_shapes: Any = None)#
Bases: AbstractTransform
Base class providing common functionality for atomic transforms.
“Single” transforms apply a single operation to convert latent vectors to outputs.
- Parameters:
output_size (int) – Size of the output vector.
priors (Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]]) – Prior distributions for transform parameters.
shapes (Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]]) – Shape specifications for transform parameters.
transform (Callable[..., Float[Array, 'output']]) – The transform function. Should take latents as the first argument, followed by any parameters.
vmap (bool) – Whether to automatically vectorize the transform over the batch dimension. If True (default), the transform function should be written for a single sample (latents of shape (latent_size,)), and JAX’s vmap will be applied to handle batches; parameters are shared across all samples. If False, the transform function must handle batching itself. This is useful when parameters are per-sample (e.g., per-star nuisance parameters) or when the function has custom batching requirements.
param_priors (Any) – Deprecated. Use priors instead.
param_shapes (Any) – Deprecated. Use shapes instead.
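To make the vmap=True contract concrete, the sketch below writes a transform for a single latent vector and maps it over a batch while sharing one parameter matrix A. It uses NumPy, with a Python loop standing in for jax.vmap. The name `linear_single` is hypothetical. The hand-batched expression at the end is the kind of code a vmap=False transform would have to contain itself.

```python
import numpy as np


# Single-sample transform, as vmap=True expects:
# z has shape (latent_size,), A has shape (output_size, latent_size).
def linear_single(z, A):
    return A @ z


rng = np.random.default_rng(0)
n_stars, latent_size, output_size = 5, 3, 4
latents = rng.normal(size=(n_stars, latent_size))
A = rng.normal(size=(output_size, latent_size))  # one A shared by all stars

# Looping over rows mimics what jax.vmap does (vectorized, in JAX):
# the same A is reused for every star.
mapped = np.stack([linear_single(z, A) for z in latents])

# Equivalent hand-batched expression, as a vmap=False function would need.
batched = latents @ A.T
print(mapped.shape, np.allclose(mapped, batched))  # (5, 4) True
```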
- priors: ImmutableMap[str, Distribution] = ImmutableMap({})#
- param_priors: ImmutableMap[str, Distribution] | None = None#
- apply(latents: Float[Array, '#stars latents'], **pars: Any)#
Apply the transform to input latent vectors.
Extracts the required parameters from the kwargs and applies the transform function to the latents, handling vectorization automatically.
- Parameters:
latents (Float[Array, '#stars latents'])
pars (Any)
- Return type:
Float[Array, '#stars output']
- get_expanded_priors(latent_size: int, data_size: int | None = None)#
Get expanded parameter priors.
Expands the parameter prior distributions to the concrete shapes needed for the transform, based on latent size and optional data size.
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]
- unpack_pars(flat_pars: dict[str, Any], ignore_missing: bool = False)#
Unpack parameters (identity for single transforms).
- class pollux.models.transforms.AbstractTransform(output_size: int)#
Bases: Module
Base class defining the transform interface.
Transforms convert latent vectors to observable quantities through parameterized functions. They define the mapping between latent space and output spaces.
- Parameters:
output_size (int)
- abstractmethod apply(latents: Float[Array, '#stars latents'], **pars: Any)#
Apply the transform to input latent vectors.
Takes a batch of latent vectors and transforms them using the provided parameters to produce output values.
- Parameters:
latents (Float[Array, '#stars latents'])
pars (Any)
- Return type:
Float[Array, '#stars output']
- abstractmethod get_expanded_priors(latent_size: int, data_size: int | None = None)#
Get expanded parameter priors.
Expands the parameter prior distributions to the concrete shapes needed for the transform, based on latent size and optional data size.
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]
- class pollux.models.transforms.AdditiveOffsetTransform(base_transform: AbstractSingleTransform | TransformSequence, offset_prior: Distribution = <factory>)#
Bases: Module
Transform that wraps a base transform and adds a per-star scalar offset.
This transform is useful for modeling per-object nuisance parameters like distance modulus, where each object has its own offset that applies uniformly to all output dimensions. It differs from the AffineTransform and OffsetTransform classes in that the offset varies per object rather than per output dimension.
In other words, unlike OffsetTransform, which has a fixed offset vector of shape (output_size,), this transform samples a separate scalar offset for each object in the dataset, with shape (data_size,). The offset is then broadcast to all output dimensions.
- Parameters:
base_transform (AbstractSingleTransform | TransformSequence) – The underlying transform to wrap (e.g., LinearTransform).
offset_prior (Distribution) – Prior distribution for the per-object offset. This will be expanded to shape (data_size,) during inference.
Examples
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> from pollux.models.transforms import AdditiveOffsetTransform, LinearTransform
Model apparent magnitudes as absolute magnitudes plus distance modulus:
>>> phot_trans = AdditiveOffsetTransform(
...     base_transform=LinearTransform(output_size=3),  # 3 photometric bands
...     offset_prior=dist.Normal(11.0, 3.0),  # Distance modulus prior
... )
The offset adapts to the data size automatically:
>>> import pollux as plx
>>> model = plx.Lux(latent_size=8)
>>> model.register_output("phot", phot_trans)
>>> # During training with 1000 stars, offset has shape (1000,)
>>> # During testing with 500 stars, offset has shape (500,)
Notes
The per-star offset is broadcast to all output dimensions, meaning the same offset value is added to every element of the output for a given object. This is appropriate for distance modulus (which shifts all magnitudes equally) but may not be appropriate for other use cases.
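The broadcasting described in these Notes amounts to adding a trailing axis to the per-star offset. Below is a NumPy sketch, with hypothetical stand-in values for the base transform's output and the sampled offsets.

```python
import numpy as np

n_stars, n_bands = 4, 3
# Stand-in for the base transform's output, e.g. absolute magnitudes.
absolute_mags = np.zeros((n_stars, n_bands))
# One scalar offset per star (shape (data_size,)), e.g. distance moduli.
offset = np.array([10.0, 11.5, 9.0, 12.0])

# Adding a trailing axis broadcasts each star's offset across all bands.
apparent_mags = absolute_mags + offset[:, None]
print(apparent_mags.shape)  # (4, 3)
print(apparent_mags[1])     # [11.5 11.5 11.5]
```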
- base_transform: AbstractSingleTransform | TransformSequence#
- offset_prior: Distribution#
- get_expanded_priors(latent_size: int, data_size: int | None = None)#
Get expanded parameter priors including the per-star offset.
- Parameters:
latent_size (int)
data_size (int | None)
- Returns:
Dictionary of priors including base transform priors (prefixed with “base:”) and the offset prior with shape (data_size,).
- Return type:
ParamPriorsT
- Raises:
ValueError – If data_size is None.
- apply(latents: Float[Array, '#stars latents'], **params: Any)#
Apply the base transform and add the per-star offset.
- Parameters:
latents (Float[Array, '#stars latents']) – Input latent vectors of shape (n_samples, latent_size).
**params (Any) – Parameters including base transform parameters (prefixed with “base:”) and the “offset” parameter of shape (n_samples,).
- Returns:
Output of shape (n_samples, output_size).
- Return type:
array
- unpack_pars(flat_pars: dict[str, Any], ignore_missing: bool = False)#
Unpack flat parameters into nested structure.
- class pollux.models.transforms.AffineTransform(output_size: int, transform: Callable[[...], Float[Array, 'output']] = <function _affine_transform>, priors: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]] = ImmutableMap({'A': Normal(...), 'b': Normal(...)}), shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]] = ImmutableMap({'A': ShapeSpec(dims=('output_size', 'latent_size')), 'b': ShapeSpec(dims=('output_size', 'one'))}), vmap: bool = True, param_priors: Any = None, param_shapes: Any = None)#
Bases: AbstractSingleTransform
Affine transformation combining a linear transform and an offset.
Implements the transformation: y = A @ z + b, where A is a matrix, z is a latent vector, and b is a bias vector.
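Below is a NumPy sketch of the same formula. It shows that the per-sample form A @ z + b and the batched form Z @ A.T + b agree. Shapes follow the documented annotations. Here b is treated as a 1-D vector of length output_size, and `affine_single` is a hypothetical name, not the library's _affine_transform.

```python
import numpy as np


def affine_single(z, A, b):
    # y = A @ z + b for one latent vector z of shape (latent_size,)
    return A @ z + b


rng = np.random.default_rng(1)
latent_size, output_size, n_stars = 8, 128, 10
A = rng.normal(size=(output_size, latent_size))
b = rng.normal(size=output_size)
Z = rng.normal(size=(n_stars, latent_size))

# Batched form: row i of Z @ A.T equals A @ Z[i].
Y = Z @ A.T + b
print(Y.shape)  # (10, 128)
print(np.allclose(Y[0], affine_single(Z[0], A, b)))  # True
```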
- transform(z: Float[Array, 'latents'], A: Float[Array, 'output latents'], b: Float[Array, 'output'])#
Apply an affine transformation.
Computes a linear transformation followed by an offset: A @ z + b.
- Parameters:
z (Float[Array, 'latents'])
A (Float[Array, 'output latents'])
b (Float[Array, 'output'])
- Return type:
Float[Array, 'output']
- priors: ImmutableMap[str, Distribution] = ImmutableMap({'A': Normal(...), 'b': Normal(...)})#
- shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]] = ImmutableMap({'A': ShapeSpec(dims=('output_size', 'latent_size')), 'b': ShapeSpec(dims=('output_size', 'one'))})#
- apply(latents: Float[Array, '#stars latents'], **pars: Any)#
Apply the transform to input latent vectors.
Extracts the required parameters from the kwargs and applies the transform function to the latents, handling vectorization automatically.
- Parameters:
latents (Float[Array, '#stars latents'])
pars (Any)
- Return type:
Float[Array, '#stars output']
- get_expanded_priors(latent_size: int, data_size: int | None = None)#
Get expanded parameter priors.
Expands the parameter prior distributions to the concrete shapes needed for the transform, based on latent size and optional data size.
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]
- pack_pars(nested_pars: dict[str, Any], ignore_missing: bool = False)#
Pack parameters (identity for single transforms).
- param_priors: ImmutableMap[str, Distribution] | None = None#
- class pollux.models.transforms.EquinoxNNTransform(output_size: int, nn_factory: Any, weight_prior: Distribution = <factory>, bias_prior: Distribution = <factory>, priors: ImmutableMap[str, Distribution] = <factory>, shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]] = <factory>, _param_paths: tuple[str, ...] = (), _template_nn: Any = None)#
Bases: AbstractTransform
Neural network transform using an Equinox module.
This transform wraps an Equinox neural network module and exposes its parameters for Bayesian inference via numpyro. The network structure is defined by a factory function that creates the network given input size, output size, and a random key.
- Parameters:
output_size (int) – The output dimension of the transform.
nn_factory (Any) – A callable that creates an Equinox module. It should have the signature: nn_factory(in_size: int, out_size: int, key: jax.Array) -> eqx.Module
weight_prior (Distribution) – Prior distribution for weight parameters. Default is Normal(0, 1).
bias_prior (Distribution) – Prior distribution for bias parameters. Default is Normal(0, 1).
priors (ImmutableMap[str, Distribution])
_template_nn (Any)
Examples
>>> import jax
>>> import equinox as eqx
>>> import numpyro.distributions as dist
>>> from pollux.models.transforms import EquinoxNNTransform
Create a simple MLP transform:
>>> nn_trans = EquinoxNNTransform(
...     output_size=128,
...     nn_factory=lambda in_size, out_size, key: eqx.nn.MLP(
...         in_size=in_size,
...         out_size=out_size,
...         width_size=64,
...         depth=2,
...         key=key,
...     ),
...     weight_prior=dist.Normal(0, 0.1),
...     bias_prior=dist.Normal(0, 0.01),
... )
Use with LuxModel:
>>> import pollux as plx
>>> model = plx.LuxModel(latent_size=8)
>>> model.register_output("flux", nn_trans)
- weight_prior: Distribution#
- bias_prior: Distribution#
- priors: ImmutableMap[str, Distribution]#
- property param_priors: ImmutableMap[str, Distribution]#
Deprecated. Use priors instead.
- property param_shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]]#
Deprecated. Use shapes instead.
- get_expanded_priors(latent_size: int, data_size: int | None = None)#
Create one prior per neural network parameter.
- apply(latents: Float[Array, '#stars latents'], **params: Any)#
Apply the neural network transform.
- Parameters:
latents (Float[Array, '#stars latents']) – Input latent vectors of shape (n_samples, latent_size).
**params (Any) – Neural network parameters, keyed by their path names.
- Returns:
Output of shape (n_samples, output_size).
- Return type:
array
- class pollux.models.transforms.FunctionTransform(output_size: int, transform: Callable[[...], Float[Array, 'output']], priors: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]] = ImmutableMap({}), shapes: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]] = ImmutableMap({}), vmap: bool = True, param_priors: Any = None, param_shapes: Any = None)#
Bases: AbstractSingleTransform
Custom transformation using a user-defined function.
This transform allows for arbitrary transformations defined by the user. It is particularly useful for modeling complex relationships or per-sample nuisance parameters.
- Parameters:
output_size (int) – Size of the output vector.
transform (Callable[..., Float[Array, 'output']]) – The transform function. Should take latents as the first argument, followed by any parameters defined in priors.
priors (Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]]) – Prior distributions for transform parameters. Use ParamPriorsT (an ImmutableMap[str, dist.Distribution]).
shapes (Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]]) – Shape specifications for transform parameters. Use ParamShapesT (an ImmutableMap[str, ShapeSpec | tuple[int, ...]]). Use ShapeSpec when shapes depend on latent_size or data_size.
vmap (bool) – Whether to automatically vectorize the transform over the batch dimension. Set to False when parameters are per-sample (e.g., per-star continuum corrections) and the function handles batching internally.
param_priors (ImmutableMap[str, Distribution] | None) – Deprecated. Use priors instead.
param_shapes (ImmutableMap[str, ShapeSpec | tuple[int, ...]] | None) – Deprecated. Use shapes instead.
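In the vmap=False case, the transform receives the full batch and must index per-sample parameters itself. The sketch below is hypothetical: `scaled_linear_batched` and its per-star parameter `c` of shape (data_size,) are illustrative only. In pollux, such a shape would presumably be declared with ShapeSpec(("data_size",)). NumPy stands in for jax.numpy.

```python
import numpy as np


# Hypothetical transform written for vmap=False: it receives the whole batch
# and a per-star parameter c of shape (data_size,), e.g. a per-star scale.
def scaled_linear_batched(latents, A, c):
    # latents: (n_stars, latent_size); A: (output_size, latent_size); c: (n_stars,)
    return (latents @ A.T) * c[:, None]


rng = np.random.default_rng(2)
n_stars, latent_size, output_size = 6, 3, 5
latents = rng.normal(size=(n_stars, latent_size))
A = rng.normal(size=(output_size, latent_size))  # shared across stars
c = rng.uniform(0.5, 1.5, size=n_stars)          # one value per star

out = scaled_linear_batched(latents, A, c)
print(out.shape)  # (6, 5)
# Row i depends only on star i's own parameter c[i].
print(np.allclose(out[2], c[2] * (A @ latents[2])))  # True
```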
Examples
Define a custom linear transform with learnable weights:
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> from xmmutablemap import ImmutableMap
>>> from pollux.models.transforms import FunctionTransform, ShapeSpec
>>>
>>> def my_transform(z, A):
...     return jnp.dot(A, z)
>>>
>>> custom = FunctionTransform(
...     output_size=128,
...     transform=my_transform,
...     priors=ImmutableMap({"A": dist.Normal(0, 1)}),
...     shapes=ImmutableMap({"A": ShapeSpec(("output_size", "latent_size"))}),
... )
The parameter A will have shape (128, latent_size), where latent_size is determined when the transform is registered with a model.
See also the “Inferring Continuum Model Parameters” tutorial for an example of using FunctionTransform with per-star parameters and vmap=False.
- apply(latents: Float[Array, '#stars latents'], **pars: Any)#
Apply the transform to input latent vectors.
Extracts the required parameters from the kwargs and applies the transform function to the latents, handling vectorization automatically.
- Parameters:
latents (Float[Array, '#stars latents'])
pars (Any)
- Return type:
Float[Array, '#stars output']
- get_expanded_priors(latent_size: int, data_size: int | None = None)#
Get expanded parameter priors.
Expands the parameter prior distributions to the concrete shapes needed for the transform, based on latent size and optional data size.
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]
- pack_pars(nested_pars: dict[str, Any], ignore_missing: bool = False)#
Pack parameters (identity for single transforms).
- param_priors: ImmutableMap[str, Distribution] | None = None#
- priors: ImmutableMap[str, Distribution] = ImmutableMap({})#
- class pollux.models.transforms.LinearTransform(output_size: int, transform: Callable[[...], Float[Array, 'output']] = <function _linear_transform>, priors: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]] = ImmutableMap({'A': Normal(...)}), shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]] = ImmutableMap({'A': ShapeSpec(dims=('output_size', 'latent_size'))}), vmap: bool = True, param_priors: Any = None, param_shapes: Any = None)#
Bases: AbstractSingleTransform
Linear transformation from latent to output space.
Implements the transformation: y = A @ z, where A is a matrix and z is a latent vector.
- transform(z: Float[Array, 'latents'], A: Float[Array, 'output latents'])#
Apply a linear transformation.
Computes the matrix product A @ z.
- Parameters:
z (Float[Array, 'latents'])
A (Float[Array, 'output latents'])
- Return type:
Float[Array, 'output']
- priors: ImmutableMap[str, Distribution] = ImmutableMap({'A': Normal(...)})#
- shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]] = ImmutableMap({'A': ShapeSpec(dims=('output_size', 'latent_size'))})#
- apply(latents: Float[Array, '#stars latents'], **pars: Any)#
Apply the transform to input latent vectors.
Extracts the required parameters from the kwargs and applies the transform function to the latents, handling vectorization automatically.
- Parameters:
latents (Float[Array, '#stars latents'])
pars (Any)
- Return type:
Float[Array, '#stars output']
- get_expanded_priors(latent_size: int, data_size: int | None = None)#
Get expanded parameter priors.
Expands the parameter prior distributions to the concrete shapes needed for the transform, based on latent size and optional data size.
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]
- pack_pars(nested_pars: dict[str, Any], ignore_missing: bool = False)#
Pack parameters (identity for single transforms).
- param_priors: ImmutableMap[str, Distribution] | None = None#
- class pollux.models.transforms.NoOpTransform(output_size: int = 0, transform: Callable[[...], Float[Array, 'output']] = <function _noop_transform>, priors: ImmutableMap[str, Distribution] = ImmutableMap({}), shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]] = ImmutableMap({}), vmap: bool = True, param_priors: Any = None, param_shapes: Any = None)#
Bases: AbstractSingleTransform
No-op transformation.
- transform(z: Float[Array, 'latents'])#
No-op transformation.
- Parameters:
z (Float[Array, 'latents'])
- Return type:
Float[Array, 'output']
- priors: ImmutableMap[str, Distribution] = ImmutableMap({})#
- apply(latents: Float[Array, '#stars latents'], **pars: Any)#
Apply the transform to input latent vectors.
Extracts the required parameters from the kwargs and applies the transform function to the latents, handling vectorization automatically.
- Parameters:
latents (Float[Array, '#stars latents'])
pars (Any)
- Return type:
Float[Array, '#stars output']
- get_expanded_priors(latent_size: int, data_size: int | None = None)#
Get expanded parameter priors.
Expands the parameter prior distributions to the concrete shapes needed for the transform, based on latent size and optional data size.
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]
- pack_pars(nested_pars: dict[str, Any], ignore_missing: bool = False)#
Pack parameters (identity for single transforms).
- param_priors: ImmutableMap[str, Distribution] | None = None#
- class pollux.models.transforms.OffsetTransform(output_size: int, transform: Callable[[...], Float[Array, 'output']] = <function _offset_transform>, priors: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]] = ImmutableMap({'b': Normal(...)}), shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]] = ImmutableMap({'b': ShapeSpec(dims=('output_size', 'one'))}), vmap: bool = True, param_priors: Any = None, param_shapes: Any = None)#
Bases: AbstractSingleTransform
Offset transformation that adds a bias vector to inputs.
Implements the transformation: y = z + b, where b is a bias vector.
- transform(z: Float[Array, 'latents'], b: Float[Array, 'output'])#
Apply an offset transformation.
Adds a bias vector b to the input: z + b.
- Parameters:
z (Float[Array, 'latents'])
b (Float[Array, 'output'])
- Return type:
Float[Array, 'output']
- priors: ImmutableMap[str, Distribution] = ImmutableMap({'b': Normal(...)})#
- shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]] = ImmutableMap({'b': ShapeSpec(dims=('output_size', 'one'))})#
- apply(latents: Float[Array, '#stars latents'], **pars: Any)#
Apply the transform to input latent vectors.
Extracts the required parameters from the kwargs and applies the transform function to the latents, handling vectorization automatically.
- Parameters:
latents (Float[Array, '#stars latents'])
pars (Any)
- Return type:
Float[Array, '#stars output']
- get_expanded_priors(latent_size: int, data_size: int | None = None)#
Get expanded parameter priors.
Expands the parameter prior distributions to the concrete shapes needed for the transform, based on latent size and optional data size.
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]
- pack_pars(nested_pars: dict[str, Any], ignore_missing: bool = False)#
Pack parameters (identity for single transforms).
- param_priors: ImmutableMap[str, Distribution] | None = None#
- pollux.models.transforms.ParamPriorsT#
Type alias for parameter priors: maps parameter names to distributions. Used with FunctionTransform to specify priors for learnable parameters.
alias of ImmutableMap[str, Distribution]
- pollux.models.transforms.ParamShapesT#
Type alias for parameter shapes: maps parameter names to shape specifications. Shapes can be ShapeSpec (for named dimensions) or concrete tuples.
- class pollux.models.transforms.PolyFeatureTransform(output_size: int = 0, degree: int = 2, include_bias: bool = True, priors: ImmutableMap[str, Distribution] = ImmutableMap({}), shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]] = ImmutableMap({}))#
Bases: AbstractTransform
Polynomial feature expansion transform.
Expands input features into polynomial combinations up to the specified degree. This transform has no learnable parameters; it is a deterministic feature expansion.
This is useful for implementing The Cannon model, where labels are expanded into polynomial features before a linear transformation to predict spectra.
Examples
>>> import jax.numpy as jnp
>>> from pollux.models.transforms import PolyFeatureTransform, LinearTransform
>>> from pollux.models.transforms import TransformSequence
Create a Cannon-style transform (polynomial features -> linear):
>>> cannon = TransformSequence((
...     PolyFeatureTransform(degree=2),
...     LinearTransform(output_size=128),
... ))
For example, with 3 input labels the polynomial transform expands them into 10 features (with bias):
  - degree 0: 1 (bias)
  - degree 1: x1, x2, x3
  - degree 2: x1^2, x1*x2, x1*x3, x2^2, x2*x3, x3^2
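The feature list above comes from enumerating combinations with replacement of label indices for each degree, so the count matches C(n_labels + degree, degree). Below is a plain-Python sketch. The helper `poly_features` is hypothetical, and its feature ordering is an assumption that may differ from PolyFeatureTransform's.

```python
from itertools import combinations_with_replacement
from math import comb, prod


def poly_features(x, degree=2, include_bias=True):
    """Return the monomials of x up to `degree`, ordered by degree."""
    features = [1.0] if include_bias else []
    for d in range(1, degree + 1):
        # Each multiset of d label indices gives one monomial, e.g. (0, 1) -> x1*x2.
        for idx in combinations_with_replacement(range(len(x)), d):
            features.append(prod(x[i] for i in idx))
    return features


x = [2.0, 3.0, 5.0]
feats = poly_features(x)
print(len(feats), comb(3 + 2, 2))  # 10 10
print(feats)  # [1.0, 2.0, 3.0, 5.0, 4.0, 6.0, 10.0, 9.0, 15.0, 25.0]
```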
- priors: ImmutableMap[str, Distribution] = ImmutableMap({})#
- property param_priors: ImmutableMap[str, Distribution]#
Deprecated. Use priors instead.
- property param_shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]]#
Deprecated. Use shapes instead.
- apply(latents: Float[Array, '#stars latents'], **_pars: Any)#
Apply polynomial feature expansion.
- Parameters:
latents (Float[Array, '#stars latents']) – Input array of shape (n_samples, n_features).
**_pars (Any) – Ignored (no learnable parameters).
- Returns:
Polynomial features of shape (n_samples, n_poly_features).
- Return type:
array
- get_expanded_priors(latent_size: int, data_size: int | None = None)#
Return empty priors (no learnable parameters).
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]
- class pollux.models.transforms.QuadraticTransform(output_size: int, transform: Callable[[...], Float[Array, 'output']] = <function _quadratic_transform>, priors: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]] = ImmutableMap({'Q': Normal(...), 'A': Normal(...), 'b': Normal(...)}), shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]] = ImmutableMap({'Q': ShapeSpec(dims=('output_size', 'latent_size', 'latent_size')), 'A': ShapeSpec(dims=('output_size', 'latent_size')), 'b': ShapeSpec(dims=('output_size', 'one'))}), vmap: bool = True, param_priors: Any = None, param_shapes: Any = None)#
Bases: AbstractSingleTransform
Quadratic transformation of latent vectors.
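Implements the transformation y = z^T Q z + A @ z + b, where Q is a tensor of shape (output, latents, latents), A is a matrix, z is a latent vector, and b is a bias vector. The quadratic term is evaluated per output component: y_k = z^T Q_k z + (A @ z)_k + b_k.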
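Because the quadratic term is evaluated per output component, it maps naturally onto a single einsum. Below is a NumPy sketch with hypothetical names and random stand-in parameters, shaped as the documented ShapeSpecs (b treated as 1-D).

```python
import numpy as np


def quadratic_single(z, Q, A, b):
    # Component k: z^T Q[k] z + (A @ z)[k] + b[k]
    return np.einsum("i,kij,j->k", z, Q, z) + A @ z + b


rng = np.random.default_rng(3)
latent_size, output_size = 3, 4
Q = rng.normal(size=(output_size, latent_size, latent_size))
A = rng.normal(size=(output_size, latent_size))
b = rng.normal(size=output_size)
z = rng.normal(size=latent_size)

y = quadratic_single(z, Q, A, b)
print(y.shape)  # (4,)
# Cross-check one component against the explicit quadratic form.
k = 2
print(np.isclose(y[k], z @ Q[k] @ z + A[k] @ z + b[k]))  # True
```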
- transform(z: Float[Array, 'latents'], Q: Float[Array, 'output latents latents'], A: Float[Array, 'output latents'], b: Float[Array, 'output'])#
Apply a quadratic transformation.
Computes a quadratic form plus a linear term and an offset: z^T Q z + A @ z + b.
- Parameters:
z (Float[Array, 'latents'])
Q (Float[Array, 'output latents latents'])
A (Float[Array, 'output latents'])
b (Float[Array, 'output'])
- Return type:
Float[Array, 'output']
- priors: ImmutableMap[str, Distribution] = ImmutableMap({'Q': Normal(...), 'A': Normal(...), 'b': Normal(...)})#
- shapes: ImmutableMap[str, ShapeSpec | tuple[int, ...]] = ImmutableMap({'Q': ShapeSpec(dims=('output_size', 'latent_size', 'latent_size')), 'A': ShapeSpec(dims=('output_size', 'latent_size')), 'b': ShapeSpec(dims=('output_size', 'one'))})#
- apply(latents: Float[Array, '#stars latents'], **pars: Any)#
Apply the transform to input latent vectors.
Extracts the required parameters from the kwargs and applies the transform function to the latents, handling vectorization automatically.
- Parameters:
latents (Float[Array, '#stars latents'])
pars (Any)
- Return type:
Float[Array, '#stars output']
- get_expanded_priors(latent_size: int, data_size: int | None = None)#
Get expanded parameter priors.
Expands the parameter prior distributions to the concrete shapes needed for the transform, based on latent size and optional data size.
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]
- pack_pars(nested_pars: dict[str, Any], ignore_missing: bool = False)#
Pack parameters (identity for single transforms).
- param_priors: ImmutableMap[str, Distribution] | None = None#
- class pollux.models.transforms.ShapeSpec(dims: tuple[str | int, ...])#
Bases: object
Specification for parameter shapes using named dimensions.
ShapeSpec allows you to define parameter shapes that depend on dimensions only known at model construction time (like latent_size). Named dimensions are resolved to concrete integers when the model is built.
- Parameters:
dims (tuple[str | int, ...]) – Tuple of dimension names (strings) or concrete sizes (integers).
Available Named Dimensions:
"output_size" – The output dimension of the transform.
"latent_size" – The latent space dimension (set when registering with a model).
"data_size" – The number of samples in the batch (useful for per-sample parameters).
"one" – Always resolves to 1 (useful for bias terms).
Examples
Define a weight matrix shape that depends on output and latent dimensions:
>>> from pollux.models.transforms import ShapeSpec
>>> shape = ShapeSpec(("output_size", "latent_size"))
>>> shape.resolve({"output_size": 128, "latent_size": 8})
(128, 8)
Define a bias vector shape using the special “one” dimension:
>>> bias_shape = ShapeSpec(("output_size", "one"))
>>> bias_shape.resolve({"output_size": 128})
(128, 1)
Use with FunctionTransform to define custom transforms:
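>>> from xmmutablemap import ImmutableMap
>>> import numpyro.distributions as dist
>>> from pollux.models.transforms import FunctionTransform
>>> shapes = ImmutableMap({"A": ShapeSpec(("output_size", "latent_size"))})
>>> priors = ImmutableMap({"A": dist.Normal(0, 1)})
>>> custom = FunctionTransform(output_size=128, transform=lambda z, A: A @ z, priors=priors, shapes=shapes)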
- class pollux.models.transforms.TransformSequence(transforms: tuple[AbstractSingleTransform, ...])#
Bases: AbstractTransform
A sequence of transforms applied in order.
Composes multiple transforms together, where the output of each transform becomes the input to the next transform in the sequence.
Parameters are stored as tuples of dictionaries, one element per transform.
- Parameters:
transforms (tuple[AbstractSingleTransform, ...])
- transforms: tuple[AbstractSingleTransform, ...]#
- property priors: tuple[ImmutableMap[str, Distribution], ...]#
Collect parameter priors from all transforms in the sequence.
- property param_priors: tuple[ImmutableMap[str, Distribution], ...]#
Deprecated. Use priors instead.
- property shapes: tuple[ImmutableMap[str, ShapeSpec | tuple[int, ...]], ...]#
Collect parameter shapes from all transforms in the sequence.
- property param_shapes: tuple[ImmutableMap[str, ShapeSpec | tuple[int, ...]], ...]#
Deprecated. Use shapes instead.
- apply(latents: Float[Array, '#stars latents'], *args: dict[str, Any], **kwargs: Any)#
Apply the sequence of transforms to input latent vectors.
Parameters can be provided in two ways:
1. As positional arguments: one dictionary per transform, in sequence order.
2. As keyword arguments: using the “{transform_index}:{param}” naming scheme, so a parameter named “A” in transform 0 of the sequence would be passed as “0:A”.
- unpack_pars(flat_pars: dict[str, Any], ignore_missing: bool = False)#
Convert flat parameter names to nested tuple structure.
Takes parameters with names like “0:A”, “1:p1” and converts them to a list of parameter dictionaries: [{“A”: value}, {“p1”: value}]
- pack_pars(nested_pars: list[dict[str, Any]], ignore_missing: bool = False)#
Convert nested parameter structure to flat naming scheme.
Takes a list of parameter dictionaries and converts them to flat parameter names like “0:A”, “1:p1”.
- Parameters:
nested_pars (list[dict[str, Any]])
ignore_missing (bool)
- Returns:
Dictionary with parameter names in format “{transform_index}:{param_name}”.
- get_expanded_priors(latent_size: int, data_size: int | None = None)#
Get expanded parameter priors using flat naming scheme.
Returns flattened parameter priors with index-based naming for compatibility with the AbstractTransform interface. Parameter names will be in the format: “{transform_index}:{param_name}”
Note: For transform sequences, each transform’s “latent_size” is the output size of the previous transform (or the model’s latent_size for the first transform).
- Parameters:
latent_size (int)
data_size (int | None)
- Return type:
ImmutableMap[str, Distribution]