pollux.data

Root module

class pollux.data.AbstractPreprocessor

Bases: Module

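Base class for data preprocessors. Subclasses compute their parameters with from_data() and implement forward and inverse transforms for both the data and their uncertainties.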

abstractmethod classmethod from_data(data: Float[Array, '#stars output'])

Compute preprocessing parameters from data.

Parameters:

data (Float[Array, '#stars output'])

Return type:

AbstractPreprocessor

abstractmethod inverse_transform(X: Float[Array, '#stars output'])

Apply inverse preprocessing transform to the input data.

Parameters:

X (Float[Array, '#stars output'])

Return type:

Float[Array, '#stars output']

abstractmethod inverse_transform_err(X_err: Float[Array, '#stars output'])

Apply inverse preprocessing transform to the input data uncertainties.

Parameters:

X_err (Float[Array, '#stars output'])

Return type:

Float[Array, '#stars output']

abstractmethod transform(X: Float[Array, '#stars output'])

Apply preprocessing transform to the input data.

Parameters:

X (Float[Array, '#stars output'])

Return type:

Float[Array, '#stars output']

abstractmethod transform_err(X_err: Float[Array, '#stars output'])

Apply preprocessing transform to the input data uncertainties.

Parameters:

X_err (Float[Array, '#stars output'])

Return type:

Float[Array, '#stars output']
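The following is a minimal sketch of a custom preprocessor, showing how the five abstract methods fit together. It is a plain-Python stand-in, not a pollux subclass, and assumes nothing beyond the method names above. It uses abc.ABC in place of the Equinox Module base class and NumPy in place of jax.numpy, so it runs without either library. The ScaleOnlyPreprocessor class is hypothetical. Its *_err methods apply only the scale, because rescaling multiplies uncertainties by the same factor.

```python
from abc import ABC, abstractmethod

import numpy as np


class PreprocessorInterface(ABC):
    """Hypothetical stand-in mirroring the AbstractPreprocessor methods."""

    @classmethod
    @abstractmethod
    def from_data(cls, data):
        """Return an instance whose parameters are estimated from `data`."""

    @abstractmethod
    def transform(self, X):
        """Map data values into the processed space."""

    @abstractmethod
    def inverse_transform(self, X):
        """Map processed values back to the original units."""

    @abstractmethod
    def transform_err(self, X_err):
        """Map uncertainties into the processed space."""

    @abstractmethod
    def inverse_transform_err(self, X_err):
        """Map processed uncertainties back to the original units."""


class ScaleOnlyPreprocessor(PreprocessorInterface):
    """Hypothetical subclass: divide each output column by its standard deviation."""

    def __init__(self, scale):
        self.scale = np.asarray(scale)

    @classmethod
    def from_data(cls, data):
        # One scale per output column, computed across stars (axis=0).
        return cls(np.std(data, axis=0))

    def transform(self, X):
        return X / self.scale

    def inverse_transform(self, X):
        return X * self.scale

    def transform_err(self, X_err):
        # A pure rescaling multiplies uncertainties by the same factor.
        return X_err / self.scale

    def inverse_transform_err(self, X_err):
        return X_err * self.scale


rng = np.random.default_rng(0)
data = rng.normal(loc=5.0, scale=3.0, size=(256, 4))  # 256 stars, 4 outputs
pre = ScaleOnlyPreprocessor.from_data(data)
scaled = pre.transform(data)
roundtrip = pre.inverse_transform(scaled)
```

Applying inverse_transform() to the output of transform() recovers the input, which is what the inverse_ naming implies. The same holds for the *_err pair.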

class pollux.data.NullPreprocessor

Bases: AbstractPreprocessor

A preprocessor that does nothing to the input data.

Examples

>>> import jax.numpy as jnp
>>> import jax.random as jrnd
>>> from pollux.data import NullPreprocessor
>>> data = jrnd.normal(jrnd.PRNGKey(0), shape=(1024, 10))
>>> preprocessor = NullPreprocessor()
>>> new_data = preprocessor.transform(data)
>>> assert jnp.all(new_data == data)

classmethod from_data(*_: Any)

Create a NullPreprocessor. There are no parameters to compute, so any arguments are ignored.

Parameters:

_ (Any)

Return type:

NullPreprocessor

inverse_transform(X: Float[Array, '#stars output'])

Apply inverse preprocessing transform to the input data.

Parameters:

X (Float[Array, '#stars output'])

Return type:

Float[Array, '#stars output']

inverse_transform_err(X_err: Float[Array, '#stars output'])

Apply inverse preprocessing transform to the input data uncertainties.

Parameters:

X_err (Float[Array, '#stars output'])

Return type:

Float[Array, '#stars output']

transform(X: Float[Array, '#stars output'])

Apply preprocessing transform to the input data.

Parameters:

X (Float[Array, '#stars output'])

Return type:

Float[Array, '#stars output']

transform_err(X_err: Float[Array, '#stars output'])

Apply preprocessing transform to the input data uncertainties.

Parameters:

X_err (Float[Array, '#stars output'])

Return type:

Float[Array, '#stars output']

class pollux.data.OutputData(data: jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex, err: typing.Any = <factory>, preprocessor: pollux.data.preprocessor.AbstractPreprocessor = NullPreprocessor(), processed: bool = False)

Bases: Module

A container for a single block of output data.

This class stores the data for a single output of a model, such as spectral fluxes or stellar labels for a collection of stars, or other data such as broadband magnitudes. Each instance should correspond to a single output data type (e.g., spectral fluxes should be stored in a separate instance from stellar labels).

Parameters:
  • data (array-like) – The output data.

  • err (array-like, optional) – The uncertainties (errors) of the output data.

  • preprocessor (AbstractPreprocessor, optional) – A preprocessor to apply to the data. For example, the ShiftScalePreprocessor can recenter the data on the mean and scale them to unit variance. Defaults to NullPreprocessor(), which leaves the data unchanged. Use the .processed attribute to check whether an instance has already been preprocessed.

  • processed (bool, optional) – Whether the data have already been preprocessed. Defaults to False.

Examples

Let’s assume you have spectra for a collection of 128 stars, all aligned on the same wavelength grid with 2048 pixels, so the fluxes can be stored in a 2D array with shape (128, 2048). The flux errors are stored in a 2D array with the same shape. Both can be stored in an instance of OutputData. In the example below, we generate random data to stand in for these spectra:

>>> import jax.numpy as jnp
>>> import jax.random as jrnd
>>> from pollux.data import OutputData
>>> rngs = jrnd.split(jrnd.PRNGKey(0), 2)
>>> spectra = jrnd.uniform(rngs[0], minval=0, maxval=10, shape=(128, 2048))
>>> spectra_err = jrnd.uniform(rngs[1], minval=0.1, maxval=1, shape=spectra.shape)
>>> flux_data = OutputData(data=spectra, err=spectra_err)
>>> flux_data
OutputData(data=f…[128,2048], err=f…[128,2048])
>>> assert flux_data.processed is False

We did not specify a preprocessor, so the default NullPreprocessor is used. Calling .preprocess() marks the returned instance as processed but leaves the values unchanged, so the processed data equal the unprocessed data:

>>> tmp = flux_data.preprocess()
>>> assert tmp.processed
>>> assert jnp.all(tmp.data == flux_data.data)

We can instead specify a preprocessor that centers and rescales the input data. Here we use the ShiftScalePreprocessor. When built with ShiftScalePreprocessor.from_data(), it subtracts the mean and divides by the standard deviation, both computed along axis=0 by default:

>>> from pollux.data import ShiftScalePreprocessor
>>> flux_data = OutputData(
...     data=spectra,
...     err=spectra_err,
...     preprocessor=ShiftScalePreprocessor.from_data(spectra),
... )
>>> processed_data = flux_data.preprocess()
>>> assert processed_data.processed
>>> assert jnp.allclose(jnp.mean(processed_data.data, axis=0), 0.0, atol=1e-5)
>>> assert jnp.allclose(jnp.std(processed_data.data, axis=0), 1.0, atol=1e-5)

Parameters:
  • processed (bool)

  • data (Float[Array, '#stars output'])

  • err (Float[Array, '#stars output'])

  • preprocessor (AbstractPreprocessor)

preprocess()

Preprocess the data using the preprocessor.

Return type:

OutputData

preprocessor: AbstractPreprocessor = NullPreprocessor()

Parameters:
  • X (Float[Array, '#stars output'])

  • inverse (bool)

Return type:

Float[Array, '#stars output']

processed: bool = False

unprocess(data: Float[Array, '#stars output'] | OutputData | None = None)

Unprocess the data using the preprocessor.

Parameters:

data (Union[Float[Array, '#stars output'], OutputData, None]) – The data to unprocess. If None, the instance’s data will be unprocessed.

Return type:

OutputData
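A minimal sketch of the preprocess/unprocess round trip is shown below. It assumes that preprocess() applies transform() to data and transform_err() to err, then returns a new instance marked processed=True. It also assumes that unprocess() applies the inverse transforms. ToyShiftScale and ToyOutputData are hypothetical stand-ins written with frozen dataclasses and NumPy; the real OutputData is an Equinox module. How pollux fills err when unprocess() receives a bare array is not documented here, so the sketch simply sets it to None.

```python
from dataclasses import dataclass, replace
from typing import Optional

import numpy as np


@dataclass(frozen=True)
class ToyShiftScale:
    """Hypothetical preprocessor computing (X - loc) / scale."""

    loc: np.ndarray
    scale: np.ndarray

    def transform(self, X):
        return (X - self.loc) / self.scale

    def inverse_transform(self, X):
        return X * self.scale + self.loc

    def transform_err(self, X_err):
        return X_err / self.scale

    def inverse_transform_err(self, X_err):
        return X_err * self.scale


@dataclass(frozen=True)
class ToyOutputData:
    """Hypothetical stand-in for OutputData; not the pollux class."""

    data: np.ndarray
    err: Optional[np.ndarray]
    preprocessor: ToyShiftScale
    processed: bool = False

    def preprocess(self):
        # Return a new instance; the original is left untouched.
        pre = self.preprocessor
        return replace(
            self,
            data=pre.transform(self.data),
            err=None if self.err is None else pre.transform_err(self.err),
            processed=True,
        )

    def unprocess(self, data=None):
        # None -> undo this instance; ToyOutputData -> undo that instance;
        # bare array (e.g. model predictions) -> wrap it without uncertainties.
        if data is None:
            data, err = self.data, self.err
        elif isinstance(data, ToyOutputData):
            data, err = data.data, data.err
        else:
            err = None
        pre = self.preprocessor
        return replace(
            self,
            data=pre.inverse_transform(data),
            err=None if err is None else pre.inverse_transform_err(err),
            processed=False,
        )


rng = np.random.default_rng(1)
spectra = rng.uniform(0.0, 10.0, size=(128, 32))
spectra_err = rng.uniform(0.1, 1.0, size=spectra.shape)
pre = ToyShiftScale(loc=spectra.mean(axis=0), scale=spectra.std(axis=0))
flux = ToyOutputData(spectra, spectra_err, pre)
processed = flux.preprocess()
restored = processed.unprocess()
```

Returning new instances instead of mutating in place matches the immutable style of Equinox modules. It also keeps the original, unprocessed data available.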

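data: Float[Array, '#stars output']
err: Float[Array, '#stars output']

class pollux.data.PolluxData(**kwargs: OutputData)

Bases: ImmutableMap[str, OutputData]

An immutable mapping from output names to OutputData instances.

Parameters:

kwargs (OutputData) – The output data blocks, keyed by name.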

get(key: K, /, default: V | _T | None = None)

Get an item by key.

Examples

>>> from xmmutablemap import ImmutableMap
>>> d = ImmutableMap(a=1, b=2)
>>> d.get("a")
1
>>> d.get("c")
>>> d.get("c", 3)
3

Parameters:
  • key (K)

  • default (V | _T | None)

Return type:

Union[TypeVar(V), TypeVar(_T), None]

items()

Return the items.

Return type:

ItemsView[TypeVar(K), TypeVar(V)]

Examples

>>> from xmmutablemap import ImmutableMap
>>> d = ImmutableMap(a=1, b=2)
>>> d.items()
dict_items([('a', 1), ('b', 2)])

keys()

Return the keys.

Return type:

KeysView[TypeVar(K)]

Examples

>>> from xmmutablemap import ImmutableMap
>>> d = ImmutableMap(a=1, b=2)
>>> d.keys()
dict_keys(['a', 'b'])

preprocess()

Preprocess all output data.

Return type:

PolluxData

tree_flatten()

Flatten dict to the values (and keys).

This is used for JAX’s tree flattening.

Return type:

tuple[tuple[TypeVar(V), ...], tuple[TypeVar(K), ...]]

Examples

>>> import jax
>>> from xmmutablemap import ImmutableMap
>>> d = ImmutableMap(a=1, b=2)
>>> d.tree_flatten()
((1, 2), ('a', 'b'))
>>> jax.tree.flatten(d)
([1, 2], PyTreeDef(CustomNode(ImmutableMap[('a', 'b')], [*, *])))

classmethod tree_unflatten(aux_data: tuple[K, ...], children: tuple[V, ...])

Unflatten into an ImmutableMap from the keys and values.

This is used for JAX’s tree un-flattening. Note that tree_flatten() returns (children, aux_data), so the two elements must be swapped when they are passed back to tree_unflatten().

Examples

>>> import jax
>>> from xmmutablemap import ImmutableMap
>>> d = ImmutableMap(a=1, b=2)
>>> children, aux_data = d.tree_flatten()
>>> ImmutableMap.tree_unflatten(aux_data, children)
ImmutableMap({'a': 1, 'b': 2})
>>> jax.tree.unflatten(jax.tree.structure(d), children)
ImmutableMap({'a': 1, 'b': 2})

Parameters:
  • aux_data (tuple[K, ...]) – The keys.

  • children (tuple[V, ...]) – The values.

Return type:

ImmutableMap[TypeVar(K), TypeVar(V)]
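The argument order is the usual source of confusion here: tree_flatten() returns (children, aux_data), while tree_unflatten() takes (aux_data, children). The sketch below uses a hypothetical ToyImmutableMap that implements only this protocol in plain Python. JAX calls these two methods once a class is registered as a pytree (for example with jax.tree_util.register_pytree_node_class). That registration step is omitted so the code runs without JAX.

```python
class ToyImmutableMap:
    """Hypothetical minimal map following JAX's tree_flatten/tree_unflatten protocol."""

    def __init__(self, *args, **kwargs):
        self._data = dict(*args, **kwargs)

    def tree_flatten(self):
        # (children, aux_data): values are the leaves, keys are static metadata.
        return tuple(self._data.values()), tuple(self._data.keys())

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Reversed order compared with tree_flatten(): keys first, then values.
        return cls(zip(aux_data, children))

    def __repr__(self):
        return f"{type(self).__name__}({self._data!r})"


d = ToyImmutableMap(a=1, b=2)
children, aux_data = d.tree_flatten()
print(ToyImmutableMap.tree_unflatten(aux_data, children))
# ToyImmutableMap({'a': 1, 'b': 2})

# Passing the flattened pair straight back with * swaps keys and values:
print(ToyImmutableMap.tree_unflatten(*d.tree_flatten()))
# ToyImmutableMap({1: 'a', 2: 'b'})
```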

unprocess(data: PolluxData | dict[str, Float[Array, '#stars output']] | None = None, ignore_missing: bool = False)

Unprocess all output data.

Parameters:
  • data (Union[PolluxData, dict[str, Float[Array, '#stars output']], None]) – Data to unprocess. If None, unprocess self.

  • ignore_missing (bool) – If True, only unprocess keys that are present in both the instance and the input data. If False (default), raise an error if keys don’t match.

Return type:

PolluxData
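The ignore_missing flag matters when, for example, a model predicts only some of the outputs stored in a PolluxData instance. The hypothetical unprocess_all() helper below mimics the documented key-matching rule on plain dictionaries, using a toy Shift preprocessor. It raises KeyError on a mismatch, which may differ from the exception pollux actually raises.

```python
import numpy as np


class Shift:
    """Hypothetical preprocessor whose inverse adds back a constant offset."""

    def __init__(self, loc):
        self.loc = loc

    def inverse_transform(self, X):
        return X + self.loc


def unprocess_all(preprocessors, data, ignore_missing=False):
    """Hypothetical helper mimicking the documented key-matching rule."""
    mismatch = set(data) ^ set(preprocessors)
    if mismatch and not ignore_missing:
        raise KeyError(f"keys do not match: {sorted(mismatch)}")
    # Keep only names present in both mappings.
    return {
        name: preprocessors[name].inverse_transform(values)
        for name, values in data.items()
        if name in preprocessors
    }


preprocessors = {"flux": Shift(1.0), "label": Shift(100.0)}
predictions = {"flux": np.zeros(3)}  # the model only predicted "flux"

out = unprocess_all(preprocessors, predictions, ignore_missing=True)
print(sorted(out))  # ['flux']
```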

values()

Return the values.

Return type:

ValuesView[TypeVar(V)]

Examples

>>> from xmmutablemap import ImmutableMap
>>> d = ImmutableMap(a=1, b=2)
>>> d.values()
dict_values([1, 2])

class pollux.data.ShiftScalePreprocessor(loc: Any, scale: Any)

Bases: AbstractPreprocessor

Shift and then scale the data.

The data are shifted by the location parameter loc and then scaled by the scale parameter, so the transform computes (X - loc) / scale.

Use the from_data() and from_data_percentiles() class methods to compute the preprocessing parameters from data. The from_data() method uses the mean and standard deviation of the data along the specified axis. The from_data_percentiles() method uses a percentile of the data (the median by default) as the location and half the difference between two percentiles as the scale.

Examples

The default way of computing the preprocessing parameters uses the from_data() class method, which computes the mean and standard deviation of the data along axis=0:

>>> import jax.numpy as jnp
>>> data = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
>>> from pollux.data import ShiftScalePreprocessor
>>> preprocessor = ShiftScalePreprocessor.from_data(data)
>>> processed_data = preprocessor.transform(data)
>>> assert jnp.allclose(jnp.mean(processed_data, axis=0), 0.0, atol=1e-5)
>>> assert jnp.allclose(jnp.std(processed_data, axis=0), 1.0, atol=1e-5)

To compute a single mean and standard deviation over all elements of the data instead, set axis=None:

>>> preprocessor = ShiftScalePreprocessor.from_data(data, axis=None)
>>> processed_data = preprocessor.transform(data)
>>> assert jnp.allclose(processed_data, (data - jnp.mean(data)) / jnp.std(data), atol=1e-5)
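The parameters that from_data() computes can be reproduced with a few NumPy calls. This sketch assumes, consistent with the examples above, that loc is the mean and scale is the population standard deviation along axis. shift_scale_from_data() and shift_scale_transform() are hypothetical helpers, not part of pollux.

```python
import numpy as np


def shift_scale_from_data(data, axis=0):
    """Hypothetical helper mirroring from_data(): mean and std along `axis`."""
    # axis=0 gives one (loc, scale) pair per output column;
    # axis=None gives a single pair computed from every element.
    loc = np.mean(data, axis=axis)
    scale = np.std(data, axis=axis)  # population std (ddof=0), like jnp.std
    return loc, scale


def shift_scale_transform(X, loc, scale):
    # Shift first, then scale.
    return (X - loc) / scale


data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
loc, scale = shift_scale_from_data(data)
processed = shift_scale_transform(data, loc, scale)
print(loc)  # [3. 4.]
```

Because loc and scale have one entry per column when axis=0, NumPy broadcasting applies each column's parameters to every star.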

An alternative way of computing the preprocessing parameters uses the from_data_percentiles() class method. It uses the median as the location and half the difference between two percentiles as the scale. By default these are the 16th and 84th percentiles; here we use the 5th and 95th percentiles instead:

>>> preprocessor = ShiftScalePreprocessor.from_data_percentiles(
...     data, scale_percentiles=(5.0, 95.0)
... )
>>> processed_data = preprocessor.transform(data)
>>> assert jnp.allclose(jnp.median(processed_data, axis=0), 0.0, atol=1e-4)

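Parameters:
  • loc (Any) – The location parameter subtracted from the data.

  • scale (Any) – The scale parameter that the shifted data are divided by.

classmethod from_data(data: Float[Array, '#stars output'], axis: int = 0)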

Compute preprocessing parameters from data.

Parameters:
  • data (Float[Array, '#stars output']) – The data to preprocess.

  • axis (int) – The axis along which to compute the mean and standard deviation. Use None to compute them over all elements.

Return type:

ShiftScalePreprocessor

classmethod from_data_percentiles(data: Float[Array, '#stars output'], loc_percentile: float = 50.0, scale_percentiles: tuple[float, float] = (16.0, 84.0), axis: int = 0)

Compute preprocessing parameters from percentiles of the data.

Parameters:
  • data (Float[Array, '#stars output']) – The data to preprocess.

  • loc_percentile (float) – The percentile to use as the location. The default, 50.0, is the median.

  • scale_percentiles (tuple[float, float]) – The lower and upper percentiles used to compute the scale.

  • axis (int) – The axis along which to compute the percentiles.

Return type:

ShiftScalePreprocessor
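Below is a NumPy sketch of from_data_percentiles(). As the example in the class description indicates, it assumes that loc is the loc_percentile percentile and that the scale is half the difference between the two scale_percentiles. shift_scale_from_percentiles() is a hypothetical helper, not part of pollux.

```python
import numpy as np


def shift_scale_from_percentiles(data, loc_percentile=50.0,
                                 scale_percentiles=(16.0, 84.0), axis=0):
    """Hypothetical helper mirroring from_data_percentiles()."""
    loc = np.percentile(data, loc_percentile, axis=axis)
    lo, hi = np.percentile(data, scale_percentiles, axis=axis)
    # Assumed: half the inter-percentile range. For Gaussian data and the
    # default (16, 84) percentiles this is close to one standard deviation.
    scale = 0.5 * (hi - lo)
    return loc, scale


rng = np.random.default_rng(42)
samples = rng.normal(loc=10.0, scale=2.0, size=(100_000, 3))
loc, scale = shift_scale_from_percentiles(samples)
print(loc.round(2), scale.round(2))
```

Percentile-based estimates are less sensitive to outliers than the mean and standard deviation. That is the usual reason to prefer them for data with heavy tails.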

inverse_transform(X: Float[Array, '#stars output'])

Apply inverse preprocessing transform to the input data.

Parameters:

X (Float[Array, '#stars output'])

Return type:

Float[Array, '#stars output']

inverse_transform_err(X_err: Float[Array, '#stars output'])

Apply inverse preprocessing transform to the input data uncertainties.

Parameters:

X_err (Float[Array, '#stars output'])

Return type:

Float[Array, '#stars output']

transform(X: Float[Array, '#stars output'])

Apply preprocessing transform to the input data.

Parameters:

X (Float[Array, '#stars output'])

Return type:

Float[Array, '#stars output']

transform_err(X_err: Float[Array, '#stars output'])

Apply preprocessing transform to the input data uncertainties.

Parameters:

X_err (Float[Array, '#stars output'])

Return type:

Float[Array, '#stars output']
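The *_err methods propagate uncertainties through the same shift-and-scale map. Standard error propagation for (X - loc) / scale says the uncertainty is divided by the scale and loc drops out. This sketch assumes transform_err() divides by the scale and inverse_transform_err() multiplies by it. The quick Monte Carlo check in NumPy confirms this propagation rule; the parameter values are illustrative, not taken from pollux.

```python
import numpy as np

loc, scale = 4.0, 0.5     # illustrative ShiftScale parameters
x_true, x_err = 6.0, 0.1  # one measurement and its 1-sigma uncertainty

rng = np.random.default_rng(0)
draws = rng.normal(x_true, x_err, size=200_000)

# Push noisy draws through the forward transform (X - loc) / scale.
transformed = (draws - loc) / scale

# For a linear map the spread depends only on the scale, not on loc.
propagated_err = x_err / abs(scale)
recovered_err = propagated_err * abs(scale)  # the inverse mapping

print(transformed.std(), propagated_err)
```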

loc: Array
scale: Array