# ============================================================
# Dual Channel Fluorescence Voxel Counting Workflow - Jython
# Version: 2.0 (Optimised)
# Requires: FIJI / ImageJ
#
# CHANGES FROM v1.0 (all changes preserve scientific output exactly):
#   - PERFORMANCE: Voxel counting now uses ImageProcessor.getHistogram()
#     instead of nested Python pixel loops. ~50-100x faster for typical
#     1024x1024x120 stacks.
#   - PERFORMANCE: ROI cleanup now uses ImageProcessor.fillOutside() which
#     is a Java-native batch operation, replacing the per-pixel Python loop.
#   - ROBUSTNESS: Each sample is now wrapped in try/except so one bad image
#     does not crash the entire batch.
#   - UX: Pre-flight summary shows sample count and missing-channel warnings
#     before processing starts.
#   - UX: Resume capability -- if already-processed samples are detected,
#     user is offered the choice to skip them.
#   - UX: Progress counter [N/total] logged for each sample.
#   - UX: Final batch summary lists counts of single/multi/cleanup/error/skipped.
#   - CODE QUALITY: CSV building extracted into reusable helper functions.
#   - CODE QUALITY: File writes wrapped in try/finally for handle safety.
#
# WORKFLOW:
#   1. Opens paired channel files (_C0, _C1, _C2), sums them into a 32-bit
#      greyscale stack to amplify signal-over-noise
#   2. Despeckles, applies auto brightness/contrast
#   3. Calculates a Default auto threshold from the middle slice
#   4. Applies threshold at 11 sensitivities (70% to 120% in 5% steps)
#   5. Binary Opens to clean cell edges
#   6. Saves a processed TIFF and computes voxel statistics for each threshold
#   7. Displays Z projections side by side for user validation
#   8. User chooses preferred threshold and indicates cell count + cleanup needs
#   9. For multi-cell or cleanup cases, prompts ROI drawing on chosen threshold
#  10. Saves per-sample CSV (all 11 thresholds) and a summary CSV (chosen only)
#
# SCIENTIFIC NOTE: Voxel counting is performed natively in Jython via
# ImageProcessor.getHistogram() rather than through the Voxel Counter plugin.
# This was done to avoid the plugin's "save measurements" prompt that appears
# when called from outside the IJM macro interpreter. The calculations are
# mathematically identical -- both count pixels with value 255 in a binary
# stack and apply voxel volume calibration -- but the histogram-based
# approach is implemented in Java rather than Python and is consequently
# much faster. Output values match the original Voxel Counter plugin.
#
# JYTHON NOTE: This script uses Python 2.7 syntax (Jython inside FIJI).
# Key differences from Python 3:
#   - print is a statement: print "text"
#   - integer division floors by default
#   - strings are bytes, not unicode
# ============================================================

from ij import IJ, WindowManager
from ij.gui import GenericDialog, WaitForUserDialog
from ij.plugin import ZProjector, ImageCalculator
from ij.plugin.frame import RoiManager
from ij.process import AutoThresholder
from java.io import File, FileWriter, BufferedWriter
from loci.plugins import BF
from loci.plugins.in import ImporterOptions
import java.lang.System as System
import os

# ============================================================
# CONFIGURATION CONSTANTS
# ============================================================

SCRIPT_VERSION = "2.0"

# Voxel calibration -- adjust if microscope acquisition settings change.
# These values are written onto every image by apply_voxel_calibration().
PIXEL_WIDTH = 0.104        # microns per pixel (X)
PIXEL_HEIGHT = 0.104       # microns per pixel (Y)
VOXEL_DEPTH = 0.216683     # microns per Z step (slice spacing)
VOXEL_UNIT = "micron"

# Threshold sensitivity multipliers applied to the auto-calculated threshold.
#   below 1.0 -> more permissive (catches more signal)
#   above 1.0 -> more conservative (catches less signal)
# MULTIPLIERS and MULTIPLIER_LABELS are walked in lockstep (zip / shared
# index) elsewhere in this file, so they must stay the same length and order.
MULTIPLIERS = [
    0.7, 0.75, 0.8, 0.85, 0.9, 0.95,
    1.0,
    1.05, 1.10, 1.15, 1.20,
]
MULTIPLIER_LABELS = "70 75 80 85 90 95 100 105 110 115 120".split()
DEFAULT_CHOSEN_THRESHOLD = "80"

# File naming conventions for input channels (index = channel number)
CHANNEL_SUFFIXES = [
    "_C0",
    "_C1",
    "_C2",
]

# Voxel Counter output fields, in the order they appear in the per-sample CSV.
# The same strings are used as dictionary keys for the computed values.
FIELD_NAMES = [
    # voxel counts
    "Thresholded voxels", "Average voxels per slice", "Total ROI Voxels",
    "Volume fraction", "Voxels in stack",
    # calibration
    "Voxel size",
    # calibrated volumes
    "Thresholded volume", "Average volume per slice", "Total ROI volume",
    "Volume of stack",
]

# Pixel value counted as thresholded foreground in the binary stacks
# (the masks are produced by "Convert to Mask" with background=Dark).
FOREGROUND_VALUE = 255

# ============================================================
# I/O HELPER FUNCTIONS
# ============================================================

def get_directory(title):
    """
    Asks FIJI for a directory via IJ.getDirectory(title) and returns its path.

    The returned path is guaranteed to end with a path separator so callers
    can build file names by plain string concatenation. IJ.getDirectory() is
    expected to supply the separator already; it is enforced here so that
    guarantee does not silently depend on the ImageJ version in use.

    Args:
      title -- dialog title passed straight to IJ.getDirectory()

    Returns: the selected directory path, ending in "/" or "\\".
    Raises RuntimeError if IJ.getDirectory() returns None (e.g. the user
    cancelled the dialog).
    """
    chosen_dir = IJ.getDirectory(title)
    if chosen_dir is None:
        raise RuntimeError("No directory selected for: " + title)
    # Only append when something was returned and it lacks a separator;
    # an already-terminated path is passed through unchanged.
    if chosen_dir and not chosen_dir.endswith(("/", "\\")):
        chosen_dir = chosen_dir + os.sep
    return chosen_dir


def make_directory(path):
    """
    Ensures a directory exists at `path`.

    Does nothing when the path already exists; otherwise calls Java's
    File.mkdirs(), which also creates any missing parent directories.
    The result of mkdirs() is not checked, so a failed creation is silent.
    """
    target = File(path)
    if target.exists():
        return
    target.mkdirs()


def save_string(content, path):
    """
    Writes `content` to the file at `path`, replacing whatever was there.

    The FileWriter is opened with append=False (overwrite mode) and wrapped
    in a BufferedWriter. The writer is closed in a finally clause so the
    handle is released even when write() raises.
    """
    overwrite_target = FileWriter(path, False)
    writer = BufferedWriter(overwrite_target)
    try:
        writer.write(content)
    finally:
        writer.close()


def append_string(content, path):
    """
    Adds `content` to the end of the file at `path`.

    The FileWriter is opened with append=True and wrapped in a
    BufferedWriter. The writer is closed in a finally clause so the handle
    is released even when write() raises.
    """
    append_target = FileWriter(path, True)
    writer = BufferedWriter(append_target)
    try:
        writer.write(content)
    finally:
        writer.close()


