Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions ctis/instruments/_instruments.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class AbstractInstrument(
@abc.abstractmethod
def image(
self,
scene: na.AbstractScalar,
scene: na.AbstractScalar | na.AbstractFunctionArray,
integrate: bool = True,
noise: bool = True,
) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar]:
Expand All @@ -66,7 +66,7 @@ def image(
@abc.abstractmethod
def backproject(
self,
image: na.AbstractScalar,
image: na.AbstractScalar | na.AbstractFunctionArray,
integrate: bool = True,
) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar]:
"""
Expand Down Expand Up @@ -233,11 +233,18 @@ def _energy_per_photon(self) -> u.Quantity | na.AbstractScalar:

def image(
self,
scene: na.AbstractScalar,
scene: na.AbstractScalar | na.AbstractFunctionArray,
integrate: bool = True,
noise: bool = True,
) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar]:

if isinstance(scene, na.AbstractFunctionArray):
if not np.all(scene.inputs == self.coordinates_scene):
raise ValueError(
"`scene.inputs` and `self.coordinates_scene` are not equal."
)
scene = scene.outputs

values_input = scene * self._volume_scene

values_input = values_input / self._energy_per_photon
Expand Down Expand Up @@ -279,10 +286,17 @@ def image(

def backproject(
self,
image: na.AbstractScalar,
image: na.AbstractScalar | na.AbstractFunctionArray,
integrate: bool = True,
) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar]:

if isinstance(image, na.AbstractFunctionArray):
if not np.all(image.inputs.position == self.coordinates_sensor.position):
raise ValueError(
"`image.inputs` and `self.coordinates_sensor` are not equal."
)
image = image.outputs

coordinates = self.coordinates_scene

axis_wavelength = self.axis_wavelength
Expand Down Expand Up @@ -514,7 +528,7 @@ def weights_transpose(self):

def image(
self,
scene: na.AbstractScalar,
scene: na.AbstractScalar | na.AbstractFunctionArray,
integrate: bool = True,
noise: bool = True,
) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar]:
Expand All @@ -529,7 +543,7 @@ def image(

def backproject(
self,
image: na.AbstractScalar,
image: na.AbstractScalar | na.AbstractFunctionArray,
integrate: bool = True,
) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar]:

Expand Down
11 changes: 7 additions & 4 deletions ctis/instruments/_instruments_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class AbstractTestAbstractInstrument(
def test_image(
self,
a: ctis.instruments.AbstractInstrument,
scene: na.AbstractScalar,
scene: na.AbstractScalar | na.AbstractFunctionArray,
):
result = a.image(scene)
assert np.all(result.inputs.position == coordinates_sensor.position)
Expand All @@ -79,20 +79,23 @@ def test_image(
@pytest.mark.parametrize(
argnames="image",
argvalues=[
instrument_ideal.image(gaussians.outputs, noise=False).outputs,
instrument_ideal.image(gaussians, noise=False),
],
)
def test_backproject(
self,
a: ctis.instruments.AbstractInstrument,
image: na.AbstractScalar,
image: na.AbstractScalar | na.AbstractFunctionArray,
):
result = a.backproject(image)

assert np.all(result.inputs == coordinates_scene)
assert result.outputs.sum() > 0

image_check = a.image(result.outputs, noise=False).outputs
if isinstance(image, na.AbstractFunctionArray):
image = image.outputs

image_check = a.image(result, noise=False).outputs

assert np.allclose(image.sum(), image_check.sum())

Expand Down
12 changes: 3 additions & 9 deletions docs/tutorials/ideal-instrument.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -501,9 +501,7 @@
"tags": []
},
"outputs": [],
"source": [
"image = instrument.image(scene.outputs, integrate=False)"
]
"source": "image = instrument.image(scene, integrate=False)"
},
{
"cell_type": "raw",
Expand Down Expand Up @@ -602,9 +600,7 @@
"tags": []
},
"outputs": [],
"source": [
"image_sum = instrument.image(scene.outputs)"
]
"source": "image_sum = instrument.image(scene)"
},
{
"cell_type": "raw",
Expand Down Expand Up @@ -633,9 +629,7 @@
"tags": []
},
"outputs": [],
"source": [
"backprojected = instrument.backproject(image_sum.outputs)"
]
"source": "backprojected = instrument.backproject(image_sum)"
},
{
"cell_type": "raw",
Expand Down
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,10 @@ Documentation = "https://ctis.readthedocs.io/en/latest"
packages = ["ctis"]

[tool.setuptools_scm]

[tool.coverage.report]
exclude_also = [
"return NotImplemented",
"raise ValueError",
"raise NotImplementedError",
]
Loading