import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import spsolve
from mpl_toolkits.mplot3d import Axes3D

# Parameters
L = 15  # Domain length (m)
Nx = 500  # Number of grid nodes in x (both end points included)
Nt = 200  # Number of time steps
# The grid is np.linspace(0, L, Nx): Nx nodes span Nx - 1 intervals, so the
# node spacing used by the finite-difference operator must be L / (Nx - 1).
dx = L / (Nx - 1)  # Spatial step
dt = 1  # Time step
D = 0.012  # Diffusion coefficient
gamma_I = 0.06  # Fission yield for Iodine
gamma_X = 0.003  # Fission yield for Xenon
lambda_I = 2.9e-5  # Decay of Iodine
lambda_X = 2.1e-5  # Decay of Xenon
nu = 2.47  # Number of neutrons per fission

sigma_X = 2.76e-18  # Cross section for Xenon absorption
sigma_f_5 = 580e-23  # U-235 microscopic fission cross section
sigma_a_5 = 680e-23  # U-235 microscopic absorption cross section
sigma_a_8 = 2.7e-23  # U-238 microscopic absorption cross section
sigma_a_bore = 2840e-23  # presumably boron ("bore"); not used below

Phi_0 = 3e13  # Nominal flux used for the equilibrium concentrations

# Equilibrium value before imbalance
rho_UO2 = 10  # UO2 density
Na = 6.02214076e23  # Avogadro's number
eta = 0.15  # Fuel to total volume ratio
rich = 0.03  # Fuel enrichment, used below as the U-235 atom fraction
M_O2 = 32.00  # O2 molecule (O2)
M_U235 = 235.0439  # Uranium-235
M_U238 = 238.0289  # Uranium-238
M_UO2 = rich * M_U235 + (1 - rich) * M_U238 + M_O2  # Molar mass of UO2

# U-235 / U-238 atom densities, split by the enrichment defined above.
N_U5 = rich * (rho_UO2 * Na * eta / M_UO2)
N_U8 = (1 - rich) * (rho_UO2 * Na * eta / M_UO2)

# Macroscopic cross section and Keff
Sigma_a_barre = 0.0304
Sigma_a_eau = 0.04
# NOTE(review): this value appears tuned so that K below evaluates to ~1 — confirm.
Sigma_a_bore = 0.5682434270366046
Sigma_f = N_U5 * sigma_f_5
Sigma_a = N_U8 * sigma_a_8 + N_U5 * sigma_a_5 + Sigma_a_barre + Sigma_a_eau + Sigma_a_bore

# Equilibrium iodine and xenon concentrations at the nominal flux Phi_0.
N_I_eq = gamma_I * Sigma_f * Phi_0 / lambda_I
N_X_eq = (gamma_I + gamma_X) * Sigma_f * Phi_0 / (lambda_X + sigma_X * Phi_0)

# Keff: fission production over total absorption including equilibrium
# xenon (no leakage term).
K = nu * Sigma_f / (Sigma_a + sigma_X * N_X_eq)
print(K)

# Grid initialization: Nx nodes on [0, L]; the flux carries a small
# sinusoidal tilt (one full period of sin over the domain) as a column vector.
x = np.linspace(0, L, Nx)
perturbation = (0.0001 * np.sin(np.pi * x * 2 / L))[:, np.newaxis]
Phi = Phi_0 * (1 + perturbation)  # Gentle imbalance, shape (Nx, 1)

# Iodine and xenon both start at their uniform equilibrium values, shape (Nx, 1).
N_I, N_X = (np.full((Nx, 1), value) for value in (N_I_eq, N_X_eq))

# Residues at the initial time: take one explicit (forward-Euler) predictor
# step, then evaluate the backward-Euler residuals of the iodine and xenon
# balances at that predicted state.
_source_I = gamma_I * Sigma_f * Phi
_source_X = gamma_X * Sigma_f * Phi
_removal_X = lambda_X + sigma_X * Phi

