242 lines
9.3 KiB
Python
242 lines
9.3 KiB
Python
import os
import sys
import time

import imageio.v2 as imageio
import numpy as np
import torch
# --- Configuration & Constants ---
# Default tuning values. They match the NumPy implementation of this
# algorithm; here they are applied to PyTorch tensors.

# Threshold T of the gray-point test (|U| + |V|) / Y < T (Eq. 10).
DEFAULT_T_THRESHOLD: float = 0.1321
# Step size mu; every gain update is mu * K(error).
DEFAULT_MU_STEP: float = 0.0312

# Divisor for the K(x) thresholds below. Presumably the thresholds were
# specified on a 0-255 scale and are rescaled here for images normalized
# to [0, 1].
SCALE_FACTOR: float = 255.0
DEFAULT_A_THRESHOLD: float = 0.8 / SCALE_FACTOR   # |error| >= a -> weight 2
DEFAULT_B_THRESHOLD: float = 0.15 / SCALE_FACTOR  # |error| >= b -> weight 1; also the convergence tolerance

# Upper bound on the number of gain-update iterations.
DEFAULT_MAX_ITERATIONS: int = 60
|
|
|
|
# --- PyTorch Core Algorithm Functions ---
|
|
|
|
def rgb_to_yuv_torch(image_tensor: torch.Tensor, matrix_torch: torch.Tensor) -> torch.Tensor:
    """
    Map every RGB pixel of ``image_tensor`` into the paper's YUV space.

    The conversion is one matrix product over the last axis: each output
    pixel is the input pixel (as a row vector) times ``matrix_torch``.

    Args:
        image_tensor (torch.Tensor): A (H, W, 3) tensor on a target device.
        matrix_torch (torch.Tensor): The 3x3 conversion matrix, on the same
            device, laid out for right-multiplication of RGB row vectors.

    Returns:
        torch.Tensor: An (H, W, 3) YUV tensor on the same device.
    """
    yuv = image_tensor @ matrix_torch
    return yuv
|
|
|
|
def k_function_torch(error: float, a: float, b: float) -> float:
    """
    Non-linear error weighting K(x) from Eq. 16 of the paper.

    The magnitude of ``error`` is bucketed against the two thresholds, and
    the resulting weight carries the sign of ``error``:

    * ``|error| >= a``                    -> ``2 * sign(error)``
    * ``|error| < a`` and ``|error| >= b`` -> ``sign(error)``
    * anything else (including NaN)       -> ``0.0``

    This works on a single Python scalar, so it stays on the CPU.
    """
    magnitude = abs(error)
    if magnitude >= a:
        weight = 2.0
    elif magnitude >= b:
        weight = 1.0
    else:
        return 0.0
    return weight * np.sign(error)
