/********************************************************
    © 2020 Continuum Graphics LLC. All Rights Reserved
 ********************************************************/

#if !defined _SCREENSPACERAYTRACER_
#define _SCREENSPACERAYTRACER_

// Hit tolerance used by rayTrace: a sample only counts as an intersection when
// abs(linearD - linearZ) / -linearZ (the depth gap relative to the ray's own
// view-space depth) is below this value.
const float zThicknessThreshold = 0.1; // thickness to ascribe to each pixel in the depth buffer
// Scale applied to the clamped adaptive step on every refinement iteration in rayTrace.
const float refinementSize = 0.5;

/*
const int binarySearchSteps = 4;

void binarySearch(vec3 rayDir, float depth, inout vec3 rayPosition, inout bool intersect){

    for (int i = 0; i < binarySearchSteps; ++i){

        float linearZ = ScreenToViewSpaceDepth(rayPosition.z);
        float linearD = ScreenToViewSpaceDepth(depth);
        
        float dist = abs(linearD - linearZ) / -linearZ;
        
        if (dist < zThicknessThreshold && linearZ < 0.0 && linearZ > -far) {
            intersect = true;
        }

        if (linearD - linearZ < 0.0) rayPosition += rayDir;

        rayDir *= 0.5;
        rayPosition -= rayDir;

        depth = texture2D(colortex0, rayPosition.xy).x;
    }
}
*/