N_I_pred = N_I + dt * (_source_I - lambda_I * N_I)
N_X_pred = N_X + dt * (_source_X + lambda_I * N_I - _removal_X * N_X)

RI_0 = N_I_pred - N_I - dt * (_source_I - lambda_I * N_I_pred)
RX_0 = N_X_pred - N_X - dt * (_source_X + lambda_I * N_I_pred - _removal_X * N_X_pred)

# Drop the scratch terms so they do not linger as module-level names.
del _source_I, _source_X, _removal_X

def cal_A(N_X):
    """Build the dense finite-difference operator of the flux equation.

    For interior node i, row i of ``A @ Phi`` is

        -D * (Phi[i-1] - 2*Phi[i] + Phi[i+1]) / dx**2
            + (Sigma_a + sigma_X * N_X[i]) * Phi[i]

    The first and last rows use the one-sided second-order stencil
    ``D / (2*dx**2) * (3, -4, 1)`` (mirrored at the right end) plus the same
    absorption term.

    Args:
        N_X: Xenon concentration per node, shape (n,) or (n, 1), with n >= 3.
            The input is flattened, so both shapes give the same matrix.

    Returns:
        numpy.ndarray: Dense (n, n) operator matrix.

    Raises:
        ValueError: If fewer than 3 nodes are supplied (the boundary stencil
            needs three points).
    """
    n_x = np.ravel(N_X)
    n = n_x.size
    if n < 3:
        raise ValueError(f"cal_A needs at least 3 nodes, got {n}")

    A = np.zeros((n, n))

    # Interior rows: centered second difference plus local absorption.
    interior = np.arange(1, n - 1)
    A[interior, interior] = (2 * D / (dx ** 2) + Sigma_a + sigma_X * n_x)[1:-1]
    A[interior, interior - 1] = -D / (dx ** 2)
    A[interior, interior + 1] = -D / (dx ** 2)

    # Boundary rows. (3*Phi0 - 4*Phi1 + Phi2) / (2*dx) is minus the one-sided
    # second-order estimate of dPhi/dx, so each row equals -/+ (D/dx)*dPhi/dx
    # plus absorption.
    # NOTE(review): intended as a zero-gradient condition; a pure Neumann row
    # would usually not carry the absorption term — confirm intent.
    A[0, 0] = 3 * D / (2 * dx ** 2) + Sigma_a + sigma_X * n_x[0]
    A[0, 1] = -2 * D / (dx ** 2)
    A[0, 2] = D / (2 * dx ** 2)
    A[-1, -1] = 3 * D / (2 * dx ** 2) + Sigma_a + sigma_X * n_x[-1]
    A[-1, -2] = -2 * D / (dx ** 2)
    A[-1, -3] = D / (2 * dx ** 2)

    return A

# Flux operator for the initial xenon field, then the initial flux residual
# A·Phi - nu·Sigma_f·Phi / K. All three residual blocks are stacked into one
# (3*Nx, 1) column in the order [iodine, xenon, flux].
A = cal_A(N_X)
RPhi_0 = np.dot(A, Phi) - nu * Sigma_f * Phi / K
R_0 = np.vstack((RI_0, RX_0, RPhi_0))