|
|
|
|
def huo_awb_core_torch(image_tensor: torch.Tensor,
                       t_threshold: float,
                       mu: float,
                       a: float,
                       b: float,
                       max_iter: int,
                       device: torch.device,
                       min_gray_points: int = 50,
                       y_min: float = 0.1,
                       y_max: float = 0.95) -> torch.Tensor:
    """
    Performs the core iterative AWB algorithm using PyTorch tensors on a specified device.

    Each iteration does the following:

    1. Apply the current gains to the image, clamped to [0, 1].
    2. Convert the result to YUV.
    3. Keep pixels with ``y_min < Y < y_max`` that pass the gray-point test
       ``(|U| + |V|) / Y < t_threshold`` (Eq. 10).
    4. Average U and V over those gray points. If both means are below
       ``b``, the loop stops.
    5. Otherwise, change the B gain (when ``|u_mean| > |v_mean|``) or the
       R gain by ``mu * K(-mean)``. The G gain is never changed.

    Args:
        image_tensor (torch.Tensor): Input image as a (H, W, 3) tensor on the
            target device, values expected in [0, 1]. The colour maths runs
            in the dtype that ``image_tensor * gains`` promotes to. That is
            float32 for float32, half and integer inputs, and float64 for
            float64 inputs.
        t_threshold (float): Gray-point threshold T from Eq. 10.
        mu (float): Gain step size.
        a (float): Upper K(x) threshold; ``|error| >= a`` gives a step of ``2 * mu``.
        b (float): Lower K(x) threshold, also the convergence tolerance on
            the mean chrominance.
        max_iter (int): Maximum number of gain-update iterations.
        device (torch.device): Device on which the conversion matrix and the
            gains are created; should be the device of ``image_tensor``.
        min_gray_points (int): Stop when fewer gray points than this are found.
        y_min (float): Exclusive lower luminance bound for candidate pixels.
        y_max (float): Exclusive upper luminance bound for candidate pixels.

    Returns:
        torch.Tensor: A (3,) float32 tensor on ``device`` containing the final
        calculated [R, G, B] gains.
    """
    gains = torch.ones(3, dtype=torch.float32, device=device)

    # Build the matrix in the dtype that `image_tensor * gains` promotes to.
    # A fixed float32 matrix made the matmul below fail with a dtype
    # mismatch for float64 images.
    compute_dtype = torch.promote_types(image_tensor.dtype, gains.dtype)

    # The rows compute Y, U (= B - Y) and V (= R - Y) from R, G, B. The matrix
    # is transposed so that an (..., 3) RGB tensor times it gives (..., 3) YUV.
    yuv_matrix = torch.tensor([
        [0.299, 0.587, 0.114],
        [-0.299, -0.587, 0.886],
        [0.701, -0.587, -0.114],
    ], dtype=compute_dtype, device=device).T

    print(f"Starting iterative AWB on device: '{device.type}'...")
    for i in range(max_iter):
        # 1. Apply current gains to the image (on device).
        balanced_image = torch.clamp(image_tensor * gains, 0.0, 1.0)

        # 2. Convert to YUV (on device).
        yuv_image = rgb_to_yuv_torch(balanced_image, yuv_matrix)
        Y, U, V = yuv_image.unbind(dim=-1)

        # 3. Identify gray points (on device). The luminance mask excludes
        # overly dark or overly bright pixels.
        luminance_mask = (Y > y_min) & (Y < y_max)
        if not torch.any(luminance_mask):
            print(f"Iteration {i+1}: No pixels in luminance range. Stopping.")
            break

        Y_masked = Y[luminance_mask]
        U_masked = U[luminance_mask]
        V_masked = V[luminance_mask]

        # Criterion from Eq. 10. Every Y_masked > y_min, so when y_min >= 0
        # the division cannot hit zero.
        gray_mask = (torch.abs(U_masked) + torch.abs(V_masked)) / Y_masked < t_threshold
        gray_points_U = U_masked[gray_mask]

        num_gray_points = gray_points_U.shape[0]
        # The default floor of 50 is deliberately high, intended for large images.
        if num_gray_points < min_gray_points:
            print(f"Iteration {i+1}: Not enough gray points found ({num_gray_points}). Stopping.")
            break

        # 4. Average chrominance over the gray points (reduction on device).
        # The scalar results then come back to the CPU for control flow.
        u_mean_cpu = torch.mean(gray_points_U).item()
        v_mean_cpu = torch.mean(V_masked[gray_mask]).item()

        # Check for convergence.
        if abs(u_mean_cpu) < b and abs(v_mean_cpu) < b:
            print(f"Iteration {i+1}: Converged. u_mean={u_mean_cpu:.4f}, v_mean={v_mean_cpu:.4f}")
            break

        # 5. Correct the channel with the larger cast. U = B - Y drives the
        # B gain and V = R - Y drives the R gain. A positive mean means that
        # channel is too strong, hence error = -mean.
        if abs(u_mean_cpu) > abs(v_mean_cpu):
            adjustment = mu * k_function_torch(-u_mean_cpu, a, b)
            gains[2] += adjustment
            print(f"Iter {i+1}: Adjusting B-gain. u_mean={u_mean_cpu:.4f}, v_mean={v_mean_cpu:.4f}, B-adj={adjustment:.4f}")
        else:
            adjustment = mu * k_function_torch(-v_mean_cpu, a, b)
            gains[0] += adjustment
            print(f"Iter {i+1}: Adjusting R-gain. u_mean={u_mean_cpu:.4f}, v_mean={v_mean_cpu:.4f}, R-adj={adjustment:.4f}")

    print(f"Final gains: R={gains[0].item():.4f}, G={gains[1].item():.4f}, B={gains[2].item():.4f}")
    return gains
