"""
Output class objects.
The classes provided by this module exist to simplify access to large datasets created within C.
Fundamentally, ownership of the data belongs to these classes, and the C functions merely accesses
this and fills it. The various boxes and lightcones associated with each output are available as
instance attributes. Along with the output data, each output object contains the various input
parameter objects necessary to define it.
.. warning:: These should not be instantiated or filled by the user, but always handled
as output objects from the various functions contained here. Only the data
within the objects should be accessed.
"""
from __future__ import annotations
import logging
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from enum import Enum
from functools import cached_property
from typing import Any, Self
import attrs
import numpy as np
from astropy import units as u
from astropy.cosmology import z_at_value
from bidict import bidict
from .._cfg import config
from ..c_21cmfast import lib
from .arrays import Array
from .exceptions import _process_exitcode
from .inputs import (
AstroOptions,
AstroParams,
CosmoParams,
InputParameters,
InputStruct,
MatterOptions,
SimulationOptions,
)
from .structs import StructWrapper
[docs]
logger = logging.getLogger(__name__)
_ALL_OUTPUT_STRUCTS = {}
def _arrayfield(optional: bool = False, **kw):
if optional:
return attrs.field(
default=None,
validator=attrs.validators.optional(attrs.validators.instance_of(Array)),
eq=False,
type=Array,
)
else:
return attrs.field(
validator=attrs.validators.instance_of(Array),
eq=False,
type=Array,
)
class _HashType(Enum):
user_cosmo = 0
zgrid = 1
full = 2
@attrs.define(slots=False, kw_only=True)
class OutputStruct(ABC):
"""Base class for any class that wraps a C struct meant to be output from a C function."""
_meta = False
_c_compute_function = None
_compat_hash = _HashType.full
_TYPEMAP = bidict({"float32": "float *", "float64": "double *", "int32": "int *"})
inputs: InputParameters = attrs.field(
validator=attrs.validators.instance_of(InputParameters)
)
dummy: bool = attrs.field(default=False, converter=bool)
initial: bool = attrs.field(default=False, converter=bool)
@property
def _name(self):
"""The name of the struct."""
return self.__class__.__name__
def __init_subclass__(cls):
"""Store subclasses for easy access."""
if not cls._meta:
_ALL_OUTPUT_STRUCTS[cls.__name__] = cls
return super().__init_subclass__()
@property
def simulation_options(self) -> SimulationOptions:
"""The SimulationOptions object for this output struct."""
return self.inputs.simulation_options
@property
def matter_options(self) -> MatterOptions:
"""The SimulationOptions object for this output struct."""
return self.inputs.matter_options
@property
def cosmo_params(self) -> CosmoParams:
"""The CosmoParams object for this output struct."""
return self.inputs.cosmo_params
@property
def astro_params(self) -> AstroParams:
"""The AstroParams object for this output struct."""
return self.inputs.astro_params
@property
def astro_options(self) -> AstroOptions:
"""The AstroOptions object for this output struct."""
return self.inputs.astro_options
def _inputs_compatible_with(self, other: OutputStruct | InputParameters) -> bool:
"""Check whether this objects' inputs are compatible with another object's.
This check is sensitive to the fact that the other object may be at a different
level of the simulation heirarchy, and therefore may be compatible even if the
params are different. As long as they are the same at the level higher than the
minimum level of the simulation, they are considered compatible.
"""
if not isinstance(other, OutputStruct | InputParameters):
return False
if isinstance(other, InputParameters):
# Compare at the level required by this object only
return getattr(self.inputs, f"_{self._compat_hash.name}_hash") == getattr(
other, f"_{self._compat_hash.name}_hash"
)
min_req = min(self._compat_hash.value, other._compat_hash.value)
min_req = _HashType(min_req)
return getattr(self.inputs, f"_{min_req.name}_hash") == getattr(
other.inputs, f"_{min_req.name}_hash"
)
@property
def arrays(self) -> dict[str, Array]:
"""A dictionary of Array objects whose memory is shared between this object and the C backend."""
me = attrs.asdict(self, recurse=False)
return {k: x for k, x in me.items() if isinstance(x, Array)}
@cached_property
def struct(self) -> StructWrapper:
"""The python-wrapped struct associated with this input object."""
return StructWrapper(self._name)
@cached_property
def cstruct(self) -> StructWrapper:
"""The object pointing to the memory accessed by C-code for this struct."""
return self.struct.cstruct
def _init_arrays(self):
for k, array in self.arrays.items():
# Don't initialize C-based pointers or already-inited stuff, or stuff
# that's computed on disk (if it's on disk, accessing the array should
# just give the computed version, which is what we would want, not a
# zero-inited array).
if array.state.c_memory or array.state.initialized or array.state.on_disk:
continue
setattr(self, k, array.initialize())
@property
def random_seed(self) -> int:
"""The random seed for this particular instance."""
return self.inputs.random_seed
def push_to_backend(self):
"""Push the current state of the object with the underlying C-struct.
This will link any memory initialized by numpy in this object with the underlying
C-struct, and also update the C struct with any values in the python object.
"""
# Initialize all uninitialized arrays.
self._init_arrays()
for name, array in self.arrays.items():
# We do *not* set COMPUTED_ON_DISK items to the C-struct here, because we have no
# way of knowing (in this function) what is required to load in, and we don't want
# to unnecessarily load things in. We leave it to the user to ensure that all
# required arrays are loaded into memory before calling this function.
if array.state.initialized:
self.struct.expose_to_c(array, name)
for k in self.struct.primitive_fields:
if getattr(self, k) is not None:
setattr(self.cstruct, k, getattr(self, k))
def pull_from_backend(self):
"""Sync the current state of the object with the underlying C-struct.
This will pull any primitives calculated in the backend to the python object.
Arrays are passed in as pointers, and do not need to be copied back.
"""
for k in self.struct.primitive_fields:
setattr(self, k, getattr(self.cstruct, k))
def get(self, ary: str | Array):
"""If possible, load an array from disk, storing it and returning the underlying array."""
if isinstance(ary, str):
name = ary
try:
ary = self.arrays[ary]
except KeyError as e:
try:
return getattr(self, ary) # could be a different attribute...
except AttributeError:
raise AttributeError(f"The array {ary} does not exist") from e
elif names := [name for name, x in self.arrays.items() if x is ary]:
name = names[0]
else:
raise ValueError("The given array is not a part of this instance.")
if not ary.state.on_disk and not ary.state.initialized:
raise ValueError(f"Array '{name}' is not on disk and not initialized.")
if ary.state.on_disk and not ary.state.computed_in_mem:
ary = ary.loaded_from_disk()
setattr(self, name, ary)
return ary.value
def set(self, name: str, value: Any):
"""Set the value of an array."""
if name not in self.arrays:
try:
setattr(self, name, value)
except AttributeError:
raise AttributeError(f"The attribute '{name}' does not exist") from None
else:
setattr(self, name, self.arrays[name].with_value(value))
def prepare(
self,
flush: Sequence[str] | None = None,
keep: Sequence[str] | None = None,
force: bool = False,
):
"""Prepare the instance for being passed to another function.
This will flush all arrays in "flush" from memory, and ensure all arrays
in "keep" are in memory. At least one of these must be provided. By default,
the complement of the given parameter is all flushed/kept.
Parameters
----------
flush
Arrays to flush out of memory. Note that if no file is associated with this
instance, these arrays will be lost forever.
keep
Arrays to keep or load into memory. Note that if these do not already
exist, they will be loaded from file (if the file exists). Only one of
``flush`` and ``keep`` should be specified.
force
Whether to force flushing arrays even if no disk storage exists.
"""
if flush is None and keep is None:
raise ValueError("Must provide either flush or keep")
if flush is not None and keep is None:
keep = [k for k in self.arrays if k not in flush]
elif flush is None:
flush = [
k
for k, array in self.arrays.items()
if k not in keep and array.state.initialized
]
flush = flush or []
keep = keep or []
for k in flush:
self._remove_array(k, force=force)
# For everything we want to keep, we check if it is computed in memory,
# and if not, load it from disk.
for k in keep:
self.get(k)
def _remove_array(self, k: str, *, force=False):
array = self.arrays[k]
state = array.state
if not state.initialized:
warnings.warn(
f"Trying to remove array that isn't yet created: {k}", stacklevel=2
)
return
if state.computed_in_mem and not state.on_disk and not force:
# if we don't have the array on disk, don't purge unless we really want to
warnings.warn(
f"Trying to purge array '{k}' from memory that hasn't been stored! Use force=True if you meant to do this.",
stacklevel=2,
)
return
if state.c_has_active_memory:
lib.free(getattr(self.cstruct, k))
setattr(self, k, array.without_value())
def purge(self, force=False):
"""Flush all the boxes out of memory.
Parameters
----------
force
Whether to force the purge even if no disk storage exists.
"""
self.prepare(keep=[], force=force)
def load_all(self):
"""Load all possible arrays into memory."""
for x in self.arrays:
self.get(x)
@property
def is_computed(self) -> bool:
"""Whether this instance has been computed at all.
This is true either if the current instance has called :meth:`compute`,
or if it has a current existing :attr:`path` pointing to stored data,
or if such a path exists.
Just because the instance has been computed does *not* mean that all
relevant quantities are available -- some may have been purged from
memory without writing. Use :meth:`has` to check whether certain arrays
are available.
"""
return any(v.state.is_computed for v in self.arrays.values())
def ensure_arrays_computed(self, *arrays, load=False) -> bool:
"""Check if the given arrays are computed (not just initialized)."""
if not self.is_computed:
return False
computed = all(self.arrays[k].state.is_computed for k in arrays)
if computed and load:
self.prepare(keep=arrays, flush=[])
return computed
def ensure_arrays_inited(self, *arrays, init=False) -> bool:
"""Check if the given arrays are initialized (or computed)."""
inited = all(self.arrays[k].state.initialized for k in arrays)
if init and not inited:
self._init_arrays()
return True
@abstractmethod
def get_required_input_arrays(self, input_box: Self) -> list[str]:
"""Return all input arrays required to compute this object."""
def ensure_input_computed(self, input_box: Self, load: bool = False) -> bool:
"""Ensure all the inputs have been computed."""
if input_box.dummy:
return True
arrays = self.get_required_input_arrays(input_box)
if input_box.initial:
return input_box.ensure_arrays_inited(*arrays, init=load)
return input_box.ensure_arrays_computed(*arrays, load=load)
def summarize(self, indent: int = 0) -> str:
"""Generate a string summary of the struct."""
indent = indent * " "
# print array type and column headings
out = (
f"\n{indent}{self.__class__.__name__:>25} "
+ " 1st: End: Min: Max: Mean: \n"
)
# print array extrema and means
for fieldname, array in self.arrays.items():
state = array.state
if not state.initialized:
out += f"{indent} {fieldname:>25}: uninitialized\n"
elif not state.is_computed:
out += f"{indent} {fieldname:>25}: initialized\n"
elif not state.computed_in_mem:
out += f"{indent} {fieldname:>25}: computed on disk\n"
else:
x = self.get(fieldname).flatten()
if len(x) > 0:
out += f"{indent} {fieldname:>25}: {x[0]:11.4e}, {x[-1]:11.4e}, {x.min():11.4e}, {x.max():11.4e}, {np.mean(x):11.4e}\n"
else:
out += f"{indent} {fieldname:>25}: size zero\n"
# print primitive fields
out += "".join(
f"{indent} {fieldname:>25}: {getattr(self, fieldname, 'non-existent')}\n"
for fieldname in self.struct.primitive_fields
)
return out
@classmethod
def _log_call_arguments(cls, *args):
logger.debug(f"Calling {cls._c_compute_function.__name__} with following args:")
for arg in args:
if isinstance(arg, OutputStruct):
for line in arg.summarize(indent=1).split("\n"):
logger.debug(line)
elif isinstance(arg, InputStruct):
for line in str(arg).split("\n"):
logger.debug(f" {line}")
else:
logger.debug(f" {arg}")
def _ensure_arguments_exist(self, *args):
for arg in args:
if (
isinstance(arg, OutputStruct)
and not arg.dummy
and not self.ensure_input_computed(arg, load=True)
):
raise ValueError(
f"Trying to use {arg.__class__.__name__} to compute "
f"{self.__class__.__name__}, but some required arrays "
f"are not computed!\nArrays required: "
f"{self.get_required_input_arrays(arg)}\n"
f"Current State: {[(k, str(v.state)) for k, v in self.arrays.items()]}"
)
def _compute(self, allow_already_computed: bool = False, *args):
"""Compute the actual function that fills this struct."""
# Check that all required inputs are really computed, and load them into memory
# if they're not already.
self._ensure_arguments_exist(*args)
# Write a detailed message about call arguments if debug turned on.
if logger.getEffectiveLevel() <= logging.DEBUG:
self._log_call_arguments(*args)
# Construct the args. All StructWrapper objects need to actually pass their
# underlying cstruct, rather than themselves.
inputs = [
arg.cstruct if isinstance(arg, OutputStruct | InputStruct) else arg
for arg in args
]
# Sync the python/C memory
self.push_to_backend()
for arg in args:
if isinstance(arg, OutputStruct):
arg.push_to_backend()
# Ensure we haven't already tried to compute this instance.
if self.is_computed and not allow_already_computed:
raise ValueError(
f"You are trying to compute {self.__class__.__name__}, but it has already been computed."
)
# Perform the C computation
try:
exitcode = self._c_compute_function(*inputs, self.cstruct)
except TypeError as e:
logger.error(f"Arguments to {self._c_compute_function.__name__}: {inputs}")
raise e
_process_exitcode(exitcode, self._c_compute_function, args)
for name, array in self.arrays.items():
setattr(self, name, array.computed())
self.pull_from_backend()
return self
@classmethod
@abstractmethod
def new(cls, inputs: InputParameters, **kwargs) -> Self:
"""Instantiate the class from InputParameters."""
def get_full_size(self) -> int:
"""Return the size of the object in bytes.
This represents the size of the object if it is fully initialized/computed
and all in memory. Equivalently, it is close to the file size on disk.
"""
size = 0
for ary in self.arrays.values():
size += np.prod(ary.shape) * np.dtype(ary.dtype).itemsize
return size
@attrs.define(slots=False, kw_only=True)
class InitialConditions(OutputStruct):
"""A class representing an InitialConditions C-struct."""
_c_compute_function = lib.ComputeInitialConditions
_meta = False
_compat_hash = _HashType.user_cosmo
lowres_density = _arrayfield()
lowres_vx = _arrayfield(optional=True)
lowres_vy = _arrayfield(optional=True)
lowres_vz = _arrayfield(optional=True)
hires_density = _arrayfield()
hires_vx = _arrayfield(optional=True)
hires_vy = _arrayfield(optional=True)
hires_vz = _arrayfield(optional=True)
lowres_vx_2LPT = _arrayfield(optional=True)
lowres_vy_2LPT = _arrayfield(optional=True)
lowres_vz_2LPT = _arrayfield(optional=True)
hires_vx_2LPT = _arrayfield(optional=True)
hires_vy_2LPT = _arrayfield(optional=True)
hires_vz_2LPT = _arrayfield(optional=True)
lowres_vcb = _arrayfield(optional=True)
@classmethod
def new(cls, inputs: InputParameters, **kw) -> Self:
"""Create a new instance, given a set of input parameters."""
shape = (inputs.simulation_options.HII_DIM,) * 2 + (
int(
inputs.simulation_options.NON_CUBIC_FACTOR
* inputs.simulation_options.HII_DIM
),
)
hires_shape = (inputs.simulation_options.DIM,) * 2 + (
int(
inputs.simulation_options.NON_CUBIC_FACTOR
* inputs.simulation_options.DIM
),
)
out = {
"lowres_density": Array(shape, dtype=np.float32),
"hires_density": Array(hires_shape, dtype=np.float32),
}
if inputs.matter_options.PERTURB_ON_HIGH_RES:
out |= {
"hires_vx": Array(hires_shape, dtype=np.float32),
"hires_vy": Array(hires_shape, dtype=np.float32),
"hires_vz": Array(hires_shape, dtype=np.float32),
}
else:
out |= {
"lowres_vx": Array(shape, dtype=np.float32),
"lowres_vy": Array(shape, dtype=np.float32),
"lowres_vz": Array(shape, dtype=np.float32),
}
if inputs.matter_options.PERTURB_ALGORITHM == "2LPT":
out |= {
"hires_vx_2LPT": Array(hires_shape, dtype=np.float32),
"hires_vy_2LPT": Array(hires_shape, dtype=np.float32),
"hires_vz_2LPT": Array(hires_shape, dtype=np.float32),
}
if not inputs.matter_options.PERTURB_ON_HIGH_RES:
out |= {
"lowres_vx_2LPT": Array(shape, dtype=np.float32),
"lowres_vy_2LPT": Array(shape, dtype=np.float32),
"lowres_vz_2LPT": Array(shape, dtype=np.float32),
}
if inputs.matter_options.USE_RELATIVE_VELOCITIES:
out["lowres_vcb"] = Array(shape, dtype=np.float32)
return cls(inputs=inputs, **out, **kw)
def prepare_for_perturb(self, force: bool = False):
"""Ensure the ICs have all the boxes loaded for perturb, but no extra."""
keep = ["hires_density"]
if not self.matter_options.PERTURB_ON_HIGH_RES:
keep.append("lowres_density")
keep.append("lowres_vx")
keep.append("lowres_vy")
keep.append("lowres_vz")
if self.matter_options.PERTURB_ALGORITHM == "2LPT":
keep.append("lowres_vx_2LPT")
keep.append("lowres_vy_2LPT")
keep.append("lowres_vz_2LPT")
else:
keep.append("hires_vx")
keep.append("hires_vy")
keep.append("hires_vz")
if self.matter_options.PERTURB_ALGORITHM == "2LPT":
keep.append("hires_vx_2LPT")
keep.append("hires_vy_2LPT")
keep.append("hires_vz_2LPT")
if self.matter_options.USE_RELATIVE_VELOCITIES:
keep.append("lowres_vcb")
self.prepare(keep=keep, force=force)
def prepare_for_spin_temp(self, force: bool = False):
"""Ensure ICs have all boxes required for spin_temp, and no more."""
keep = []
if self.matter_options.USE_RELATIVE_VELOCITIES:
keep.append("lowres_vcb")
if self.matter_options.lagrangian_source_grid:
if not self.matter_options.PERTURB_ON_HIGH_RES:
keep.append("lowres_density")
keep.append("lowres_vx")
keep.append("lowres_vy")
keep.append("lowres_vz")
if self.matter_options.PERTURB_ALGORITHM == "2LPT":
keep.append("lowres_vx_2LPT")
keep.append("lowres_vy_2LPT")
keep.append("lowres_vz_2LPT")
else:
keep.append("hires_density")
keep.append("hires_vx")
keep.append("hires_vy")
keep.append("hires_vz")
if self.matter_options.PERTURB_ALGORITHM == "2LPT":
keep.append("hires_vx_2LPT")
keep.append("hires_vy_2LPT")
keep.append("hires_vz_2LPT")
self.prepare(keep=keep, force=force)
def get_required_input_arrays(self, input_box: OutputStruct) -> list[str]:
"""Return all input arrays required to compute this object."""
return []
def compute(self, allow_already_computed: bool = False):
"""Compute the function."""
return self._compute(
allow_already_computed,
self.random_seed,
)
@attrs.define(slots=False, kw_only=True)
class OutputStructZ(OutputStruct):
"""The same as an OutputStruct, but containing a redshift."""
_meta = True
redshift: float = attrs.field(converter=float)
@classmethod
def dummy(cls):
"""Create a dummy instance with the given inputs."""
return cls.new(inputs=InputParameters(random_seed=1), redshift=-1.0, dummy=True)
@classmethod
def initial(cls, inputs):
"""Create a dummy instance with the given inputs."""
return cls.new(inputs=inputs, redshift=-1.0, initial=True)
@attrs.define(slots=False, kw_only=True)
class PerturbedField(OutputStructZ):
"""A class containing all perturbed field boxes."""
_c_compute_function = lib.ComputePerturbedField
_meta = False
_compat_hash = _HashType.zgrid
density = _arrayfield()
velocity_z = _arrayfield()
velocity_x = _arrayfield(optional=True)
velocity_y = _arrayfield(optional=True)
@classmethod
def new(cls, inputs: InputParameters, redshift: float, **kw) -> Self:
"""Create a new PerturbedField instance with the given inputs.
Parameters
----------
inputs : InputParameters
The input parameters defining the output struct.
redshift : float
The redshift at which to compute fields.
Other Parameters
----------------
All other parameters are passed through to the :class:`PerturbedField`
constructor.
"""
dim = inputs.simulation_options.HII_DIM
shape = (dim, dim, int(inputs.simulation_options.NON_CUBIC_FACTOR * dim))
out = {
"density": Array(shape, dtype=np.float32),
"velocity_z": Array(shape, dtype=np.float32),
}
if inputs.matter_options.KEEP_3D_VELOCITIES:
out["velocity_x"] = Array(shape, dtype=np.float32)
out["velocity_y"] = Array(shape, dtype=np.float32)
return cls(inputs=inputs, redshift=redshift, **out, **kw)
def get_required_input_arrays(self, input_box: OutputStruct) -> list[str]:
"""Return all input arrays required to compute this object."""
required = []
if not isinstance(input_box, InitialConditions):
raise ValueError(
f"{type(input_box)} is not an input required for PerturbedField!"
)
# Always require hires_density
required += ["hires_density"]
if self.matter_options.PERTURB_ON_HIGH_RES:
required += ["hires_vx", "hires_vy", "hires_vz"]
if self.matter_options.PERTURB_ALGORITHM == "2LPT":
required += ["hires_vx_2LPT", "hires_vy_2LPT", "hires_vz_2LPT"]
else:
required += ["lowres_density", "lowres_vx", "lowres_vy", "lowres_vz"]
if self.matter_options.PERTURB_ALGORITHM == "2LPT":
required += [
"lowres_vx_2LPT",
"lowres_vy_2LPT",
"lowres_vz_2LPT",
]
if self.matter_options.USE_RELATIVE_VELOCITIES:
required.append("lowres_vcb")
return required
def compute(self, *, allow_already_computed: bool = False, ics: InitialConditions):
"""Compute the function."""
return self._compute(
allow_already_computed,
self.redshift,
ics,
)
@property
def velocity(self):
"""The velocity of the box in the 3rd dimension (for backwards compat)."""
return self.velocity_z # for backwards compatibility
@attrs.define(slots=False, kw_only=True)
class HaloCatalog(OutputStructZ):
"""A class containing all fields related to halos."""
_c_compute_function = lib.ComputeHaloCatalog
_meta = False
desc_redshift: float | None = attrs.field(default=None)
_compat_hash = _HashType.zgrid
halo_masses = _arrayfield()
star_rng = _arrayfield()
sfr_rng = _arrayfield()
xray_rng = _arrayfield()
halo_coords = _arrayfield()
n_halos: int = attrs.field(default=None)
buffer_size: int = attrs.field(default=None)
@classmethod
def new(
cls,
inputs: InputParameters,
redshift: float,
buffer_size: float | None = None,
**kw,
) -> Self:
"""Create a new PerturbedHaloCatalog instance with the given inputs.
Parameters
----------
inputs : InputParameters
The input parameters defining the output struct.
redshift : float
The redshift at which to compute fields.
Other Parameters
----------------
All other parameters are passed through to the :class:`PerturbedHaloCatalog`
constructor.
"""
from .cfuncs import get_halo_catalog_buffer_size
if kw.get("dummy", False):
buffer_size = 0
elif buffer_size is None:
buffer_size = get_halo_catalog_buffer_size(
redshift=redshift,
inputs=inputs,
free_cosmo_tables=kw.get("free_cosmo_tables", False),
)
return cls(
inputs=inputs,
halo_masses=Array((buffer_size,), dtype=np.float32),
star_rng=Array((buffer_size,), dtype=np.float32),
sfr_rng=Array((buffer_size,), dtype=np.float32),
xray_rng=Array((buffer_size,), dtype=np.float32),
halo_coords=Array((buffer_size, 3), dtype=np.float32),
redshift=redshift,
buffer_size=buffer_size,
**kw,
)
def get_required_input_arrays(self, input_box: OutputStruct) -> list[str]:
"""Return all input arrays required to compute this object."""
required = []
if isinstance(input_box, InitialConditions):
if self.matter_options.SOURCE_MODEL == "CHMF-SAMPLER":
# when the sampler is on, the grids are only needed for the first sample
if self.desc_redshift <= 0:
required += ["hires_density"]
required += ["lowres_density"]
# without the sampler, dexm needs the hires density at each redshift
else:
required += ["hires_density"]
elif isinstance(input_box, HaloCatalog):
if self.matter_options.SOURCE_MODEL == "CHMF-SAMPLER":
required += [
"halo_masses",
"halo_coords",
"star_rng",
"sfr_rng",
"xray_rng",
]
else:
raise ValueError(
f"{type(input_box)} is not an input required for HaloCatalog!"
)
return required
def compute(
self,
*,
descendant_halos: HaloCatalog,
ics: InitialConditions,
allow_already_computed: bool = False,
):
"""Compute the function."""
return self._compute(
allow_already_computed,
self.desc_redshift,
self.redshift,
ics,
ics.random_seed,
descendant_halos,
)
@attrs.define(slots=False, kw_only=True)
class PerturbedHaloCatalog(OutputStructZ):
"""A class to hold a HaloCatalog whose coordinates are in real (Eulerian) space."""
_c_compute_function = lib.ComputePerturbedHaloCatalog
_meta = False
desc_redshift: float | None = attrs.field(default=None)
_compat_hash = _HashType.zgrid
halo_masses = _arrayfield()
halo_coords = _arrayfield()
sfr = _arrayfield()
stellar_masses = _arrayfield()
ion_emissivity = _arrayfield()
xray_emissivity = _arrayfield(optional=True)
fesc_sfr = _arrayfield(optional=True)
stellar_mini = _arrayfield(optional=True)
sfr_mini = _arrayfield(optional=True)
n_halos: int = attrs.field(default=None)
buffer_size: int = attrs.field(default=None)
@classmethod
def new(
cls,
inputs: InputParameters,
redshift: float,
buffer_size: float,
**kw,
) -> Self:
"""Create a new PerturbedHaloCatalog instance with the given inputs.
Parameters
----------
inputs : InputParameters
The input parameters defining the output struct.
redshift : float
The redshift at which to compute fields.
Other Parameters
----------------
All other parameters are passed through to the :class:`PerturbedHaloCatalog`
constructor.
"""
out = {
"halo_coords": Array((buffer_size, 3), dtype=np.float32),
"halo_masses": Array((buffer_size,), dtype=np.float32),
"stellar_masses": Array((buffer_size,), dtype=np.float32),
"sfr": Array((buffer_size,), dtype=np.float32),
"ion_emissivity": Array((buffer_size,), dtype=np.float32),
}
if inputs.astro_options.USE_TS_FLUCT:
out["xray_emissivity"] = Array((buffer_size,), dtype=np.float32)
if inputs.astro_options.INHOMO_RECO:
out["fesc_sfr"] = Array((buffer_size,), dtype=np.float32)
if inputs.astro_options.USE_MINI_HALOS:
out["stellar_mini"] = Array((buffer_size,), dtype=np.float32)
out["sfr_mini"] = Array((buffer_size,), dtype=np.float32)
return cls(
inputs=inputs,
redshift=redshift,
buffer_size=buffer_size,
**out,
**kw,
)
def get_required_input_arrays(self, input_box: OutputStruct) -> list[str]:
"""Return all input arrays required to compute this object."""
required = []
if isinstance(input_box, InitialConditions):
if self.matter_options.PERTURB_ON_HIGH_RES:
required += ["hires_vx", "hires_vy", "hires_vz"]
else:
required += ["lowres_vx", "lowres_vy", "lowres_vz"]
if self.matter_options.PERTURB_ALGORITHM == "2LPT":
required += [f"{k}_2LPT" for k in required]
if self.matter_options.USE_RELATIVE_VELOCITIES:
required += ["lowres_vcb"]
elif isinstance(input_box, TsBox):
if self.astro_options.USE_MINI_HALOS:
required += ["J_21_LW"]
elif isinstance(input_box, IonizedBox):
required += ["ionisation_rate_G12", "z_reion"]
elif isinstance(input_box, HaloCatalog):
required += [
"halo_coords",
"halo_masses",
]
else:
raise ValueError(
f"{type(input_box)} is not an input required for PerturbedHaloCatalog!"
)
return required
def compute(
self,
*,
ics: InitialConditions,
previous_spin_temp: TsBox,
previous_ionize_box: IonizedBox,
halo_catalog: HaloCatalog,
allow_already_computed: bool = False,
):
"""Compute the function."""
return self._compute(
allow_already_computed,
self.redshift,
ics,
previous_spin_temp,
previous_ionize_box,
halo_catalog,
)
@attrs.define(slots=False, kw_only=True)
class HaloBox(OutputStructZ):
"""A class containing all gridded halo properties."""
_meta = False
_c_compute_function = lib.ComputeHaloBox
count = _arrayfield(optional=True)
halo_mass = _arrayfield(optional=True)
halo_stars = _arrayfield(optional=True)
halo_stars_mini = _arrayfield(optional=True)
halo_sfr = _arrayfield()
halo_sfr_mini = _arrayfield(optional=True)
halo_xray = _arrayfield(optional=True)
n_ion = _arrayfield()
whalo_sfr = _arrayfield(optional=True)
log10_Mcrit_ACG_ave: float = attrs.field(default=None)
log10_Mcrit_MCG_ave: float = attrs.field(default=None)
@classmethod
def new(cls, inputs: InputParameters, redshift: float, **kw) -> Self:
"""Create a new HaloBox instance with the given inputs.
Parameters
----------
inputs : InputParameters
The input parameters defining the output struct.
redshift : float
The redshift at which to compute fields.
Other Parameters
----------------
All other parameters are passed through to the :class:`HaloBox`
constructor.
"""
dim = inputs.simulation_options.HII_DIM
shape = (dim, dim, int(inputs.simulation_options.NON_CUBIC_FACTOR * dim))
out = {
"halo_sfr": Array(shape, dtype=np.float32),
"n_ion": Array(shape, dtype=np.float32),
}
if inputs.astro_options.USE_MINI_HALOS:
out["halo_sfr_mini"] = Array(shape, dtype=np.float32)
if inputs.astro_options.INHOMO_RECO:
out["whalo_sfr"] = Array(shape, dtype=np.float32)
if inputs.astro_options.USE_TS_FLUCT:
out["halo_xray"] = Array(shape, dtype=np.float32)
if config["EXTRA_HALOBOX_FIELDS"]:
out["count"] = Array(shape, dtype=np.int32)
out["halo_mass"] = Array(shape, dtype=np.float32)
out["halo_stars"] = Array(shape, dtype=np.float32)
if inputs.astro_options.USE_MINI_HALOS:
out["halo_stars_mini"] = Array(shape, dtype=np.float32)
return cls(
inputs=inputs,
redshift=redshift,
**out,
**kw,
)
def get_required_input_arrays(self, input_box: OutputStruct) -> list[str]:
"""Return all input arrays required to compute this object."""
required = []
if isinstance(input_box, HaloCatalog):
if self.matter_options.has_discrete_halos:
required += [
"halo_coords",
"halo_masses",
"star_rng",
"sfr_rng",
"xray_rng",
]
elif isinstance(input_box, TsBox):
if self.astro_options.USE_MINI_HALOS:
required += ["J_21_LW"]
elif isinstance(input_box, IonizedBox):
required += ["ionisation_rate_G12", "z_reion"]
elif isinstance(input_box, InitialConditions):
if self.matter_options.PERTURB_ON_HIGH_RES:
required += ["hires_density", "hires_vx", "hires_vy", "hires_vz"]
else:
required += ["lowres_density", "lowres_vx", "lowres_vy", "lowres_vz"]
if self.matter_options.PERTURB_ALGORITHM == "2LPT":
required += [f"{k}_2LPT" for k in required if "_v" in k]
if self.matter_options.USE_RELATIVE_VELOCITIES:
required += ["lowres_vcb"]
else:
raise ValueError(f"{type(input_box)} is not an input required for HaloBox!")
return required
def compute(
self,
*,
initial_conditions: InitialConditions,
halo_catalog: HaloCatalog,
previous_spin_temp: TsBox,
previous_ionize_box: IonizedBox,
allow_already_computed: bool = False,
):
"""Compute the function."""
return self._compute(
allow_already_computed,
self.redshift,
initial_conditions,
halo_catalog,
previous_spin_temp,
previous_ionize_box,
)
def prepare_for_next_snapshot(self, next_z, force: bool = False):
"""Prepare the HaloBox for the next snapshot."""
# find maximum z
d_max_needed = (
self.cosmo_params.cosmo.comoving_distance(next_z)
+ self.astro_params.R_MAX_TS * u.Mpc
)
max_z_needed = z_at_value(
self.cosmo_params.cosmo.comoving_distance, d_max_needed
)
z_arr = np.array(self.inputs.node_redshifts)
# we need one redshift above the max z for interpolation, so find that value
last_z_above = (
z_arr[z_arr > max_z_needed].min()
if z_arr.max() > max_z_needed
else z_arr.max() + 1
)
# If we need the box, only keep the interpolated fields
keep = []
if self.redshift <= last_z_above:
if self.astro_options.USE_TS_FLUCT:
keep += ["halo_sfr", "halo_xray"]
if self.astro_options.USE_MINI_HALOS and self.astro_options.USE_TS_FLUCT:
keep += ["halo_sfr_mini"]
self.prepare(keep=keep, force=force)
@attrs.define(slots=False, kw_only=True)
class XraySourceBox(OutputStructZ):
"""A class containing the filtered sfr grids."""
_meta = False
_c_compute_function = lib.UpdateXraySourceBox
filtered_sfr = _arrayfield()
filtered_sfr_mini = _arrayfield(optional=True)
filtered_xray = _arrayfield()
mean_sfr = _arrayfield()
mean_sfr_mini = _arrayfield(optional=True)
mean_log10_Mcrit_LW = _arrayfield(optional=True)
@classmethod
def new(cls, inputs: InputParameters, redshift: float, **kw) -> Self:
"""Create a new XraySourceBox instance with the given inputs.
Parameters
----------
inputs : InputParameters
The input parameters defining the output struct.
redshift : float
The redshift at which to compute fields.
Other Parameters
----------------
All other parameters are passed through to the :class:`XraySourceBox`
constructor.
"""
shape = (
(inputs.astro_params.N_STEP_TS,)
+ (inputs.simulation_options.HII_DIM,) * 2
+ (
int(
inputs.simulation_options.NON_CUBIC_FACTOR
* inputs.simulation_options.HII_DIM
),
)
)
out = {
"filtered_sfr": Array(shape, dtype=np.float32),
"filtered_xray": Array(shape, dtype=np.float32),
"mean_sfr": Array((inputs.astro_params.N_STEP_TS,), dtype=np.float64),
}
if inputs.astro_options.USE_MINI_HALOS:
out["filtered_sfr_mini"] = Array(shape, dtype=np.float32)
out["mean_sfr_mini"] = Array(
(inputs.astro_params.N_STEP_TS,), dtype=np.float64
)
out["mean_log10_Mcrit_LW"] = Array(
(inputs.astro_params.N_STEP_TS,), dtype=np.float64
)
return cls(
inputs=inputs,
redshift=redshift,
**out,
**kw,
)
def get_required_input_arrays(self, input_box: OutputStruct) -> list[str]:
"""Return all input arrays required to compute this object."""
required = []
if not isinstance(input_box, HaloBox):
raise ValueError(f"{type(input_box)} is not an input required for HaloBox!")
required += ["halo_sfr", "halo_xray"]
if self.astro_options.USE_MINI_HALOS:
required += ["halo_sfr_mini"]
return required
def compute(
self,
*,
halobox: HaloBox,
R_inner,
R_outer,
R_ct,
allow_already_computed: bool = False,
):
"""Compute the function."""
return self._compute(
allow_already_computed,
halobox,
R_inner,
R_outer,
R_ct,
)
@attrs.define(slots=False, kw_only=True)
class TsBox(OutputStructZ):
"""A class containing all spin temperature boxes."""
_c_compute_function = lib.ComputeTsBox
_meta = False
spin_temperature = _arrayfield()
xray_ionised_fraction = _arrayfield()
kinetic_temp_neutral = _arrayfield()
J_21_LW = _arrayfield(optional=True)
@classmethod
def new(cls, inputs: InputParameters, redshift: float, **kw) -> Self:
"""Create a new TsBox instance with the given inputs.
Parameters
----------
inputs : InputParameters
The input parameters defining the output struct.
redshift : float
The redshift at which to compute fields.
Other Parameters
----------------
All other parameters are passed through to the :class:`TsBox`
constructor.
"""
shape = (inputs.simulation_options.HII_DIM,) * 2 + (
int(
inputs.simulation_options.NON_CUBIC_FACTOR
* inputs.simulation_options.HII_DIM
),
)
out = {
"spin_temperature": Array(shape, dtype=np.float32),
"xray_ionised_fraction": Array(shape, dtype=np.float32),
"kinetic_temp_neutral": Array(shape, dtype=np.float32),
}
if inputs.astro_options.USE_MINI_HALOS:
out["J_21_LW"] = Array(shape, dtype=np.float32)
return cls(inputs=inputs, redshift=redshift, **out, **kw)
@cached_property
def global_Ts(self):
"""Global (mean) spin temperature."""
if not self.is_computed:
raise AttributeError(
"global_Ts is not defined until the ionization calculation has been performed"
)
else:
return np.mean(self.get("spin_temperature"))
@cached_property
def global_Tk(self):
"""Global (mean) Tk."""
if not self.is_computed:
raise AttributeError(
"global_Tk is not defined until the ionization calculation has been performed"
)
else:
return np.mean(self.get("kinetic_temp_neutral"))
@cached_property
def global_x_e(self):
"""Global (mean) x_e."""
if not self.is_computed:
raise AttributeError(
"global_x_e is not defined until the ionization calculation has been performed"
)
else:
return np.mean(self.get("xray_ionised_fraction"))
def get_required_input_arrays(self, input_box: OutputStruct) -> list[str]:
"""Return all input arrays required to compute this object."""
required = []
if isinstance(input_box, InitialConditions):
if (
self.matter_options.USE_RELATIVE_VELOCITIES
and self.astro_options.USE_MINI_HALOS
):
required += ["lowres_vcb"]
elif isinstance(input_box, PerturbedField):
required += ["density"]
elif isinstance(input_box, TsBox):
required += [
"kinetic_temp_neutral",
"xray_ionised_fraction",
"spin_temperature",
]
if self.astro_options.USE_MINI_HALOS:
required += ["J_21_LW"]
elif isinstance(input_box, XraySourceBox):
if self.matter_options.lagrangian_source_grid:
required += ["filtered_sfr", "filtered_xray"]
if self.astro_options.USE_MINI_HALOS:
required += ["filtered_sfr_mini"]
else:
raise ValueError(
f"{type(input_box)} is not an input required for PerturbedHaloCatalog!"
)
return required
def compute(
self,
*,
cleanup: bool,
perturbed_field: PerturbedField,
xray_source_box: XraySourceBox,
prev_spin_temp: TsBox,
ics: InitialConditions,
allow_already_computed: bool = False,
):
"""Compute the function."""
return self._compute(
allow_already_computed,
self.redshift,
prev_spin_temp.redshift,
perturbed_field.redshift,
cleanup,
perturbed_field,
xray_source_box,
prev_spin_temp,
ics,
)
@attrs.define(slots=False, kw_only=True)
class IonizedBox(OutputStructZ):
"""A class containing all ionized boxes."""
_meta = False
_c_compute_function = lib.ComputeIonizedBox
neutral_fraction = _arrayfield()
ionisation_rate_G12 = _arrayfield()
mean_free_path = _arrayfield()
z_reion = _arrayfield()
cumulative_recombinations = _arrayfield(optional=True)
kinetic_temperature = _arrayfield()
unnormalised_nion = _arrayfield()
unnormalised_nion_mini = _arrayfield(optional=True)
log10_Mturnover_ave: float = attrs.field(default=None)
log10_Mturnover_MINI_ave: float = attrs.field(default=None)
mean_f_coll: float = attrs.field(default=None)
mean_f_coll_MINI: float = attrs.field(default=None)
@classmethod
def new(cls, inputs, redshift: float, **kw) -> Self:
"""Create a new IonizedBox instance with the given inputs.
Parameters
----------
inputs : InputParameters
The input parameters defining the output struct.
redshift : float
The redshift at which to compute fields.
Other Parameters
----------------
All other parameters are passed through to the :class:`IonizedBox`
constructor.
"""
if (
inputs.astro_options.USE_MINI_HALOS
and not inputs.matter_options.lagrangian_source_grid
):
n_filtering = (
int(
np.log(
min(
inputs.astro_params.R_BUBBLE_MAX,
0.620350491 * inputs.simulation_options.BOX_LEN,
)
/ max(
inputs.astro_params.R_BUBBLE_MIN,
0.620350491
* inputs.simulation_options.BOX_LEN
/ inputs.simulation_options.HII_DIM,
)
)
/ np.log(inputs.astro_params.DELTA_R_HII_FACTOR)
)
+ 1
)
else:
n_filtering = 1
shape = (inputs.simulation_options.HII_DIM,) * 2 + (
int(
inputs.simulation_options.NON_CUBIC_FACTOR
* inputs.simulation_options.HII_DIM
),
)
filter_shape = (n_filtering, *shape)
out = {
"neutral_fraction": Array(shape, initfunc=np.ones, dtype=np.float32),
"ionisation_rate_G12": Array(shape, dtype=np.float32),
"mean_free_path": Array(shape, dtype=np.float32),
"z_reion": Array(shape, dtype=np.float32),
"kinetic_temperature": Array(shape, dtype=np.float32),
"unnormalised_nion": Array(filter_shape, dtype=np.float32),
}
if inputs.astro_options.INHOMO_RECO:
out["cumulative_recombinations"] = Array(shape, dtype=np.float32)
if (
inputs.astro_options.USE_MINI_HALOS
and not inputs.matter_options.lagrangian_source_grid
):
out["unnormalised_nion_mini"] = Array(filter_shape, dtype=np.float32)
return cls(inputs=inputs, redshift=redshift, **out, **kw)
@cached_property
def global_xH(self):
"""Global (mean) neutral fraction."""
if not self.is_computed:
raise AttributeError(
"global_xH is not defined until the ionization calculation has been performed"
)
else:
return np.mean(self.get("neutral_fraction"))
def get_required_input_arrays(self, input_box: OutputStruct) -> list[str]:
"""Return all input arrays required to compute this object."""
required = []
if isinstance(input_box, InitialConditions):
if (
self.matter_options.USE_RELATIVE_VELOCITIES
and self.matter_options.mass_dependent_zeta
):
required += ["lowres_vcb"]
elif isinstance(input_box, PerturbedField):
required += ["density"]
elif isinstance(input_box, TsBox):
required += ["kinetic_temp_neutral", "xray_ionised_fraction"]
if self.astro_options.USE_MINI_HALOS:
required += ["J_21_LW"]
elif isinstance(input_box, IonizedBox):
required += ["z_reion", "ionisation_rate_G12"]
if self.astro_options.INHOMO_RECO:
required += [
"cumulative_recombinations",
]
if (
self.matter_options.mass_dependent_zeta
and self.astro_options.USE_MINI_HALOS
):
required += [
"unnormalised_nion",
]
if self.matter_options.SOURCE_MODEL == "E-INTEGRAL":
required += [
"unnormalised_nion_mini",
]
elif isinstance(input_box, HaloBox):
required += ["n_ion"]
if self.astro_options.INHOMO_RECO:
required += ["whalo_sfr"]
else:
raise ValueError(
f"{type(input_box)} is not an input required for IonizedBox!"
)
return required
def compute(
self,
*,
perturbed_field: PerturbedField,
prev_perturbed_field: PerturbedField,
prev_ionize_box,
spin_temp: TsBox,
halobox: HaloBox,
ics: InitialConditions,
allow_already_computed: bool = False,
):
"""Compute the function."""
return self._compute(
allow_already_computed,
self.redshift,
prev_perturbed_field.redshift,
perturbed_field,
prev_perturbed_field,
prev_ionize_box,
spin_temp,
halobox,
ics,
)
@attrs.define(slots=False, kw_only=True)
class BrightnessTemp(OutputStructZ):
"""A class containing the brightness temperature box."""
_c_compute_function = lib.ComputeBrightnessTemp
_meta = False
brightness_temp = _arrayfield()
tau_21 = _arrayfield(optional=True)
@classmethod
def new(cls, inputs: InputParameters, redshift: float, **kw) -> Self:
"""Create a new BrightnessTemp instance with the given inputs.
Parameters
----------
inputs : InputParameters
The input parameters defining the output struct.
redshift : float
The redshift at which to compute fields.
Other Parameters
----------------
All other parameters are passed through to the :class:`BrightnessTemp`
constructor.
"""
shape = (inputs.simulation_options.HII_DIM,) * 2 + (
int(
inputs.simulation_options.NON_CUBIC_FACTOR
* inputs.simulation_options.HII_DIM
),
)
out = {"brightness_temp": Array(shape, dtype=np.float32)}
if inputs.astro_options.USE_TS_FLUCT:
out["tau_21"] = Array(shape, dtype=np.float32)
return cls(
inputs=inputs,
redshift=redshift,
**out,
**kw,
)
@cached_property
def global_Tb(self):
"""Global (mean) brightness temperature."""
if not self.is_computed:
raise AttributeError(
"global_Tb is not defined until the ionization calculation has been performed"
)
else:
return np.mean(self.get("brightness_temp"))
def get_required_input_arrays(self, input_box: OutputStruct) -> list[str]:
"""Return all input arrays required to compute this object."""
required = []
if isinstance(input_box, PerturbedField):
required += ["density"]
elif isinstance(input_box, TsBox):
required += ["spin_temperature"]
elif isinstance(input_box, IonizedBox):
required += ["neutral_fraction"]
else:
raise ValueError(
f"{type(input_box)} is not an input required for BrightnessTemp!"
)
return required
def compute(
self,
*,
spin_temp: TsBox,
ionized_box: IonizedBox,
perturbed_field: PerturbedField,
allow_already_computed: bool = False,
):
"""Compute the function."""
return self._compute(
allow_already_computed,
self.redshift,
spin_temp,
ionized_box,
perturbed_field,
)