def J(N_X, Phi, k_eff=1.0):
    """Assemble the dense Jacobian of the coupled residual system.

    Unknown columns follow the state-vector layout ``X = [N_I, N_X, Phi]``.
    Residual rows are ``[RI, RX, RPhi]`` with::

        RI   = N_I - N_I_prev - dt*(gamma_I*Sigma_f*Phi - lambda_I*N_I)
        RX   = N_X - N_X_prev - dt*(gamma_X*Sigma_f*Phi + lambda_I*N_I
                                    - (lambda_X + sigma_X*Phi)*N_X)
        RPhi = cal_A(N_X) @ Phi - nu*Sigma_f*Phi / k_eff

    Args:
        N_X: Xenon concentration at the linearisation point, shape (Nx,) or
            (Nx, 1).
        Phi: Neutron flux at the linearisation point, shape (Nx,) or (Nx, 1).
        k_eff: Factor dividing the fission source in RPhi. Defaults to 1.0.
            Pass the value actually used in the residual for an exact
            Jacobian; a nearby value still yields a usable (approximate)
            Newton step.

    Returns:
        numpy.ndarray: Dense (3*Nx, 3*Nx) Jacobian matrix.
    """
    n_x = np.ravel(N_X)
    phi = np.ravel(Phi)
    n = phi.size
    eye = np.eye(n)
    zero = np.zeros((n, n))

    # Row RI: depends on N_I (decay, implicit) and Phi (fission source).
    dRI_dNI = (1 + lambda_I * dt) * eye
    dRI_dPhi = -dt * gamma_I * Sigma_f * eye

    # Row RX: the Phi derivative has both the fission source and the
    # neutron-absorption removal term sigma_X*Phi*N_X.
    dRX_dNI = -lambda_I * dt * eye
    dRX_dNX = np.diag(1 + dt * (lambda_X + sigma_X * phi))
    dRX_dPhi = np.diag(dt * (sigma_X * n_x - gamma_X * Sigma_f))

    # Row RPhi: xenon enters cal_A only through its diagonal
    # (sigma_X*N_X[i]), so d(A@Phi)/dN_X is diag(sigma_X*Phi). No dt here:
    # RPhi is not a time-integrated residual.
    dRP_dNX = np.diag(sigma_X * phi)
    dRP_dPhi = cal_A(n_x) - (nu * Sigma_f / k_eff) * eye

    jacobian = np.block([
        [dRI_dNI, zero, dRI_dPhi],
        [dRX_dNI, dRX_dNX, dRX_dPhi],
        [zero, dRP_dNX, dRP_dPhi],
    ])
    return jacobian

# First Newton correction, linearised at the forward-Euler predictor state
# where R_0's iodine/xenon residuals were evaluated.
J_dense = J(N_X_pred, Phi)

J_sparse = csr_matrix(J_dense)

# spsolve returns a 1-D array of length 3*Nx.
delta_x0 = spsolve(J_sparse, -R_0)

# State vector layout: [N_I, N_X, Phi]. It is flattened to 1-D so that adding
# the 1-D correction does not broadcast (3*Nx, 1) + (3*Nx,) into a
# (3*Nx, 3*Nx) matrix.
X = np.concatenate([N_I, N_X, Phi]).ravel()

# Store X over time
L_X = [X]

# The correction applies to the point the system was linearised at.
X = np.concatenate([N_I_pred, N_X_pred, Phi]).ravel() + delta_x0
L_X.append(X)

# Helper: split a stacked state vector into its (N_I, N_X, Phi) parts
def extract(X):
    """Split a stacked state vector into its three per-node fields.

    The state is laid out as ``[N_I, N_X, Phi]``, each block ``Nx`` entries
    long. ``X`` is reshaped to 1-D first, so a (3*Nx, 1) column works too.

    Returns:
        tuple: ``(N_I, N_X, Phi)`` as 1-D slices of the flattened input.
    """
    flat = X.reshape(-1)
    bounds = [(k * Nx, (k + 1) * Nx) for k in range(3)]
    iodine, xenon, flux = (flat[lo:hi] for lo, hi in bounds)
    return iodine, xenon, flux