|
|
|
|
# --- Main Public Function ---
|
|
|
|
def apply_huo_awb_torch(image_path: str, output_path: str, **kwargs):
    """
    Loads a high-resolution image, applies the Huo et al. AWB algorithm
    using PyTorch, and saves the result as a 16-bit image.

    Integer inputs (e.g. 8- or 16-bit) are normalized to [0, 1] by their
    dtype's maximum value. Floating-point inputs are assumed to already be
    in [0, 1]. Gains are applied to the RGB channels only. A fourth (alpha)
    channel is carried through unchanged apart from being rescaled to 16-bit
    with the rest of the image.

    Args:
        image_path (str): Path to the input image (typically a 16-bit TIFF).
        output_path (str): Path to save the white-balanced 16-bit image.
        **kwargs: Optional algorithm parameters ``t_threshold``, ``mu``,
            ``a``, ``b`` and ``max_iter``. Also accepts ``device``, a
            ``torch.device`` or device string; the default is ``'cpu'``.

    Returns:
        None. If ``image_path`` does not exist, an error is printed and
        nothing is written.

    Raises:
        ValueError: If the image is not (H, W, 3) or (H, W, 4).
    """
    start_time = time.perf_counter()

    # 1. Select the device. CPU by default; pass device='cuda' (or 'mps', or
    # a torch.device) to opt in to an accelerator.
    device = kwargs.get('device', 'cpu')
    if not isinstance(device, torch.device):
        device = torch.device(device)
    print(f"Using device: {device}")

    # 2. Load Image with imageio (on CPU)
    print(f"Loading image from: {image_path}")
    try:
        image_np = imageio.imread(image_path)
    except FileNotFoundError:
        print(f"Error: The file '{image_path}' was not found.")
        return

    load_time = time.perf_counter()
    print(f"Image loaded in {load_time - start_time:.2f} seconds.")

    if image_np.ndim != 3 or image_np.shape[-1] not in (3, 4):
        raise ValueError(
            f"Expected an RGB (H, W, 3) or RGBA (H, W, 4) image, got shape {image_np.shape}."
        )

    # 3. Pre-process and move to device.
    # Normalize by the input's own range rather than always by 65535, so
    # 8-bit files are not squashed to near-black.
    if np.issubdtype(image_np.dtype, np.integer):
        input_max = float(np.iinfo(image_np.dtype).max)
    else:
        # NOTE(review): floating-point images are assumed to be in [0, 1] already.
        input_max = 1.0
    image_float_np = image_np.astype(np.float32) / input_max

    # Gains apply to RGB only; a 4th (alpha) channel is set aside.
    # For 3-channel input, alpha_np has zero channels and rgb_np is not copied.
    rgb_np = np.ascontiguousarray(image_float_np[..., :3])
    alpha_np = image_float_np[..., 3:]

    # Move the large image tensor to the selected device.
    image_tensor = torch.from_numpy(rgb_np).to(device)

    transfer_time = time.perf_counter()
    print(f"Data transferred to {device.type} in {transfer_time - load_time:.2f} seconds.")

    # 4. Run the core algorithm on the device.
    params = {
        't_threshold': kwargs.get('t_threshold', DEFAULT_T_THRESHOLD),
        'mu': kwargs.get('mu', DEFAULT_MU_STEP),
        'a': kwargs.get('a', DEFAULT_A_THRESHOLD),
        'b': kwargs.get('b', DEFAULT_B_THRESHOLD),
        'max_iter': kwargs.get('max_iter', DEFAULT_MAX_ITERATIONS),
    }
    gains = huo_awb_core_torch(image_tensor, device=device, **params)

    process_time = time.perf_counter()
    print(f"AWB processing finished in {process_time - transfer_time:.2f} seconds.")

    # 5. Apply final gains, move back to CPU, and save.
    corrected_image_tensor = torch.clamp(image_tensor * gains, 0.0, 1.0)
    corrected_image_np = corrected_image_tensor.cpu().numpy()
    if alpha_np.shape[-1]:
        corrected_image_np = np.concatenate(
            [corrected_image_np, np.clip(alpha_np, 0.0, 1.0)], axis=-1
        )

    # Convert back to 16-bit integers. Round instead of truncating: after the
    # float32 round trip a value can land just below an integer, and astype()
    # would then drop it by one code value. Done in place to avoid another
    # full-size temporary.
    corrected_image_np *= 65535.0
    np.rint(corrected_image_np, out=corrected_image_np)
    corrected_image_uint16 = corrected_image_np.astype(np.uint16)

    print(f"Saving corrected image to: {output_path}")
    imageio.imwrite(output_path, corrected_image_uint16)

    end_time = time.perf_counter()
    print(f"Image saved. Total time: {end_time - start_time:.2f} seconds.")
