pollux.models


Root module

class pollux.models.AbstractTransform(output_size: int, param_priors: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]], param_shapes: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]])

Bases: Module

Base class defining the transform interface.

Transforms convert latent vectors to observable quantities through parameterized functions. They define the mapping between latent space and output spaces.

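Parameters:
  • output_size (int)

  • param_priors (Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]])

  • param_shapes (Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]])

abstractmethod apply(latents: Float[Array, '#stars latents'], **params: Any)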

Apply the transform to input latent vectors.

Takes a batch of latent vectors and transforms them using the provided parameters to produce output values.

Parameters:
  • latents (Float[Array, '#stars latents'])

  • params (Any)

Return type:

Float[Array, '#stars output']

abstractmethod get_priors(latent_size: int, data_size: int | None = None)

Get expanded parameter priors.

Expands the parameter prior distributions to the concrete shapes needed for the transform, based on latent size and optional data size.

Parameters:
  • latent_size (int)

  • data_size (int | None)

Return type:

ImmutableMap[str, Distribution]
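As a rough illustration of what "expanding to concrete shapes" involves, the sketch below resolves the symbolic dimension names used by the ShapeSpec defaults on this page (`'output_size'`, `'latent_size'`, `'one'`) into integer shapes. The helper `resolve_shape` is hypothetical and not part of Pollux; it assumes `'one'` denotes a length-1 axis and ignores `data_size`, which Pollux may use in ways not shown here.

```python
def resolve_shape(dims, output_size, latent_size):
    """Map symbolic dimension names to integer sizes (hypothetical helper, not Pollux API)."""
    # Assumption: 'one' stands for a singleton axis
    sizes = {"output_size": output_size, "latent_size": latent_size, "one": 1}
    return tuple(sizes[d] for d in dims)


# Dimension names copied from the AffineTransform param_shapes default
affine_dims = {"A": ("output_size", "latent_size"), "b": ("output_size", "one")}
shapes = {
    name: resolve_shape(dims, output_size=128, latent_size=8)
    for name, dims in affine_dims.items()
}
print(shapes)  # {'A': (128, 8), 'b': (128, 1)}
```

A scalar prior such as a NumPyro `Normal` could then be broadcast to each resolved shape (NumPyro distributions provide `Distribution.expand(batch_shape)` for this), though whether Pollux does it exactly that way is an assumption.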

output_size: int
param_priors: ImmutableMap[str, Distribution]
param_shapes: ImmutableMap[str, ShapeSpec]
class pollux.models.AffineTransform(output_size: int, param_priors: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]] = ImmutableMap({'A': <numpyro.distributions.continuous.Normal with batch shape () and event shape ()>, 'b': <numpyro.distributions.continuous.Normal with batch shape () and event shape ()>}), param_shapes: ImmutableMap[str, ShapeSpec] = ImmutableMap({'A': ShapeSpec(dims=('output_size', 'latent_size')), 'b': ShapeSpec(dims=('output_size', 'one'))}), transform: Callable[..., Float[Array, 'output']] = <function _affine_transform>, vmap: bool = True)

Bases: AbstractAtomicTransform

Affine transformation combining linear transform and offset.

Implements the transformation: y = A @ z + b, where A is a matrix, z is a latent vector, and b is a bias vector.
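A minimal NumPy sketch of this mapping, first for a single latent vector and then for a batch of stars. The function name `affine` is illustrative; the documented signatures use JAX arrays, and Pollux handles the batching over stars itself.

```python
import numpy as np


def affine(z, A, b):
    """y = A @ z + b for one latent vector z with shape (latents,)."""
    return A @ z + b


A = np.array([[1.0, 2.0],
              [0.0, 1.0],
              [1.0, -1.0]])       # shape (output=3, latents=2)
b = np.array([0.5, 0.0, -1.0])   # shape (output=3,)
z = np.array([1.0, 1.0])

y = affine(z, A, b)  # A @ z = [3, 1, 0], so y = [3.5, 1.0, -1.0]

# For a (stars, latents) batch the same map is latents @ A.T + b,
# giving an array with shape (stars, output)
latents = np.array([[1.0, 1.0],
                    [0.0, 2.0]])
batch = latents @ A.T + b
```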

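Parameters:
  • output_size (int)

  • param_priors (Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]])

  • param_shapes (ImmutableMap[str, ShapeSpec])

  • transform (Callable[..., Float[Array, 'output']])

  • vmap (bool)

apply(latents: Float[Array, '#stars latents'], **params: Any)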

Apply the transform to input latent vectors.

Extracts the required parameters from the kwargs and applies the transform function to the latents, handling vectorization automatically.

Parameters:
  • latents (Float[Array, '#stars latents'])

  • params (Any)

Return type:

Float[Array, '#stars output']

get_priors(latent_size: int, data_size: int | None = None)

Get expanded parameter priors.

Expands the parameter prior distributions to the concrete shapes needed for the transform, based on latent size and optional data size.

Parameters:
  • latent_size (int)

  • data_size (int | None)

Return type:

ImmutableMap[str, Distribution]

param_priors: ImmutableMap[str, Distribution] = ImmutableMap({'A': <numpyro.distributions.continuous.Normal with batch shape () and event shape ()>, 'b': <numpyro.distributions.continuous.Normal with batch shape () and event shape ()>})
param_shapes: ImmutableMap[str, ShapeSpec] = ImmutableMap({'A': ShapeSpec(dims=('output_size', 'latent_size')), 'b': ShapeSpec(dims=('output_size', 'one'))})
transform(z: Float[Array, 'latents'], A: Float[Array, 'output latents'], b: Float[Array, 'output'])

Apply an affine transformation.

Computes a linear transformation followed by an offset: A @ z + b.

Parameters:
  • z (Float[Array, 'latents'])

  • A (Float[Array, 'output latents'])

  • b (Float[Array, 'output'])

Return type:

Float[Array, 'output']

vmap: bool = True
output_size: int
class pollux.models.FunctionTransform(output_size: int, param_priors: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]], param_shapes: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]], transform: Callable[..., Float[Array, 'output']], vmap: bool = True)

Bases: AbstractAtomicTransform

Function transformation using a user-defined function.

This transform allows for arbitrary transformations defined by the user.
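The sketch below shows the kind of per-star function that could be wrapped. It assumes the same call convention the built-in transforms document (the latent vector `z` first, then the named parameters as keyword arguments); `exp_linear` is a hypothetical example, and the commented constructor call is illustrative only and not executed here.

```python
import numpy as np


def exp_linear(z, A):
    """Hypothetical user-defined transform: y = exp(A @ z), always positive."""
    return np.exp(A @ z)


# Wrapping it in Pollux might look roughly like this (not run; the prior choice
# and the ShapeSpec import path, taken from the signatures on this page, are assumptions):
#   import numpyro.distributions as dist
#   from pollux._src.models.transforms import ShapeSpec
#   FunctionTransform(
#       output_size=4,
#       param_priors={"A": dist.Normal(0.0, 1.0)},
#       param_shapes={"A": ShapeSpec(dims=("output_size", "latent_size"))},
#       transform=exp_linear,
#   )

A = np.array([[0.0, 0.0],
              [1.0, 0.0],
              [0.0, -1.0],
              [1.0, 1.0]])  # shape (output=4, latents=2)
z = np.array([1.0, 2.0])
y = exp_linear(z, A)  # A @ z = [0, 1, -2, 3], so y = exp([0, 1, -2, 3])
```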

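Parameters:
  • output_size (int)

  • param_priors (Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]])

  • param_shapes (Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]])

  • transform (Callable[..., Float[Array, 'output']])

  • vmap (bool)

apply(latents: Float[Array, '#stars latents'], **params: Any)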

Apply the transform to input latent vectors.

Extracts the required parameters from the kwargs and applies the transform function to the latents, handling vectorization automatically.

Parameters:
  • latents (Float[Array, '#stars latents'])

  • params (Any)

Return type:

Float[Array, '#stars output']

get_priors(latent_size: int, data_size: int | None = None)

Get expanded parameter priors.

Expands the parameter prior distributions to the concrete shapes needed for the transform, based on latent size and optional data size.

Parameters:
  • latent_size (int)

  • data_size (int | None)

Return type:

ImmutableMap[str, Distribution]

vmap: bool = True
transform: Callable[..., Float[Array, 'output']]
output_size: int
param_priors: ImmutableMap[str, Distribution]
param_shapes: ImmutableMap[str, ShapeSpec]
class pollux.models.LinearTransform(output_size: int, param_priors: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]] = ImmutableMap({'A': <numpyro.distributions.continuous.Normal with batch shape () and event shape ()>}), param_shapes: ImmutableMap[str, ShapeSpec] = ImmutableMap({'A': ShapeSpec(dims=('output_size', 'latent_size'))}), transform: Callable[..., Float[Array, 'output']] = <function _linear_transform>, vmap: bool = True)

Bases: AbstractAtomicTransform

Linear transformation from latent to output space.

Implements the transformation: y = A @ z, where A is a matrix and z is a latent vector.
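The NumPy sketch below illustrates the shape contract of `apply` for this transform: a per-star function of a `(latents,)` vector, applied across a `(stars, latents)` batch, yields a `(stars, output)` array. Treating the per-star loop as what the `vmap=True` option achieves is an assumption about intent, not a copy of Pollux's code.

```python
import numpy as np


def linear(z, A):
    """y = A @ z for one latent vector (illustrative stand-in for the built-in)."""
    return A @ z


rng = np.random.default_rng(42)
latents = rng.normal(size=(10, 3))  # '#stars latents'
A = rng.normal(size=(6, 3))         # 'output latents'

# Apply the single-vector function star by star (the vectorized-map view)
looped = np.stack([linear(z, A) for z in latents])

# Equivalent single batched expression: out[s, o] = sum_l A[o, l] * latents[s, l]
batched = np.einsum("ol,sl->so", A, latents)

assert np.allclose(looped, batched)
print(batched.shape)  # (10, 6)
```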

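Parameters:
  • output_size (int)

  • param_priors (Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]])

  • param_shapes (ImmutableMap[str, ShapeSpec])

  • transform (Callable[..., Float[Array, 'output']])

  • vmap (bool)

apply(latents: Float[Array, '#stars latents'], **params: Any)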

Apply the transform to input latent vectors.

Extracts the required parameters from the kwargs and applies the transform function to the latents, handling vectorization automatically.

Parameters:
  • latents (Float[Array, '#stars latents'])

  • params (Any)

Return type:

Float[Array, '#stars output']

get_priors(latent_size: int, data_size: int | None = None)

Get expanded parameter priors.

Expands the parameter prior distributions to the concrete shapes needed for the transform, based on latent size and optional data size.

Parameters:
  • latent_size (int)

  • data_size (int | None)

Return type:

ImmutableMap[str, Distribution]

param_priors: ImmutableMap[str, Distribution] = ImmutableMap({'A': <numpyro.distributions.continuous.Normal with batch shape () and event shape ()>})
param_shapes: ImmutableMap[str, ShapeSpec] = ImmutableMap({'A': ShapeSpec(dims=('output_size', 'latent_size'))})
transform(z: Float[Array, 'latents'], A: Float[Array, 'output latents'])

Apply a linear transformation.

Computes the matrix product A @ z.

Parameters:
  • z (Float[Array, 'latents'])

  • A (Float[Array, 'output latents'])

Return type:

Float[Array, 'output']

vmap: bool = True
output_size: int
class pollux.models.LuxModel(latent_size: int)

Bases: Module

A latent variable model with multiple outputs.

A Pollux model is a generative latent variable model for output data. It provides a general framework for constructing multi-output (multi-task) models in which the output data are generated by transforming an embedded (latent) vector representation of each object. Although this class and model structure can be used in a broad range of applications, the package was written with stellar spectroscopic data in mind.
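To make this structure concrete, the NumPy sketch below simulates data from a two-output model of this kind: each star has one latent vector, and each output is a separate linear transform of that shared vector plus Gaussian noise. The output names (`flux`, `labels`), sizes, and noise levels are arbitrary illustrative choices, not Pollux defaults.

```python
import numpy as np

rng = np.random.default_rng(1)
n_stars, latent_size = 100, 4

# One shared latent vector per star, drawn from a unit Gaussian
z = rng.normal(size=(n_stars, latent_size))

# Two hypothetical outputs, each with its own linear transform of the same latents
A_flux = rng.normal(size=(50, latent_size))   # e.g. a 50-pixel spectrum
A_label = rng.normal(size=(3, latent_size))   # e.g. 3 stellar labels

flux_err, label_err = 0.01, 0.1
flux = z @ A_flux.T + rng.normal(scale=flux_err, size=(n_stars, 50))
labels = z @ A_label.T + rng.normal(scale=label_err, size=(n_stars, 3))
print(flux.shape, labels.shape)  # (100, 50) (100, 3)
```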

Parameters:

latent_size (int) – The size of the latent vector representation for each object (i.e. the embedded dimensionality).

default_numpyro_model(data: PolluxData, latent_prior: Distribution | None | bool = None, fixed_params: dict[str, Array] | None = None, names: list[str] | None = None, custom_model: Callable[[Float[Array, '#stars latents'], dict[str, Any], PolluxData], None] | None = None)

Create the default numpyro model for this Lux model.

The default model uses the specified latent vector prior and assumes that the data are Gaussian distributed away from the true (predicted) values given the specified errors.
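The sketch below writes out, with NumPy, the log-density implied by that description: a unit Gaussian prior on every latent component plus an independent Gaussian term for each observed value given its reported error. It mirrors the stated assumptions rather than reproducing Pollux's NumPyro model code; the function names are hypothetical.

```python
import numpy as np


def gaussian_log_likelihood(data, predicted, err):
    """Sum of independent Normal(predicted, err) log-densities over all entries."""
    resid = (data - predicted) / err
    return float(np.sum(-0.5 * resid**2 - np.log(err) - 0.5 * np.log(2 * np.pi)))


def unit_gaussian_log_prior(latents):
    """Standard-normal log-density summed over all latent components."""
    return float(np.sum(-0.5 * latents**2 - 0.5 * np.log(2 * np.pi)))


data = np.array([[1.0, 2.0]])
pred = np.array([[1.0, 2.0]])
err = np.array([[1.0, 1.0]])

# A perfect prediction with unit errors gives two terms of -0.5 * log(2 * pi)
log_like = gaussian_log_likelihood(data, pred, err)  # equals -log(2 * pi)
```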

Parameters:
  • data (PolluxData) – A dictionary of observed data.

  • latent_prior (Distribution | None | bool) – The prior distribution for the latent vectors. If not specified, use a unit Gaussian. If False, use an improper uniform prior.

  • fixed_params (dict[str, Array] | None) – A dictionary of fixed parameters to condition on. If None, all parameters will be sampled.

  • names (list[str] | None) – A list of output names to include in the model. If None, include all outputs.

  • custom_model (Callable[[Float[Array, '#stars latents'], dict[str, Any], PolluxData], None] | None) – Optional callable that takes latents, params, and data and adds custom modeling components.

Return type:

None

optimize(data: PolluxData, num_steps: int, rng_key: Array, optimizer: _NumPyroOptim | Optimizer | Any | None = None, latent_prior: Distribution | None | bool = None, custom_model: Callable[[Float[Array, '#stars latents'], dict[str, Any], PolluxData], None] | None = None, fixed_params: dict[str, dict[str, Array] | Array] | None = None, names: list[str] | None = None, svi_run_kwargs: dict[str, Any] | None = None)

Optimize the model parameters.

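Parameters:
  • data (PolluxData)

  • num_steps (int)

  • rng_key (Array)

  • optimizer (_NumPyroOptim | Optimizer | Any | None)

  • latent_prior (Distribution | None | bool)

  • custom_model (Callable[[Float[Array, '#stars latents'], dict[str, Any], PolluxData], None] | None)

  • fixed_params (dict[str, dict[str, Array] | Array] | None)

  • names (list[str] | None)

  • svi_run_kwargs (dict[str, Any] | None)
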
Return type:

tuple[dict[str, dict[str, Array] | Array], Any]

pack_numpyro_params(params: dict[str, dict[str, Array] | Array])

Pack parameters into a flat dictionary keyed on numpyro names.

This method is the inverse of unpack_numpyro_params: it flattens a nested dictionary keyed like [output_name][param_name] into a flat dictionary of numpyro parameters keyed like “output_name:param_name”.

Parameters:

params (dict[str, dict[str, Array] | Array]) – A nested dictionary of parameters, where the top level keys are the output names.

Returns:

A dictionary of numpyro parameters, with keys in the format “output_name:param_name”.

Return type:

dict
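A plain-Python sketch of the flattening described above. `pack_params` is a hypothetical stand-in, not the Pollux method; passing non-nested entries (a bare array) through under their original key is an assumption based on the `dict[str, dict[str, Array] | Array]` type.

```python
def pack_params(params):
    """Flatten {output: {param: value}} into {"output:param": value} (hypothetical)."""
    flat = {}
    for name, value in params.items():
        if isinstance(value, dict):
            for param_name, arr in value.items():
                flat[f"{name}:{param_name}"] = arr
        else:
            # Assumption: non-nested entries keep their key unchanged
            flat[name] = value
    return flat


nested = {"flux": {"A": [[1.0, 2.0]], "b": [0.5]}, "latents": [[0.1, 0.2]]}
flat = pack_params(nested)
# {'flux:A': [[1.0, 2.0]], 'flux:b': [0.5], 'latents': [[0.1, 0.2]]}
```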

predict_outputs(latents: Float[Array, '#stars latents'], params: dict[str, Any], names: list[str] | str | None = None)

Predict output values for given latent vectors and parameters.

Parameters:
  • latents (Float[Array, '#stars latents']) – The latent vectors that transform into the outputs.

  • params (dict[str, Any]) – A dictionary of parameters for each output transformation in the model.

  • names (list[str] | str | None) – A single string or a list of output names to predict. If None, predict all outputs (default).

Returns:

A dictionary of predicted output values, where the keys are the output names.

Return type:

dict

register_output(name: str, data_transform: AbstractTransform, err_transform: AbstractTransform | None = None)

Register a new output of the model given a specified transform.

Parameters:
  • name (str) – The name of the output. If you intend to use this model with numpyro and observed data, this name should match the name of the corresponding data in the pollux.data.PolluxData object passed in.

  • data_transform (AbstractTransform) – A specification of the transformation function that takes a latent vector representation in and predicts the output values.

  • err_transform (AbstractTransform | None)

Return type:

None

setup_numpyro(latents: Float[Array, '#stars latents'], data: PolluxData, names: list[str] | None = None)

Sample parameters and set up basic numpyro model.

Parameters:
  • latents (Float[Array, '#stars latents']) – The latent vectors that transform into the outputs. In the case of the Paton, these are the (unknown) latent vectors. In the case of the Cannon, these are the observed latents for the training set (combinations of stellar labels).

  • data (PolluxData) – A dictionary-like object of observed data for each output. The keys should correspond to the output names.

  • names (list[str] | None) – A list of output names to set up. If None, set up all outputs (default).

Returns:

A dictionary of sampled parameters for each output.

Return type:

dict

unpack_numpyro_params(params: dict[str, Array], skip_missing: bool = False)

Unpack numpyro parameters into a nested structure.

numpyro parameters use names like “output_name:param_name” so that the numpyro internal names are unique. This method unpacks them into a nested dictionary keyed on [output_name][param_name].

Parameters:
  • params (dict[str, Array]) – A dictionary of numpyro parameters. The keys should be in the format “output_name:param_name”.

  • skip_missing (bool)

Returns:

A nested dictionary of parameters, where the top level keys are the output names.

Return type:

dict
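A plain-Python sketch of the unpacking. `unpack_params` is hypothetical, not the Pollux method: it splits each key on the first `:` and leaves keys without a colon at the top level (an assumption), and it does not model the `skip_missing` option.

```python
def unpack_params(flat):
    """Nest {"output:param": value} into {output: {param: value}} (hypothetical)."""
    nested = {}
    for key, value in flat.items():
        if ":" in key:
            name, param_name = key.split(":", 1)
            nested.setdefault(name, {})[param_name] = value
        else:
            # Assumption: keys without an output prefix stay at the top level
            nested[key] = value
    return nested


flat = {"flux:A": [[1.0, 2.0]], "flux:b": [0.5], "label:A": [[3.0]]}
nested = unpack_params(flat)
# {'flux': {'A': [[1.0, 2.0]], 'b': [0.5]}, 'label': {'A': [[3.0]]}}
```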

latent_size: int
outputs: dict[str, LuxOutput]
class pollux.models.NoOpTransform(output_size: int = 0, param_priors: ImmutableMap[str, Distribution] = ImmutableMap({}), param_shapes: ImmutableMap[str, ShapeSpec] = ImmutableMap({}), transform: Callable[..., Float[Array, 'output']] = <function _noop_transform>, vmap: bool = True)

Bases: AbstractAtomicTransform

No-op transformation.

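Parameters:
  • output_size (int)

  • param_priors (ImmutableMap[str, Distribution])

  • param_shapes (ImmutableMap[str, ShapeSpec])

  • transform (Callable[..., Float[Array, 'output']])

  • vmap (bool)

apply(latents: Float[Array, '#stars latents'], **params: Any)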

Apply the transform to input latent vectors.

Extracts the required parameters from the kwargs and applies the transform function to the latents, handling vectorization automatically.

Parameters:
  • latents (Float[Array, '#stars latents'])

  • params (Any)

Return type:

Float[Array, '#stars output']

get_priors(latent_size: int, data_size: int | None = None)

Get expanded parameter priors.

Expands the parameter prior distributions to the concrete shapes needed for the transform, based on latent size and optional data size.

Parameters:
  • latent_size (int)

  • data_size (int | None)

Return type:

ImmutableMap[str, Distribution]

output_size: int = 0
param_priors: ImmutableMap[str, Distribution] = ImmutableMap({})
param_shapes: ImmutableMap[str, ShapeSpec] = ImmutableMap({})
transform(z: Float[Array, 'latents'])

No-op transformation.

Parameters:

z (Float[Array, 'latents'])

Return type:

Float[Array, 'output']

vmap: bool = True
class pollux.models.OffsetTransform(output_size: int, param_priors: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]] = ImmutableMap({'b': <numpyro.distributions.continuous.Normal with batch shape () and event shape ()>}), param_shapes: ImmutableMap[str, ShapeSpec] = ImmutableMap({'b': ShapeSpec(dims=('output_size', 'one'))}), transform: Callable[..., Float[Array, 'output']] = <function _offset_transform>, vmap: bool = True)

Bases: AbstractAtomicTransform

Offset transformation that adds a bias vector to inputs.

Implements the transformation: y = z + b, where b is a bias vector.
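A small NumPy illustration of the offset applied to a batch. For `z + b` to be defined, `b` must have one entry per input dimension; that reading of the formula is an assumption, since Pollux's shape handling is not shown here.

```python
import numpy as np

inputs = np.array([[0.0, 1.0, 2.0],
                   [3.0, 4.0, 5.0]])  # two stars, three input dimensions
b = np.array([10.0, 20.0, 30.0])      # one offset per dimension

# Broadcasting adds the same b to every star's vector
outputs = inputs + b  # [[10, 21, 32], [13, 24, 35]]
```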

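Parameters:
  • output_size (int)

  • param_priors (Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]])

  • param_shapes (ImmutableMap[str, ShapeSpec])

  • transform (Callable[..., Float[Array, 'output']])

  • vmap (bool)

apply(latents: Float[Array, '#stars latents'], **params: Any)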

Apply the transform to input latent vectors.

Extracts the required parameters from the kwargs and applies the transform function to the latents, handling vectorization automatically.

Parameters:
  • latents (Float[Array, '#stars latents'])

  • params (Any)

Return type:

Float[Array, '#stars output']

get_priors(latent_size: int, data_size: int | None = None)

Get expanded parameter priors.

Expands the parameter prior distributions to the concrete shapes needed for the transform, based on latent size and optional data size.

Parameters:
  • latent_size (int)

  • data_size (int | None)

Return type:

ImmutableMap[str, Distribution]

param_priors: ImmutableMap[str, Distribution] = ImmutableMap({'b': <numpyro.distributions.continuous.Normal with batch shape () and event shape ()>})
param_shapes: ImmutableMap[str, ShapeSpec] = ImmutableMap({'b': ShapeSpec(dims=('output_size', 'one'))})
transform(z: Float[Array, 'latents'], b: Float[Array, 'output'])

Apply an offset transformation.

Adds a bias vector b to the input: z + b.

Parameters:
  • z (Float[Array, 'latents'])

  • b (Float[Array, 'output'])

Return type:

Float[Array, 'output']

vmap: bool = True
output_size: int
class pollux.models.QuadraticTransform(output_size: int, param_priors: Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]] = ImmutableMap({'Q': <numpyro.distributions.continuous.Normal with batch shape () and event shape ()>, 'A': <numpyro.distributions.continuous.Normal with batch shape () and event shape ()>, 'b': <numpyro.distributions.continuous.Normal with batch shape () and event shape ()>}), param_shapes: ImmutableMap[str, ShapeSpec] = ImmutableMap({'Q': ShapeSpec(dims=('output_size', 'latent_size', 'latent_size')), 'A': ShapeSpec(dims=('output_size', 'latent_size')), 'b': ShapeSpec(dims=('output_size', 'one'))}), transform: Callable[..., Float[Array, 'output']] = <function _quadratic_transform>, vmap: bool = True)

Bases: AbstractAtomicTransform

Quadratic transformation of latent vectors.

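Implements the transformation y = z^T Q z + A @ z + b, where z is a latent vector, Q is a rank-3 tensor holding one latent_size × latent_size matrix Q_i per output (so output i includes the term z^T Q_i z), A is a matrix, and b is a bias vector.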
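A NumPy sketch of this formula for a single latent vector, using `np.einsum` to evaluate z^T Q_i z for every output i at once. The function name is illustrative and not Pollux's `_quadratic_transform`.

```python
import numpy as np


def quadratic(z, Q, A, b):
    """y_i = z^T Q_i z + (A @ z)_i + b_i for one latent vector z."""
    # einsum sums z_j * Q[i, j, k] * z_k over j and k, leaving one value per output i
    return np.einsum("j,ijk,k->i", z, Q, z) + A @ z + b


z = np.array([1.0, 2.0])
Q = np.stack([np.eye(2),                            # z^T I z = 1 + 4 = 5
              np.array([[0.0, 1.0], [0.0, 0.0]])])  # z^T Q_1 z = z_0 * z_1 = 2
A = np.zeros((2, 2))
b = np.array([0.0, 1.0])
y = quadratic(z, Q, A, b)  # [5.0, 3.0]
```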

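Parameters:
  • output_size (int)

  • param_priors (Mapping[K, V] | tuple[K, V] | Iterable[tuple[K, V]])

  • param_shapes (ImmutableMap[str, ShapeSpec])

  • transform (Callable[..., Float[Array, 'output']])

  • vmap (bool)

apply(latents: Float[Array, '#stars latents'], **params: Any)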

Apply the transform to input latent vectors.

Extracts the required parameters from the kwargs and applies the transform function to the latents, handling vectorization automatically.

Parameters:
  • latents (Float[Array, '#stars latents'])

  • params (Any)

Return type:

Float[Array, '#stars output']

get_priors(latent_size: int, data_size: int | None = None)

Get expanded parameter priors.

Expands the parameter prior distributions to the concrete shapes needed for the transform, based on latent size and optional data size.

Parameters:
  • latent_size (int)

  • data_size (int | None)

Return type:

ImmutableMap[str, Distribution]

param_priors: ImmutableMap[str, Distribution] = ImmutableMap({'Q': <numpyro.distributions.continuous.Normal with batch shape () and event shape ()>, 'A': <numpyro.distributions.continuous.Normal with batch shape () and event shape ()>, 'b': <numpyro.distributions.continuous.Normal with batch shape () and event shape ()>})
param_shapes: ImmutableMap[str, ShapeSpec] = ImmutableMap({'Q': ShapeSpec(dims=('output_size', 'latent_size', 'latent_size')), 'A': ShapeSpec(dims=('output_size', 'latent_size')), 'b': ShapeSpec(dims=('output_size', 'one'))})
transform(z: Float[Array, 'latents'], Q: Float[Array, 'output latents latents'], A: Float[Array, 'output latents'], b: Float[Array, 'output'])

Apply a quadratic transformation.

Computes a quadratic form plus a linear term and an offset: z^T Q z + A @ z + b.

Parameters:
  • z (Float[Array, 'latents'])

  • Q (Float[Array, 'output latents latents'])

  • A (Float[Array, 'output latents'])

  • b (Float[Array, 'output'])

Return type:

Float[Array, 'output']

vmap: bool = True
output_size: int
class pollux.models.TransformSequence(transforms: tuple[AbstractTransform, ...])

Bases: AbstractTransform

A sequence of transforms applied in order.

Composes multiple transforms together, where the output of each transform becomes the input to the next transform in the sequence.
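The NumPy sketch below shows the composition idea: each step's output becomes the next step's input. To avoid guessing Pollux's parameter-prefix format, it passes each step its own parameter dictionary instead of routing prefixed names; `apply_sequence`, `linear`, and `offset` are hypothetical helpers.

```python
import numpy as np


def linear(z, A):
    return A @ z


def offset(z, b):
    return z + b


def apply_sequence(z, steps):
    """Apply (function, params) pairs in order; each output feeds the next step."""
    for fn, params in steps:
        z = fn(z, **params)
    return z


z = np.array([1.0, 1.0])
steps = [
    (linear, {"A": np.array([[1.0, 2.0], [3.0, 4.0], [0.0, 1.0]])}),  # 2 -> 3 dims
    (offset, {"b": np.array([0.5, 0.0, -1.0])}),                      # stays 3 dims
]
result = apply_sequence(z, steps)  # A @ z = [3, 7, 1], plus b gives [3.5, 7.0, 0.0]
```

Because the first step changes the dimensionality, the offset's `b` must match the linear step's output size rather than the original latent size.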

Parameters:

transforms (tuple[AbstractTransform, ...])

apply(latents: Float[Array, '#stars latents'], **params: Any)

Apply the sequence of transforms to input latent vectors.

Passes the input through each transform in sequence, routing the appropriate parameters to each transform based on the prefixed parameter names.

Parameters:
  • latents (Float[Array, '#stars latents'])

  • params (Any)

Return type:

Float[Array, '#stars output']

get_priors(latent_size: int, data_size: int | None = None)

Get expanded parameter priors.

Expands the parameter prior distributions to the concrete shapes needed for the transform, based on latent size and optional data size.

Parameters:
  • latent_size (int)

  • data_size (int | None)

Return type:

ImmutableMap[str, Distribution]

transforms: tuple[AbstractTransform, ...]
output_size: int
param_priors: ImmutableMap[str, Distribution]
param_shapes: ImmutableMap[str, ShapeSpec]

pollux.models.transforms