Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions ami/main/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from ami.jobs.models import Job
from ami.ml.models.project_pipeline_config import ProjectPipelineConfig
from ami.ml.post_processing.admin.actions import make_post_processing_action
from ami.ml.post_processing.admin.class_masking_form import ClassMaskingActionForm
from ami.ml.post_processing.admin.rank_rollup_form import RankRollupActionForm
from ami.ml.post_processing.admin.small_size_filter_form import SmallSizeFilterActionForm
from ami.ml.post_processing.class_masking import ClassMaskingTask
from ami.ml.post_processing.rank_rollup import RankRollupTask
from ami.ml.post_processing.small_size_filter import SmallSizeFilterTask
from ami.ml.tasks import remove_duplicate_classifications

Expand Down Expand Up @@ -552,6 +556,12 @@ def detections_count(self, obj) -> int:
scope_resolver=lambda occurrence: {"occurrence_id": occurrence.pk},
name_resolver=lambda task_cls, occurrence: (f"Post-processing: {task_cls.name} on Occurrence {occurrence.pk}"),
)
# Class masking admin action: the scope handed on is the occurrence's primary
# key, and the display name embeds the task name and that same primary key.
run_class_masking = make_post_processing_action(
    ClassMaskingTask,
    ClassMaskingActionForm,
    name_resolver=lambda task_cls, occurrence: (
        f"Post-processing: {task_cls.name} on Occurrence {occurrence.pk}"
    ),
    scope_resolver=lambda occurrence: {"occurrence_id": occurrence.pk},
)

@admin.action(description="Recompute determination from current classifications and identifications")
def recompute_determination(self, request: HttpRequest, queryset: QuerySet[Any]) -> None:
Expand All @@ -568,7 +578,7 @@ def recompute_determination(self, request: HttpRequest, queryset: QuerySet[Any])
count += 1
self.message_user(request, f"Recomputed determination for {count} occurrence(s).")

actions = [run_small_size_filter, recompute_determination]
actions = [run_small_size_filter, run_class_masking, recompute_determination]

# Order by -id (the indexed primary key) rather than -created_at, which has no
# index and would force a full sort of the table to find the newest page. id
Expand Down Expand Up @@ -850,11 +860,28 @@ def populate_collection_async(self, request: HttpRequest, queryset: QuerySet[Sou
f"Post-processing: {task_cls.name} on Capture Set {collection.pk}"
),
)

# Class masking admin action: the scope handed on is the collection's primary
# key, and the display name labels the collection as a "Capture Set".
run_class_masking = make_post_processing_action(
    ClassMaskingTask,
    ClassMaskingActionForm,
    name_resolver=lambda task_cls, collection: (
        f"Post-processing: {task_cls.name} on Capture Set {collection.pk}"
    ),
    scope_resolver=lambda collection: {"source_image_collection_id": collection.pk},
)
# Rank rollup admin action: same collection scope and "Capture Set" naming as
# the other post-processing actions built with make_post_processing_action.
run_rank_rollup = make_post_processing_action(
    RankRollupTask,
    RankRollupActionForm,
    name_resolver=lambda task_cls, collection: (
        f"Post-processing: {task_cls.name} on Capture Set {collection.pk}"
    ),
    scope_resolver=lambda collection: {"source_image_collection_id": collection.pk},
)
# Actions listed for this admin class: the two collection-population actions
# first, then the post-processing actions (small size filter, class masking,
# rank rollup). NOTE(review): the enclosing class header is outside this view;
# presumably this is the ModelAdmin ``actions`` attribute — confirm.
actions = [
populate_collection,
populate_collection_async,
run_small_size_filter,
run_class_masking,
run_rank_rollup,
]

# Hide images many-to-many field from form. This would list all source images in the database.
Expand Down
21 changes: 21 additions & 0 deletions ami/main/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,10 +941,26 @@ class ClassificationPredictionItemSerializer(serializers.Serializer):
logit = serializers.FloatField(read_only=True)


