-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcreate_cnn_splits.py
More file actions
62 lines (56 loc) · 2.23 KB
/
create_cnn_splits.py
File metadata and controls
62 lines (56 loc) · 2.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#!/usr/bin/env python
"""
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
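Helper script that generates CNN image-based stratified train/val/test splits
for a given year and saves them to a JSON file.
It can be run directly from inside the looted_site_detection directory; the
package does not need to be installed or already on the import path.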
Example:
python create_cnn_splits.py \
--data_root ../change_detection/planet_mosaics_final_4bands/datasets \
--year 2023 \
--output tmp_cnn_splits.json
"""
import argparse
import json
import os
import sys
from pathlib import Path
# Import the split helper in whichever way matches how this file was loaded:
# as a submodule of the package (relative import) or as a stand-alone script.
if __package__:
    from .dynamic_split_images import create_image_based_splits
else:
    # Run directly as a script: add the parent of this file's directory to
    # sys.path so that `looted_site_detection` can be imported as a package.
    pkg_parent = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    if pkg_parent not in sys.path:
        sys.path.append(pkg_parent)
    from looted_site_detection.dynamic_split_images import create_image_based_splits
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the CNN split generator.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour callers had before this parameter existed. Passing an
            explicit list allows the parser to be driven programmatically.

    Returns:
        Namespace with the attributes ``data_root``, ``year``, ``test_size``,
        ``val_size``, ``seed`` and ``output``.
    """
    parser = argparse.ArgumentParser(description="Generate year-filtered CNN splits")
    parser.add_argument('--data_root', type=str, required=True,
                        help='Root directory containing looted/ and preserved/ subdirectories')
    parser.add_argument('--year', type=int, default=2023,
                        help='Year to filter images (default 2023)')
    parser.add_argument('--test_size', type=float, default=0.2,
                        help='Fraction for test set (default 0.2)')
    parser.add_argument('--val_size', type=float, default=0.1,
                        help='Fraction for validation set (default 0.1)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed (default 42)')
    parser.add_argument('--output', type=str, default='tmp_cnn_splits.json',
                        help='Output JSON file (default tmp_cnn_splits.json)')
    return parser.parse_args(argv)
def main():
    """Generate the image-based splits and save them as JSON.

    Parses the command-line options, calls ``create_image_based_splits`` with
    them (no labels CSV is passed), writes the returned mapping to the
    ``--output`` path as indented JSON, and prints how many entries the
    ``train``, ``val`` and ``test`` splits contain.
    """
    args = parse_args()
    splits = create_image_based_splits(
        data_root=args.data_root,
        labels_csv=None,
        test_size=args.test_size,
        val_size=args.val_size,
        year=args.year,
        seed=args.seed,
    )

    out_path = Path(args.output)
    # Create missing parent directories up front so a nested --output path
    # does not fail with FileNotFoundError after the splits were computed.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Pin the encoding so the written file does not depend on the locale.
    out_path.write_text(json.dumps(splits, indent=2), encoding='utf-8')

    print(f"Saved splits to {out_path}:")
    for k in ['train', 'val', 'test']:
        print(f" {k}: {len(splits[k])} sites")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()