Files
filmsim/wb.py
2025-06-19 15:31:45 -04:00

242 lines
9.3 KiB
Python

import torch
import numpy as np
import imageio.v2 as imageio
import time
import os
# --- Configuration & Constants ---
# Default algorithm parameters. The values mirror the NumPy implementation of
# this algorithm; in this module they are consumed by the PyTorch code path.

# Gray-point criterion: a pixel counts as gray when (|U| + |V|) / Y is below this.
DEFAULT_T_THRESHOLD: float = 0.1321
# Step size (mu) multiplied into every gain adjustment.
DEFAULT_MU_STEP: float = 0.0312
# The a/b error thresholds below are presumably specified on an 8-bit (0-255)
# scale; dividing by SCALE_FACTOR expresses them on the [0, 1] scale the images
# are normalised to before processing.
SCALE_FACTOR: float = 255.0
DEFAULT_A_THRESHOLD: float = 0.8 / SCALE_FACTOR
DEFAULT_B_THRESHOLD: float = 0.15 / SCALE_FACTOR
# Upper bound on the number of gain-adjustment iterations.
DEFAULT_MAX_ITERATIONS: int = 60
# --- PyTorch Core Algorithm Functions ---
def rgb_to_yuv_torch(image_tensor: torch.Tensor, matrix_torch: torch.Tensor) -> torch.Tensor:
    """
    Map every RGB pixel of ``image_tensor`` through ``matrix_torch``.

    The conversion right-multiplies the last (channel) axis by the matrix, so
    ``matrix_torch`` must hold one output channel per *column* (i.e. the
    transpose of a row-per-output-channel matrix).

    Args:
        image_tensor (torch.Tensor): A (H, W, 3) tensor on a target device.
        matrix_torch (torch.Tensor): The 3x3 conversion matrix on the same device.

    Returns:
        torch.Tensor: An (H, W, 3) YUV tensor on the same device.
    """
    yuv = image_tensor @ matrix_torch
    return yuv
def k_function_torch(error: float, a: float, b: float) -> float:
    """
    Non-linear error weighting K(x) from Eq. 16.

    Turns a scalar error into a signed step count: 2 when ``|error| >= a``,
    otherwise 1 when ``|error| >= b``, otherwise 0. The result carries the sign
    of ``error``. This is plain scalar logic and stays on the CPU.
    """
    magnitude = abs(error)
    if magnitude >= a:
        steps = 2.0
    elif magnitude >= b:
        steps = 1.0
    else:
        # Error is inside the dead zone: no adjustment.
        return 0.0
    return steps * np.sign(error)