class ClassificationAppliedToSerializer(serializers.ModelSerializer):
    """Compact nested view of the classification a result was derived from.

    Post-processing tasks such as class masking and rank rollup store their
    provenance in ``Classification.applied_to``. Only the parent's id, creation
    time and algorithm are exposed, so rendering the parent never recurses
    into a full classification payload.
    """

    # The parent's algorithm is rendered as a nested, read-only object.
    algorithm = AlgorithmSerializer(read_only=True)

    class Meta:
        model = Classification
        fields = [
            "id",
            "created_at",
            "algorithm",
        ]


class ClassificationSerializer(DefaultSerializer):
taxon = TaxonNestedSerializer(read_only=True)
algorithm = AlgorithmSerializer(read_only=True)
top_n = ClassificationPredictionItemSerializer(many=True, read_only=True)
applied_to = ClassificationAppliedToSerializer(read_only=True)

class Meta:
model = Classification
Expand All @@ -957,6 +973,7 @@ class Meta:
"scores",
"logits",
"top_n",
"applied_to",
"created_at",
"updated_at",
]
Expand All @@ -979,6 +996,8 @@ class Meta(ClassificationSerializer.Meta):


class ClassificationListSerializer(DefaultSerializer):
applied_to = ClassificationAppliedToSerializer(read_only=True)

class Meta:
model = Classification
fields = [
Expand All @@ -987,6 +1006,7 @@ class Meta:
"taxon",
"score",
"algorithm",
"applied_to",
"created_at",
"updated_at",
]
Expand All @@ -1006,6 +1026,7 @@ class Meta:
"score",
"terminal",
"algorithm",
"applied_to",
"created_at",
]

