#!/usr/bin/env python3
"""Merge split CSVs into a single file with split source labels."""

import argparse
from pathlib import Path
from typing import Optional, Sequence

import pandas as pd


# Maps each expected split CSV filename to the integer written into the
# source column for every row loaded from that file. Iteration order of this
# dict is also the order in which the files are concatenated.
SPLIT_LABELS = {
    "split_test.csv": 1,
    "split_train.csv": 2,
    "split_val.csv": 3,
}

# Name of the source-label column unless --column-name overrides it.
DEFAULT_COLUMN_NAME = "split_source"


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the merge script.

    The defaults for ``--input-dir`` and ``--output`` are built from the
    parent of the directory that contains this file, which the script treats
    as the repository root.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is the original behaviour; an
            explicit sequence allows the parser to be driven programmatically.

    Returns:
        Namespace with ``input_dir`` (Path), ``output`` (Path) and
        ``column_name`` (str).
    """
    repo_root = Path(__file__).resolve().parent.parent
    parser = argparse.ArgumentParser(
        description=(
            "Combine split_*.csv files from splits_v2 and label their origin with integers."
        ),
    )
    parser.add_argument(
        "--input-dir",
        type=Path,
        default=repo_root / "splits_v2",
        help="Directory containing split_*.csv files (default: %(default)s)",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=repo_root / "data" / "merged_splits.csv",
        help="Destination CSV path (default: %(default)s)",
    )
    parser.add_argument(
        "--column-name",
        default=DEFAULT_COLUMN_NAME,
        help="Name for the source column (default: %(default)s)",
    )
    return parser.parse_args(argv)


def main() -> None:
    """Merge every expected split CSV into one labelled CSV file.

    Reads each file named in ``SPLIT_LABELS`` from the input directory, adds
    a column holding that file's integer label, concatenates the frames in
    mapping order with a fresh index, and writes the result to the output
    path, creating its parent directories when needed. A one-line summary is
    printed on success.

    Raises:
        SystemExit: If the input directory or any expected split file is
            missing, if a split file already contains a column with the
            requested label column name, or if no frames were loaded.
    """
    args = parse_args()

    if not args.input_dir.is_dir():
        raise SystemExit(f"Input directory not found: {args.input_dir}")

    frames = []
    for filename, label in SPLIT_LABELS.items():
        csv_path = args.input_dir / filename
        if not csv_path.is_file():
            raise SystemExit(f"Missing expected split file: {csv_path}")
        df = pd.read_csv(csv_path)
        # Assigning to an existing column would silently replace its data,
        # so refuse instead of overwriting it with the label.
        if args.column_name in df.columns:
            raise SystemExit(
                f"Column {args.column_name!r} already exists in {csv_path}; "
                "choose a different name with --column-name"
            )
        df[args.column_name] = label
        frames.append(df)

    # pd.concat raises on an empty list; this guard only triggers if
    # SPLIT_LABELS itself is empty.
    if not frames:
        raise SystemExit("No split CSV files were loaded.")

    merged = pd.concat(frames, ignore_index=True)
    args.output.parent.mkdir(parents=True, exist_ok=True)
    merged.to_csv(args.output, index=False)
    print(f"Merged {len(frames)} files with {len(merged)} rows into {args.output}")


# Run the merge only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()