# ============================================================
# IMAGE I/O AND DISPLAY HELPER FUNCTIONS
# ============================================================

def open_virtual(path):
    """
    Opens the file at `path` through Bio-Formats as a virtual stack.

    Importer options used: virtual=True (slices are read on demand instead
    of loading the whole file up front) and autoscale=False.

    Returns the first ImagePlus of the array produced by BF.openImagePlus().
    """
    importer_options = ImporterOptions()
    importer_options.setVirtual(True)
    importer_options.setAutoscale(False)
    importer_options.setId(path)
    opened_images = BF.openImagePlus(importer_options)
    first_image = opened_images[0]
    return first_image


def z_project_max(imp):
    """
    Returns a Maximum Intensity Z projection of `imp` as a new ImagePlus.

    Runs ImageJ's ZProjector with MAX_METHOD. The projection is used in this
    script as a 2D summary of a 3D stack for visual review.
    """
    projector = ZProjector(imp)
    projector.setMethod(ZProjector.MAX_METHOD)
    projector.doProjection()
    max_projection = projector.getProjection()
    return max_projection


def close_image(imp):
    """
    Closes an ImagePlus without a "Save changes?" prompt.

    The image's `changes` flag is cleared before close() so FIJI treats it
    as unmodified. Passing None is allowed and does nothing.
    """
    if imp is None:
        return
    imp.changes = False
    imp.close()


def close_all_image_windows():
    """
    Closes every image currently registered with the WindowManager.

    Each image is closed through close_image(), so no save prompts appear.
    Does nothing when WindowManager reports no open images (ID list is None).
    """
    open_ids = WindowManager.getIDList()
    if open_ids is not None:
        for window_id in open_ids:
            close_image(WindowManager.getImage(window_id))


def apply_voxel_calibration(imp):
    """
    Writes the script's voxel calibration onto `imp` via "Properties...".

    Pixel width/height, voxel depth and unit come from the calibration
    constants at the top of this script. The image's current channel, slice
    and frame counts are read first and passed back unchanged so the command
    does not alter the stack layout.
    """
    channel_count = imp.getNChannels()
    slice_count = imp.getNSlices()
    frame_count = imp.getNFrames()
    # Space-separated key=value pairs, in the order the command is given them
    settings = [
        "channels=" + str(channel_count),
        "slices=" + str(slice_count),
        "frames=" + str(frame_count),
        "pixel_width=" + str(PIXEL_WIDTH),
        "pixel_height=" + str(PIXEL_HEIGHT),
        "voxel_depth=" + str(VOXEL_DEPTH),
        "unit=" + VOXEL_UNIT,
    ]
    IJ.run(imp, "Properties...", " ".join(settings))


def force_black_white_display(imp):
    """
    Normalises how a binary image is displayed; pixel values are not edited.

    Steps, in order: clear any selection, reset the threshold on the current
    processor (removes threshold colouring left by setThreshold()), apply the
    "Grays" LUT, run "Invert LUT", set the display range to 0-255, redraw.

    The intent stated by the original author is a black background with
    white foreground.
    NOTE(review): "Grays" followed by "Invert LUT" may leave an inverting
    LUT on the image -- confirm the on-screen result matches that intent.
    """
    IJ.run(imp, "Select None", "")
    current_processor = imp.getProcessor()
    current_processor.resetThreshold()
    for lut_command in ("Grays", "Invert LUT"):
        IJ.run(imp, lut_command, "")
    imp.setDisplayRange(0, 255)
    imp.updateAndDraw()


def show_side_by_side(pre_imp_path, chosen_imp_path, image_name, chosen):
    """
    Opens the pre-threshold reference image and the chosen-threshold
    processed image, builds a max projection of each, shows both projections
    and tiles the open windows.

    Args:
      pre_imp_path    -- path of the pre-threshold TIFF
      chosen_imp_path -- path of the chosen-threshold processed TIFF
      image_name      -- sample name, used in the window titles
      chosen          -- chosen multiplier label (e.g. "80"), used in a title

    Returns: (pre_imp, chosen_imp, pre_proj, chosen_proj). The caller is
    responsible for closing all four.

    Raises RuntimeError naming the offending path if either file cannot be
    opened (IJ.openImage() returned None). Without this check the failure
    surfaced later as an unrelated error inside the projection code. If the
    second file fails, the already-opened first image is closed first.
    """
    pre_imp = IJ.openImage(pre_imp_path)
    if pre_imp is None:
        raise RuntimeError("Could not open pre-threshold image: " + pre_imp_path)

    chosen_imp = IJ.openImage(chosen_imp_path)
    if chosen_imp is None:
        close_image(pre_imp)
        raise RuntimeError("Could not open chosen threshold image: " + chosen_imp_path)

    pre_proj = z_project_max(pre_imp)
    chosen_proj = z_project_max(chosen_imp)
    pre_proj.setTitle("Pre-Threshold Reference - " + image_name)
    chosen_proj.setTitle("Chosen Threshold (" + chosen + "%) - " + image_name)
    pre_proj.show()
    chosen_proj.show()
    IJ.run("Tile")
    return (pre_imp, chosen_imp, pre_proj, chosen_proj)


# ============================================================
# FORMATTING HELPER FUNCTIONS
# ============================================================

def format_number(value, digits):
    """
    Returns `value` as a string with `digits` decimal places.

    The value is converted to float and formatted by IJ.d2s() ("double to
    string") so all numeric CSV cells share one formatting routine.
    """
    as_float = float(value)
    return IJ.d2s(as_float, digits)


def csv_field_name(field_name):
    """
    Returns the CSV row label for a field, with a unit suffix where relevant.

      "Voxel size"                      -> " (um)" appended
      the four calibrated volume fields -> " (um^3)" appended
      anything else                     -> returned unchanged
    """
    volume_fields = (
        "Thresholded volume",
        "Average volume per slice",
        "Total ROI volume",
        "Volume of stack",
    )
    if field_name in volume_fields:
        unit_suffix = " (um^3)"
    elif field_name == "Voxel size":
        unit_suffix = " (um)"
    else:
        unit_suffix = ""
    return field_name + unit_suffix


# ============================================================
# CORE VOXEL COUNTING (PERFORMANCE-CRITICAL)
# ============================================================

def count_foreground_voxels(stack, n_slices, roi=None):
    """
    Returns how many pixels equal FOREGROUND_VALUE across slices 1..n_slices.

    For every slice the processor's histogram is fetched with one
    getHistogram() call and the FOREGROUND_VALUE bin is added to the total,
    so no per-pixel Python loop is needed.

    Args:
      stack    -- image stack providing getProcessor(slice_number), 1-based
      n_slices -- number of slices to visit
      roi      -- optional ROI; when given it is set on each slice processor
                  before the histogram is read (restricting the count to the
                  ROI) and reset again afterwards
    """
    restrict_to_roi = roi is not None
    foreground_total = 0
    for offset in range(n_slices):
        processor = stack.getProcessor(offset + 1)   # ImageJ slices are 1-based
        if restrict_to_roi:
            processor.setRoi(roi)
        foreground_total += processor.getHistogram()[FOREGROUND_VALUE]
        if restrict_to_roi:
            processor.resetRoi()
    return foreground_total