|
|
|
|
|
|
# --- Example Usage ---
|
|
|
|
if __name__ == '__main__':
    # Usage: python <this script> [INPUT_IMAGE OUTPUT_IMAGE]
    # With two paths, INPUT_IMAGE is corrected and written to OUTPUT_IMAGE.
    # With no arguments, a synthetic ~50MP 16-bit TIFF with a bluish cast is
    # written to the current directory and then corrected.
    if len(sys.argv) == 3:
        input_filename, output_filename = sys.argv[1], sys.argv[2]
    elif len(sys.argv) == 1:
        # 50MP is approx. 8660 x 5773 pixels
        h, w = 5773, 8660
        print(f"Creating a sample {h*w/1e6:.1f}MP 16-bit TIFF image with a bluish cast...")

        # Every row holds the same horizontal gray gradient. So build one
        # (w, 3) row and broadcast it, instead of materializing several
        # full-size float32 intermediates.
        gray_row = np.linspace(0.2, 0.8, w, dtype=np.float32)[:, None]

        # Apply a bluish cast (decrease R, increase B). It is kept mild on
        # purpose: a pixel only counts as a gray point when
        # (|U| + |V|) / Y < DEFAULT_T_THRESHOLD. A cast of [0.85, 1.0, 1.15]
        # gives a ratio of ~0.31, leaving no gray points, so nothing would be
        # corrected.
        blue_cast = np.array([0.95, 1.0, 1.05], dtype=np.float32)
        row_uint16 = np.rint(np.clip(gray_row * blue_cast, 0, 1) * 65535).astype(np.uint16)
        image_uint16_cast = np.ascontiguousarray(np.broadcast_to(row_uint16, (h, w, 3)))

        input_filename = os.path.abspath("sample_bluish_cast.tiff")
        output_filename = os.path.abspath("sample_bluish_cast_corrected.tiff")
        print(f"Saving sample image to: {input_filename}")
        imageio.imwrite(input_filename, image_uint16_cast)
        del image_uint16_cast  # release ~300 MB before the correction run
    else:
        sys.exit(f"Usage: {sys.argv[0]} [INPUT_IMAGE OUTPUT_IMAGE]")

    # Run the white balance algorithm
    apply_huo_awb_torch(input_filename, output_filename)