/********************************************************
    © 2020 Continuum Graphics LLC. All Rights Reserved
 ********************************************************/

#if !defined _WATERVOLUME_
#define _WATERVOLUME_

/* Temporary water-volume implementation; we plan to revisit it later. */

// Henyey-Greenstein phase function:
//   p(theta) = (1 - g^2) / (4 * pi * (1 + g^2 - 2 * g * cos(theta))^(3/2))
// cosTheta: cosine of the scattering angle.
// g:        asymmetry parameter (g > 0 biases scattering forward).
// Assumes rPI (defined elsewhere) is 1/pi; the 0.25 factors supply the 1/4.
float WaterPhaseG(float cosTheta, const float g){
    float gSquared        = g * g;
    float numerator       = 0.25 - 0.25 * gSquared;
    float denominatorBase = 1.0 + gSquared - 2.0 * g * cosTheta;

    return rPI * numerator * pow(denominatorBase, -1.5);
}

// Analytic, per-channel integral of the transmittance over one ray segment:
//   integral_0^d exp(-coeff * t) dt = (1 - exp(-coeff * d)) / coeff
// opticalDepth: segment length d (same length unit as 1 / coeff).
// coeff:        per-channel attenuation coefficient; every component must be non-zero.
//
// Uses the file's exp2(x * rLOG2) == exp(x) convention (assuming rLOG2 == 1 / ln 2).
// The base change is applied exactly once: additionally dividing the exponent by
// log(2.0) would evaluate exp(-coeff * d / ln 2) and over-attenuate the segment.
vec3 TransmittedScatteringIntegral(const float opticalDepth, const vec3 coeff) {
    // Plain locals: initializing `const` locals from parameters needs GLSL 4.20+.
    vec3 segmentTransmit = exp2(-coeff * opticalDepth * rLOG2);

    return (1.0 - segmentTransmit) / coeff;
}

// Single-sample, unfiltered shadow test.
// Returns 1.0 when shadowPosition.z >= 1.0 (outside the shadow map's depth range);
// otherwise returns 1.0 - fstep(shadowDepth, shadowPosition.z - bias).
// NOTE(review): fstep is defined elsewhere; presumably a step(edge, x)-style 0/1 test,
// so the result is 0.0 (shadowed) once the biased depth reaches the stored depth. Confirm.
float calculateHardShadows(float shadowDepth, vec3 shadowPosition, float bias) {
    float biasedDepth = shadowPosition.z - bias;

    return (shadowPosition.z >= 1.0) ? 1.0 : 1.0 - fstep(shadowDepth, biasedDepth);
}

// Estimates the sunlight transmittance through the water between the shadow-map
// surface and worldPosition.
// worldPosition:   sample position, transformed by transMAD(shadowModelViewCustom, ...).
// shadowWaterMask: weight applied to the water path length.
// depth0:          shadow-map depth sampled at this texel.
// waterDepth (out): shadow-view z difference (surface minus sample), forced to 0 by
//                   fstep(0.0, ...) when negative, then multiplied by shadowWaterMask.
// Returns exp(-waterTransmitCoefficient * waterDepth) per channel (assuming rLOG2 == 1 / ln 2).
vec3 calculateWaterTransmit(vec3 worldPosition, float shadowWaterMask, float depth0, out float waterDepth){
    // Un-project the stored depth's z through the inverse shadow projection.
    // NOTE(review): the *8-4 remap (instead of the usual *2-1) presumably matches a depth
    // scale applied in the shadow pass. Confirm.
    float surfaceViewZ = shadowProjectionInverse[2].z * (depth0 * 8.0 - 4.0) + shadowProjectionInverse[3].z;
    float sampleViewZ  = transMAD(shadowModelViewCustom, worldPosition).z;
    float pathLength   = surfaceViewZ - sampleViewZ;

    waterDepth = mix(0.0, pathLength, shadowWaterMask * fstep(0.0, pathLength));

    return exp2(-waterTransmitCoefficient * waterDepth * rLOG2);
}

// Accumulates one ray-march sample of underwater in-scattering.
// position:       sample position (passed to calculateWaterTransmit and the caustic fade).
// shadowPosition: matching shadow-space position; xy is distorted here before sampling.
// transmit:       view-ray transmittance from the ray start up to this sample.
// directScattering (inout):   += shadow * transmit * sun-side water transmittance
//                               (optionally modulated by caustics).
// indirectScattering (inout): += transmit (no shadowing applied to the sky term).
void calculateVolumetricLightScatteringWater(vec3 position, vec3 shadowPosition, vec3 transmit, inout vec3 directScattering, inout vec3 indirectScattering){
    shadowPosition.xy = DistortShadowSpaceProj(shadowPosition.xy);
    vec2 shadowUV = shadowPosition.xy;

    // NOTE(review): presumably shadowtex0 includes the water surface and shadowtex1 only
    // opaque occluders (OptiFine convention). Confirm.
    float surfaceDepth  = texture2DLod(shadowtex0, shadowUV, 0).x;
    float occluderDepth = texture2DLod(shadowtex1, shadowUV, 0).x;
    float shadowWaterMask = texture2DLod(shadowcolor1, shadowUV, 0).a * 2.0 - 1.0;

    float waterDepth = 0.0;
    float visibility  = calculateHardShadows(occluderDepth, shadowPosition, 0.0);
    vec3  sunTransmit = calculateWaterTransmit(position, shadowWaterMask, surfaceDepth, waterDepth);

    #ifdef WATER_CAUSTICS
        #ifdef VOLUMETRIC_WATER_CAUSTICS
            // Caustics fade out with distance from gbufferModelViewInverse[3].xyz
            // (presumably the camera origin) and are skipped beyond maxDist.
            float viewerDistance = length(position - gbufferModelViewInverse[3].xyz);
            const float maxDist = 20.0;

            if (viewerDistance < maxDist) {
                float fullWater = float(shadowWaterMask > 0.999);
                float caustics  = pow(texture2DLod(shadowcolor, shadowUV, 0).r, 2.2) * 100.0 * fullWater;
                float fade      = clamp01((maxDist - viewerDistance) * 2.0 / maxDist);

                sunTransmit.rgb *= mix(1.0, caustics, shadowWaterMask * fade);
            }
        #endif
    #endif

    directScattering   += visibility * transmit * sunTransmit;
    indirectScattering += transmit;
}