def count_roi_pixels(roi, image_width, image_height):
    """
    Returns the number of pixels covered by a 2D ROI on one slice.

      roi is None          -> the whole frame: image_width * image_height
      roi.getMask() is None -> treated as rectangular: bounds width * height
      otherwise            -> number of non-zero pixels in the ROI mask,
                              obtained by summing mask histogram bins 1..end

    image_width and image_height are only used in the no-ROI case.
    """
    if roi is None:
        return image_width * image_height

    roi_mask = roi.getMask()
    if roi_mask is not None:
        # Shaped ROI: bin 0 holds the "outside" pixels, every other bin is inside
        bins = roi_mask.getHistogram()
        return sum(bins[i] for i in range(1, len(bins)))

    # No mask: the bounding box is the ROI
    box = roi.getBounds()
    return box.width * box.height


def calculate_voxel_counter_fields(imp):
    """
    Computes the Voxel Counter style output fields directly from a binary
    stack, without calling the Voxel Counter plugin.

    The script's voxel calibration is applied to `imp` first. If the image
    carries a ROI (imp.getRoi()), foreground counting and the "Total ROI"
    figures are restricted to it; otherwise the whole frame is used.

    Calculations:
      Thresholded voxels       = pixels equal to FOREGROUND_VALUE
      Average voxels per slice = thresholded / number of slices
      Total ROI Voxels         = ROI pixels per slice * number of slices
      Volume fraction          = 100 * thresholded / total ROI voxels
                                 (0.0 when the ROI contains no voxels)
      Voxels in stack          = width * height * number of slices
      volume fields            = matching voxel figure * voxel volume, where
                                 voxel volume = pixelWidth * pixelHeight *
                                 pixelDepth from the image calibration

    Returns: dict mapping each FIELD_NAMES entry to a formatted string.
    """
    apply_voxel_calibration(imp)

    active_roi = imp.getRoi()
    frame_width = imp.getWidth()
    frame_height = imp.getHeight()
    slice_count = imp.getStackSize()

    # Voxel counts
    foreground = count_foreground_voxels(imp.getStack(), slice_count, active_roi)
    roi_voxels = count_roi_pixels(active_roi, frame_width, frame_height) * slice_count
    stack_voxels = frame_width * frame_height * slice_count
    mean_per_slice = float(foreground) / float(slice_count)

    # Percentage of the ROI that is foreground; guarded against an empty ROI
    if roi_voxels > 0:
        fraction_pct = 100.0 * float(foreground) / float(roi_voxels)
    else:
        fraction_pct = 0.0

    # Calibration read back from the image for the volume conversions
    calibration = imp.getCalibration()
    px_width = calibration.pixelWidth
    px_height = calibration.pixelHeight
    px_depth = calibration.pixelDepth
    voxel_volume = px_width * px_height * px_depth
    unit_name = calibration.getUnit()
    if unit_name is None or unit_name == "":
        unit_name = "pixel"

    voxel_size_text = " x ".join([
        format_number(px_width, 6),
        format_number(px_height, 6),
        format_number(px_depth, 6),
    ]) + " " + unit_name

    return {
        "Thresholded voxels": str(foreground),
        "Average voxels per slice": format_number(mean_per_slice, 3),
        "Total ROI Voxels": str(roi_voxels),
        "Volume fraction": format_number(fraction_pct, 3),
        "Voxels in stack": str(stack_voxels),
        "Voxel size": voxel_size_text,
        "Thresholded volume": format_number(foreground * voxel_volume, 6),
        "Average volume per slice": format_number(mean_per_slice * voxel_volume, 6),
        "Total ROI volume": format_number(roi_voxels * voxel_volume, 6),
        "Volume of stack": format_number(stack_voxels * voxel_volume, 6),
    }


def run_voxel_counter(imp):
    """
    Runs voxel counting on `imp` as it stands and returns the field dict.

    Thin wrapper around calculate_voxel_counter_fields(), kept for call-site
    clarity. It does not clear a selection itself: if `imp` already carries
    a ROI, calculate_voxel_counter_fields() will count within that ROI, so
    callers wanting a whole-image count must pass an image with no ROI set.
    """
    field_values = calculate_voxel_counter_fields(imp)
    return field_values


def run_voxel_counter_with_roi(thresh_imp, roi):
    """
    Sets `roi` on the binary image `thresh_imp`, then runs voxel counting.

    Returns the same field_name -> string dict as run_voxel_counter().
    The ROI is left set on the image afterwards.
    """
    thresh_imp.setRoi(roi)
    field_values = calculate_voxel_counter_fields(thresh_imp)
    return field_values


def clear_outside_roi_stack(imp, roi):
    """
    Sets every pixel outside `roi` to 0 on each slice of the stack.

    Used by the cleanup workflow to remove unwanted regions from the binary
    mask before counting. Per slice, the processor's fill value is set to 0
    and ImageProcessor.fillOutside(roi) does the clearing in one call.

    Afterwards the image's own ROI is removed and the display is reset with
    force_black_white_display().
    """
    slices = imp.getStack()
    slice_count = imp.getStackSize()
    for offset in range(slice_count):
        processor = slices.getProcessor(offset + 1)   # 1-based slice numbers
        processor.setColor(0)        # fill value = background
        processor.fillOutside(roi)   # zero everything outside the ROI
    imp.killRoi()
    force_black_white_display(imp)


# ============================================================
# ROI MANAGEMENT
# ============================================================

def save_roi_to_path(roi, full_path):
    """
    Saves a single ROI to `full_path` through the ROI Manager.

    The existing ROI Manager instance is reused, or a new one is created if
    none exists. The manager is reset before the ROI is added, so the saved
    file contains only this ROI, and reset again afterwards.

    Note that reset() empties the manager: any ROIs the user had listed in
    an open ROI Manager are discarded by this call.
    """
    existing_manager = RoiManager.getInstance()
    manager = existing_manager if existing_manager is not None else RoiManager()
    manager.reset()
    manager.addRoi(roi)
    manager.runCommand("Save", full_path)
    manager.reset()


def save_cell_roi(roi, cell_dir, image_name, cell_num):
    """
    Saves a cell ROI using the standard file name
    <cell_dir><image_name>_cell<cell_num>_ROI.zip.

    `cell_dir` is concatenated as-is, so it must already end with a path
    separator. Delegates the actual write to save_roi_to_path().
    """
    roi_file_name = image_name + "_cell" + str(cell_num) + "_ROI.zip"
    save_roi_to_path(roi, cell_dir + roi_file_name)


def draw_roi_on_image(imp, dialog_title, dialog_message):
    """
    Lets the user draw a ROI on `imp` and returns it.

    Shows the image, selects the freehand tool, brings the image window to
    the front by title, then blocks on a WaitForUserDialog until the user
    clicks OK. A WaitForUserDialog is used because the user must be able to
    interact with the image window while the prompt is open.

    Returns whatever imp.getRoi() reports once the dialog is dismissed,
    which is None if nothing was drawn.
    """
    imp.show()
    IJ.setTool("freehand")
    window_title = imp.getTitle()
    IJ.selectWindow(window_title)
    prompt = WaitForUserDialog(dialog_title, dialog_message)
    prompt.show()
    drawn_roi = imp.getRoi()
    return drawn_roi


