#version 330 core
// Fragment shader that applies Whites/Blacks tone adjustments to InputTexture.

out vec4 FragColor;   // adjusted color; alpha is passed through from the input texel
in vec2 TexCoord;     // UV coordinate used to sample InputTexture

uniform sampler2D InputTexture;   // source image being adjusted
uniform float whitesValue; // highlights slider: -100 (pull down) to 100 (push up)
uniform float blacksValue; // shadows slider: -100 (lift up) to 100 (push down)

// Rec. 709 luma coefficients (R, G, B). Strictly these are defined for linear RGB;
// NOTE(review): confirm InputTexture holds linear rather than sRGB-encoded values.
const vec3 luminanceWeight = vec3(0.2126, 0.7152, 0.0722);

// Rescales `color` so that its luminance (per luminanceWeight) becomes `newLum`,
// keeping the ratios between R, G and B unchanged.
// Colors whose luminance is not strictly positive cannot be scaled proportionally,
// so they are replaced by a neutral grey of luminance `newLum`.
vec3 preserveColor(vec3 color, float newLum) {
    float currentLum = dot(color, luminanceWeight);
    if (currentLum > 0.0) {
        float gain = newLum / currentLum;
        return color * gain;
    }
    return vec3(newLum);
}

// Antiderivative of smoothstep(edge0, edge1, x), anchored so it is 0 for x <= edge0.
// Its slope at x is exactly smoothstep(edge0, edge1, x), which lies in [0, 1];
// the whites pull-down below relies on that bound to stay monotonic.
// Requires edge0 < edge1.
float smoothstepIntegral(float edge0, float edge1, float x) {
    float width = edge1 - edge0;
    float t = clamp((x - edge0) / width, 0.0, 1.0);
    float t3 = t * t * t;
    // Integral of (3t^2 - 2t^3) over [0, t], converted from t units back to x units.
    float ramp = width * (t3 - 0.5 * t3 * t);
    // Beyond edge1 smoothstep is 1, so the integral grows linearly from there.
    return ramp + max(x - edge1, 0.0);
}

// Applies "Whites" and "Blacks" tone adjustments to one RGB color.
//
// whites: slider in [-100, 100]. Positive values brighten highlights (luminance
//         above 0.25, strongest near 1.0); negative values pull them down.
// blacks: slider in [-100, 100]. Positive values darken shadows (luminance below
//         0.5, strongest near 0.0); negative values lift them.
// Slider values outside [-100, 100] are clamped to that range.
//
// Returns `color` rescaled by preserveColor so that its luminance equals the
// adjusted luminance, which is clamped to [0, 2] to leave headroom for highlights.
vec3 applyWhitesBlacks(vec3 color, float whites, float blacks) {
    float lum = dot(color, luminanceWeight);

    // Normalize sliders to [-1, 1]. Clamping keeps the pow() exponents below in
    // [1, 2] (pow(0, y) is undefined in GLSL for y <= 0) and guarantees the
    // whites pull-down stays monotonic.
    float whitesStrength = clamp(whites / 100.0, -1.0, 1.0);
    float blacksStrength = clamp(blacks / 100.0, -1.0, 1.0);

    // --- Whites ---
    float whitesLum;
    if (whitesStrength >= 0.0) {
        // Push up: the highlight mask widens toward the midtones as strength
        // grows (exponent goes from 2 down to 1).
        float whiteMask = smoothstep(0.25, 1.0, lum);
        whiteMask = pow(whiteMask, 2.0 - whitesStrength);
        float whitesAdj = 1.0 + whitesStrength * whiteMask * (1.0 - pow(1.0 - whiteMask, 3.0));
        whitesLum = lum * whitesAdj;
    } else {
        // Pull down: subtract the integral of the highlight weight rather than
        // scaling by (1 - mask). A multiplicative (1 - mask) gain that reaches 0
        // at lum 1.0 inverts tones (at -100 it sends white to black while
        // mid-greys barely move). Here d(out)/d(lum) =
        // 1 - |strength| * smoothstep(0.25, 1.0, lum) >= 0, so tonal order is
        // preserved; at -100 pure white (lum 1.0) maps to 0.625.
        whitesLum = lum + whitesStrength * smoothstepIntegral(0.25, 1.0, lum);
    }

    // --- Blacks ---
    // Shadow mask; widens toward the midtones when lifting (negative strength).
    float blackMask = 1.0 - smoothstep(0.0, 0.5, lum);
    blackMask = pow(blackMask, 2.0 - max(0.0, -blacksStrength));
    float blacksAdj = 1.0 - blacksStrength * blackMask * (1.0 - pow(1.0 - blackMask, 3.0));

    // Combine both adjustments; allow some headroom above 1.0 for highlights.
    float adjustedLum = clamp(whitesLum * blacksAdj, 0.0, 2.0);

    // Rescale RGB proportionally so hue and saturation are kept.
    return preserveColor(color, adjustedLum);
}

// Samples the input texel, applies the Whites/Blacks adjustment to its RGB,
// and writes the result with any negative channels clamped to zero.
// The input alpha is forwarded unchanged.
void main() {
    vec4 texel = texture(InputTexture, TexCoord);
    vec3 graded = applyWhitesBlacks(texel.rgb, whitesValue, blacksValue);
    vec3 nonNegative = max(graded, vec3(0.0));
    FragColor = vec4(nonNegative, texel.a);
}