Source code for py21cmfast.wrapper.structs
"""Data structure wrappers for the C code."""
from __future__ import annotations
import contextlib
import logging
from typing import Any
import attrs
from bidict import bidict
from ..c_21cmfast import ffi
from .arrays import Array
[docs]
logger = logging.getLogger(__name__)
@attrs.define(slots=False)
class StructWrapper:
"""
A base-class python wrapper for C structures (not instances of them).
Provides simple methods for creating new instances and accessing field names and values.
To implement wrappers of specific structures, make a subclass with the same name as the
appropriate C struct (which must be defined in the C code that has been compiled to the ``ffi``
object), *or* use an arbitrary name, but set the ``_name`` attribute to the C struct name.
"""
_name: str = attrs.field(converter=str)
cstruct = attrs.field(default=None)
_ffi = attrs.field(default=ffi)
_TYPEMAP = bidict({"float32": "float *", "float64": "double *", "int32": "int *"})
@_name.default
def _name_default(self):
return self.__class__.__name__
def __init__(self, *args):
"""Perform custom initializion actions.
This instantiates the memory associated with the C struct, attached to this inst.
"""
self.__attrs_init__(*args)
self.cstruct = self._new()
def _new(self):
"""Return a new empty C structure corresponding to this class."""
return self._ffi.new(f"struct {self._name}*")
@property
def fields(self) -> dict[str, Any]:
"""A list of fields of the underlying C struct (a list of tuples of "name, type")."""
return dict(self._ffi.typeof(self.cstruct[0]).fields)
@property
def fieldnames(self) -> list[str]:
"""A list of names of fields of the underlying C struct."""
return [f for f, t in self.fields.items()]
@property
def pointer_fields(self) -> list[str]:
"""A list of names of fields which have pointer type in the C struct."""
return [f for f, t in self.fields.items() if t.type.kind == "pointer"]
@property
def primitive_fields(self) -> list[str]:
"""The list of names of fields which have primitive type in the C struct."""
return [f for f, t in self.fields.items() if t.type.kind == "primitive"]
def __getstate__(self):
"""Return the current state of the class without pointers."""
return {
k: v
for k, v in self.__dict__.items()
if k not in ["_strings", "cstruct", "_ffi"]
}
def expose_to_c(self, array: Array, name: str):
"""Expose the memory of a particular Array to the backend C code."""
if not array.state.initialized:
raise ValueError("Array must be initialized before exposing to C")
def _ary2buf(ary):
return self._ffi.cast(
self._TYPEMAP[ary.dtype.name], self._ffi.from_buffer(ary)
)
try:
setattr(self.cstruct, name, _ary2buf(array.value))
except TypeError as e:
raise TypeError(f"Error setting {name}") from e
class StructInstanceWrapper:
"""A wrapper for *instances* of C structs.
This is as opposed to :class:`StructWrapper`, which is for the un-instantiated structs.
Parameters
----------
wrapped :
The reference to the C object to wrap (contained in the ``cffi.lib`` object).
ffi :
The ``cffi.ffi`` object.
"""
def __init__(self, wrapped, ffi):
self._cobj = wrapped
self._ffi = ffi
for nm, _tp in self._ffi.typeof(self._cobj).fields:
setattr(self, nm, getattr(self._cobj, nm))
# Get the name of the structure
self._ctype = self._ffi.typeof(self._cobj).cname.split()[-1]
def __setattr__(self, name, value):
"""Set an attribute of the instance, attempting to change it in the C struct as well."""
with contextlib.suppress(AttributeError):
setattr(self._cobj, name, value)
object.__setattr__(self, name, value)
def items(self):
"""Yield (name, value) pairs for each element of the struct."""
for nm, _tp in self._ffi.typeof(self._cobj).fields:
yield nm, getattr(self, nm)
def keys(self):
"""Return a list of names of elements in the struct."""
return [nm for nm, tp in self.items()]
def __iter__(self):
"""Iterate over the object like a dict."""
yield from self.keys()
def __repr__(self):
"""Return a unique representation of the instance."""
return (
self._ctype + "(" + ";".join(f"{k}={v!s}" for k, v in sorted(self.items()))
) + ")"
def filtered_repr(self, filter_params):
"""Get a fully unique representation of the instance that filters out some parameters.
Parameters
----------
filter_params : list of str
The parameter names which should not appear in the representation.
"""
return (
self._ctype
+ "("
+ ";".join(
f"{k}={v!s}" for k, v in sorted(self.items()) if k not in filter_params
)
) + ")"