# ============================================================
# CSV BUILDING HELPERS
# ============================================================

def build_csv_header():
    """
    Returns the header row for per-sample CSVs, newline-terminated.

    First column is "Field"; one column follows per MULTIPLIER_LABELS entry
    with "%" appended, e.g. 'Field,70%,75%,...,120%\\n'.
    """
    columns = ["Field"]
    for label in MULTIPLIER_LABELS:
        columns.append(label + "%")
    return ",".join(columns) + "\n"


def build_csv_data_section(all_values):
    """
    Builds the data rows of a per-sample CSV, newline-terminated.

    One row per FIELD_NAMES entry; the first cell is the field label from
    csv_field_name(), followed by one cell per multiplier. Values are looked
    up in `all_values` by (multiplier_index, field_name); a missing key
    yields "N/A".
    """
    multiplier_indices = range(len(MULTIPLIERS))
    rows = []
    for field in FIELD_NAMES:
        row_cells = [csv_field_name(field)]
        row_cells.extend(
            all_values.get((index, field), "N/A") for index in multiplier_indices
        )
        rows.append(",".join(row_cells))
    return "\n".join(rows) + "\n"


def build_csv_corrected_section(chosen, parsed_corrected, suffix, header_label, note):
    """
    Builds the corrected-values section appended to per-sample CSVs.

    Layout: a blank line, a "--- <header_label> ---" line, the note line,
    then one row per FIELD_NAMES entry. In each row only the column of the
    chosen threshold carries a value (from `parsed_corrected`, "N/A" if the
    field is absent); every other column is "N/A".

    Args:
      chosen           -- chosen multiplier label (e.g. "80"); must be one of
                          MULTIPLIER_LABELS, otherwise list.index() raises
                          ValueError
      parsed_corrected -- dict of corrected field values for chosen threshold
      suffix           -- text added after each field label
      header_label     -- section header text
      note             -- explanatory line placed under the header
    """
    chosen_column = MULTIPLIER_LABELS.index(chosen)
    column_count = len(MULTIPLIERS)

    section = ["", "--- " + header_label + " ---", note]
    for field in FIELD_NAMES:
        # Start with N/A everywhere, then fill in the single corrected column
        value_cells = ["N/A"] * column_count
        if 0 <= chosen_column < column_count:
            value_cells[chosen_column] = parsed_corrected.get(field, "N/A")
        section.append(",".join([csv_field_name(field) + " " + suffix] + value_cells))
    return "\n".join(section) + "\n"


def write_per_sample_csv(path, all_values, corrected_section=""):
    """
    Writes a per-sample CSV to `path`, overwriting any existing file.

    Content is the header row, then the data rows built from `all_values`,
    then `corrected_section` (empty by default, i.e. nothing appended).
    """
    header_text = build_csv_header()
    data_text = build_csv_data_section(all_values)
    save_string(header_text + data_text + corrected_section, path)


# ============================================================
# CELL PROCESSING (ROI workflows)
# ============================================================

def process_cell_with_roi(roi, cell_num, image_name, sample_dir,
                           chosen, needs_cleanup, summary_file, all_values):
    """
    Processes a single cell ROI applied to the chosen threshold TIFF.

    Creates <sample_dir>cell_<cell_num>/, saves the ROI there, counts voxels
    inside the ROI on the chosen-threshold stack, writes a per-cell CSV and
    appends one line to the summary CSV.

    The per-cell CSV keeps all threshold columns with the original
    uncorrected values from `all_values`; a clearly labelled section with the
    ROI-corrected values for the chosen threshold only is appended below.
    The summary line records the ROI-corrected "Thresholded volume".

    Args:
      roi           -- ROI outlining the cell
      cell_num      -- cell number, used in directory and file names
      image_name    -- sample name
      sample_dir    -- sample output directory (must end with a separator)
      chosen        -- chosen multiplier label (e.g. "80")
      needs_cleanup -- when true, " - NEEDS CLEANUP" is added to the summary note
      summary_file  -- path of the summary CSV to append to
      all_values    -- dict keyed by (multiplier_index, field_name)

    Raises RuntimeError if the chosen-threshold TIFF cannot be opened
    (IJ.openImage() returned None); previously this surfaced as an
    attribute error on None.
    """
    cell_dir = sample_dir + "cell_" + str(cell_num) + "/"
    make_directory(cell_dir)
    save_cell_roi(roi, cell_dir, image_name, cell_num)

    # Open chosen threshold TIFF and run native voxel count with ROI applied
    chosen_path = sample_dir + image_name + "_" + chosen + "pct_processed.tif"
    chosen_imp = IJ.openImage(chosen_path)
    if chosen_imp is None:
        raise RuntimeError("Could not open chosen threshold image: " + chosen_path)
    try:
        parsed_corrected = run_voxel_counter_with_roi(chosen_imp, roi)
    finally:
        # Release the image even if counting fails
        close_image(chosen_imp)

    # Build CSV with original values + ROI-corrected appendix
    corrected_section = build_csv_corrected_section(
        chosen, parsed_corrected,
        suffix="(ROI corrected)",
        header_label="CHOSEN THRESHOLD ROI CORRECTIONS (" + chosen + "%)",
        note=("NOTE: ROI corrections apply to chosen threshold only. "
              "Other columns are original uncorrected values and may be "
              "inaccurate due to ROI overlap.")
    )
    write_per_sample_csv(
        cell_dir + image_name + "_cell" + str(cell_num) + "_voxels.csv",
        all_values,
        corrected_section
    )

    # Append corrected volume to summary CSV
    corrected_volume = parsed_corrected.get("Thresholded volume", "N/A")
    cleanup_note = " - NEEDS CLEANUP" if needs_cleanup else ""
    append_string(
        image_name + "," + str(cell_num) + "," + chosen + "," + corrected_volume +
        ",ROI segmented from multi-cell image" + cleanup_note + "\n",
        summary_file
    )
    IJ.log("  Recorded: " + image_name + " cell " + str(cell_num) +
           " -- " + chosen + "% (ROI corrected)")