Expand Down
2 changes: 1 addition & 1 deletion ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2060,7 +2060,7 @@ class ClassificationViewSet(DefaultViewSet, ProjectMixin):
"""

require_project_for_list = True # Unfiltered list scans are too expensive on this table
queryset = Classification.objects.all().select_related("taxon", "algorithm") # , "detection")
queryset = Classification.objects.all().select_related("taxon", "algorithm", "applied_to__algorithm")
serializer_class = ClassificationSerializer
filterset_fields = [
# Docs about slow loading API browser because of large choice fields
Expand Down
5 changes: 4 additions & 1 deletion ami/main/models_future/occurrence.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ def _detections_prefetch(*, ordering: tuple[str, ...], with_source_image: bool)
qs = Detection.objects.prefetch_related(
Prefetch(
"classifications",
queryset=Classification.objects.select_related("taxon", "algorithm"),
# applied_to__algorithm: post-processed classifications (class masking,
# rank rollup) serialize their provenance parent; pull it here so the
# nested applied_to render doesn't issue a query per classification.
queryset=Classification.objects.select_related("taxon", "algorithm", "applied_to__algorithm"),
)
).order_by(*ordering)
if with_source_image:
Expand Down
83 changes: 83 additions & 0 deletions ami/ml/management/commands/run_class_masking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from django.core.management.base import BaseCommand, CommandError

from ami.main.models import SourceImageCollection, TaxaList
from ami.ml.models.algorithm import Algorithm
from ami.ml.post_processing.class_masking import ClassMaskingTask


class Command(BaseCommand):
    """Run class masking post-processing on one source image collection.

    The command validates the three referenced objects, reports how many
    classifications match, and then either stops (``--dry-run``) or runs
    ``ClassMaskingTask`` with the given ids.
    """

    help = (
        "Run class masking post-processing on a source image collection. "
        "Masks classifier logits for species not in the given taxa list and recalculates softmax scores."
    )

    def add_arguments(self, parser):
        """Register the three required ids and the optional ``--dry-run`` flag."""
        parser.add_argument("--collection-id", type=int, required=True, help="SourceImageCollection ID to process")
        parser.add_argument("--taxa-list-id", type=int, required=True, help="TaxaList ID to use as the species mask")
        parser.add_argument(
            "--algorithm-id", type=int, required=True, help="Algorithm ID whose classifications to mask"
        )
        parser.add_argument("--dry-run", action="store_true", help="Show what would be done without making changes")

    @staticmethod
    def _get_or_fail(model, pk):
        """Return the ``model`` row with primary key ``pk``.

        Raises:
            CommandError: if no such row exists. The original ``DoesNotExist``
                is chained as the cause so ``--traceback`` shows where the
                lookup failed.
        """
        try:
            return model.objects.get(pk=pk)
        except model.DoesNotExist as err:
            raise CommandError(f"{model.__name__} {pk} does not exist.") from err

    def handle(self, *args, **options):
        """Validate inputs, print a summary, and run the masking task.

        Raises:
            CommandError: if any referenced object is missing, the algorithm
                has no category map, or no matching classifications exist.
        """
        collection_id = options["collection_id"]
        taxa_list_id = options["taxa_list_id"]
        algorithm_id = options["algorithm_id"]
        dry_run = options["dry_run"]

        # Validate inputs in a fixed order so the first missing object is the
        # one reported: collection, then taxa list, then algorithm.
        collection = self._get_or_fail(SourceImageCollection, collection_id)
        taxa_list = self._get_or_fail(TaxaList, taxa_list_id)
        algorithm = self._get_or_fail(Algorithm, algorithm_id)

        if not algorithm.category_map:
            raise CommandError(f"Algorithm '{algorithm.name}' does not have a category map.")

        from ami.main.models import Classification

        # distinct() guards against duplicate rows from the join through the
        # image-to-collection relation.
        classification_count = (
            Classification.objects.filter(
                detection__source_image__collections=collection,
                terminal=True,
                algorithm=algorithm,
                scores__isnull=False,
            )
            .distinct()
            .count()
        )

        taxa_count = taxa_list.taxa.count()

        self.stdout.write(
            f"Collection: {collection.name} (id={collection.pk})\n"
            f"Taxa list: {taxa_list.name} (id={taxa_list.pk}, {taxa_count} taxa)\n"
            f"Algorithm: {algorithm.name} (id={algorithm.pk})\n"
            f"Classifications to process: {classification_count}"
        )

        if classification_count == 0:
            raise CommandError("No terminal classifications with scores found for this collection/algorithm.")

        if dry_run:
            self.stdout.write(self.style.WARNING("Dry run — no changes made."))
            return

        self.stdout.write("Running class masking...")
        task = ClassMaskingTask(
            source_image_collection_id=collection_id,
            taxa_list_id=taxa_list_id,
            algorithm_id=algorithm_id,
        )
        task.run()
        self.stdout.write(self.style.SUCCESS("Class masking completed."))
2 changes: 2 additions & 0 deletions ami/ml/post_processing/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from . import class_masking # noqa: F401
from . import rank_rollup # noqa: F401
from . import small_size_filter # noqa: F401
6 changes: 4 additions & 2 deletions ami/ml/post_processing/admin/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,12 @@ def _render(form: BasePostProcessingActionForm) -> TemplateResponse:
)
return None

# Hand the form the selected rows so it can scope its fields to the
# selection (e.g. only offer algorithms that ran on the chosen occurrence).
if not request.POST.get("confirm"):
return _render(form_class())
return _render(form_class(scope_queryset=queryset))

form = form_class(request.POST)
form = form_class(request.POST, scope_queryset=queryset)
if not form.is_valid():
return _render(form)

Expand Down
66 changes: 66 additions & 0 deletions ami/ml/post_processing/admin/class_masking_form.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from __future__ import annotations

from django import forms

from ami.main.models import Occurrence, TaxaList
from ami.ml.models import Algorithm
from ami.ml.models.algorithm import AlgorithmTaskType
from ami.ml.post_processing.admin.forms import BasePostProcessingActionForm


class ClassMaskingActionForm(BasePostProcessingActionForm):
    """Options an admin chooses when triggering Class masking.

    The operator selects the source classifier and the taxa list to keep. The
    scope (collection or occurrence) comes from the admin entry point rather
    than this form. Both fields yield model instances, so ``to_config``
    converts them to primary keys (``ClassMaskingConfig`` expects ``*_id``
    ints).
    """

    algorithm_id = forms.ModelChoiceField(
        queryset=Algorithm.objects.filter(task_type=AlgorithmTaskType.CLASSIFICATION.value).order_by("name"),
        label="Source classifier",
        help_text="The classification algorithm whose terminal predictions will be re-scored.",
    )
    taxa_list_id = forms.ModelChoiceField(
        queryset=TaxaList.objects.all().order_by("name"),
        label="Taxa list to keep",
        help_text=(
            "Classes whose taxon is not in this list are masked out; each "
            "classification's softmax is renormalised over the classes that remain."
        ),
    )

    def __init__(self, *args, **kwargs):
        """Build the form and, for an occurrence scope, narrow the classifier choices."""
        super().__init__(*args, **kwargs)
        scope = self.scope_queryset
        # Only an occurrence selection narrows the dropdown: the lookup then
        # touches just the classifications under the picked occurrences, and
        # the operator cannot choose a classifier that would be a no-op there.
        # A collection selection deliberately keeps the full list, because the
        # same lookup would be an unbounded DISTINCT over every classification
        # in the collection and can time out while the form renders. Offering
        # too many options is harmless: masking a classifier that produced
        # nothing in scope changes nothing.
        if scope is None or scope.model is not Occurrence:
            return
        self.fields["algorithm_id"].queryset = self._algorithms_for_scope(scope)

    @staticmethod
    def _algorithms_for_scope(scope_queryset):
        """Return classification algorithms with classifications under the given occurrences."""
        in_scope = Algorithm.objects.filter(
            task_type=AlgorithmTaskType.CLASSIFICATION.value,
            classifications__detection__occurrence__in=scope_queryset,
        )
        return in_scope.distinct().order_by("name")

    def to_config(self) -> dict:
        """Return the primary keys of the chosen classifier and taxa list."""
        return {key: self.cleaned_data[key].pk for key in ("algorithm_id", "taxa_list_id")}
10 changes: 10 additions & 0 deletions ami/ml/post_processing/admin/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ class BasePostProcessingActionForm(forms.Form):
optional fields, derive computed values, rename keys).
"""

def __init__(self, *args, scope_queryset=None, **kwargs):
    """Record the admin selection this action targets, then build the form.

    ``scope_queryset`` holds the rows the operator selected (for example the
    chosen occurrences or collections). Subclasses can use it to restrict
    their fields to that selection; forms with no such need leave it unused.
    """
    # Keyword-only, so it never reaches the base form's own arguments.
    self.scope_queryset = scope_queryset
    super().__init__(*args, **kwargs)

def to_config(self) -> dict:
    """Return a copy of ``cleaned_data`` shaped for ``Job.params['config']``."""
    config = dict(self.cleaned_data)
    return config
13 changes: 13 additions & 0 deletions ami/ml/post_processing/admin/rank_rollup_form.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from __future__ import annotations

from ami.ml.post_processing.admin.forms import BasePostProcessingActionForm


class RankRollupActionForm(BasePostProcessingActionForm):
    """Confirmation-only form for the Rank rollup action.

    The per-rank score thresholds and the rollup order are defined on
    ``RankRollupConfig``. No per-run options exist yet, so this form declares
    no fields and only confirms the selected capture set(s). Its empty
    ``cleaned_data`` leaves the schema to apply its defaults. Threshold
    overrides can be added here later.
    """
Loading