# Time loop: each pass takes one Newton correction of the implicit
# (backward-Euler) step starting from the most recently stored state.
for _ in range(Nt):
    N_I, N_X, Phi = extract(X)

    # The previous time level is the last stored state, which is X itself.
    # Using an older entry would mix two different time levels.
    N_I_prec, N_X_prec, _ = extract(L_X[-1])

    # Calculate new keff from domain-averaged absorption (no leakage term).
    Sigma_a_total = Sigma_a + sigma_X * N_X  # Total macroscopic cross section
    keff = nu * Sigma_f / np.mean(Sigma_a_total)

    print(keff)

    RI_0 = N_I - N_I_prec - dt * (gamma_I * Sigma_f * Phi - lambda_I * N_I)
    RX_0 = N_X - N_X_prec - dt * (gamma_X * Sigma_f * Phi + lambda_I * N_I - (lambda_X + sigma_X * Phi) * N_X)

    # A is refreshed for the current xenon field. It is kept as a
    # module-level name in case other code (e.g. J) reads it.
    A = cal_A(N_X)

    # Same form as the initial flux residual: the fission source is divided
    # by k. Phi itself is not rescaled, so every residual is evaluated at the
    # state the correction is added to.
    RPhi_0 = np.dot(A, Phi) - nu * Sigma_f * Phi / keff

    R_0 = np.concatenate([RI_0, RX_0, RPhi_0])

    J_dense = J(N_X, Phi)

    J_sparse = csr_matrix(J_dense)

    delta_x0 = spsolve(J_sparse, -R_0)

    X = X + delta_x0

    L_X.append(X)

def plot_3d_results(L_X, Nx, Nt, L):
    """
    Draw space-time surfaces of iodine, xenon and max-normalized flux, then
    the time history of each quantity at the centre node.

    Args:
        L_X (list): Stored state vectors, one per entry, each laid out as
            [N_I, N_X, Phi] (flattened before slicing).
        Nx (int): Number of spatial points.
        Nt (int): Number of time steps (unused; the time axis is built from
            len(L_X)).
        L (float): Domain length.
    """
    n_states = len(L_X)
    positions = np.linspace(0, L, Nx)
    t = np.arange(n_states)  # one tick per stored state
    X_grid, T_grid = np.meshgrid(positions, t)

    # fields[k] holds quantity k (0 = iodine, 1 = xenon, 2 = flux), one row per state.
    fields = np.zeros((3, n_states, Nx))
    for row, state in enumerate(L_X):
        flat = np.array(state).flatten()
        for k in range(3):
            fields[k, row, :] = flat[k * Nx:(k + 1) * Nx]
    iodine, xenon, flux = fields

    flux_norm = flux / np.max(flux)

    fig = plt.figure(figsize=(18, 12))
    surface_specs = (
        (311, iodine, 'viridis',
         'Iodine Concentration vs Time and Space', 'Iodine Concentration (at/m³)'),
        (312, xenon, 'plasma',
         'Xenon Concentration vs Time and Space', 'Xenon Concentration (at/m³)'),
        (313, flux_norm, 'inferno',
         'Normalized Neutron Flux vs Time and Space', 'Normalized Flux'),
    )
    for subplot_code, data, cmap, title, zlabel in surface_specs:
        ax = fig.add_subplot(subplot_code, projection='3d')
        surf = ax.plot_surface(X_grid, T_grid, data, cmap=cmap)
        ax.set_title(title)
        ax.set_xlabel('Position (m)')
        ax.set_ylabel('Time (s)')
        ax.set_zlabel(zlabel)
        fig.colorbar(surf, ax=ax, shrink=0.5, aspect=5)

    plt.tight_layout()
    plt.show()

    # Time histories at the centre node (x ~ L/2).
    centre = Nx // 2
    fig2, (ax_conc, ax_flux) = plt.subplots(2, 1, figsize=(10, 8))

    ax_conc.plot(t, iodine[:, centre], label='Iodine')
    ax_conc.plot(t, xenon[:, centre], label='Xenon')
    ax_conc.set_title('Concentration Evolution at Reactor Center')
    ax_conc.set_xlabel('Time (s)')
    ax_conc.set_ylabel('Concentration (at/m³)')
    ax_conc.legend()
    ax_conc.grid()

    ax_flux.plot(t, flux[:, centre])
    ax_flux.set_title('Neutron Flux Evolution at Reactor Center')
    ax_flux.set_xlabel('Time (s)')
    ax_flux.set_ylabel('Neutron Flux (n/cm²/s)')
    ax_flux.grid()

    plt.tight_layout()
    plt.show()

# Usage: plot the full stored history of states.
plot_3d_results(L_X, Nx=Nx, Nt=Nt, L=L)