import torch
import numpy as np
import imageio.v2 as imageio
import time
import os

# --- Configuration & Constants ---
# These are the same as the NumPy version but will be used on PyTorch tensors.
DEFAULT_T_THRESHOLD = 0.1321
DEFAULT_MU_STEP = 0.0312
SCALE_FACTOR = 255.0
DEFAULT_A_THRESHOLD = 0.8 / SCALE_FACTOR
DEFAULT_B_THRESHOLD = 0.15 / SCALE_FACTOR
DEFAULT_MAX_ITERATIONS = 60


# --- PyTorch Core Algorithm Functions ---

def rgb_to_yuv_torch(image_tensor: torch.Tensor, matrix_torch: torch.Tensor) -> torch.Tensor:
    """
    Converts an RGB image tensor to the paper's YUV color space using PyTorch.

    Args:
        image_tensor (torch.Tensor): A (H, W, 3) tensor on a target device.
        matrix_torch (torch.Tensor): The 3x3 conversion matrix on the same device.

    Returns:
        torch.Tensor: An (H, W, 3) YUV tensor on the same device.
    """
    return torch.matmul(image_tensor, matrix_torch)


def k_function_torch(error: float, a: float, b: float) -> float:
    """
    Implements the non-linear error weighting function K(x) from Eq. 16.
    This function remains on the CPU as it operates on a single scalar value.
    """
    abs_error = abs(error)
    sign = np.sign(error)
    if abs_error >= a:
        return 2.0 * sign
    elif abs_error >= b:
        return 1.0 * sign
    else:
        return 0.0
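
# Illustrative check of the three K(x) regimes with the default thresholds
# (a = 0.8/255 ≈ 0.00314, b = 0.15/255 ≈ 0.00059); the inputs below are example
# values, not part of the algorithm:
#   k_function_torch( 0.0050, DEFAULT_A_THRESHOLD, DEFAULT_B_THRESHOLD) ->  2.0  (|x| >= a)
#   k_function_torch( 0.0010, DEFAULT_A_THRESHOLD, DEFAULT_B_THRESHOLD) ->  1.0  (b <= |x| < a)
#   k_function_torch(-0.0002, DEFAULT_A_THRESHOLD, DEFAULT_B_THRESHOLD) ->  0.0  (|x| < b)
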
def huo_awb_core_torch(image_tensor: torch.Tensor, t_threshold: float, mu: float,
                       a: float, b: float, max_iter: int,
                       device: torch.device) -> torch.Tensor:
    """
    Performs the core iterative AWB algorithm using PyTorch tensors on a specified device.

    Args:
        image_tensor (torch.Tensor): Input image as a (H, W, 3) float32 tensor on the target device.
        (other params): Algorithm configuration constants.
        device (torch.device): The device (e.g., 'cuda' or 'cpu') to run on.

    Returns:
        torch.Tensor: A (3,) tensor containing the final calculated [R, G, B] gains.
    """
    # Create the YUV conversion matrix and gains tensor on the target device
    yuv_matrix = torch.tensor([
        [0.299, 0.587, 0.114],
        [-0.299, -0.587, 0.886],
        [0.701, -0.587, -0.114]
    ], dtype=torch.float32, device=device).T

    gains = torch.tensor([1.0, 1.0, 1.0], dtype=torch.float32, device=device)

    print(f"Starting iterative AWB on device: '{device.type}'...")
    for i in range(max_iter):
        # 1. Apply current gains to the image (all on GPU/device)
        balanced_image = torch.clamp(image_tensor * gains, 0.0, 1.0)

        # 2. Convert to YUV (on GPU/device)
        yuv_image = rgb_to_yuv_torch(balanced_image, yuv_matrix)
        Y, U, V = yuv_image.unbind(dim=-1)

        # 3. Identify gray points (on GPU/device)
        # Luminance mask to exclude overly dark or bright pixels
        luminance_mask = (Y > 0.1) & (Y < 0.95)
        if not torch.any(luminance_mask):
            print(f"Iteration {i+1}: No pixels in luminance range. Stopping.")
            break

        Y_masked = Y[luminance_mask]
        U_masked = U[luminance_mask]
        V_masked = V[luminance_mask]

        # Gray-point criterion from Eq. 10
        gray_mask = (torch.abs(U_masked) + torch.abs(V_masked)) / Y_masked < t_threshold
        gray_points_U = U_masked[gray_mask]
        gray_points_V = V_masked[gray_mask]

        num_gray_points = gray_points_U.shape[0]
        if num_gray_points < 50:  # Use a higher threshold for large images
            print(f"Iteration {i+1}: Not enough gray points found ({num_gray_points}). Stopping.")
            break

        # 4. Calculate average chrominance (reduction on GPU/device)
        u_mean = torch.mean(gray_points_U)
        v_mean = torch.mean(gray_points_V)

        # Bring the scalar results back to CPU for control flow
        u_mean_cpu = u_mean.item()
        v_mean_cpu = v_mean.item()

        # Check for convergence
        if abs(u_mean_cpu) < b and abs(v_mean_cpu) < b:
            print(f"Iteration {i+1}: Converged. u_mean={u_mean_cpu:.4f}, v_mean={v_mean_cpu:.4f}")
            break

        # 5. Determine adjustment (logic on CPU, gain update on GPU/device)
        if abs(u_mean_cpu) > abs(v_mean_cpu):
            error = -u_mean_cpu
            adjustment = mu * k_function_torch(error, a, b)
            gains[2] += adjustment
            print(f"Iter {i+1}: Adjusting B-gain. u_mean={u_mean_cpu:.4f}, v_mean={v_mean_cpu:.4f}, B-adj={adjustment:.4f}")
        else:
            error = -v_mean_cpu
            adjustment = mu * k_function_torch(error, a, b)
            gains[0] += adjustment
            print(f"Iter {i+1}: Adjusting R-gain. u_mean={u_mean_cpu:.4f}, v_mean={v_mean_cpu:.4f}, R-adj={adjustment:.4f}")

    print(f"Final gains: R={gains[0].item():.4f}, G={gains[1].item():.4f}, B={gains[2].item():.4f}")
    return gains


# --- Main Public Function ---

def apply_huo_awb_torch(image_path: str, output_path: str, **kwargs):
    """
    Loads a high-resolution 16-bit TIFF, applies the Huo et al. AWB algorithm
    using PyTorch for high performance, and saves the result.

    Args:
        image_path (str): Path to the input 16-bit TIFF image.
        output_path (str): Path to save the white-balanced 16-bit TIFF image.
        **kwargs: Optional algorithm parameters.
    """
    start_time = time.perf_counter()

    # 1. Select Device (GPU if available, otherwise CPU)
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():  # For Apple Silicon
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    print(f"Using device: {device}")

    # 2. Load Image with imageio (on CPU)
    print(f"Loading image from: {image_path}")
    try:
        image_np = imageio.imread(image_path)
    except FileNotFoundError:
        print(f"Error: The file '{image_path}' was not found.")
        return

    load_time = time.perf_counter()
    print(f"Image loaded in {load_time - start_time:.2f} seconds.")

    # 3. Pre-process and Move to Device
    # Normalize to float32 and convert to a PyTorch tensor
    image_float_np = image_np.astype(np.float32) / 65535.0
    # Move the large image tensor to the selected device
    image_tensor = torch.from_numpy(image_float_np).to(device)

    transfer_time = time.perf_counter()
    print(f"Data transferred to {device.type} in {transfer_time - load_time:.2f} seconds.")

    # 4. Run the core algorithm on the device
    params = {
        't_threshold': kwargs.get('t_threshold', DEFAULT_T_THRESHOLD),
        'mu': kwargs.get('mu', DEFAULT_MU_STEP),
        'a': kwargs.get('a', DEFAULT_A_THRESHOLD),
        'b': kwargs.get('b', DEFAULT_B_THRESHOLD),
        'max_iter': kwargs.get('max_iter', DEFAULT_MAX_ITERATIONS),
    }
    gains = huo_awb_core_torch(image_tensor, device=device, **params)

    process_time = time.perf_counter()
    print(f"AWB processing finished in {process_time - transfer_time:.2f} seconds.")

    # 5. Apply final gains, move back to CPU, and save
    corrected_image_tensor = torch.clamp(image_tensor * gains, 0.0, 1.0)

    # Move the tensor back to CPU for conversion to NumPy
    corrected_image_np = corrected_image_tensor.cpu().numpy()

    # Convert back to 16-bit integer for saving
    corrected_image_uint16 = (corrected_image_np * 65535).astype(np.uint16)

    print(f"Saving corrected image to: {output_path}")
    imageio.imwrite(output_path, corrected_image_uint16)

    end_time = time.perf_counter()
    print(f"Image saved. Total time: {end_time - start_time:.2f} seconds.")
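
# Usage sketch (hypothetical file paths): override the gray-point threshold and
# the gain step size instead of relying on the module defaults, e.g.:
#
#   apply_huo_awb_torch(
#       "scan_input.tiff",
#       "scan_balanced.tiff",
#       t_threshold=0.10,   # stricter gray-point selection than DEFAULT_T_THRESHOLD
#       mu=0.02,            # smaller per-iteration gain step than DEFAULT_MU_STEP
#   )
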
# --- Example Usage ---

if __name__ == '__main__':
    # Create a dummy 50MP, 16-bit TIFF with a bluish cast.
    # 50MP is approx. 8660 x 5773 pixels.
    h, w = 5773, 8660
    print(f"Creating a sample {h*w/1e6:.1f}MP 16-bit TIFF image with a bluish cast...")

    # A gray gradient (create a smaller version and resize to save memory/time)
    small_w = w // 10
    gray_base_small = np.linspace(0.2, 0.8, small_w, dtype=np.float32)
    gray_image_small = np.tile(gray_base_small, (h // 10, 1))

    # Upscale the small gradient to full resolution with bilinear interpolation
    import torch.nn.functional as F
    gray_image = F.interpolate(
        torch.from_numpy(gray_image_small)[None, None, ...],
        size=(h, w), mode='bilinear', align_corners=False
    )[0, 0, ...].numpy()

    image_float = np.stack([gray_image, gray_image, gray_image], axis=-1)

    # Apply a bluish cast (decrease R, increase B)
    blue_cast = np.array([0.85, 1.0, 1.15], dtype=np.float32)
    image_float_cast = np.clip(image_float * blue_cast, 0, 1)
    image_uint16_cast = (image_float_cast * 65535).astype(np.uint16)

    # Illustrative local filenames for the synthetic sample and its corrected output
    input_filename = "sample_bluish_cast_16bit.tiff"
    output_filename = "sample_awb_corrected_16bit.tiff"

    print(f"Saving sample image to: {input_filename}")
    imageio.imwrite(input_filename, image_uint16_cast)

    # Run the white balance algorithm on the sample image
    apply_huo_awb_torch(input_filename, output_filename)
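
# Batch sketch (illustrative, assumes a local "scans/" directory of 16-bit TIFFs);
# uncomment to white-balance every file in the folder with the default parameters:
#
#   import glob
#   for path in glob.glob("scans/*.tiff"):
#       apply_huo_awb_torch(path, path.replace(".tiff", "_awb.tiff"))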