import numpy as np


# Experiment geometry: rays leave the point source (x0, y0, z0) and are
# recorded where they cross the plane x = R.
x0, y0, z0 = -1.0, -2.0, -3.0
R = 1.0

# Histogram grid on that plane: bins_y x bins_z cells covering a
# length_y x length_z window centred on y = 0, z = 0.
bins_y = bins_z = 100
length_y = length_z = 10.0

# Width of one bin along each axis.
spacing_y = length_y / bins_y
spacing_z = length_z / bins_z

# Lower edge of the window along each axis (the upper edge is begin + length).
begin_y = -0.5*length_y
begin_z = -0.5*length_z

def make_mc_histogram(num_points, seed=6112):
    """Histogram the points where random isotropic rays cross the plane x = R.

    Draws ``num_points`` directions uniformly on the unit sphere (Muller's
    method: normalize i.i.d. standard-normal 3-vectors). It keeps only rays that
    reach the plane x = R at a positive, finite distance from the source
    (x0, y0, z0). It then bins the (y, z) coordinates of their hit points on the
    grid given by the module-level ``bins_*``, ``begin_*`` and ``length_*``
    parameters.

    Args:
        num_points: Number of rays to sample, including those that miss the plane.
        seed: Seed for a private ``np.random.RandomState``. NumPy's global random
            state is not touched.

    Returns:
        Raw counts with shape (bins_y, bins_z); rows index y bins and columns
        index z bins. Hits falling outside the window are not counted.
    """
    # A private legacy generator produces the same stream as
    # np.random.seed(seed) followed by np.random.randn, without reseeding the
    # global RNG as a side effect.
    rng = np.random.RandomState(seed)

    # Sample uniform random directions with Muller's algorithm.
    # See equation (16) here: https://mathworld.wolfram.com/SpherePointPicking.html
    directions = rng.randn(3, num_points)
    directions /= np.linalg.norm(directions, axis=0)

    # Solve R = x0 + direction_x*distance for each ray; then
    # y = y0 + direction_y*distance and z = z0 + direction_z*distance.
    # Rays parallel to the plane (direction_x == 0) give inf/nan. Silence the
    # warning and drop them together with rays travelling away from the plane
    # (negative distance).
    with np.errstate(divide="ignore", invalid="ignore"):
        distance = (R - x0)/directions[0, :]
        mask = np.isfinite(distance) & (distance > 0)

    distance = distance[mask]
    directions = directions[:, mask]

    # One row per surviving ray: (x, y, z), with x equal to R up to rounding.
    source_point = np.array([x0, y0, z0])
    hit_points = source_point + (distance*directions).T

    # Bin (y, z) coordinates.
    return np.histogram2d(
        hit_points[:, 1],
        hit_points[:, 2],
        bins=(bins_y, bins_z),
        range=[[begin_y, begin_y + length_y], [begin_z, begin_z + length_z]],
    )[0]

def make_density():
    """Return the analytic density of isotropic rays hitting the plane x = R.

    The density is evaluated at the centres of the (y, z) histogram bins. A ray
    from (x0, y0, z0) that reaches (R, y, z) travels
    r = sqrt((R - x0)**2 + (y - y0)**2 + (z - z0)**2). The density per unit
    area is |R - x0| / (4*pi*r**3), i.e. cos(theta) / (4*pi*r**2).

    Returns:
        Array of shape (bins_y, bins_z); rows index y bins and columns index
        z bins, the same layout as np.histogram2d(y, z, ...).
    """
    # Bin centres along each axis.
    centers_y = begin_y + (np.arange(bins_y) + 0.5)*spacing_y
    centers_z = begin_z + (np.arange(bins_z) + 0.5)*spacing_z

    # y runs down the rows and z across the columns; broadcasting builds the grid.
    dx = R - x0
    dy = centers_y[:, np.newaxis] - y0
    dz = centers_z[np.newaxis, :] - z0

    r = np.sqrt(dx**2 + dy**2 + dz**2)
    prefactor = 1.0/(4.0*np.pi) * np.abs(dx)
    return prefactor/r**3

def make_density_invsquare():
    """Return an unnormalized 1/r**2 density on the (y, z) bin-centre grid.

    r is the distance from the source (x0, y0, z0) to each bin centre
    (R, y, z) on the plane x = R. Compared with make_density(), this omits
    the |R - x0| / r (cos(theta)) factor and the 1/(4*pi) constant. It is an
    alternative model that the experiment loop also compares the Monte Carlo
    histogram against.

    Returns:
        Array of shape (bins_y, bins_z); rows index y bins and columns index
        z bins, the same layout as np.histogram2d(y, z, ...).
    """
    y = begin_y + (np.arange(bins_y) + 0.5)*spacing_y
    z = begin_z + (np.arange(bins_z) + 0.5)*spacing_z

    # Column vector of y centres, row vector of z centres: broadcasting
    # produces the full (bins_y, bins_z) grid.
    y = y[:, None]
    z = z[None, :]

    # Use the squared distance directly rather than squaring a square root,
    # which only adds rounding error.
    r_squared = (R - x0)**2 + (y - y0)**2 + (z - z0)**2
    return 1.0/r_squared

def KL(p, q):
    """Return the Kullback-Leibler divergence D_KL(p || q) = sum(p * log(p / q)).

    Terms with p == 0 contribute nothing (0 * log 0 is taken as 0) and are
    skipped. If p puts mass anywhere q is not positive, the divergence is
    infinite and ``np.inf`` is returned. Those terms are not silently dropped.

    Args:
        p, q: Array-likes of the same shape. They are assumed to be
            non-negative and already normalized by the caller; this function
            does not normalize them.

    Returns:
        The divergence as a float (0.0 when p has no positive entries).
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)

    support = p > 0
    # Mass of p where q has none makes the divergence unbounded.
    if np.any(q[support] <= 0):
        return np.inf

    p = p[support]
    q = q[support]
    return (p*np.log(p/q)).sum()
   
# Compare the Monte Carlo hit histogram with the two analytic densities for
# increasing sample counts. Both densities are normalized to unit sum over the
# grid. They do not depend on num_points, so they are built once, outside
# the loop.
density = make_density()
density /= density.sum()
density_invsq = make_density_invsquare()
density_invsq /= density_invsq.sum()

for num_points in [10**k for k in range(3, 9)]:
    counts = make_mc_histogram(num_points)

    # Probability mass over the grid. Only hits inside the binned window
    # were counted, matching the grid the densities are normalized over.
    normalized_counts = counts / counts.sum()

    mse = ((normalized_counts - density)**2).mean()
    mse2 = ((normalized_counts - density_invsq)**2).mean()

    kl = KL(normalized_counts, density)
    kl2 = KL(normalized_counts, density_invsq)

    print(f"num_points = {num_points}, mse = {mse}, mse_invsq = {mse2}, kl = {kl}, kl_invsq = {kl2}")