def process_cleanup_single_cell(image_name, sample_dir, chosen,
                                  summary_file, all_values):
    """
    Cleanup workflow for a single cell:
      1. Show pre-threshold and chosen threshold projections side by side
      2. User draws cleanup boundary -- region to KEEP
      3. Apply boundary to chosen threshold stack (zero everything outside)
      4. Show cleaned result and pre-threshold side by side
      5. User draws measurement ROI -- the actual cell region
      6. Run voxel counter restricted to measurement ROI
      7. Save cleaned TIFF, per-sample CSV with corrections, and summary entry

    Both drawing steps are optional for the user: if no cleanup boundary is
    drawn the stack is left uncleaned, and if no measurement ROI is drawn the
    whole image is counted. The summary CSV note and the final log line
    state which of the two were actually applied, so a skipped step is not
    recorded as if it had been performed. When both are drawn, the summary
    note and log text are identical to the previous wording.

    Args:
      image_name   -- sample name
      sample_dir   -- sample output directory (must end with a separator)
      chosen       -- chosen multiplier label (e.g. "80")
      summary_file -- path of the summary CSV to append to
      all_values   -- dict keyed by (multiplier_index, field_name)
    """
    IJ.log("  Cleanup requested -- prompting cleanup boundary drawing.")

    # Step 1-2: Show side-by-side, prompt for cleanup boundary
    pre_imp, chosen_imp, pre_proj, chosen_proj = show_side_by_side(
        sample_dir + image_name + "_pre_threshold.tif",
        sample_dir + image_name + "_" + chosen + "pct_processed.tif",
        image_name, chosen
    )
    excl_roi = draw_roi_on_image(
        chosen_proj,
        "Draw Cleanup Boundary - " + image_name,
        "Draw a freehand ROI around the analysis area to KEEP.\n" +
        "Everything outside this boundary will be cleared.\n" +
        "Click OK when done."
    )
    close_image(pre_proj)
    close_image(chosen_proj)
    close_image(pre_imp)
    close_image(chosen_imp)

    if excl_roi is None:
        IJ.log("  WARNING: No cleanup boundary drawn -- continuing without cleanup.")
    else:
        save_roi_to_path(excl_roi,
                         sample_dir + image_name + "_cleanup_boundary_ROI.zip")
        IJ.log("  Cleanup boundary ROI saved.")

    # Step 3: Apply cleanup to the chosen threshold stack
    chosen_thresh_imp = IJ.openImage(
        sample_dir + image_name + "_" + chosen + "pct_processed.tif")
    if excl_roi is not None:
        clear_outside_roi_stack(chosen_thresh_imp, excl_roi)

    # Step 4-5: Show cleaned result alongside pre-threshold for measurement ROI
    pre_imp = IJ.openImage(sample_dir + image_name + "_pre_threshold.tif")
    pre_proj = z_project_max(pre_imp)
    cleaned_proj = z_project_max(chosen_thresh_imp)
    pre_proj.setTitle("Pre-Threshold Reference - " + image_name)
    cleaned_proj.setTitle("Cleaned Chosen Threshold (" + chosen + "%) - " + image_name)
    pre_proj.show()
    cleaned_proj.show()
    IJ.run("Tile")

    measure_roi = draw_roi_on_image(
        cleaned_proj,
        "Draw Measurement ROI - " + image_name,
        "Draw a freehand ROI around the cell to MEASURE.\n" +
        "This is separate from the cleanup boundary.\n" +
        "Click OK when done."
    )
    close_image(pre_proj)
    close_image(cleaned_proj)
    close_image(pre_imp)

    # Step 6: Run voxel counter on cleaned stack (with measurement ROI if drawn)
    if measure_roi is not None:
        save_cell_roi(measure_roi, sample_dir, image_name, 1)
        parsed_corrected = run_voxel_counter_with_roi(chosen_thresh_imp, measure_roi)
        # Drop the selection again so it is not stored with the cleaned TIFF
        IJ.run(chosen_thresh_imp, "Select None", "")
    else:
        IJ.log("  WARNING: No measurement ROI drawn -- counting cleaned full image.")
        parsed_corrected = run_voxel_counter(chosen_thresh_imp)

    # Save cleaned TIFF and close
    IJ.saveAs(chosen_thresh_imp, "Tiff",
              sample_dir + image_name + "_" + chosen + "pct_cleaned.tif")
    close_image(chosen_thresh_imp)

    # Step 7: Write per-sample CSV with original + corrected sections
    corrected_section = build_csv_corrected_section(
        chosen, parsed_corrected,
        suffix="(cleanup + measurement ROI corrected)",
        header_label=("CHOSEN THRESHOLD CLEANUP + MEASUREMENT ROI CORRECTIONS (" +
                      chosen + "%)"),
        note=("NOTE: Cleanup and measurement ROI corrections apply to chosen "
              "threshold only. Other columns are original uncorrected values "
              "and may be inaccurate due to excluded region overlap.")
    )
    write_per_sample_csv(
        sample_dir + image_name + "_voxels.csv",
        all_values,
        corrected_section
    )

    # Describe what was really applied. The notes contain no commas because
    # they are written as a single CSV cell.
    if excl_roi is not None and measure_roi is not None:
        summary_note = ("Single cell confirmed - cleanup boundary and "
                        "measurement ROI applied")
        log_note = "with cleanup and measurement ROI correction."
    elif excl_roi is not None:
        summary_note = ("Single cell confirmed - cleanup boundary applied; "
                        "no measurement ROI drawn")
        log_note = "with cleanup boundary only (no measurement ROI drawn)."
    elif measure_roi is not None:
        summary_note = ("Single cell confirmed - measurement ROI applied; "
                        "no cleanup boundary drawn")
        log_note = "with measurement ROI only (no cleanup boundary drawn)."
    else:
        summary_note = ("Single cell confirmed - cleanup requested but "
                        "no ROI drawn")
        log_note = "without correction (no cleanup boundary or measurement ROI drawn)."

    # Append corrected volume to summary CSV
    corrected_volume = parsed_corrected.get("Thresholded volume", "N/A")
    append_string(
        image_name + ",1," + chosen + "," + corrected_volume +
        "," + summary_note + "\n",
        summary_file
    )
    IJ.log("  Recorded: " + chosen + "% " + log_note)


# ============================================================
# THRESHOLD LOOP (called once per sample)
# ============================================================

def run_threshold_loop(merged_imp, image_name, sample_dir, lower):
    """
    Produces one binary stack, projection, TIFF and voxel count per
    threshold multiplier.

    For each (multiplier, label) pair from MULTIPLIERS / MULTIPLIER_LABELS:
      - Duplicate the merged stack
      - Apply threshold = lower * multiplier (upper bound fixed at 65535)
      - Convert to binary mask
      - Binary Open to clean stray pixels
      - Apply voxel calibration
      - Generate Z projection for visual review (shown immediately)
      - Save processed TIFF as <sample_dir><image_name>_<label>pct_processed.tif
      - Run voxel counter, store results
      - Close the duplicate

    Args:
      merged_imp -- preprocessed stack to threshold; it is only duplicated,
                    never modified here
      image_name -- sample name, used in window titles and file names
      sample_dir -- sample output directory (must end with a separator)
      lower      -- base lower threshold that each multiplier scales

    Returns: (proj_list, all_values, thresholded_volumes)
      - proj_list: list of projection ImagePlus objects (caller must close)
      - all_values: dict keyed by (multiplier_index, field_name)
      - thresholded_volumes: list of "Thresholded volume" strings, one per multiplier
    """
    proj_list = []
    all_values = {}
    thresholded_volumes = []

    for m, (multiplier, label) in enumerate(zip(MULTIPLIERS, MULTIPLIER_LABELS)):

        # Work on a copy so every multiplier starts from the same input
        dup_imp = merged_imp.duplicate()

        # Threshold and binarise
        # NOTE(review): the upper bound is fixed at 65535. If merged_imp is a
        # 32-bit sum of several 16-bit channels, summed values above 65535
        # would fall outside this range -- confirm whether that can occur
        # with the acquisition settings in use.
        IJ.setThreshold(dup_imp, lower * multiplier, 65535)
        IJ.run(dup_imp, "Convert to Mask", "method=Default background=Dark black")

        # Morphological Open to clean cell edges and remove stray pixels
        IJ.run(dup_imp, "Open", "stack")

        # Ensure correct binary display and apply voxel calibration
        force_black_white_display(dup_imp)
        apply_voxel_calibration(dup_imp)

        # Z projection for visual review
        proj = z_project_max(dup_imp)
        force_black_white_display(proj)
        proj.setTitle("Threshold " + label + "% - " + image_name)
        proj.show()
        proj_list.append(proj)

        # Save processed TIFF for QC and ROI workflows
        # (the ROI/cleanup functions reopen this exact file name)
        IJ.saveAs(dup_imp, "Tiff",
                  sample_dir + image_name + "_" + label + "pct_processed.tif")

        # Voxel counting (fast histogram-based, no plugin call, no save prompt)
        parsed = run_voxel_counter(dup_imp)
        for f in FIELD_NAMES:
            all_values[(m, f)] = parsed[f]
        thresholded_volumes.append(parsed.get("Thresholded volume", "N/A"))

        # Only the projection stays open; the binary stack itself is released
        close_image(dup_imp)

    return (proj_list, all_values, thresholded_volumes)


