Skip to content

Commit b0f562a

Browse files
committed
Modified the image() and backproject() methods of AbstractInstrument to accept instances of na.FunctionArray.
1 parent 9f6df49 commit b0f562a

3 files changed

Lines changed: 30 additions & 19 deletions

File tree

ctis/instruments/_instruments.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class AbstractInstrument(
3939
@abc.abstractmethod
4040
def image(
4141
self,
42-
scene: na.AbstractScalar,
42+
scene: na.AbstractScalar | na.AbstractFunctionArray,
4343
integrate: bool = True,
4444
noise: bool = True,
4545
) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar]:
@@ -66,7 +66,7 @@ def image(
6666
@abc.abstractmethod
6767
def backproject(
6868
self,
69-
image: na.AbstractScalar,
69+
image: na.AbstractScalar | na.AbstractFunctionArray,
7070
integrate: bool = True,
7171
) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar]:
7272
"""
@@ -233,11 +233,18 @@ def _energy_per_photon(self) -> u.Quantity | na.AbstractScalar:
233233

234234
def image(
235235
self,
236-
scene: na.AbstractScalar,
236+
scene: na.AbstractScalar | na.AbstractFunctionArray,
237237
integrate: bool = True,
238238
noise: bool = True,
239239
) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar]:
240240

241+
if isinstance(scene, na.AbstractFunctionArray):
242+
if not np.all(scene.inputs == self.coordinates_scene):
243+
raise ValueError(
244+
"`scene.inputs` and `self.coordinates_scene` are not equal."
245+
)
246+
scene = scene.outputs
247+
241248
values_input = scene * self._volume_scene
242249

243250
values_input = values_input / self._energy_per_photon
@@ -279,10 +286,17 @@ def image(
279286

280287
def backproject(
281288
self,
282-
image: na.AbstractScalar,
289+
image: na.AbstractScalar | na.AbstractFunctionArray,
283290
integrate: bool = True,
284291
) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar]:
285292

293+
if isinstance(image, na.AbstractFunctionArray):
294+
if not np.all(image.inputs.position == self.coordinates_sensor.position):
295+
raise ValueError(
296+
"`image.inputs` and `self.coordinates_sensor` are not equal."
297+
)
298+
image = image.outputs
299+
286300
coordinates = self.coordinates_scene
287301

288302
axis_wavelength = self.axis_wavelength
@@ -514,7 +528,7 @@ def weights_transpose(self):
514528

515529
def image(
516530
self,
517-
scene: na.AbstractScalar,
531+
scene: na.AbstractScalar | na.AbstractFunctionArray,
518532
integrate: bool = True,
519533
noise: bool = True,
520534
) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar]:
@@ -529,7 +543,7 @@ def image(
529543

530544
def backproject(
531545
self,
532-
image: na.AbstractScalar,
546+
image: na.AbstractScalar | na.AbstractFunctionArray,
533547
integrate: bool = True,
534548
) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar]:
535549

ctis/instruments/_instruments_test.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class AbstractTestAbstractInstrument(
7070
def test_image(
7171
self,
7272
a: ctis.instruments.AbstractInstrument,
73-
scene: na.AbstractScalar,
73+
scene: na.AbstractScalar | na.AbstractFunctionArray,
7474
):
7575
result = a.image(scene)
7676
assert np.all(result.inputs.position == coordinates_sensor.position)
@@ -79,20 +79,23 @@ def test_image(
7979
@pytest.mark.parametrize(
8080
argnames="image",
8181
argvalues=[
82-
instrument_ideal.image(gaussians.outputs, noise=False).outputs,
82+
instrument_ideal.image(gaussians, noise=False),
8383
],
8484
)
8585
def test_backproject(
8686
self,
8787
a: ctis.instruments.AbstractInstrument,
88-
image: na.AbstractScalar,
88+
image: na.AbstractScalar | na.AbstractFunctionArray,
8989
):
9090
result = a.backproject(image)
9191

9292
assert np.all(result.inputs == coordinates_scene)
9393
assert result.outputs.sum() > 0
9494

95-
image_check = a.image(result.outputs, noise=False).outputs
95+
if isinstance(image, na.AbstractFunctionArray):
96+
image = image.outputs
97+
98+
image_check = a.image(result, noise=False).outputs
9699

97100
assert np.allclose(image.sum(), image_check.sum())
98101

docs/tutorials/ideal-instrument.ipynb

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -501,9 +501,7 @@
501501
"tags": []
502502
},
503503
"outputs": [],
504-
"source": [
505-
"image = instrument.image(scene.outputs, integrate=False)"
506-
]
504+
"source": "image = instrument.image(scene, integrate=False)"
507505
},
508506
{
509507
"cell_type": "raw",
@@ -602,9 +600,7 @@
602600
"tags": []
603601
},
604602
"outputs": [],
605-
"source": [
606-
"image_sum = instrument.image(scene.outputs)"
607-
]
603+
"source": "image_sum = instrument.image(scene)"
608604
},
609605
{
610606
"cell_type": "raw",
@@ -633,9 +629,7 @@
633629
"tags": []
634630
},
635631
"outputs": [],
636-
"source": [
637-
"backprojected = instrument.backproject(image_sum.outputs)"
638-
]
632+
"source": "backprojected = instrument.backproject(image_sum)"
639633
},
640634
{
641635
"cell_type": "raw",

0 commit comments

Comments
 (0)