def huo_awb_core_torch(image_tensor: torch.Tensor,
                       t_threshold: float,
                       mu: float,
                       a: float,
                       b: float,
                       max_iter: int,
                       device: torch.device,
                       min_gray_points: int = 50,
                       min_luminance: float = 0.1,
                       max_luminance: float = 0.95) -> torch.Tensor:
    """
    Perform the core iterative AWB algorithm using PyTorch tensors on a specified device.

    Each iteration applies the current gains, converts to YUV, selects near-gray
    pixels, and nudges either the B gain (driven by the mean U of the gray
    pixels) or the R gain (driven by the mean V) by ``mu * K(error)``. The G
    gain is never modified.

    Args:
        image_tensor (torch.Tensor): Input image as an (H, W, 3) tensor on ``device``.
            The gain-adjusted image is clamped to [0, 1] every iteration, so
            values are expected in that range.
        t_threshold (float): Gray-point threshold T in (|U| + |V|) / Y < T.
        mu (float): Gain step size.
        a (float): Upper error threshold of K(x) (see ``k_function_torch``).
        b (float): Lower error threshold of K(x); also the convergence
            tolerance on |mean U| and |mean V|.
        max_iter (int): Maximum number of iterations.
        device (torch.device): The device (e.g. 'cuda' or 'cpu') on which the
            conversion matrix and gains are created.
        min_gray_points (int): Stop when fewer gray points than this are found.
            Defaults to 50 (the previously hard-coded value); increase it for
            large images.
        min_luminance (float): Pixels must have Y strictly above this.
        max_luminance (float): Pixels must have Y strictly below this.

    Returns:
        torch.Tensor: A (3,) tensor containing the final [R, G, B] gains. Its
        dtype is float32 unless the input is float64, in which case it is float64.
    """
    # Match the dtype that `image_tensor * gains` promotes to. A hard-coded
    # float32 matrix made float64 inputs fail in matmul with a dtype mismatch.
    work_dtype = torch.promote_types(image_tensor.dtype, torch.float32)

    # Rows are Y, U (= B - Y), V (= R - Y). The transpose lets
    # rgb_to_yuv_torch right-multiply the channel axis.
    yuv_matrix = torch.tensor([
        [0.299, 0.587, 0.114],
        [-0.299, -0.587, 0.886],
        [0.701, -0.587, -0.114],
    ], dtype=work_dtype, device=device).T
    gains = torch.ones(3, dtype=work_dtype, device=device)

    print(f"Starting iterative AWB on device: '{device.type}'...")
    for i in range(max_iter):
        # 1. Apply current gains to the image (on device).
        balanced_image = torch.clamp(image_tensor * gains, 0.0, 1.0)

        # 2. Convert to YUV (on device).
        yuv_image = rgb_to_yuv_torch(balanced_image, yuv_matrix)
        Y, U, V = yuv_image.unbind(dim=-1)

        # 3. Keep pixels inside the luminance window. This excludes overly dark
        #    or bright pixels and, with min_luminance >= 0, keeps Y > 0 for the
        #    division below.
        luminance_mask = (Y > min_luminance) & (Y < max_luminance)
        if not torch.any(luminance_mask):
            print(f"Iteration {i+1}: No pixels in luminance range. Stopping.")
            break
        Y_masked = Y[luminance_mask]
        U_masked = U[luminance_mask]
        V_masked = V[luminance_mask]

        # Gray-point criterion from Eq. 10.
        gray_mask = (torch.abs(U_masked) + torch.abs(V_masked)) / Y_masked < t_threshold
        gray_points_U = U_masked[gray_mask]
        num_gray_points = gray_points_U.shape[0]
        if num_gray_points < min_gray_points:
            print(f"Iteration {i+1}: Not enough gray points found ({num_gray_points}). Stopping.")
            break

        # 4. Average chrominance of the gray points. The reduction runs on the
        #    device; the scalars come back to the host for control flow.
        u_mean_cpu = torch.mean(gray_points_U).item()
        v_mean_cpu = torch.mean(V_masked[gray_mask]).item()

        # Converged once both chroma means are within tolerance b.
        if abs(u_mean_cpu) < b and abs(v_mean_cpu) < b:
            print(f"Iteration {i+1}: Converged. u_mean={u_mean_cpu:.4f}, v_mean={v_mean_cpu:.4f}")
            break

        # 5. Correct whichever chroma axis has the larger cast. A positive mean
        #    yields a negative error, which lowers the corresponding gain.
        if abs(u_mean_cpu) > abs(v_mean_cpu):
            adjustment = mu * k_function_torch(-u_mean_cpu, a, b)
            gains[2] += adjustment
            print(f"Iter {i+1}: Adjusting B-gain. u_mean={u_mean_cpu:.4f}, v_mean={v_mean_cpu:.4f}, B-adj={adjustment:.4f}")
        else:
            adjustment = mu * k_function_torch(-v_mean_cpu, a, b)
            gains[0] += adjustment
            print(f"Iter {i+1}: Adjusting R-gain. u_mean={u_mean_cpu:.4f}, v_mean={v_mean_cpu:.4f}, R-adj={adjustment:.4f}")
    else:
        # Runs only when no `break` fired, i.e. the iteration budget ran out.
        print(f"Stopped after max_iter={max_iter} iterations without converging.")

    print(f"Final gains: R={gains[0].item():.4f}, G={gains[1].item():.4f}, B={gains[2].item():.4f}")
    return gains
# --- Main Public Function ---
def _normalize_to_unit_float(image_np: np.ndarray) -> np.ndarray:
    """Convert a loaded image array to float32 scaled to [0, 1].

    Integer images are divided by the maximum value of their dtype (65535 for
    uint16, 255 for uint8, ...). Floating-point images are assumed to already be
    in [0, 1] and are only cast to float32.
    """
    if np.issubdtype(image_np.dtype, np.integer):
        return image_np.astype(np.float32) / float(np.iinfo(image_np.dtype).max)
    return image_np.astype(np.float32)


def _to_uint16(image: np.ndarray) -> np.ndarray:
    """Quantise a [0, 1] float image to uint16, rounding to the nearest code value."""
    # Rounding (rather than truncating) keeps an unmodified pixel at its exact
    # original 16-bit value instead of occasionally dropping it by one.
    return np.rint(np.clip(image, 0.0, 1.0) * 65535.0).astype(np.uint16)