# ============================================================
# SAMPLE PREPROCESSING
# ============================================================

def open_and_merge_channels(input_dir, image_name, ext, n_channels):
    """
    Opens the channel files of one sample as virtual stacks and sums them
    into a single stack.

    The first channel is opened on its own; each further channel is added
    to the running result with ImageCalculator "Add create 32-bit stack",
    and both inputs of every addition are closed straight afterwards. With
    n_channels == 1 no addition happens and the first channel is returned
    as opened.

    File names are <input_dir><image_name><suffix><ext>, with the suffix
    taken from CHANNEL_SUFFIXES for every channel including the first, so
    the naming convention is defined in one place.

    Args:
      input_dir  -- input directory (must end with a separator)
      image_name -- sample name without channel suffix or extension
      ext        -- file extension including the dot
      n_channels -- number of channels to merge

    Returns the merged ImagePlus. Caller is responsible for closing it.

    Raises RuntimeError naming the channel file if ImageCalculator returns
    None for an addition, instead of carrying None forward to fail later in
    an unrelated place.
    """
    ic = ImageCalculator()
    merged_imp = open_virtual(input_dir + image_name + CHANNEL_SUFFIXES[0] + ext)
    for c in range(1, n_channels):
        next_path = input_dir + image_name + CHANNEL_SUFFIXES[c] + ext
        next_imp = open_virtual(next_path)
        result_imp = ic.run("Add create 32-bit stack", merged_imp, next_imp)
        # The inputs are no longer needed whether or not the addition worked
        close_image(merged_imp)
        close_image(next_imp)
        if result_imp is None:
            raise RuntimeError(
                "Could not add channel file to merged stack: " + next_path)
        merged_imp = result_imp
    return merged_imp


def preprocess_merged(merged_imp, image_name, sample_dir):
    """
    Runs the shared clean-up steps on a freshly merged stack and shows the
    user a reference projection.

    Steps, in order:
      1. "Despeckle" on the whole stack
      2. "Enhance Contrast" with saturated=0.35
      3. Save the result as <sample_dir><image_name>_pre_threshold.tif
      4. Build a projection with z_project_max(), title it and display it

    merged_imp is modified in place by steps 1-2.

    Returns the displayed pre-threshold projection ImagePlus; the caller is
    responsible for closing it.
    """
    # Filtering and contrast adjustment, applied directly to merged_imp
    for command, options in (("Despeckle", "stack"),
                             ("Enhance Contrast", "saturated=0.35")):
        IJ.run(merged_imp, command, options)

    # Keep a copy on disk; the ROI workflows reopen this file later
    reference_path = sample_dir + image_name + "_pre_threshold.tif"
    IJ.saveAs(merged_imp, "Tiff", reference_path)

    # Projection shown to the user alongside the thresholded versions
    projection = z_project_max(merged_imp)
    projection.setTitle("Pre-Threshold - " + image_name)
    projection.show()
    return projection


def calculate_middle_slice_threshold(merged_imp):
    """
    Derives a "Default" auto-threshold from the middle slice of the stack.

    The slice index is int(round(getNSlices() / 2.0)); the histogram of that
    slice's processor is passed to AutoThresholder.getThreshold("Default", ...).
    The middle slice is used on the assumption that it is the most in-focus
    and therefore representative of the whole stack.

    Returns whatever AutoThresholder.getThreshold() yields for that histogram.

    NOTE(review): the result is presumably an index into the histogram
    returned by getHistogram(); whether that equals a raw intensity depends
    on the processor type of merged_imp -- confirm for the 32-bit merged
    stacks this workflow produces.
    """
    slices = merged_imp.getStack()
    # Stack slices are addressed from 1; halving and rounding picks the centre
    centre_index = int(round(merged_imp.getNSlices() / 2.0))
    centre_histogram = slices.getProcessor(centre_index).getHistogram()
    return AutoThresholder().getThreshold("Default", centre_histogram)


# ============================================================
# PRE-FLIGHT CHECK
# ============================================================

def list_samples(input_dir, n_channels):
    """
    Scans the input directory for sample sets. A "sample" is identified by a
    file ending in _C0.tif or _C0.tiff; it is complete when the files for the
    first n_channels entries of CHANNEL_SUFFIXES all exist with the same
    base name and extension.

    The sample name is obtained by cutting the "_C0<ext>" marker off the END
    of the file name only, so a base name that happens to contain the marker
    text elsewhere is left intact.

    Args:
      input_dir  -- folder to scan; it is concatenated directly with file
                    names, so it must end with a path separator
      n_channels -- number of channel files each sample must have

    Returns: (complete_samples, incomplete_samples)
      complete_samples   -- list of (image_name, ext) tuples ready to process
      incomplete_samples -- list of image_name strings missing channel files
    Both lists follow the sorted order of the directory listing.
    """
    complete = []
    incomplete = []
    for filename in sorted(os.listdir(input_dir)):
        # Work out which of the supported extensions (if any) this _C0 file has
        ext = None
        for candidate in (".tif", ".tiff"):
            if filename.endswith("_C0" + candidate):
                ext = candidate
                break
        if ext is None:
            continue

        # Strip the trailing marker by length; str.replace() would also remove
        # any earlier occurrence of the same text inside the base name.
        image_name = filename[:-len("_C0" + ext)]

        all_present = all(
            os.path.exists(input_dir + image_name + CHANNEL_SUFFIXES[c] + ext)
            for c in range(n_channels))

        if all_present:
            complete.append((image_name, ext))
        else:
            incomplete.append(image_name)
    return (complete, incomplete)


def find_already_processed(complete_samples, output_dir):
    """
    Identifies samples that already have a per-sample CSV on disk.

    For each (image_name, ext) tuple in complete_samples the path
    <output_dir><image_name>/<image_name>_voxels.csv is tested; output_dir is
    concatenated as-is, so it must end with a path separator.

    Returns a set of the image_name strings whose CSV exists.
    """
    def _per_sample_csv(sample_name):
        # Same location the main workflow writes the per-sample CSV to
        return output_dir + sample_name + "/" + sample_name + "_voxels.csv"

    return set(sample_name
               for (sample_name, _ext) in complete_samples
               if os.path.exists(_per_sample_csv(sample_name)))


