Skip to content
Open
2 changes: 2 additions & 0 deletions ctis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

from . import scenes
from . import instruments
from . import inverters

__all__ = [
"scenes",
"instruments",
"inverters",
]
17 changes: 17 additions & 0 deletions ctis/inverters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Inversion algorithms which can reconstruct scenes from observed images."""

from ._results import InversionResult
from ._inverters import AbstractInverter
from ._iterative import (
AbstractIterativeInverter,
MartInverter,
IterativeInversionResult,
)

__all__ = [
"AbstractInverter",
"AbstractIterativeInverter",
"MartInverter",
"InversionResult",
"IterativeInversionResult",
]
48 changes: 48 additions & 0 deletions ctis/inverters/_inverters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import abc
import dataclasses
import named_arrays as na
import ctis
from ._results import InversionResult

__all__ = [
"AbstractInverter",
]


@dataclasses.dataclass
class AbstractInverter(
abc.ABC,
):
"""
An interface describing an algorithm which can invert CTIS observations
to yield a reconstruction of the observed scene.
"""

@property
@abc.abstractmethod
def instrument(self) -> ctis.instruments.AbstractInstrument:
"""
A model of a CTIS instrument which transforms the radiance of an observed
scene to photons measured by the sensors.
"""

@abc.abstractmethod
def __call__(
self,
images: na.FunctionArray[na.SpectralPositionalVectorArray, na.ScalarArray],
**kwargs,
) -> InversionResult:
"""
Reconstruct a scene using the observed images.

Parameters
----------
images
The observed images used to calculate the reconstruction.
Must be evaluated on the same coordinates as
:attr:`~ctis.instruments.AbstractInstrument.coordinates_sensor`
attribute of :attr:`instrument`.
kwargs
Additional keyword arguments which can be used by subclass
implementations.
"""
28 changes: 28 additions & 0 deletions ctis/inverters/_inverters_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import abc
import numpy as np
import named_arrays as na
import ctis


class AbstractTestAbstractInverter(
abc.ABC,
):

def test_instrument(self, a: ctis.inverters.AbstractInverter):
result = a.instrument
assert isinstance(result, ctis.instruments.AbstractInstrument)

def test__call__(
self,
a: ctis.inverters.AbstractInverter,
images: na.FunctionArray[na.SpectralPositionalVectorArray, na.ScalarArray],
):
result = a(images)

assert isinstance(result, ctis.inverters.InversionResult)

assert result.solution.sum() > 0
assert result.success
assert isinstance(result.message, str)
assert np.all(result.images == images)
assert result.inverter == a
8 changes: 8 additions & 0 deletions ctis/inverters/_iterative/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from ._iterative import AbstractIterativeInverter, IterativeInversionResult
from ._mart import MartInverter

__all__ = [
"AbstractIterativeInverter",
"IterativeInversionResult",
"MartInverter",
]
63 changes: 63 additions & 0 deletions ctis/inverters/_iterative/_iterative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import ClassVar
import abc
import dataclasses
import named_arrays as na
from .. import AbstractInverter, InversionResult

__all__ = [
"AbstractIterativeInverter",
"IterativeInversionResult",
]


@dataclasses.dataclass
class AbstractIterativeInverter(
AbstractInverter,
):
"""
An abstract inversion algorithm which reconstructs an observed scene
using iterative methods.

These methods will apply some operation repeatedly until a specified
convergence criteria is met.
"""

axis_iteration: ClassVar[str] = "iteration"
"""The logical axis associated with changing iteration index."""

@property
@abc.abstractmethod
def num_iteration(self) -> int:
"""
The maximum number of iterations to perform.

If convergence is not reached before this number is exceeded,
a warning is raised and an unsuccessful result is returned.
"""


@dataclasses.dataclass
class IterativeInversionResult(
InversionResult,
):
"""The results of an iterative inversion attempt."""

inverter: AbstractIterativeInverter

num_iteration: int
"""The number of iterations performed by the inverter."""

merit: na.ScalarArray
"""The value of the merit function for each iteration."""

merit_name: str
"""Human-readable name of the merit function."""

@property
def iteration(self) -> na.ScalarArray:
"""The iteration value for each iteration."""
return na.arange(
start=0,
stop=self.num_iteration,
axis=self.inverter.axis_iteration,
)
12 changes: 12 additions & 0 deletions ctis/inverters/_iterative/_iterative_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import ctis
from .._inverters_test import AbstractTestAbstractInverter


class AbstractTestAbstractIterativeInverter(
AbstractTestAbstractInverter,
):

def test_num_iteration(self, a: ctis.inverters.AbstractIterativeInverter):
result = a.num_iteration
assert isinstance(result, int)
assert result > 0
5 changes: 5 additions & 0 deletions ctis/inverters/_iterative/_mart/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ._mart import MartInverter

__all__ = [
"MartInverter",
]
Loading
Loading