def apply_huo_awb_torch(image_path: str, output_path: str, **kwargs):
    """
    Load an image, apply the Huo et al. AWB algorithm with PyTorch, and save a 16-bit result.

    Args:
        image_path (str): Path to the input image, typically a 16-bit TIFF.
            Other integer bit depths are rescaled by their dtype range. Float
            images are assumed to be in [0, 1].
        output_path (str): Path to save the white-balanced 16-bit image.
        **kwargs: Optional parameters: ``t_threshold``, ``mu``, ``a``, ``b``,
            ``max_iter`` (algorithm settings) and ``device`` (a ``torch.device``
            or device string; defaults to CPU).

    Returns:
        None. If ``image_path`` does not exist, an error is printed and nothing
        is written.

    Raises:
        ValueError: If the image is not (H, W, C) with at least 3 channels.
    """
    start_time = time.perf_counter()

    # 1. Select device: CPU unless the caller asks for another (e.g. 'cuda', 'mps').
    device = kwargs.get('device', 'cpu')
    if not isinstance(device, torch.device):
        device = torch.device(device)
    print(f"Using device: {device}")

    # 2. Load image with imageio (on CPU).
    print(f"Loading image from: {image_path}")
    try:
        image_np = imageio.imread(image_path)
    except FileNotFoundError:
        print(f"Error: The file '{image_path}' was not found.")
        return
    load_time = time.perf_counter()
    print(f"Image loaded in {load_time - start_time:.2f} seconds.")

    # Fail early with a clear message instead of a shape error deep in the core loop.
    if image_np.ndim != 3 or image_np.shape[-1] < 3:
        raise ValueError(
            f"Expected an (H, W, C) image with at least 3 channels, got shape {image_np.shape}."
        )

    # 3. Normalise to float32 in [0, 1] and move the tensor to the device.
    image_float_np = _normalize_to_unit_float(image_np)
    image_tensor = torch.from_numpy(image_float_np).to(device)
    transfer_time = time.perf_counter()
    print(f"Data transferred to {device.type} in {transfer_time - load_time:.2f} seconds.")

    # The AWB operates on RGB only. Extra channels (e.g. alpha) pass through unchanged.
    rgb_tensor = image_tensor[..., :3]

    # 4. Run the core algorithm on the device.
    params = {
        't_threshold': kwargs.get('t_threshold', DEFAULT_T_THRESHOLD),
        'mu': kwargs.get('mu', DEFAULT_MU_STEP),
        'a': kwargs.get('a', DEFAULT_A_THRESHOLD),
        'b': kwargs.get('b', DEFAULT_B_THRESHOLD),
        'max_iter': kwargs.get('max_iter', DEFAULT_MAX_ITERATIONS),
    }
    gains = huo_awb_core_torch(rgb_tensor, device=device, **params)
    process_time = time.perf_counter()
    print(f"AWB processing finished in {process_time - transfer_time:.2f} seconds.")

    # 5. Apply final gains, re-attach any extra channels, move back to CPU, and save.
    corrected_image_tensor = torch.clamp(rgb_tensor * gains, 0.0, 1.0)
    if image_tensor.shape[-1] > 3:
        corrected_image_tensor = torch.cat([corrected_image_tensor, image_tensor[..., 3:]], dim=-1)
    corrected_image_np = corrected_image_tensor.cpu().numpy()
    corrected_image_uint16 = _to_uint16(corrected_image_np)

    print(f"Saving corrected image to: {output_path}")
    imageio.imwrite(output_path, corrected_image_uint16)
    end_time = time.perf_counter()
    print(f"Image saved. Total time: {end_time - start_time:.2f} seconds.")
# --- Example Usage ---
if __name__ == '__main__':
    # Usage:
    #   python wb.py INPUT_TIFF OUTPUT_TIFF   -> white-balance an existing image
    #   python wb.py                          -> generate a synthetic sample and balance it
    import sys

    import torch.nn.functional as F

    if len(sys.argv) not in (1, 3):
        sys.exit(f"Usage: {sys.argv[0]} [INPUT_TIFF OUTPUT_TIFF]")

    if len(sys.argv) == 3:
        input_filename, output_filename = sys.argv[1], sys.argv[2]
    else:
        # Create a dummy 50MP, 16-bit TIFF with a bluish cast.
        # 50MP is approx. 8660 x 5773 pixels.
        h, w = 5773, 8660
        print(f"Creating a sample {h*w/1e6:.1f}MP 16-bit TIFF image with a bluish cast...")

        # A horizontal gray gradient: build a 1/10-scale version and upsample
        # it to save memory/time during generation.
        gray_base_small = np.linspace(0.2, 0.8, w // 10, dtype=np.float32)
        gray_image_small = np.tile(gray_base_small, (h // 10, 1))
        gray_image = F.interpolate(
            torch.from_numpy(gray_image_small)[None, None, ...],
            size=(h, w),
            mode='bilinear',
            align_corners=False,
        )[0, 0, ...].numpy()
        image_float = np.stack([gray_image, gray_image, gray_image], axis=-1)

        # Apply a bluish cast (decrease R, increase B).
        blue_cast = np.array([0.85, 1.0, 1.15], dtype=np.float32)
        image_float_cast = np.clip(image_float * blue_cast, 0, 1)
        image_uint16_cast = np.rint(image_float_cast * 65535).astype(np.uint16)

        # Write the sample so the algorithm actually runs on it.
        input_filename = "sample_bluish_cast.tiff"
        output_filename = "sample_bluish_cast_corrected.tiff"
        print(f"Writing sample image to: {input_filename}")
        imageio.imwrite(input_filename, image_uint16_cast)

    # Run the white balance algorithm on the chosen image.
    apply_huo_awb_torch(input_filename, output_filename)