This Jupyter Notebook serves as the core development and training
environment for the object detection component of the MORTISCOPE
system. The primary objective is to train a
YOLOv8 model to accurately detect and classify the
different life stages of Chrysomya megacephala. The
successful training and evaluation of this model are fundamental to
the application’s ability to process images and provide accurate
entomological data for PMI estimation.
Workflow Summary
This notebook follows a pipeline designed not just for training, but
also for deployment and preliminary analysis of forensic imagery:
Environment & Storage Initialization: Mounts
Google Drive for persistent artifact storage and installs specific
dependencies, including sahi for advanced inference.
Dynamic Dataset Aggregation: Downloads multiple
versioned datasets (varying in resolution and environmental
conditions) from Roboflow and programmatically builds a unified
data.yaml configuration file to prevent path errors.
Custom Architecture Definition (P2): Unlike
standard training, this step explicitly defines a
YOLOv8-P2 architecture. This variant adds a
high-resolution detection head (P2) to the model to potentially
improve the ability to detect tiny objects like early-stage
instars. It uses transfer learning by mapping standard YOLOv8
weights to this custom architecture.
Training Pipeline: Features a smart resume system
that automatically detects interruptions and resumes training from
the last saved checkpoint in Google Drive, alongside a custom
callback to archive weights for every epoch.
Comprehensive Evaluation: Generates training loss
plots, a normalized confusion matrix for biological stage
differentiation, and detailed inference speed benchmarks.
Advanced Inference Setup: Integrates
SAHI (Slicing Aided Hyper Inference) to handle
high-resolution field imagery, allowing the detection of small
insect instars that standard resizing would miss.
Deployment & Demonstration: Exports the final
model to ONNX format for edge deployment and
creates an interactive inference engine. This engine detects scene
types (Macro vs. Field) to automatically switch between standard
and sliced inference modes, applies outlier filtering, and
generates publication-ready visualizations.
Tech Stack
Deep Learning and Inference Engines
YOLOv8-P2 (P2 Head Variant): The standard YOLO
architecture downsamples images by up to 32x, often
causing small objects to vanish. This notebook uses a
P2 variant, which includes an additional
detection head at a higher resolution (4x downsampling). This
preserves the fine-grained texture details essential for
distinguishing between instars.
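The stride arithmetic above can be made concrete: each detection level divides the input resolution by its stride, so for a 640 px image the coarsest P5 level produces only a 20x20 grid, while the added P2 level keeps a 160x160 grid. A minimal sketch (illustrative, not code from the notebook):

```python
# Illustrative only: feature-map sizes for a 640 px input at each YOLOv8
# detection stride. The P2 head adds the stride-4 level, where one grid
# cell spans only ~4 px of the original image.
img_size = 640
strides = {"P2": 4, "P3": 8, "P4": 16, "P5": 32}

for head, stride in strides.items():
    cells = img_size // stride
    print(f"{head}: stride {stride:>2} -> {cells}x{cells} grid")
```

An early-stage instar a few pixels wide can occupy less than one cell at stride 32, which is why the extra P2 level matters here.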
SAHI (Slicing Aided Hyper Inference): An
advanced library used to perform inference on high-resolution
images by slicing them into smaller overlapping windows. This is
critical for detecting small instars in wide-angle “field”
shots.
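The windowing idea behind sliced inference can be sketched in a few lines (this is an illustration of the concept, not SAHI's actual internals): the slice step is the slice size minus the overlap, and a final window is clamped to the image border so no pixels are skipped.

```python
# Sketch of sliced-inference windowing (not SAHI's actual implementation):
# step = slice size minus overlap; the last window is clamped to the edge.
def slice_origins(length, slice_size=640, overlap=0.2):
    step = int(slice_size * (1 - overlap))
    origins = list(range(0, max(length - slice_size, 0) + 1, step))
    # Ensure the final window reaches the image edge.
    if origins[-1] + slice_size < length:
        origins.append(length - slice_size)
    return origins

xs = slice_origins(4000)  # column offsets for a 4000 px wide image
ys = slice_origins(3000)  # row offsets for a 3000 px tall image
print(len(xs) * len(ys), "slices")
```

Each window is then run through the detector at full slice resolution, and the per-slice detections are merged back into image coordinates.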
PyTorch: The underlying tensor computation
framework powering YOLOv8.
Transfer Learning: A technique to load weights
from the standard yolov8l.pt model into the custom
P2 backbone, speeding up convergence.
ONNX (Open Neural Network Exchange): Used to
export the trained PyTorch model into a standardized format to
optimize it for deployment on various hardware platforms without
Python dependencies.
Data Management and Processing
Roboflow: Used for versioning, hosting,
preprocessing, augmenting and downloading the multi-part
Chrysomya megacephala dataset.
Google Drive: Serves as the persistent storage
layer to ensure trained weights (best.pt,
last.pt) and historical logs are preserved across
Google Colab sessions.
Pandas: Used to parse and manipulate the
training logs (results.csv) for custom performance
plotting.
PyYAML: Automates the creation of the dataset
configuration file, dynamically linking downloaded image paths
to the model.
Visualization and Analysis
OpenCV (cv2): Handles image I/O,
color space conversion, and drawing bounding boxes for the
inference demonstration.
Seaborn & Matplotlib: Used to generate
high-quality statistical plots, including the smoothed loss
curves and the heatmap-style confusion matrix.
NumPy: Performs essential array operations,
particularly for calculating Intersection over Union (IoU) and
filtering statistical outliers in prediction areas.
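The IoU computation mentioned above is a simple ratio of intersection area to union area between two boxes. A minimal helper (illustrative, not the notebook's exact implementation), using the `[x1, y1, x2, y2]` corner format:

```python
# Minimal IoU between two [x1, y1, x2, y2] boxes, as used when comparing
# predictions against ground truth (an illustrative helper).
def iou(box_a, box_b):
    xa1, ya1, xa2, ya2 = box_a
    xb1, yb1, xb2, yb2 = box_b
    inter_w = max(0.0, min(xa2, xb2) - max(xa1, xb1))
    inter_h = max(0.0, min(ya2, yb2) - max(ya1, yb1))
    inter = inter_w * inter_h
    union = (xa2 - xa1) * (ya2 - ya1) + (xb2 - xb1) * (yb2 - yb1) - inter
    return inter / union if union > 0 else 0.0

print(iou([0, 0, 10, 10], [5, 5, 15, 15]))  # partial overlap: 25 / 175
```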
Section 1: Project Initialization and Dependencies
Purpose: This section establishes the file system
structure required for the project and installs the specific
external libraries needed for computer vision tasks. It ensures
the notebook has access to persistent storage (Google Drive) and
the necessary tools for dataset management and inference.
Key Activities:
Drive Mounting: Connects to Google Drive to
create a persistent directory. This is crucial for saving
trained weights and logs so they are not lost when the Colab
session disconnects.
Library Installation: Installs
roboflow (for downloading datasets),
ultralytics (for the YOLO model), and
sahi (for Slicing Aided Hyper Inference, used for
detecting small objects).
Credential Management: Securely retrieves the
Roboflow API key from Colab’s secrets manager to allow
authorized access to the private datasets.
import os
from google.colab import drive

# Mounts Google Drive to allow the Colab notebook to access and save files directly.
drive.mount('/content/drive')

# Defines configuration variables for the main project and the specific model folder.
root_folder = "Mortiscope Models"
model_name = "YOLOv8"

# Constructs the full, platform-independent path to the model's project directory.
project_path = os.path.join('/content/drive/MyDrive', root_folder, model_name)

# Constructs the path for a dedicated subdirectory to store model weights.
weights_path = os.path.join(project_path, 'weights')

# Creates the project and weights directories.
os.makedirs(project_path, exist_ok=True)
os.makedirs(weights_path, exist_ok=True)

# Prints the fully constructed paths to the console for user confirmation.
print(f"Project Directory: {project_path}")
print(f"Weight Storage: {weights_path}")
# Executes the shell command to install the specified libraries using pip.
!pip install roboflow ultralytics sahi shapely

# Prints a confirmation message to the console after the installation command completes.
print("Libraries installed successfully.")
from google.colab import userdata
from roboflow import Roboflow

try:
    # Attempts to retrieve the 'ROBOFLOW_API_KEY' from Google Colab's secret manager.
    rf_api_key = userdata.get('ROBOFLOW_API_KEY')
    print("API Key retrieved successfully.")
except Exception as e:
    # Handles potential errors, such as the secret not being defined in the Colab environment.
    print("Error: Could not retrieve key.")
    # Re-raises the exception to halt execution if the key is essential.
    raise e

# Initializes the Roboflow client with the successfully retrieved API key.
rf = Roboflow(api_key=rf_api_key)
API Key retrieved successfully.
Section 2: Dataset Acquisition and Configuration
Purpose: To aggregate data from multiple sources
and format it for the YOLOv8 training pipeline. Since the dataset
is versioned across different projects (representing different
resolutions and environmental conditions), this section
consolidates them into a single training configuration.
Key Activities:
Dataset Download: Uses the Roboflow API to
download specific versions of the
Chrysomya megacephala datasets.
Dynamic YAML Generation: Instead of manually
creating a configuration file, the script iterates through the
downloaded folders to verify valid paths. It then
programmatically generates a data.yaml file to
ensure that the model only attempts to train on data that
actually exists locally, preventing “path not found” errors.
# Accesses the specific Roboflow workspace that contains all the project datasets.
workspace = rf.workspace("mortiscope-fvkhd")

# Downloads version 1 of several distinct projects from the workspace.
decomposition_1 = workspace.project("decomposition-high-resolution-300").version(1).download("yolov8")
decomposition_2 = workspace.project("decomposition-high-resolution-250").version(1).download("yolov8")
decomposition_3 = workspace.project("decomposition-high-resolution-200").version(1).download("yolov8")
decomposition_4 = workspace.project("decomposition-standard-resolution-300").version(1).download("yolov8")
decomposition_5 = workspace.project("decomposition-standard-resolution-250").version(1).download("yolov8")
complementary = workspace.project("complementary").version(1).download("yolov8")

# Prints a confirmation message to the console after all download operations have successfully completed.
print("\nAll datasets downloaded successfully.")
import yaml

# A list containing all the dataset objects that were previously downloaded.
all_datasets = [
    decomposition_1, decomposition_2, decomposition_3,
    decomposition_4, decomposition_5, complementary
]

# Initializes lists to store the file paths to the training and validation image folders.
train_paths = []
val_paths = []

print("-" * 70)
print("Building Dataset Configuration")

# Iterates through each dataset to locate and collect the paths to its image directories.
for ds in all_datasets:
    # Constructs the expected paths for the 'train' and 'valid' image subdirectories.
    t_path = os.path.join(ds.location, 'train', 'images')
    v_path = os.path.join(ds.location, 'valid', 'images')

    # Verifies that the training directory actually exists before adding it to the list.
    if os.path.exists(t_path):
        train_paths.append(t_path)
        print(f"Added train: {ds.location.split('/')[-1]}")
    else:
        print(f"Skipped train (Empty): {ds.location.split('/')[-1]}")

    # Verifies that the validation directory actually exists before adding it.
    if os.path.exists(v_path):
        val_paths.append(v_path)
        print(f"Added valid: {ds.location.split('/')[-1]}")
    else:
        print(f"Skipped valid (Not found): {ds.location.split('/')[-1]}")

# Defines the master configuration dictionary in the format required by Ultralytics YOLO.
data_config = {
    'names': {
        0: 'adult',
        1: 'instar_1',
        2: 'instar_2',
        3: 'instar_3',
        4: 'pupa'
    },
    'nc': 5,  # Number of classes.
    'train': train_paths,
    'val': val_paths,
}

# Defines the output path for the YAML file in the current working directory.
yaml_path = os.path.join(os.getcwd(), 'data.yaml')

# Writes the configuration dictionary to the 'data.yaml' file.
with open(yaml_path, 'w') as outfile:
    yaml.dump(data_config, outfile, default_flow_style=False)

# Prints a confirmation summary, showing the location of the file and the
# total number of included data folders.
print("\n" + "-" * 70)
print(f"Balanced configuration created at: {yaml_path}")
print(f"Total Train Folders: {len(train_paths)}")
print(f"Total Valid Folders: {len(val_paths)}")
print("-" * 70)
Section 3: Custom Architecture Definition (P2)
Purpose: Standard YOLO models struggle with tiny
objects. This section manually defines the
YOLOv8-P2 architecture, which adds a
high-resolution detection layer, and prepares it for training.
Key Activities:
Architecture Definition: A raw string
containing the YAML configuration for yolov8-p2 is
defined and written to disk. This config adds a detection head
at stride 4 (P2), keeping high-resolution features.
Weight Transfer: The standard
yolov8l.pt weights are downloaded. The
model.load() function then transfers the matching
backbone weights into our custom P2 architecture, allowing us to
benefit from pre-training even on a custom structure.
Callback Registration: Attaches the custom
on_train_epoch_end function to save history to
Drive.
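The weight-transfer step above can be illustrated with plain arrays: a pre-trained tensor is reused only when a parameter of the same name and shape exists in the custom model, while new P2 layers keep their fresh initialization. This is a conceptual sketch, not the actual Ultralytics model.load() logic, and the parameter names below are invented for illustration:

```python
import numpy as np

# Illustrative sketch of shape-matched weight transfer (not the actual
# Ultralytics `model.load()` implementation; parameter names are invented).
def transfer_weights(custom_state, pretrained_state):
    merged = {}
    for name, weight in custom_state.items():
        src = pretrained_state.get(name)
        if src is not None and src.shape == weight.shape:
            merged[name] = src      # reuse the pre-trained tensor
        else:
            merged[name] = weight   # keep fresh init (e.g. the new P2 head)
    return merged

# The shared backbone layer receives pre-trained weights; the P2 head does not.
custom = {"backbone.0.conv": np.zeros((16, 3)), "head.p2.conv": np.zeros((8, 4))}
pretrained = {"backbone.0.conv": np.ones((16, 3))}
merged = transfer_weights(custom, pretrained)
```

This is why convergence speeds up even on a custom structure: most of the backbone starts from learned features, and only the new high-resolution head trains from scratch.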
import shutil

# Defines and creates a dedicated directory for storing a snapshot of the
# model weights at the end of each training epoch.
history_path = os.path.join(weights_path, 'epoch_history')
os.makedirs(history_path, exist_ok=True)

def on_train_epoch_end(trainer):
    """
    A callback function executed at the end of each training epoch.

    This function performs the following actions:
    1. Saves a copy of the latest model weights (`last.pt`) to a historical
       archive, named with the corresponding epoch number.
    2. Updates the primary `last.pt` file in the persistent Google Drive
       weights folder, allowing for training resumption.
    3. Updates the `best.pt` file in the persistent folder whenever the
       trainer identifies a new best-performing model.

    Args:
        trainer: The Ultralytics trainer object, which provides access to the
            current training state, including epoch number and file paths.
    """
    # Gets the current epoch number.
    current_epoch = trainer.epoch + 1

    # Defines the source paths for the weights generated by the trainer in the
    # temporary, session-specific output directory.
    local_last = os.path.join(trainer.save_dir, 'weights', 'last.pt')
    local_best = os.path.join(trainer.save_dir, 'weights', 'best.pt')

    # Checks for the existence of the latest epoch's weights before proceeding.
    if os.path.exists(local_last):
        # Creates a unique filename for the historical weight file.
        history_filename = f"{current_epoch:03d}_epoch.pt"
        history_dest = os.path.join(history_path, history_filename)

        # Copies the latest weights to the historical archive directory.
        shutil.copy(local_last, history_dest)
        print(f" History Saved: {history_filename}")

        # Overwrites the main 'last.pt' file in the persistent Google Drive folder.
        resume_dest = os.path.join(weights_path, 'last.pt')
        shutil.copy(local_last, resume_dest)

    # Checks if the trainer has produced a new best-performing model weight file.
    if os.path.exists(local_best):
        # Overwrites the main 'best.pt' file in the persistent Google Drive folder.
        best_dest = os.path.join(weights_path, 'best.pt')
        shutil.copy(local_best, best_dest)

# Prints a confirmation message detailing the callback's configuration and the
# locations where files will be saved.
print("Callback Defined: ")
print(f" 1. History saved to: {history_path}")
print(f" 2. Resume file active at: {weights_path}/last.pt")
Callback Defined:
1. History saved to: /content/drive/MyDrive/Mortiscope Models/YOLOv8/weights/epoch_history
2. Resume file active at: /content/drive/MyDrive/Mortiscope Models/YOLOv8/weights/last.pt
Section 4: Training Pipeline
Purpose: This is the computational core of the
notebook. It defines the hyperparameters that govern how the
model learns, then executes the training loop.
Key Activities:
Hyperparameter Definition: Sets critical
parameters such as epochs (100),
batch_size (16), and img_size (640).
Augmentation Strategy: Configures geometric and
color-space augmentations (Mosaic, HSV shifts, Scale) to
artificially increase dataset diversity and prevent overfitting.
Smart Resume Logic: The script checks if a
training run was interrupted. If valid weights exist in Google
Drive, it automatically resumes training from the last
checkpoint; otherwise, it starts a new session.
Training Loop: Calls
model.train() to begin the backpropagation process,
optimizing the weights to detect the five life stages.
# A unique identifier for this specific training run, used for naming output folders.
experiment_name = 'run_v1'
# The total number of times the model will iterate over the entire training dataset.
epochs = 100
# The number of images processed in a single forward/backward pass of the model.
batch_size = 16
# The resolution (in pixels) to which all input images will be resized before training.
img_size = 640
# The number of consecutive epochs with no improvement in validation metrics
# before training is stopped early.
patience = 15
# The initial step size for the optimizer's weight updates.
learning_rate = 0.005
# The optimization algorithm to be used. 'auto' allows Ultralytics to select a suitable default.
optimizer_type = 'auto'

# Probability of applying the mosaic augmentation, which combines four images into one.
mosaic_probability = 1.0
# The degree of random hue shift applied in HSV color-space augmentation.
hsv_hue_fraction = 0.015
# The degree of random saturation shift applied in HSV color-space augmentation.
hsv_saturation_fraction = 0.7
# The degree of random brightness shift applied in HSV color-space augmentation.
hsv_brightness_fraction = 0.0
# The range (in degrees) of random rotation applied to images.
rotation_degrees = 0.0
# The probability of vertically flipping an image.
flip_vertical_probability = 0.0
# The probability of horizontally flipping an image.
flip_horizontal_probability = 0.0
# The gain for applying random scaling (zoom in/out) to images.
scale_gain = 0.5

# A boolean flag to force the training to start from scratch, ignoring any existing checkpoints.
force_restart = False
# The default path to the 'last.pt' checkpoint file for automatic training resumption.
auto_resume_path = os.path.join(weights_path, 'last.pt')
# An optional path to a specific weight file to resume from, overriding the
# auto-resume logic if set.
specific_weight_path = ""

# Prints a formatted summary of the key training and augmentation
# configurations for user verification.
print("-" * 40)
print("Training Configuration")
print(f"{'Experiment Name':<25} : {experiment_name}")
print(f"{'Epochs':<25} : {epochs}")
print(f"{'Batch Size':<25} : {batch_size}")
print(f"{'Image Size':<25} : {img_size}")
print(f"{'Learning Rate':<25} : {learning_rate}")
print("-" * 40)
print("Augmentation Strategy")
print(f"{'Mosaic Probability':<25} : {mosaic_probability}")
print(f"{'Hue Fraction':<25} : {hsv_hue_fraction}")
print(f"{'Saturation Fraction':<25} : {hsv_saturation_fraction}")
print(f"{'Brightness Fraction':<25} : {hsv_brightness_fraction}")
print(f"{'Rotation Degrees':<25} : {rotation_degrees}")
print(f"{'Vertical Flip Prob':<25} : {flip_vertical_probability}")
print(f"{'Horizontal Flip Prob':<25} : {flip_horizontal_probability}")
print("-" * 40)
print(f"{'Auto Resume Path':<25} : {auto_resume_path}")
import time

# Checkpoint Resumption Logic

# Initializes the default training mode to start a new session.
resume_mode = False

# Sets the default weights to the pre-trained model file.
try:
    weights_to_load = standard_weights_file
except NameError:
    # Fallback if Cell 7 variables are lost.
    weights_to_load = "yolov8l.pt"

# Checks for a specific, manually provided weight file to resume from.
if specific_weight_path and os.path.exists(specific_weight_path):
    print(f"Manual override detected.\nResuming from specific weight: {specific_weight_path}")
    resume_mode = True
    weights_to_load = specific_weight_path
# If no manual override is given, checks for the default auto-resume checkpoint file.
elif os.path.exists(auto_resume_path):
    # Handles the case where a checkpoint exists but a fresh start is explicitly required.
    if force_restart:
        print(f"Previous run found at {auto_resume_path}, but 'force_restart' is True.")
        print("Starting new training...")
        resume_mode = False
    # If a checkpoint exists and a fresh start is not forced, sets up auto-resumption.
    else:
        print(f"Previous run detected.\nAuto-resuming from: {auto_resume_path}")
        resume_mode = True
        weights_to_load = auto_resume_path
# If no checkpoints are found, configures the script to start a new training session.
else:
    print("No previous run found. Starting new training...")

# Training Execution
print("\nInitializing training loop...")

# Records the timestamp at the beginning of the training process to measure total duration.
start_time = time.time()

# Executes the appropriate training command based on whether a session is being resumed.
if resume_mode:
    # When resuming, the `resume` argument is used.
    results = model.train(
        resume=weights_to_load,
        project=project_path,
        name=experiment_name
    )
else:
    # For a new run, all hyperparameters and augmentation settings are passed explicitly.
    results = model.train(
        # Core Parameters
        data=yaml_path,
        epochs=epochs,
        imgsz=img_size,
        batch=batch_size,
        device=0,
        patience=patience,
        # Project and Checkpointing
        save=True,
        save_period=1,
        project=project_path,
        name=experiment_name,
        exist_ok=True,
        plots=True,
        # Optimizer Settings
        lr0=learning_rate,
        optimizer=optimizer_type,
        # Augmentation Parameters
        mosaic=mosaic_probability,
        hsv_h=hsv_hue_fraction,
        hsv_s=hsv_saturation_fraction,
        hsv_v=hsv_brightness_fraction,
        degrees=rotation_degrees,
        flipud=flip_vertical_probability,
        fliplr=flip_horizontal_probability,
        scale=scale_gain
    )

# Records the timestamp at the end of the training process.
end_time = time.time()

# Calculates the total training duration and formats it into hours, minutes, and seconds.
duration_seconds = end_time - start_time
hours = int(duration_seconds // 3600)
minutes = int((duration_seconds % 3600) // 60)
seconds = int(duration_seconds % 60)

# Prints a final summary of the completed training session and its duration.
print("\n" + "-" * 30)
print("Training complete.")
print(f"Total Time: {hours}h {minutes}m {seconds}s")
print("-" * 30)
Section 5: Comprehensive Evaluation
Purpose: To assess the model’s accuracy and
reliability using visual metrics. Raw numbers are often
insufficient; visualization helps identify specific classes or
scenarios where the model struggles.
Key Activities:
Training Metrics Plotting: Parses the
results.csv log file to generate line charts for
Box Loss, Classification Loss, Precision, Recall, and mAP (Mean
Average Precision). This helps verify that the model is
converging and not overfitting.
Confusion Matrix Generation: Runs a validation
pass on the best-performing model to generate a normalized
Confusion Matrix. This heatmap visualizes how often the model
confuses one life stage with another, providing insight into
biological similarities affecting the AI.
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Constructs the full path to the results log file.
results_csv_path = os.path.join(project_path, experiment_name, 'results.csv')
# The number of data points to average over when applying a rolling mean.
smoothing_window_size = 5

# Sets a professional and consistent visual theme for all generated plots.
sns.set_theme(style="whitegrid", context="notebook", font_scale=1.1)
# Increases the resolution of the output figures for better clarity.
plt.rcParams['figure.dpi'] = 120

# Defines a perceptually uniform colormap to derive a consistent color palette.
color_map = plt.get_cmap('plasma')

# Assigns specific colors from the colormap to different metrics for consistency.
color_train = color_map(0.0)       # Color for training metrics.
color_val = color_map(0.6)         # Color for validation metrics.
color_prec = color_map(0.25)       # Color for the precision curve.
color_rec = color_map(0.75)        # Color for the recall curve.
color_map_metric = color_map(0.5)  # Color for the mAP metric.
color_lr = color_map(0.05)         # Color for the learning rate schedule.

def plot_training_results(csv_file_path):
    """
    Reads a results.csv file and generates a 2x3 grid of training performance plots.

    Args:
        csv_file_path (str): The full path to the results.csv file.
    """
    # Verifies the existence of the results file before attempting to read it.
    if not os.path.exists(csv_file_path):
        print(f"Error: Could not find results at {csv_file_path}")
        return

    # Reads the CSV data into a pandas DataFrame.
    df = pd.read_csv(csv_file_path)
    # Cleans up column names by removing leading/trailing whitespace.
    df.columns = df.columns.str.strip()

    # Creates a Matplotlib figure and an array of 2x3 subplots (axes).
    figure, axis_array = plt.subplots(2, 3, figsize=(18, 10))

    # Defines the mapping of DataFrame columns to their respective plot titles
    # for the three primary loss functions.
    loss_map = [
        ('train/box_loss', 'val/box_loss', 'Box Loss'),
        ('train/cls_loss', 'val/cls_loss', 'Classification Loss'),
        ('train/dfl_loss', 'val/dfl_loss', 'Distribution Focal Loss')
    ]

    # Iterates through the loss map to generate the top row of plots.
    for i, (train_col, val_col, title) in enumerate(loss_map):
        axis = axis_array[0, i]

        # Plots the raw, noisy data with low opacity to serve as a background reference.
        sns.lineplot(data=df, x=df.index, y=train_col, ax=axis, color=color_train, alpha=0.15)
        sns.lineplot(data=df, x=df.index, y=val_col, ax=axis, color=color_val, alpha=0.15)

        # Overlays the smoothed data using a rolling mean for clearer trend visualization.
        sns.lineplot(x=df.index, y=df[train_col].rolling(smoothing_window_size).mean(),
                     ax=axis, color=color_train, linewidth=2.5, label='Train')
        sns.lineplot(x=df.index, y=df[val_col].rolling(smoothing_window_size).mean(),
                     ax=axis, color=color_val, linewidth=2.5, label='Validation')

        # Configures the title, labels, and legend for each loss plot.
        axis.set_title(title, color='#333333')
        axis.set_xlabel('Epochs')
        axis.set_ylabel('Loss Value')
        axis.legend()

    # Precision and Recall Plot
    axis_precision_recall = axis_array[1, 0]
    sns.lineplot(x=df.index, y=df['metrics/precision(B)'].rolling(smoothing_window_size).mean(),
                 ax=axis_precision_recall, color=color_prec, label='Precision')
    sns.lineplot(x=df.index, y=df['metrics/recall(B)'].rolling(smoothing_window_size).mean(),
                 ax=axis_precision_recall, color=color_rec, label='Recall')
    axis_precision_recall.set_title('Precision & Recall')
    axis_precision_recall.set_xlabel('Epochs')
    axis_precision_recall.set_ylabel('Score')
    axis_precision_recall.set_ylim(0, 1)

    # Mean Average Precision (mAP) Plot
    axis_map = axis_array[1, 1]
    sns.lineplot(x=df.index, y=df['metrics/mAP50(B)'].rolling(smoothing_window_size).mean(),
                 ax=axis_map, color=color_map_metric, linewidth=2.5, label='mAP @ 0.50')
    axis_map.set_title('Mean Average Precision (IoU=0.50)')
    axis_map.set_xlabel('Epochs')
    axis_map.set_ylabel('Score')
    axis_map.set_ylim(0, 1)

    # Fills the area under the mAP curve to visually emphasize the performance metric.
    axis_map.fill_between(df.index,
                          df['metrics/mAP50(B)'].rolling(smoothing_window_size).mean(),
                          color=color_map_metric, alpha=0.1)

    # Learning Rate Schedule Plot
    axis_learning_rate = axis_array[1, 2]
    sns.lineplot(x=df.index, y=df['lr/pg0'], ax=axis_learning_rate,
                 color=color_lr, linestyle='--')
    axis_learning_rate.set_title('Learning Rate Schedule')
    axis_learning_rate.set_xlabel('Epochs')
    axis_learning_rate.set_ylabel('Learning Rate')

    # Adjusts the spacing between subplots to prevent labels from overlapping.
    plt.tight_layout(pad=3.0, w_pad=4.0, h_pad=5.0)

    # Renders and displays the final, complete figure.
    plt.show()

# Executes the plotting function with the path to the results file.
plot_training_results(results_csv_path)
import numpy as np

# Defines the path to the best-performing model weights saved during training.
best_weight_path = os.path.join(project_path, experiment_name, 'weights', 'best.pt')
# Defines the output path for the final, publication-quality confusion matrix image.
output_image_path = os.path.join(project_path, experiment_name, 'confusion_matrix.png')

# Ensures that the script only runs if the best weight file exists.
if os.path.exists(best_weight_path):
    print(f"Loading weights from: {best_weight_path}")

    # Instantiates a new YOLO model object using the best saved weights.
    validation_model = YOLO(best_weight_path)

    # Runs a validation pass on the model.
    validation_metrics = validation_model.val(
        data=yaml_path,
        split='val',
        plots=True,
        device=0,
        batch=16,
        conf=0.001
    )

    # Extracts the raw confusion matrix (a NumPy array) from the results.
    raw_matrix = validation_metrics.confusion_matrix.matrix
    num_classes = 5

    # Slices the matrix to ensure it only contains the defined classes.
    matrix_data = raw_matrix[:num_classes, :num_classes]

    # Retrieves the class names from the validation results and formats them.
    raw_names = list(validation_metrics.names.values())[:num_classes]
    class_names = [name.replace('_', ' ').title() for name in raw_names]

    # Normalizes the confusion matrix by rows to calculate recall scores.
    row_sums = matrix_data.sum(axis=1, keepdims=True)
    # Replaces zero sums with a small number to avoid division-by-zero errors
    # for classes that may not have appeared in the validation set.
    row_sums[row_sums == 0] = 1e-9
    matrix_normalized = matrix_data / row_sums

    # Initializes a high-resolution figure for the plot.
    plt.figure(figsize=(16, 12), dpi=300)
    sns.set_theme(style="white", font_scale=1.1)

    # Creates the heatmap using Seaborn, configuring annotations, colormap, and labels.
    axis = sns.heatmap(
        matrix_normalized,
        annot=True,               # Displays the numerical value in each cell.
        annot_kws={"size": 14},   # Sets the font size for annotations.
        fmt='.2f',                # Formats annotations to two decimal places.
        cmap='Blues',             # Sets the color scheme.
        xticklabels=class_names,
        yticklabels=class_names,
        vmin=0.0, vmax=1.0,       # Fixes the color bar range from 0 to 1.
        square=True,              # Enforces square cells for better proportionality.
        linewidths=2.5,
        linecolor='white',
        cbar_kws={
            'shrink': 0.6,        # Adjusts the size of the color bar.
            'pad': 0.04
        }
    )

    # Configures the color bar label to clarify that the values represent recall.
    cbar = axis.collections[0].colorbar
    cbar.set_label('Recall (Sensitivity)', labelpad=30, fontsize=14)

    # Sets the main title and axis labels with appropriate padding.
    plt.title('Confusion Matrix', fontsize=20, pad=30)
    plt.xlabel('Predicted Class', fontsize=16, labelpad=25)
    plt.ylabel('Actual Class', fontsize=16, labelpad=25)

    # Adjusts tick label appearance for clarity.
    plt.xticks(rotation=0, fontsize=13)
    plt.yticks(rotation=0, fontsize=13)

    # Adjusts subplot parameters to give a tight layout.
    plt.tight_layout(pad=5.0)

    # Saves the final figure to the specified output path.
    plt.savefig(output_image_path, dpi=300, bbox_inches='tight')
    print(f"Confusion Matrix saved to: {output_image_path}")

    # Displays the plot in the notebook output.
    plt.show()
else:
    # Prints an error message if the required 'best.pt' file is not found.
    print(f"Error: Best weights not found at {best_weight_path}")
Section 6: Speed Benchmarking and Advanced Inference Setup
Purpose: To evaluate the model’s operational
efficiency (speed) and prepare it for advanced inference scenarios
involving high-resolution imagery.
Key Activities:
Speed Benchmarking: Runs a validation pass
specifically to measure pre-processing, inference, and
post-processing times. This calculates the estimated FPS (Frames
Per Second) to determine if the model is suitable for real-time
applications.
SAHI Wrapper Initialization: Initializes the
Slicing Aided Hyper Inference (SAHI) wrapper.
Standard YOLO resizing can make small insects vanish in 4K
images. SAHI solves this by slicing the image into smaller
overlapping windows, performing inference on each, and stitching
the results back together.
# Defines the path to the best-performing model weights from the training run.
best_weight_path = os.path.join(project_path, experiment_name, 'weights', 'best.pt')

# Ensures the script proceeds only if the required model weight file is found.
if os.path.exists(best_weight_path):
    print(f"Loading best model for benchmarking: {best_weight_path}")

    # Loads the best model weights into a YOLO object for evaluation.
    benchmark_model = YOLO(best_weight_path)

    # Runs a validation pass on the specified dataset split.
    metrics = benchmark_model.val(data=yaml_path, split='val', plots=False, device=0)

    # Extracts the speed dictionary, which contains timing information for
    # different stages of the inference pipeline.
    speed_metrics = metrics.speed

    # Displays a formatted summary of the average inference speed metrics.
    print("\n" + "-" * 45)
    print("Inference Speed Benchmark (Average per Image)")
    print("-" * 45)
    print(f"{'Pre-process':<25} : {speed_metrics['preprocess']:.2f} ms")
    print(f"{'Inference (Model)':<25} : {speed_metrics['inference']:.2f} ms")
    print(f"{'Post-process (NMS)':<25} : {speed_metrics['postprocess']:.2f} ms")
    print("-" * 45)

    # Calculates the total latency by summing the timings of all pipeline stages.
    total_latency = sum(speed_metrics.values())
    print(f"{'Total Latency':<25} : {total_latency:.2f} ms")

    # Estimates the throughput in Frames Per Second (FPS) based on the total latency.
    fps = 1000 / total_latency
    print(f"{'Estimated FPS':<25} : {fps:.2f} fps")
    print("-" * 45)
else:
    # Prints an error message if the model weights file could not be located.
    print("Error: Best weights file not found.")
    print("Please make sure that the training completed successfully.")
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction

# Defines the path to the best-performing model weights from the training run,
# which will be loaded into the SAHI wrapper.
best_weight_path = os.path.join(project_path, experiment_name, 'weights', 'best.pt')

# The height of each individual slice in pixels.
slice_height = 640
# The width of each individual slice in pixels.
slice_width = 640
# The percentage of overlap between adjacent slices vertically.
overlap_height_ratio = 0.2
# The percentage of overlap between adjacent slices horizontally.
overlap_width_ratio = 0.2

print(f"Initializing SAHI wrapper for: {best_weight_path}")

# Verifies that the model weight file exists before attempting to load it.
if os.path.exists(best_weight_path):
    # Initializes the SAHI AutoDetectionModel.
    detection_model = AutoDetectionModel.from_pretrained(
        model_type='yolov8',           # Specifies the model architecture.
        model_path=best_weight_path,   # Provides the path to the custom-trained weights.
        confidence_threshold=0.25,     # Sets the minimum confidence for a valid detection.
        device="cuda:0"                # Assigns the model to a specific GPU device for inference.
    )

    # Prints a confirmation message summarizing the SAHI configuration.
    print("\n" + "-" * 45)
    print("SAHI Model Ready")
    print("-" * 45)
    print(f"{'Slice Dimensions':<20} : {slice_height}x{slice_width}")
    print(f"{'Overlap Ratio':<20} : {overlap_height_ratio * 100}%")
    print(f"{'Confidence Thresh':<20} : 0.25")
    print("-" * 45)
else:
    # Handles the case where the required weight file is not found.
    print(f"Error: Weights not found at {best_weight_path}")
    print("Cannot initialize SAHI.")
Section 7: Deployment Export and Interactive Demonstration
Purpose: To finalize the model for production
deployment and provide a tangible demonstration of its
capabilities on user-provided data.
Key Activities:
ONNX Export: Converts the PyTorch model
(.pt) to the ONNX (Open Neural Network Exchange)
format. This format is hardware-agnostic and optimized for
deployment on edge devices or web servers.
Interactive Inference Pipeline: A comprehensive
script that:
Accepts a user-uploaded image.
Detects the scene type (Macro vs. Field) to choose between
Standard or Sliced inference.
Filters outliers based on box area to reduce false
positives.
Draws bounding boxes and creates a summary legend of the
detected entomological evidence.
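The box-area outlier filter mentioned above can be sketched in isolation. A detection is dropped if its box covers more than 5% of the image or exceeds 15× the median box area — the same thresholds the interactive pipeline applies; the helper name and tuple format here are illustrative.

```python
import statistics

def filter_area_outliers(boxes, img_area, max_coverage=0.05, max_median_ratio=15):
    """Drops unusually large boxes, which are often false positives in field shots.

    `boxes` are (x_min, y_min, x_max, y_max) tuples in pixels. A box is kept
    only if it covers at most `max_coverage` of the image AND is at most
    `max_median_ratio` times the median box area.
    """
    areas = [(x2 - x1) * (y2 - y1) for x1, y1, x2, y2 in boxes]
    if not areas:
        return []
    median_area = statistics.median(areas)
    kept = []
    for box, area in zip(boxes, areas):
        if area / img_area > max_coverage:
            continue  # Covers too much of the frame to be a single specimen.
        if area > median_area * max_median_ratio:
            continue  # Far larger than its neighbours.
        kept.append(box)
    return kept

# Three plausible larva-sized boxes plus one huge spurious box on a 4000x3000 image.
boxes = [(0, 0, 40, 40), (100, 100, 150, 140), (200, 200, 230, 250), (0, 0, 2000, 2000)]
print(filter_area_outliers(boxes, 4000 * 3000))  # keeps the three small boxes
```

Because the thresholds are relative to the median, the filter adapts to scale: it rejects one anomalously large box among many small larvae without needing absolute size calibration.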
# Defines the source path for the best PyTorch model weights and the desired
# target path and filename for the exported ONNX model.
source_weights = os.path.join(project_path, experiment_name, 'weights', 'best.pt')
target_filename = "yolov8_mortiscope.onnx"
target_path = os.path.join(project_path, experiment_name, 'weights', target_filename)

print(f"Loading weights from: {source_weights}")

# Verifies the existence of the source weight file before initiating the export process.
if os.path.exists(source_weights):
    # Loads the trained PyTorch model from the specified '.pt' file.
    model = YOLO(source_weights)

    # Executes the model export process with specific configurations.
    exported_path = model.export(
        format='onnx',  # Specifies the target export format as ONNX.
        dynamic=False,  # Exports the model with fixed input/output dimensions for performance.
        simplify=True,  # Applies the ONNX-Simplifier to optimize the model graph.
        opset=12        # Sets the ONNX operator set version for broad compatibility.
    )

    # After export, the file is renamed and moved to the final target location.
    if isinstance(exported_path, str):
        # The `export` method saves the file with a default name.
        shutil.move(exported_path, target_path)
        print("\n" + "-" * 100)
        print(f"File Saved: {target_path}")
        print("-" * 100)
    else:
        # Handles cases where the export process does not return a valid file path.
        print("Export returned unexpected format.")
else:
    # Provides an error message if the source PyTorch model weights are not found.
    print(f"Error: Could not find weights at {source_weights}")
from collections import Counter, defaultdict

import cv2
import matplotlib.patches as mpatches
from google.colab import files
from sahi.prediction import ObjectPrediction

# Global Configuration
# The minimum confidence score required for a detection to be considered valid.
confidence_threshold = 0.25
# The target width in pixels for the final output visualization.
target_width = 3840
# The target height in pixels for the final output visualization.
target_height = 2160
# The resolution (Dots Per Inch) for the generated Matplotlib figure.
dpi = 100

# A standardized color palette for different insect life stages.
color_map = {
    "instar_1": "#eab308",
    "instar_2": "#84cc16",
    "instar_3": "#22c55e",
    "pupa": "#f97316",
    "adult": "#f43f5e"
}

# Defines the canonical order for presenting life stages in summaries and legends.
lifecycle_order = ["instar_1", "instar_2", "instar_3", "pupa", "adult"]


def hex_to_bgr(hex_color):
    """
    Converts a hexadecimal color string to a BGR tuple for use with OpenCV.

    Args:
        hex_color (str): The color in hexadecimal format (e.g., '#eab308').

    Returns:
        tuple: The color in BGR format (e.g., (8, 179, 234)).
    """
    hex_color = hex_color.lstrip('#')
    rgb = tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
    # OpenCV uses BGR order, so the RGB tuple is reversed.
    return (rgb[2], rgb[1], rgb[0])


def format_class_name(name):
    """
    Formats an internal class name into a human-readable, title-cased string.

    Args:
        name (str): The internal class name (e.g., 'instar_1').

    Returns:
        str: The formatted name (e.g., 'Instar 1').
    """
    return name.replace("_", " ").title()


def calculate_iou(box1, box2):
    """
    Calculates the Intersection over Union (IoU) of two bounding boxes.

    Args:
        box1 (sahi.prediction.BBox): The first bounding box.
        box2 (sahi.prediction.BBox): The second bounding box.

    Returns:
        float: The IoU score, a value between 0.0 and 1.0.
    """
    # Extracts coordinates for easier calculation.
    b1 = [box1.minx, box1.miny, box1.maxx, box1.maxy]
    b2 = [box2.minx, box2.miny, box2.maxx, box2.maxy]

    # Determines the coordinates of the intersection rectangle.
    x1 = max(b1[0], b2[0])
    y1 = max(b1[1], b2[1])
    x2 = min(b1[2], b2[2])
    y2 = min(b1[3], b2[3])

    # Computes the area of intersection.
    intersection = max(0, x2 - x1) * max(0, y2 - y1)

    # Computes the area of both bounding boxes.
    area1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
    area2 = (b2[2] - b2[0]) * (b2[3] - b2[1])

    # Computes the area of the union.
    union = area1 + area2 - intersection
    if union == 0:
        return 0
    return intersection / union


def apply_class_agnostic_nms(predictions, iou_threshold=0.6):
    """
    Applies a custom class-agnostic Non-Maximum Suppression (NMS) to a list of
    object predictions to filter out highly overlapping boxes.

    Args:
        predictions (list[ObjectPrediction]): A list of SAHI ObjectPrediction objects.
        iou_threshold (float): The IoU threshold above which boxes are suppressed.

    Returns:
        list[ObjectPrediction]: A filtered list of object predictions.
    """
    # Sorts predictions by confidence score in descending order.
    sorted_preds = sorted(predictions, key=lambda x: x.score.value, reverse=True)
    kept_preds = []
    for current in sorted_preds:
        should_keep = True
        for kept in kept_preds:
            iou = calculate_iou(current.bbox, kept.bbox)
            if iou > iou_threshold:
                # Suppresses the current box if it has a high IoU with an already kept box.
                should_keep = False
                break
        if should_keep:
            kept_preds.append(current)
    return kept_preds


def detect_scene_type(sahi_model, image_path):
    """
    Analyzes an image to determine if it is a 'macro' (close-up) or 'field'
    (wide-angle) scene.

    This heuristic is based on the average relative area of objects detected in
    an initial, low-resolution pass. Large average areas suggest a macro shot.

    Args:
        sahi_model (sahi.AutoDetectionModel): The initialized SAHI model.
        image_path (str): The path to the image file.

    Returns:
        str: The detected scene type, either 'macro' or 'field'.
    """
    native_model = sahi_model.model
    results = native_model.predict(image_path, imgsz=640, conf=0.25, verbose=False)
    boxes = results[0].boxes
    if len(boxes) == 0:
        return "field"

    # Calculates the normalized area (width * height) of each detected box.
    areas = boxes.xywhn[:, 2] * boxes.xywhn[:, 3]
    avg_area = torch.mean(areas).item()

    # Classifies the scene based on a predefined area threshold.
    if avg_area > 0.015:
        return "macro"
    else:
        return "field"


def run_image_analysis():
    """
    Orchestrates the main image analysis workflow.

    This function handles the user file upload, selects an inference strategy
    based on the scene type, processes the detections, applies filtering, and
    generates a final visual report with annotations and a summary legend.
    """
    print("Click button to upload image:")
    uploaded_files = files.upload()
    if not uploaded_files:
        print("No file uploaded.")
        return

    for filename in uploaded_files.keys():
        print(f"\nProcessing {filename}")

        # Determines the appropriate inference strategy for the uploaded image.
        scene_type = detect_scene_type(detection_model, filename)

        image_cv = cv2.imread(filename)
        img_h, img_w, _ = image_cv.shape
        img_area = img_w * img_h
        object_prediction_list = []

        # Scene-Adaptive Inference
        if scene_type == "macro":
            # For close-up images, use the standard, non-sliced prediction method.
            native_model = detection_model.model
            results = native_model.predict(
                filename,
                conf=0.45,
                imgsz=640,
                augment=False,
                agnostic_nms=True,  # Uses YOLO's built-in NMS.
                verbose=False
            )

            # Manually converts the native YOLO results into the SAHI
            # ObjectPrediction format.
            for r in results:
                for box in r.boxes:
                    x1, y1, x2, y2 = box.xyxy[0].tolist()
                    score = box.conf[0].item()
                    cls_id = int(box.cls[0].item())
                    cls_name = native_model.names[cls_id]
                    obj = ObjectPrediction(
                        bbox=[x1, y1, x2, y2],
                        category_id=cls_id,
                        category_name=cls_name,
                        score=score
                    )
                    object_prediction_list.append(obj)
            use_outlier_filter = False
        else:
            # For wide-angle images, use SAHI's sliced prediction method.
            if img_w < 2500 or img_h < 2500:
                current_slice_size = 160
                current_overlap = 0.35
            else:
                current_slice_size = 320
                current_overlap = 0.25

            result = get_sliced_prediction(
                filename,
                detection_model,
                slice_height=current_slice_size,
                slice_width=current_slice_size,
                overlap_height_ratio=current_overlap,
                overlap_width_ratio=current_overlap,
                postprocess_type="NMS",
                postprocess_match_metric="IOS",
                postprocess_match_threshold=0.5,
                postprocess_class_agnostic=True,
                verbose=1
            )
            object_prediction_list = result.object_prediction_list
            use_outlier_filter = True

        # Applies a final class-agnostic NMS pass to refine the results.
        object_prediction_list = apply_class_agnostic_nms(object_prediction_list, iou_threshold=0.6)

        class_counts = Counter()
        class_confidences = defaultdict(list)

        # Calculates the median area of all detections to use for outlier filtering.
        all_areas = []
        for pred in object_prediction_list:
            if pred.score.value >= confidence_threshold:
                bbox = pred.bbox
                area = (bbox.maxx - bbox.minx) * (bbox.maxy - bbox.miny)
                all_areas.append(area)
        median_area = np.median(all_areas) if all_areas else 0

        # Iterates through predictions to filter outliers and draw annotations.
        for prediction in object_prediction_list:
            if prediction.score.value < confidence_threshold:
                continue
            bbox = prediction.bbox
            box_area = (bbox.maxx - bbox.minx) * (bbox.maxy - bbox.miny)

            # Applies an outlier filter to remove unusually large detections,
            # which are often false positives in field images.
            if use_outlier_filter and median_area > 0:
                coverage_ratio = box_area / img_area
                # Skip if box covers >5% of image.
                if coverage_ratio > 0.05:
                    continue
                # Skip if box is >15x median area.
                if box_area > (median_area * 15):
                    continue

            # Aggregates statistics for the final summary.
            class_name = prediction.category.name
            score = prediction.score.value
            class_counts[class_name] += 1
            class_confidences[class_name].append(score)

            # Draws the bounding box rectangle onto the image.
            color_hex = color_map.get(class_name, "#ffffff")
            color_bgr = hex_to_bgr(color_hex)
            x_min, y_min = int(bbox.minx), int(bbox.miny)
            x_max, y_max = int(bbox.maxx), int(bbox.maxy)
            cv2.rectangle(image_cv, (x_min, y_min), (x_max, y_max), color_bgr, 2)

        # Visualization and Reporting
        # Converts the OpenCV (BGR) image to RGB for Matplotlib display.
        img_rgb = cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB)
        fig_w_in, fig_h_in = target_width / dpi, target_height / dpi

        # Creates a two-panel figure: one for the image, one for the legend.
        fig, (ax_image, ax_legend) = plt.subplots(
            1, 2,
            figsize=(fig_w_in, fig_h_in),
            dpi=dpi,
            gridspec_kw={'width_ratios': [3, 1], 'wspace': 0.05}
        )
        ax_image.imshow(img_rgb)
        ax_image.axis('off')
        ax_legend.axis('off')

        # Constructs the legend handles from the aggregated detection data.
        legend_handles = []
        for class_key in lifecycle_order:
            if class_key in class_counts:
                count = class_counts[class_key]
                scores = class_confidences[class_key]
                avg_score = sum(scores) / len(scores) if scores else 0
                label_text = f"{format_class_name(class_key)}: {count} — {avg_score * 100:.2f}%"
                patch = mpatches.Patch(color=color_map.get(class_key, "#000"), label=label_text)
                legend_handles.append(patch)

        # Renders the legend on the right-hand panel.
        if legend_handles:
            ax_legend.legend(
                handles=legend_handles,
                loc='center',
                title="Detection Summary",
                fontsize=24,
                title_fontsize=30,
                frameon=False,
                labelspacing=0.8
            )
        plt.tight_layout()
        plt.show()

        # Prints a final textual summary to the console.
        print("\n" + "-" * 70)
        print(f"Image report: {filename}")
        print("-" * 70)
        for class_key in lifecycle_order:
            if class_key in class_counts:
                scores = class_confidences[class_key]
                avg = sum(scores) / len(scores)
                print(f"{format_class_name(class_key):<15} | Count: {class_counts[class_key]:<5} | Avg Conf: {avg * 100:.2f}%")
        print("-" * 70)


# Executes the main analysis function.
run_image_analysis()
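The macro-vs-field decision in the pipeline ultimately hinges on one threshold: if the mean normalized box area from the quick low-resolution pass exceeds 0.015, the image is treated as a close-up. A framework-free sketch of that decision (the function name is illustrative; the inputs stand in for the fractions YOLO reports in `boxes.xywhn`):

```python
def classify_scene(normalized_areas, macro_threshold=0.015):
    """Mirrors the detect_scene_type() decision without the YOLO dependency.

    `normalized_areas` are w*h box areas expressed as fractions of the
    image area. No detections defaults to 'field', matching the pipeline.
    """
    if not normalized_areas:
        return "field"
    avg_area = sum(normalized_areas) / len(normalized_areas)
    return "macro" if avg_area > macro_threshold else "field"

print(classify_scene([0.04, 0.06]))    # large relative boxes: a close-up shot
print(classify_scene([0.001, 0.002]))  # tiny boxes: a wide-angle field scene
```

The 0.015 cutoff is empirical; if the model were retrained or the capture distance changed substantially, this threshold would likely need re-tuning.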