// Sun, moon and sky irradiance at `point` for the water volume.
// point:          position whose length is used as the radius r passed to the atmosphere lookups.
// normal:         receiver normal; blends the sky irradiance through `diff`.
// sun_direction / moon_direction: directions to sun and moon (assumed normalized).
// sky_irradiance (out):      sun-lit sky irradiance, scaled by `diff`.
// sky_irradiance_moon (out): moon-lit sky irradiance, scaled by `diff`.
// moon_irradiance (out):     lunar irradiance times the transmittance toward the moon.
// Returns solar irradiance times the transmittance toward the sun.
vec3 GetSunAndSkyIrradianceForWaterVolume(AtmosphereParameters atmosphere, sampler3D transmit_texture, sampler3D irradiance_texture, vec3 point, vec3 normal, vec3 sun_direction, vec3 moon_direction, out vec3 sky_irradiance, out vec3 sky_irradiance_moon, out vec3 moon_irradiance) {
    float r = length(point);
    float mu_s = dot(point, sun_direction) / r;
    float mu_n = dot(point, moon_direction) / r;

    float NoP = dot(normal, point) / r;

    float diff = ((1.0 - NoP) * rTAU + NoP + 1.0) * 0.5;

    // Indirect irradiance (approximated if the surface is not horizontal).
    sky_irradiance_moon = vec3(0.0);
    sky_irradiance = GetIrradiance(atmosphere, irradiance_texture, r, mu_s, mu_n, sky_irradiance_moon) * diff;
    sky_irradiance_moon = sky_irradiance_moon * diff;

    // Direct moonlight is attenuated along the moon's own direction (mu_n), the same
    // cosine used for the moon sky term above. -mu_s only matches it when the moon is
    // exactly opposite the sun.
    moon_irradiance = atmosphere.lunar_irradiance * GetTransmitToSun(atmosphere, transmit_texture, r, mu_n);

    // Direct irradiance.
    return atmosphere.solar_irradiance * GetTransmitToSun(atmosphere, transmit_texture, r, mu_s);
}

// Ray-marches the underwater medium between `start` and `end` and composites it over `background`.
// background:        radiance behind the volume.
// start / end:       segment endpoints, also mapped through WorldSpaceToShadowSpace.
// dither:            offset (in steps) of the first sample along the segment.
// cosTheta:          cosine of the scattering angle, fed to WaterPhaseG.
// skylightOcclusion: multiplier on the indirect (sky) term.
// scatter (out):     in-scattered radiance.
// transmit (out):    transmittance across the whole segment.
// Returns background * transmit + scatter.
vec3 CalculateWaterVolume(vec3 background, vec3 start, vec3 end, float dither, float cosTheta, float skylightOcclusion, out vec3 scatter, out vec3 transmit) {
#ifndef UNDERWATER_VOLUMETRIC_LIGHT
    // Disabled: the volume is transparent and adds nothing. The `out` parameters must
    // still be written, otherwise the caller reads undefined values.
    scatter  = vec3(0.0);
    transmit = vec3(1.0);
    return background;
#endif

    const int   steps  = UNDERWATER_VOLUMETRIC_LIGHT_QUALITY;
    const float rSteps = 1.0 / float(steps);

    vec3 worldIncrement = (end - start) * rSteps;
    vec3 worldPosition  = dither * worldIncrement + start;

    // Length of one march step.
    float opticalDepth = length(worldIncrement);

    // NOTE(review): the in-step integral uses the absorption coefficient while the step
    // transmittance uses waterTransmitCoefficient; confirm this mismatch is intentional.
    vec3 scatterCoeff = waterScatterCoefficient * TransmittedScatteringIntegral(opticalDepth, waterAbsorptionCoefficient);
    vec3 stepTransmit = exp2(-waterTransmitCoefficient * opticalDepth * rLOG2);

    // Shadow-space positions are stepped linearly in lockstep with the world positions.
    vec3 shadowStart     = WorldSpaceToShadowSpace(start);
    vec3 shadowIncrement = (WorldSpaceToShadowSpace(end) - shadowStart) * rSteps;
    vec3 shadowPosition  = dither * shadowIncrement + shadowStart;

    vec3 skyIrradiance = skyIlluminanceVert;
    vec3 sunIrradiance = sunIlluminanceVert;

    transmit = vec3(1.0);

    vec3 directScattering = vec3(0.0);
    vec3 indirectScattering = vec3(0.0);

    for(int i = 0; i < steps; ++i, worldPosition += worldIncrement, shadowPosition += shadowIncrement) {
        calculateVolumetricLightScatteringWater(worldPosition, shadowPosition, transmit, directScattering, indirectScattering);

        transmit *= stepTransmit;
    }

    float phase = WaterPhaseG(cosTheta, 0.5);

    vec3 directLighting = sunIrradiance * directScattering * phase;
    // 0.25 * rPI: isotropic phase 1 / (4 * pi), assuming rPI == 1 / pi.
    vec3 indirectLighting = skyIrradiance * 0.25 * rPI * indirectScattering * skylightOcclusion;

    scatter = (directLighting + indirectLighting) * scatterCoeff;

    return background * transmit + scatter;
}

#endif