/*
    Marches a ray through screen space, comparing its depth against the depth
    textures, and returns the colour found along it.

    rayOrigin, rayDir : ray start and direction; the far end of the ray is
                        rayDir * rayLength + rayOrigin, projected with
                        ViewSpaceToScreenSpace.
    NoV               : blend factor between the shortest and longest initial step.
    jitter            : offsets the first sample (remapped as jitter * 0.5 + 0.5).
    hitPixel          : screen-space position the march starts from
                        (xy is used as a texture coordinate, z as a depth).
    sky, skyLightmap  : fallback colour and the factor applied to it on a miss.

    Result, in order of precedence:
      1. sky * skyLightmap  - the sample position's xy differs from clamp01(xy).
      2. sky                - the last fetched depth is >= 1.0.
      3. decoded colortex0  - an intersection was accepted.
      4. sky * skyLightmap  - the iteration budget ran out without a hit.
*/
vec3 rayTrace(vec3 rayOrigin, vec3 rayDir, float NoV, float jitter, vec3 hitPixel, vec3 sky, float skyLightmap) {
    const float longestStep  = 1.0 / RAYTRACE_QUALITY;
    const float shortestStep = longestStep * 0.01;

    float farLimit = far * sqrt(3.0);

    // Shorten the ray when its far end would otherwise pass the near plane.
    float travel = farLimit;
    if ((rayOrigin.z + rayDir.z * farLimit) > -near) {
        travel = (-near - rayOrigin.z) / rayDir.z;
    }

    // Unit march direction in screen space, from the start pixel to the projected ray end.
    vec3 screenEnd = ViewSpaceToScreenSpace(rayDir * travel + rayOrigin);
    vec3 marchDir = normalize(screenEnd - hitPixel);

    // Converts a screen-space depth gap into a distance along marchDir.
    float stepWeight = 1.0 / abs(marchDir.z);

    float stride = mix(shortestStep, longestStep, NoV);

    // The first step is at least pixelSize in xy, then scaled by the jitter.
    vec3 firstStep = marchDir * vec3(max(pixelSize, stride), stride);
    vec3 samplePos = hitPixel + firstStep * (jitter * 0.5 + 0.5);

    float sceneDepth = texture2D(depthtex0, samplePos.xy).x;

    bool hit = false;

    #ifdef RAYTRACE_REFINEMENT
        // The refinement pass runs at most once per call.
        bool refinePending = true;
    #endif

    for (int iter = 0; iter < RAYTRACE_QUALITY + 4; ++iter) {
        // Stop as soon as the sample position is no longer equal to its clamp01 value.
        if (clamp01(samplePos.xy) != samplePos.xy) {
            return sky * skyLightmap;
        }

        if (sceneDepth < samplePos.z) {
            // The ray is behind depthtex0 here; re-test against depthtex1.
            sceneDepth = texture2D(depthtex1, samplePos.xy).x;

            if (sceneDepth < samplePos.z) {
                float rayViewZ   = ScreenToViewSpaceDepth(samplePos.z);
                float sceneViewZ = ScreenToViewSpaceDepth(sceneDepth);

                // Depth gap relative to the ray's own view-space depth.
                float relativeGap = abs(sceneViewZ - rayViewZ) / -rayViewZ;

                #ifdef RAYTRACE_REFINEMENT
                    if (refinePending) {
                        // Back up half a step, then creep forward in smaller steps.
                        samplePos -= marchDir * stride * 0.5;
                        sceneDepth = texture2D(depthtex1, samplePos.xy).x;

                        if (samplePos.z >= 1.0) {
                            break;
                        }

                        for (int refineIter = 0; refineIter < RAYTRACE_REFINEMENT_STEPS; ++refineIter) {
                            float refinedRayZ   = ScreenToViewSpaceDepth(samplePos.z);
                            float refinedSceneZ = ScreenToViewSpaceDepth(sceneDepth);
                            float refinedGap    = abs(refinedSceneZ - refinedRayZ) / -refinedRayZ;

                            // Close enough: leave the position here for the outer loop to re-test.
                            if (refinedGap < zThicknessThreshold && refinedRayZ < 0.0 && refinedRayZ > -far) {
                                break;
                            }

                            float refinedStride = clamp(abs(sceneDepth - samplePos.z) * stepWeight, shortestStep, longestStep);
                            samplePos += marchDir * refinedStride * refinementSize;
                            sceneDepth = texture2D(depthtex1, samplePos.xy).x;
                        }

                        refinePending = false;

                        // Re-evaluate the refined position without advancing the ray.
                        continue;
                    }
                #endif

                // Accept the hit when the gap is within tolerance and the ray depth lies in (-far, 0).
                if (relativeGap < zThicknessThreshold && rayViewZ < 0.0 && rayViewZ > -far) {
                    hit = true;
                    break;
                }
            }
        }

        // Adaptive step: proportional to the current depth gap, clamped to the step limits.
        stride = clamp(abs(sceneDepth - samplePos.z) * stepWeight, shortestStep, longestStep);
        samplePos += marchDir * stride;
        sceneDepth = texture2D(depthtex0, samplePos.xy).x;
    }

    // A depth of 1.0 or more returns the unattenuated sky, even if a hit was flagged.
    if (sceneDepth >= 1.0) {
        return sky;
    }

    if (hit) {
        return DecodeRGBE8(texture(colortex0, samplePos.xy));
    }

    return sky * skyLightmap;
}

