pollux.data
Root module
- class pollux.data.AbstractPreprocessor
Bases: Module
Base class for data preprocessors.
- abstractmethod classmethod from_data(data: Float[Array, '#stars output'])
Compute preprocessing parameters from data.
- Parameters:
data (Float[Array, '#stars output'])
- Return type:
- abstractmethod inverse_transform(X: Float[Array, '#stars output'])
Apply the inverse preprocessing transform to the input data.
- Parameters:
X (Float[Array, '#stars output'])
- Return type:
Float[Array, '#stars output']
- abstractmethod inverse_transform_err(X_err: Float[Array, '#stars output'])
Apply the inverse preprocessing transform to the uncertainties of the input data.
- Parameters:
X_err (Float[Array, '#stars output'])
- Return type:
Float[Array, '#stars output']
- abstractmethod transform(X: Float[Array, '#stars output'])
Apply the preprocessing transform to the input data.
- Parameters:
X (Float[Array, '#stars output'])
- Return type:
Float[Array, '#stars output']
- abstractmethod transform_err(X_err: Float[Array, '#stars output'])
Apply the preprocessing transform to the uncertainties of the input data.
- Parameters:
X_err (Float[Array, '#stars output'])
- Return type:
Float[Array, '#stars output']
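To illustrate the contract a concrete preprocessor has to satisfy, here is a minimal sketch of a hypothetical MinMaxPreprocessor (not part of pollux) that maps each output column onto [0, 1]. It is a plain NumPy dataclass rather than a subclass of the real AbstractPreprocessor, so it only mirrors the five required methods. Because transform_err receives only the uncertainties and not the data, the sketch assumes an affine transform, for which uncertainties are simply rescaled.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class MinMaxPreprocessor:
    """Hypothetical preprocessor mapping each output column onto [0, 1]."""

    lo: np.ndarray  # per-column minimum, shape (#output,)
    span: np.ndarray  # per-column (max - min), shape (#output,)

    @classmethod
    def from_data(cls, data: np.ndarray) -> "MinMaxPreprocessor":
        # Reduce over stars (axis 0), giving one value per output column.
        lo = np.min(data, axis=0)
        return cls(lo=lo, span=np.max(data, axis=0) - lo)

    def transform(self, X: np.ndarray) -> np.ndarray:
        return (X - self.lo) / self.span

    def inverse_transform(self, X: np.ndarray) -> np.ndarray:
        return X * self.span + self.lo

    def transform_err(self, X_err: np.ndarray) -> np.ndarray:
        # An affine map only rescales uncertainties; the shift drops out.
        return X_err / self.span

    def inverse_transform_err(self, X_err: np.ndarray) -> np.ndarray:
        return X_err * self.span


rng = np.random.default_rng(0)
data = rng.uniform(2.0, 7.0, size=(100, 3))
data_err = np.full_like(data, 0.5)

pre = MinMaxPreprocessor.from_data(data)
scaled = pre.transform(data)
restored = pre.inverse_transform(scaled)
restored_err = pre.inverse_transform_err(pre.transform_err(data_err))
```

A round trip through transform followed by inverse_transform (and likewise through the two error methods) should return the inputs, which is a useful property to check for any new preprocessor.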
- class pollux.data.NullPreprocessor
Bases: AbstractPreprocessor
A preprocessor that does nothing to the input data.
Examples
>>> import jax.numpy as jnp
>>> import jax.random as jrnd
>>> from pollux.data import NullPreprocessor
>>> data = jrnd.normal(jrnd.PRNGKey(0), shape=(1024, 10))
>>> preprocessor = NullPreprocessor()
>>> new_data = preprocessor.transform(data)
>>> assert jnp.all(new_data == data)
- classmethod from_data(*_: Any)
Compute preprocessing parameters from data. All arguments are ignored, since this preprocessor has no parameters.
- Parameters:
_ (Any)
- Return type:
- inverse_transform(X: Float[Array, '#stars output'])
Apply the inverse preprocessing transform to the input data.
- Parameters:
X (Float[Array, '#stars output'])
- Return type:
Float[Array, '#stars output']
- inverse_transform_err(X_err: Float[Array, '#stars output'])
Apply the inverse preprocessing transform to the uncertainties of the input data.
- Parameters:
X_err (Float[Array, '#stars output'])
- Return type:
Float[Array, '#stars output']
- transform(X: Float[Array, '#stars output'])
Apply the preprocessing transform to the input data.
- Parameters:
X (Float[Array, '#stars output'])
- Return type:
Float[Array, '#stars output']
- transform_err(X_err: Float[Array, '#stars output'])
Apply the preprocessing transform to the uncertainties of the input data.
- Parameters:
X_err (Float[Array, '#stars output'])
- Return type:
Float[Array, '#stars output']
- class pollux.data.OutputData(data: jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex, err: typing.Any = <factory>, preprocessor: pollux.data.preprocessor.AbstractPreprocessor = NullPreprocessor(), processed: bool = False)
Bases: Module
A container for a single block of output data.
This class stores the data for a single output of a model, such as spectral fluxes for a collection of stars, stellar labels for a set of stars, or other data like broadband magnitudes. Each instance should correspond to a single output data type (e.g., spectral fluxes should be stored in a separate instance from stellar labels).
- Parameters:
data (array-like) – The output data.
err (array-like, optional) – The uncertainties (errors) of the output data.
preprocessor (AbstractPreprocessor, optional) – A preprocessor to apply to the data. For example, this might recenter the data on the mean and scale them to unit variance (using the ShiftScalePreprocessor). Use the .processed attribute to check whether an instance has already been preprocessed.
processed (bool) – Whether the data have already been preprocessed.
Examples
Let's assume you have a set of spectra for a collection of 128 stars. The data are aligned on the same wavelength grid with 2048 pixels, so they can be stored in a 2D array with shape (128, 2048). You also have the errors on the fluxes, which are stored in a 2D array with the same shape. You can create an instance of OutputData to store these data. In the example below, we generate some random data to represent this case (for the sake of illustration):
>>> import jax.numpy as jnp
>>> import jax.random as jrnd
>>> from pollux.data import OutputData
>>> rngs = jrnd.split(jrnd.PRNGKey(0), 2)
>>> spectra = jrnd.uniform(rngs[0], minval=0, maxval=10, shape=(128, 2048))
>>> spectra_err = jrnd.uniform(rngs[1], minval=0.1, maxval=1, shape=spectra.shape)
>>> flux_data = OutputData(data=spectra, err=spectra_err)
>>> flux_data
OutputData(data=f...[128,2048], err=f...[128,2048])
>>> assert flux_data.processed is False
We did not specify a preprocessor, so the default NullPreprocessor is used: the result of .preprocess() is flagged as processed, but its values are unchanged. The processed data therefore equal the unprocessed data:
>>> tmp = flux_data.preprocess()
>>> assert tmp.processed
>>> assert jnp.all(tmp.data == flux_data.data)
We can instead specify a data preprocessor to center and rescale the input data. For this, we use the ShiftScalePreprocessor; by default, its from_data() class method computes the mean and standard deviation of the data along axis=0, and the preprocessor then subtracts the mean and divides by the standard deviation:
>>> from pollux.data import ShiftScalePreprocessor
>>> flux_data = OutputData(
...     data=spectra,
...     err=spectra_err,
...     preprocessor=ShiftScalePreprocessor.from_data(spectra)
... )
>>> processed_data = flux_data.preprocess()
>>> assert processed_data.processed
>>> assert jnp.allclose(jnp.mean(processed_data.data, axis=0), 0.0, atol=1e-5)
>>> assert jnp.allclose(jnp.std(processed_data.data, axis=0), 1.0, atol=1e-5)
- Parameters:
processed (bool)
data (Float[Array, '#stars output'])
err (Float[Array, '#stars output'])
preprocessor (AbstractPreprocessor)
- preprocess()
Preprocess the data using the preprocessor.
- Return type:
- preprocessor: AbstractPreprocessor = NullPreprocessor()
- Parameters:
X (Float[Array, '#stars output'])
inverse (bool)
- Return type:
Float[Array, '#stars output']
- unprocess(data: Float[Array, '#stars output'] | OutputData | None = None)
Reverse the preprocessing of the data using the preprocessor's inverse transform.
- Parameters:
data (Union[Float[Array, '#stars output'], OutputData, None]) – The data to unprocess. If None, the instance's own data are unprocessed.
- Return type:
data:
Float[Array, '#stars output']#
-
err:
Float[Array, '#stars output']#
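The sketch below shows one plausible way preprocess() and unprocess() fit together. It assumes that preprocess() applies transform() to the data and transform_err() to the errors and returns a new instance flagged as processed, and that unprocess() applies inverse_transform() and returns a plain array. SketchOutputData and AffinePreprocessor are hypothetical NumPy stand-ins, not the pollux classes; with loc=0 and scale=1, AffinePreprocessor behaves like NullPreprocessor.

```python
from dataclasses import dataclass, replace
from typing import Optional, Union

import numpy as np


@dataclass(frozen=True)
class AffinePreprocessor:
    """Hypothetical stand-in: (X - loc) / scale. loc=0, scale=1 is a no-op."""

    loc: float = 0.0
    scale: float = 1.0

    def transform(self, X):
        return (X - self.loc) / self.scale

    def inverse_transform(self, X):
        return X * self.scale + self.loc

    def transform_err(self, X_err):
        return X_err / self.scale


@dataclass(frozen=True)
class SketchOutputData:
    """Hypothetical mirror of OutputData; not the pollux implementation."""

    data: np.ndarray
    err: Optional[np.ndarray] = None
    preprocessor: AffinePreprocessor = AffinePreprocessor()
    processed: bool = False

    def preprocess(self) -> "SketchOutputData":
        # Assumption: preprocessing an already-processed instance is a no-op.
        if self.processed:
            return self
        err = None if self.err is None else self.preprocessor.transform_err(self.err)
        # Return a new instance; the original stays unprocessed.
        return replace(
            self,
            data=self.preprocessor.transform(self.data),
            err=err,
            processed=True,
        )

    def unprocess(
        self, data: Union[np.ndarray, "SketchOutputData", None] = None
    ) -> np.ndarray:
        # With no argument, undo the transform on this instance's own data.
        if data is None:
            data = self.data
        elif isinstance(data, SketchOutputData):
            data = data.data
        return self.preprocessor.inverse_transform(data)


rng = np.random.default_rng(0)
spectra = rng.uniform(0.0, 10.0, size=(128, 64))
spectra_err = rng.uniform(0.1, 1.0, size=spectra.shape)

pre = AffinePreprocessor(loc=float(spectra.mean()), scale=float(spectra.std()))
flux = SketchOutputData(data=spectra, err=spectra_err, preprocessor=pre)
proc = flux.preprocess()
```

Keeping the instance immutable and returning a new, flagged copy means the raw and processed versions can coexist, which matches the doctest above where flux_data.processed stays False after calling .preprocess().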
- class pollux.data.PolluxData(**kwargs: OutputData)
Bases: ImmutableMap[str, OutputData]
An immutable mapping from output names to OutputData instances.
- Parameters:
kwargs (OutputData)
- get(key: K, /, default: V | _T | None = None)
Get an item by key.
Examples
>>> from xmmutablemap import ImmutableMap
>>> d = ImmutableMap(a=1, b=2)
>>> d.get("a")
1
>>> d.get("c")
>>> d.get("c", 3)
3
- items()
Return the items.
Examples
>>> from xmmutablemap import ImmutableMap
>>> d = ImmutableMap(a=1, b=2)
>>> d.items()
dict_items([('a', 1), ('b', 2)])
- keys()
Return the keys.
Examples
>>> from xmmutablemap import ImmutableMap
>>> d = ImmutableMap(a=1, b=2)
>>> d.keys()
dict_keys(['a', 'b'])
- preprocess()
Preprocess all output data.
- Return type:
- tree_flatten()
Flatten the map into its values (and keys).
This is used for JAX's tree flattening.
Examples
>>> import jax
>>> from xmmutablemap import ImmutableMap
>>> d = ImmutableMap(a=1, b=2)
>>> d.tree_flatten()
((1, 2), ('a', 'b'))
>>> jax.tree.flatten(d)
([1, 2], PyTreeDef(CustomNode(ImmutableMap[('a', 'b')], [*, *])))
- classmethod tree_unflatten(aux_data: Annotated[tuple[K, ...], Doc('The keys.')], children: Annotated[tuple[V, ...], Doc('The values.')])
Unflatten into an ImmutableMap from the keys (aux_data) and values (children).
This is used for JAX's tree unflattening.
Examples
>>> import jax
>>> from xmmutablemap import ImmutableMap
>>> d = ImmutableMap(a=1, b=2)
>>> children, keys = d.tree_flatten()
>>> ImmutableMap.tree_unflatten(keys, children)
ImmutableMap({'a': 1, 'b': 2})
>>> jax.tree.unflatten(jax.tree.structure(d), children)
ImmutableMap({'a': 1, 'b': 2})
- unprocess(data: PolluxData | dict[str, Float[Array, '#stars output']] | None = None, ignore_missing: bool = False)
Reverse the preprocessing of all output data.
- Parameters:
data (Union[PolluxData, dict[str, Float[Array, '#stars output']], None]) – Data to unprocess. If None, unprocess self.
ignore_missing (bool) – If True, only unprocess keys that are present in both the instance and the input data. If False (default), raise an error if the keys don't match.
- Return type:
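The ignore_missing rule can be illustrated with a small standalone function. unprocess_all is a hypothetical helper operating on plain dicts of NumPy arrays and inverse-transform callables, not the PolluxData implementation. It assumes that with ignore_missing=True the result contains only the keys present in both inputs, and that a key mismatch otherwise raises a ValueError (the documentation only says "an error").

```python
from typing import Callable, Mapping

import numpy as np


def unprocess_all(
    inverse_transforms: Mapping[str, Callable[[np.ndarray], np.ndarray]],
    data: Mapping[str, np.ndarray],
    ignore_missing: bool = False,
) -> dict:
    """Hypothetical helper mirroring the documented ignore_missing behavior."""
    own, given = set(inverse_transforms), set(data)
    if not ignore_missing and own != given:
        raise ValueError(
            f"keys only in instance: {sorted(own - given)}, "
            f"keys only in input: {sorted(given - own)}"
        )
    # With ignore_missing=True, keys outside the intersection are skipped.
    return {key: inverse_transforms[key](data[key]) for key in sorted(own & given)}


# Inverse transforms for two outputs, e.g. fluxes and stellar labels.
inverse = {
    "flux": lambda x: x * 2.0 + 5.0,
    "label": lambda x: x * 0.1,
}
predicted = {"flux": np.zeros(3)}  # only one output was predicted

partial = unprocess_all(inverse, predicted, ignore_missing=True)

try:
    unprocess_all(inverse, predicted)  # keys differ and ignore_missing=False
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
```

Setting ignore_missing=True is convenient when a model predicts only a subset of the outputs it was trained on; the strict default catches misspelled or forgotten keys.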
- values()
Return the values.
- Return type:
ValuesView[TypeVar(V)]
Examples
>>> from xmmutablemap import ImmutableMap
>>> d = ImmutableMap(a=1, b=2)
>>> d.values()
dict_values([1, 2])
- class pollux.data.ShiftScalePreprocessor(loc: Any, scale: Any)
Bases: AbstractPreprocessor
Shift and then scale the data.
The data are first shifted by the location parameter loc and then divided by the scale parameter scale, i.e. the transform is (X - loc) / scale.
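Written out for arrays of shape (#stars, #output), the shift-and-scale transform, its inverse, and the corresponding uncertainty transform might look like the NumPy sketch below. The function names are hypothetical. The uncertainty rule sigma_Y = sigma_X / scale follows from the transform being affine (the shift cancels in any difference); that pollux's transform_err follows the same rule is an assumption.

```python
import numpy as np


def shift_scale(X: np.ndarray, loc, scale) -> np.ndarray:
    # Forward transform: shift first, then scale.
    return (X - loc) / scale


def shift_scale_inverse(Y: np.ndarray, loc, scale) -> np.ndarray:
    # Undo the scaling first, then the shift.
    return Y * scale + loc


def shift_scale_err(X_err: np.ndarray, scale) -> np.ndarray:
    # An affine map multiplies every difference by 1 / scale, so uncertainties
    # are divided by scale (assumed positive) and loc plays no role.
    return X_err / scale


X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
X_err = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
loc, scale = X.mean(axis=0), X.std(axis=0)

Y = shift_scale(X, loc, scale)
# A perturbation of size X_err maps to a perturbation of size shift_scale_err(X_err).
dY = shift_scale(X + X_err, loc, scale) - Y
```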
Use the from_data() and from_data_percentiles() class methods to compute the preprocessing parameters from specified data. The from_data() method computes the mean and standard deviation of the data along the specified axis, while the from_data_percentiles() method computes the location from a percentile (the median by default) and the scale from the spread between two specified percentiles.
Examples
The default way of computing the preprocessing parameters uses the from_data() class method, which computes the mean and standard deviation of the data along axis=0:
>>> import jax.numpy as jnp
>>> data = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
>>> from pollux.data import ShiftScalePreprocessor
>>> preprocessor = ShiftScalePreprocessor.from_data(data)
>>> processed_data = preprocessor.transform(data)
>>> assert jnp.allclose(jnp.mean(processed_data, axis=0), 0.0, atol=1e-5)
>>> assert jnp.allclose(jnp.std(processed_data, axis=0), 1.0, atol=1e-5)
To instead use the mean and standard deviation computed over all axes at once, set axis=None:
>>> preprocessor = ShiftScalePreprocessor.from_data(data, axis=None)
>>> processed_data = preprocessor.transform(data)
>>> assert jnp.allclose(processed_data, (data - jnp.mean(data)) / jnp.std(data), atol=1e-5)
An alternative way of computing the preprocessing parameters uses the from_data_percentiles() class method, which uses the median as the location and the spread between the specified percentiles as the scale. Here we use (one half of) the difference between the 95th and 5th percentile values as the scale:
>>> preprocessor = ShiftScalePreprocessor.from_data_percentiles(
...     data, scale_percentiles=(5.0, 95.0)
... )
>>> processed_data = preprocessor.transform(data)
>>> assert jnp.allclose(jnp.median(processed_data, axis=0), 0.0, atol=1e-4)
- classmethod from_data(data: Float[Array, '#stars output'], axis: int = 0)
Compute preprocessing parameters (mean and standard deviation) from data.
- Parameters:
data (Float[Array, '#stars output']) – The data to preprocess.
axis (int) – The axis along which to compute the mean and standard deviation. Use None to compute them over all axes.
- Return type:
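The axis argument controls the shape of the fitted parameters. In the NumPy sketch below (fit_shift_scale is a hypothetical stand-in for from_data), axis=0 gives one mean and standard deviation per output column, with shape (#output,), which broadcasts against data of shape (#stars, #output), while axis=None gives scalars. The sketch uses the population standard deviation (ddof=0, the NumPy and jax.numpy default); this is an assumption, though it is consistent with the class example, in which the transformed data have a standard deviation of 1.

```python
import numpy as np


def fit_shift_scale(data: np.ndarray, axis=0):
    # Population statistics (ddof=0), matching the numpy / jax.numpy default.
    return np.mean(data, axis=axis), np.std(data, axis=axis)


data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

loc, scale = fit_shift_scale(data)  # per column, shape (2,)
loc_all, scale_all = fit_shift_scale(data, axis=None)  # scalars over all entries

per_column = (data - loc) / scale  # (3, 2) broadcasts with (2,)
global_ = (data - loc_all) / scale_all
```

With axis=0 every output column is standardized independently; with axis=None all columns share one shift and scale, which preserves the relative offsets between columns.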
- classmethod from_data_percentiles(data: Float[Array, '#stars output'], loc_percentile: float = 50.0, scale_percentiles: tuple[float, float] = (16.0, 84.0), axis: int = 0)
Compute preprocessing parameters from percentiles of the data.
- Parameters:
data (Float[Array, '#stars output']) – The data to preprocess.
loc_percentile (float) – The percentile used as the location (50.0, i.e. the median, by default).
scale_percentiles (tuple[float, float]) – The lower and upper percentiles used to compute the scale.
axis (int) – The axis along which to compute the percentiles.
- Return type:
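Below is a NumPy sketch of the percentile-based parameters, with fit_percentiles as a hypothetical stand-in for from_data_percentiles. Following the class example, it assumes the scale is one half of the difference between the two scale percentiles; with the default (16.0, 84.0) this approximates one standard deviation for Gaussian data while being far less sensitive to outliers than the standard deviation itself.

```python
import numpy as np


def fit_percentiles(data, loc_percentile=50.0, scale_percentiles=(16.0, 84.0), axis=0):
    loc = np.percentile(data, loc_percentile, axis=axis)
    # np.percentile with two q values stacks the results along a new first axis.
    lo, hi = np.percentile(data, scale_percentiles, axis=axis)
    # Assumption: half the spread, so (16, 84) is about one sigma for Gaussian data.
    scale = 0.5 * (hi - lo)
    return loc, scale


rng = np.random.default_rng(42)
samples = rng.normal(loc=10.0, scale=2.0, size=(100_000, 1))

# Corrupt 0.1% of the values with huge outliers.
corrupted = samples.copy()
corrupted[:100] = 1e6

loc, scale = fit_percentiles(corrupted)
```

Even with the outliers, loc and scale stay close to the true center (10) and width (2), whereas the plain standard deviation of the corrupted column is inflated by orders of magnitude.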
- inverse_transform(X: Float[Array, '#stars output'])
Apply the inverse preprocessing transform to the input data.
- Parameters:
X (Float[Array, '#stars output'])
- Return type:
Float[Array, '#stars output']
- inverse_transform_err(X_err: Float[Array, '#stars output'])
Apply the inverse preprocessing transform to the uncertainties of the input data.
- Parameters:
X_err (Float[Array, '#stars output'])
- Return type:
Float[Array, '#stars output']
- transform(X: Float[Array, '#stars output'])
Apply the preprocessing transform to the input data.
- Parameters:
X (Float[Array, '#stars output'])
- Return type:
Float[Array, '#stars output']
- transform_err(X_err: Float[Array, '#stars output'])
Apply the preprocessing transform to the uncertainties of the input data.
- Parameters:
X_err (Float[Array, '#stars output'])
- Return type:
Float[Array, '#stars output']