def show_preflight_dialog(complete_samples, incomplete_samples, already_processed):
    """
    Shows the batch pre-flight summary and collects the user's decision.

    The dialog lists the number of complete samples, adds a line about
    incomplete samples when there are any, and -- only when at least one
    sample looks already processed -- adds a pre-ticked "Resume mode"
    checkbox.

    Returns (proceed, resume_mode):
      proceed     -- False if the dialog was cancelled, otherwise True
      resume_mode -- state of the resume checkbox; always False when the
                     dialog was cancelled or the checkbox was not offered
    """
    ready_count = len(complete_samples)
    missing_count = len(incomplete_samples)
    done_count = len(already_processed)
    # The checkbox exists only in this case, so it may only be read back then
    offer_resume = done_count > 0

    dialog = GenericDialog("Batch Pre-flight Summary")
    dialog.addMessage("Cell Volume Workflow v" + SCRIPT_VERSION)
    dialog.addMessage("")
    dialog.addMessage("Found " + str(ready_count) + " complete samples ready to process.")
    if missing_count > 0:
        dialog.addMessage(str(missing_count) + " samples missing channel files (will be skipped).")
    if offer_resume:
        dialog.addMessage(str(done_count) + " samples appear to be already processed.")
        dialog.addCheckbox("Resume mode (skip already-processed samples)", True)
    dialog.addMessage("")
    dialog.addMessage("Click OK to begin batch, Cancel to abort.")
    dialog.showDialog()

    if dialog.wasCanceled():
        return (False, False)
    if offer_resume:
        return (True, dialog.getNextBoolean())
    return (True, False)


# ============================================================
# MAIN WORKFLOW
# ============================================================

# Announce the run in the ImageJ log
for banner_line in ("=" * 60,
                    "Cell Volume Voxel Counting Workflow v" + SCRIPT_VERSION,
                    "=" * 60):
    IJ.log(banner_line)

# --- Ask the user for the two folders, then for the channel count ---
input_dir = get_directory("Input_folder containing image stacks")
output_dir = get_directory("Output_folder to save results")

gd = GenericDialog("Channel Settings")
gd.addMessage("How many channels per sample?")
gd.addChoice("Number of channels:", ["2", "3"], "2")
gd.showDialog()
if gd.wasCanceled():
    raise RuntimeError("Macro cancelled by user.")
n_channels = int(gd.getNextChoice())

# --- Pre-flight: find sample sets and stop early if there is nothing to do ---
(complete_samples, incomplete_samples) = list_samples(input_dir, n_channels)

if not complete_samples:
    # Name exactly the channel files the chosen channel count requires
    expected_files_hint = "Check that _C0 and _C1"
    if n_channels == 3:
        expected_files_hint = expected_files_hint + " and _C2"
    IJ.error("No complete samples found in input folder. " +
             expected_files_hint + " files are present.")
    raise RuntimeError("No samples to process.")

# --- Pre-flight: detect earlier output and let the user confirm or abort ---
already_processed = find_already_processed(complete_samples, output_dir)
proceed, resume_mode = show_preflight_dialog(
    complete_samples, incomplete_samples, already_processed)
if not proceed:
    raise RuntimeError("Batch cancelled at pre-flight.")

# --- Summary CSV: write a fresh header unless we are resuming onto an
#     existing file, whose rows must be preserved ---
summary_file = output_dir + "threshold_selections.csv"
if not resume_mode or not os.path.exists(summary_file):
    save_string(
        "Sample,Cell,Chosen Threshold (%),Thresholded Volume (um^3),Notes\n",
        summary_file)

# --- Report incomplete samples up front so the skips are visible in the log ---
for name in incomplete_samples:
    IJ.log("WARNING: Skipping " + name + " -- missing channel file(s).")

# --- Tallies reported in the final batch summary ---
total_to_process = len(complete_samples)
counter_processed = counter_single = counter_single_cleanup = 0
counter_multi = counter_skipped = counter_errors = 0
error_log = []   # (image_name, error message) pairs

# ============================================================
# MAIN PROCESSING LOOP
# ============================================================