/*
const float pixelStride = 10.0; // Step in horizontal or vertical pixels between samples. This is a float
const int MAX_ITERATION = 128; // Maximum number of iterations. Higher gives better images but may be slow.
const float maxRayDistance = 512.0; // Maximum camera-space distance to trace before returning a miss.

float linearDepthTexelFetch(ivec2 hitPixel) {
    // Load returns 0 for any value accessed out of bounds
    return ScreenToViewSpaceDepth(texelFetch(depthtex1, hitPixel, 0).r);
}

bool rayIntersectDepth(float rayZNear, float rayZFar, vec2 hitPixel) {
    // Swap if bigger
    if (rayZFar > rayZNear) {
        float t = rayZFar; rayZFar = rayZNear; rayZNear = t;
    }

    float cameraZ = linearDepthTexelFetch(ivec2(hitPixel));

    return rayZFar < cameraZ && rayZNear >= cameraZ - zThicknessThreshold && rayZNear < 0.0 && cameraZ > (-far); // Cross z
}

// Returns true if the ray hit something
bool traceScreenSpaceRay(
    // Camera-space ray origin, which must be within the view volume
    vec3 rayOrigin,
    // Unit length camera-space ray direction
    vec3 rayDir,
    // Number between 0 and 1 for how far to bump the ray in stride units
    // to conceal banding artifacts. Not needed if stride == 1.
    float jitter,
    // Pixel coordinates of the first intersection with the scene
    out vec2 hitPixel,
    // Camera space location of the ray hit
    out vec3 hitPoint)
{

    // Clip to the near plane
    float rayLength = ((rayOrigin.z + rayDir.z * maxRayDistance) > -near)
        ? (-near - rayOrigin.z) / rayDir.z : maxRayDistance;

    vec3 rayEnd = rayOrigin + rayDir * rayLength;

    // Project into homogeneous clip space
    vec4 H0 = projMatrix * vec4(rayOrigin, 1.0);
    vec4 H1 = projMatrix * vec4(rayEnd, 1.0);

    float k0 = 1.0 / H0.w, k1 = 1.0 / H1.w;

    // The interpolated homogeneous version of the camera space points
    vec3 Q0 = rayOrigin * k0, Q1 = rayEnd * k1;

    // Screen space endpoints
    // PENDING viewportSize ?
    vec2 P0 = ((H0.xy * k0 * 0.5 + 0.5) + taaJitter * 0.5) * viewDimensions;
    vec2 P1 = ((H1.xy * k1 * 0.5 + 0.5) + taaJitter * 0.5) * viewDimensions;

    // If the line is degenerate, make it cover at least one pixel to avoid handling
    // zero-pixel extent as a special case later
    P1 += dot(P1 - P0, P1 - P0) < 0.0001 ? 0.01 : 0.0;
    vec2 delta = P1 - P0;

    // Permute so that the primary iteration is in x to collapse
    // all quadrant-specific DDA cases later
    bool permute = false;
    if (abs(delta.x) < abs(delta.y)) {
        // More vertical line
        permute = true;
        delta = delta.yx;
        P0 = P0.yx;
        P1 = P1.yx;
    }

    float stepDir = sign(delta.x);
    float invdx = stepDir / delta.x;

    // Track the derivatives of Q and K
    vec3 dQ = (Q1 - Q0) * invdx;
    float dk = (k1 - k0) * invdx;
    vec2 dP = vec2(stepDir, delta.y * invdx);

    float pixStride = 1.0 + pixelStride;

    // Scale derivatives by the desired pixel stride and then offset the starting values by the jitter fraction
    dP *= pixStride; dQ *= pixStride; dk *= pixStride;

    // Track ray step and derivatives in a vec4 to parallelize
    vec4 pqk = vec4(P0, Q0.z, k0);
    vec4 dPQK = vec4(dP, dQ.z, dk);

    pqk += dPQK * jitter;

    float rayZFar = (dPQK.z * 0.5 + pqk.z) / (dPQK.w * 0.5 + pqk.w);
    float rayZNear;

    bool intersect = false;

    float iterationCount = 0.0;

    for (int i = 0; i < MAX_ITERATION; i++) {
        pqk += dPQK;

        rayZNear = rayZFar;
        rayZFar = (dPQK.z * 0.5 + pqk.z) / (dPQK.w * 0.5 + pqk.w);

        hitPixel = permute ? pqk.yx : pqk.xy;

        intersect = rayIntersectDepth(rayZNear, rayZFar, hitPixel);
        iterationCount += 1.0;

        if (intersect || any(greaterThan(hitPixel, vec2(viewDimensions))) || any(lessThan(hitPixel, vec2(0.0)))) {
            break;
        }
    }

    Q0.xy += dQ.xy * iterationCount;
    Q0.z = pqk.z;
    hitPoint = Q0 / pqk.w;
    hitPixel *= pixelSize;

    return intersect;
}
*/

#endif