# Processes every complete sample in turn: merge channels, preprocess,
# threshold at all multipliers, ask the user to validate, then record the
# result according to the user's answers. Each sample is isolated in its own
# try/except so a failure is logged and the batch moves on.
for (image_name, ext) in complete_samples:

    # counter_processed is the 1-based position in the batch (skipped samples
    # included); it is only used to build the [N/total] progress prefix.
    counter_processed += 1
    progress_prefix = "[" + str(counter_processed) + "/" + str(total_to_process) + "]"

    # --- Skip already-processed samples in resume mode ---
    if resume_mode and image_name in already_processed:
        IJ.log(progress_prefix + " SKIP (already processed): " + image_name)
        counter_skipped += 1
        continue

    # Drop the reference to the previous sample's merged stack BEFORE the
    # explicit GC below; while this name still points at it the stack cannot
    # be reclaimed. It also lets the error handler tell whether a merged
    # stack exists for this sample.
    merged_imp = None

    # --- Process sample, with try/except so a bad sample doesn't kill the batch ---
    try:
        # Encourage GC between samples to free memory from the previous sample
        # without waiting for FIJI's automatic GC to trigger mid-processing
        System.gc()

        IJ.log(progress_prefix + " Processing: " + image_name)
        sample_dir = output_dir + image_name + "/"
        make_directory(sample_dir)

        # === SAMPLE PIPELINE ===

        # 1. Open and sum all channels into a 32-bit greyscale stack
        merged_imp = open_and_merge_channels(input_dir, image_name, ext, n_channels)

        # 2. Despeckle, contrast-enhance, save pre-threshold TIFF, show projection
        pre_thresh_proj = preprocess_merged(merged_imp, image_name, sample_dir)

        # 3. Calculate threshold from middle slice
        lower = calculate_middle_slice_threshold(merged_imp)

        # 4. Run threshold loop (11 multipliers)
        proj_list, all_values, thresholded_volumes = run_threshold_loop(
            merged_imp, image_name, sample_dir, lower)

        # 5. Tile projections for side-by-side review
        IJ.run("Tile")

        # === USER VALIDATION DIALOG ===
        gd2 = GenericDialog("Isolated Cell Check: " + image_name)
        gd2.addMessage("Review the projections -- leftmost is pre-threshold, " +
                       "followed by 70% through 120%.")
        gd2.addMessage("How many cells are visible?")
        gd2.addChoice("Number of cells:", ["1", "2", "3"], "1")
        gd2.addCheckbox("Needs Cleanup?", False)
        gd2.addMessage("Which threshold do you prefer?")
        gd2.addChoice("Chosen threshold:", MULTIPLIER_LABELS, DEFAULT_CHOSEN_THRESHOLD)
        gd2.showDialog()
        if gd2.wasCanceled():
            # NOTE: this RuntimeError is caught by the per-sample handler
            # below, so cancelling here records the sample as an error and
            # continues with the next sample; it does not stop the batch.
            raise RuntimeError("Macro cancelled by user at: " + image_name)
        # Values must be read back in the order the fields were added
        cell_count    = int(gd2.getNextChoice())
        needs_cleanup = gd2.getNextBoolean()
        chosen        = gd2.getNextChoice()

        # Close all projections before further processing
        for proj in proj_list:
            close_image(proj)
        close_image(pre_thresh_proj)

        # === BRANCH ON USER SELECTION ===

        if cell_count == 1 and not needs_cleanup:
            # Single cell, no cleanup -- straightforward save
            write_per_sample_csv(
                sample_dir + image_name + "_voxels.csv",
                all_values
            )
            # thresholded_volumes is indexed in the same order as MULTIPLIER_LABELS
            chosen_index  = MULTIPLIER_LABELS.index(chosen)
            chosen_volume = thresholded_volumes[chosen_index]
            append_string(
                image_name + ",1," + chosen + "," + chosen_volume +
                ",Single cell confirmed\n",
                summary_file
            )
            IJ.log("  Recorded: " + chosen + "%")
            counter_single += 1

        elif cell_count == 1 and needs_cleanup:
            # Single cell with cleanup -- delegate to helper
            process_cleanup_single_cell(
                image_name, sample_dir, chosen,
                summary_file, all_values
            )
            counter_single_cleanup += 1

        else:
            # Multiple cells -- ROI segmentation per cell
            IJ.log("  *** MULTIPLE CELLS DETECTED -- Prompting ROI segmentation ***")

            # Optional global cleanup boundary applied to chosen threshold first
            if needs_cleanup:
                IJ.log("  Cleanup requested -- prompting cleanup boundary " +
                       "before cell segmentation.")
                pre_imp, chosen_imp, pre_proj, chosen_proj = show_side_by_side(
                    sample_dir + image_name + "_pre_threshold.tif",
                    sample_dir + image_name + "_" + chosen + "pct_processed.tif",
                    image_name, chosen
                )
                excl_roi = draw_roi_on_image(
                    chosen_proj,
                    "Draw Cleanup Boundary - " + image_name,
                    "Draw a freehand ROI around the analysis area to KEEP " +
                    "for all cells.\n" +
                    "Everything outside this boundary will be cleared.\n" +
                    "Click OK when done."
                )
                close_image(pre_proj)
                close_image(chosen_proj)
                close_image(pre_imp)
                close_image(chosen_imp)

                if excl_roi is None:
                    IJ.log("  WARNING: No cleanup boundary drawn -- continuing without.")
                else:
                    save_roi_to_path(excl_roi,
                                     sample_dir + image_name + "_cleanup_boundary_ROI.zip")
                    IJ.log("  Cleanup boundary ROI saved.")
                    # Apply cleanup to chosen threshold TIFF in-place, so the
                    # per-cell loop below reopens the cleaned version
                    chosen_thresh_imp = IJ.openImage(
                        sample_dir + image_name + "_" + chosen + "pct_processed.tif")
                    clear_outside_roi_stack(chosen_thresh_imp, excl_roi)
                    IJ.saveAs(chosen_thresh_imp, "Tiff",
                              sample_dir + image_name + "_" + chosen +
                              "pct_processed.tif")
                    close_image(chosen_thresh_imp)
                    IJ.log("  Cleanup boundary applied to chosen threshold TIFF.")

            # Per-cell ROI loop
            for cell_num in range(1, cell_count + 1):
                pre_imp, chosen_imp, pre_proj, chosen_proj = show_side_by_side(
                    sample_dir + image_name + "_pre_threshold.tif",
                    sample_dir + image_name + "_" + chosen + "pct_processed.tif",
                    image_name, chosen
                )
                roi = draw_roi_on_image(
                    chosen_proj,
                    "Draw ROI - Cell " + str(cell_num) + " of " + str(cell_count),
                    "Draw a freehand ROI around cell " + str(cell_num) + " of " +
                    str(cell_count) + ".\nClick OK when done."
                )
                close_image(pre_proj)
                close_image(chosen_proj)
                close_image(pre_imp)
                close_image(chosen_imp)

                if roi is None:
                    IJ.log("  ROI drawing cancelled for cell " + str(cell_num) +
                           " -- skipping.")
                    continue

                process_cell_with_roi(
                    roi, cell_num, image_name, sample_dir,
                    chosen, needs_cleanup, summary_file, all_values
                )
            # Counted once per multi-cell sample, even if some ROIs were skipped
            counter_multi += 1

        close_image(merged_imp)
        # Mark as released so the error handler never closes it a second time
        merged_imp = None
        close_all_image_windows()

    except Exception as e:
        # Per-sample error handling: log, record in summary, clean up, continue.
        # Without this, one bad sample would abort the whole batch.
        err_msg = str(e)
        IJ.log("  *** ERROR: " + err_msg)
        error_log.append((image_name, err_msg))
        counter_errors += 1

        # Keep the message inside a single CSV field on a single row
        csv_safe_msg = (err_msg.replace(",", ";")
                               .replace("\r", " ")
                               .replace("\n", " "))
        try:
            append_string(
                image_name + ",,,,ERROR: " + csv_safe_msg + "\n",
                summary_file
            )
        except:
            # Best effort only -- but say so instead of failing silently
            IJ.log("  WARNING: Could not record this error in the summary CSV.")

        # Best-effort cleanup of any open windows from the failed sample
        try:
            close_all_image_windows()
        except:
            IJ.log("  WARNING: Could not close all image windows after the error.")

        # The merged stack is not shown by the code in this loop, so closing
        # windows may not release it; close it explicitly if it still exists.
        if merged_imp is not None:
            try:
                close_image(merged_imp)
            except:
                IJ.log("  WARNING: Could not close the merged stack after the error.")
            merged_imp = None

        # Force GC to recover memory before next sample
        System.gc()

# ============================================================
# FINAL BATCH SUMMARY
# ============================================================

# Assemble the closing report first, then write it to the ImageJ log line by
# line in the same order as it is built.
closing_rule = "=" * 60
report_lines = [
    "",
    closing_rule,
    "BATCH COMPLETE",
    closing_rule,
    "Total samples seen:       " + str(total_to_process),
    "  Single cell (no cleanup): " + str(counter_single),
    "  Single cell (cleanup):    " + str(counter_single_cleanup),
    "  Multi-cell:               " + str(counter_multi),
    "  Skipped (already done):   " + str(counter_skipped),
    "  Errors:                   " + str(counter_errors),
]

# Only mention incomplete samples when there were some
if incomplete_samples:
    report_lines.append(
        "  Skipped (missing files):  " + str(len(incomplete_samples)))
report_lines.append("")

# One line per failed sample, taken from the (image_name, message) pairs
# collected by the processing loop
if counter_errors > 0:
    report_lines.append("ERROR DETAILS:")
    for (name, msg) in error_log:
        report_lines.append("  " + name + ": " + msg)
    report_lines.append("")

report_lines.extend([
    "Results saved to:    " + output_dir,
    "Summary CSV:         " + summary_file,
    closing_rule,
])

for report_line in report_lines:
    IJ.log(report_line)
