from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import matplotlib.pyplot as plt
import numpy as np
# ---- constants & parameters ----
# Literals are deliberately kept in "mantissa * 10**exponent" form so that the
# resulting floating-point values are exactly those of the original script.
# NOTE(review): the physical meaning / units of the bare constants are not
# stated anywhere in this file; the grouping below is by usage only.

# Small angle used to seed the initial conditions, and a quantity derived
# from it (TD is not referenced again in the visible part of the script).
theta0 = 4.7 * 10**(-5)
TD = 0.25 * np.log(theta0 / (2*np.pi))**2

# Scalar constants entering the coupling and intensity formulas.
d = 1.7 * 10**(-30)
omega = 2. * np.pi * (1612. * 10**6)      # 2*pi times 1612e6
eps = 8.85 * 10**(-12)
c = 3. * 10**8
hbar = 6.626 * 10**(-34) / (2*np.pi)      # 6.626e-34 divided by 2*pi

# Density: n is the fraction eta of nn.
eta = 0.01
nn = 10.**7
n = eta * nn

lambdaOH = c / (1612. * 10**6)            # c divided by the same 1612e6 as omega

# Time scales.  T1 and T2 are both 210 * TR; Tsp is the reciprocal of gamma.
gamma = 1.282 * 10.**(-11)
Tsp = 1. / gamma
TR = 604800.
T1 = 210. * TR
T2 = 210. * TR

# Sample geometry derived from the quantities above: length L, a radius
# chosen so that pi*radius**2 equals lambdaOH*L, cross-section A and volume V.
L = (Tsp/TR) * (8.*np.pi) / ((3. * lambdaOH**2.) * n)
radius = np.sqrt(lambdaOH * L / np.pi)
A = np.pi * radius**2
phi_diffraction = lambdaOH**2 / A
V = A * L
NN = n * V                                # n times the volume

# Prefactor of the polarisation source term in the field equation.
constant3 = omega * TR * NN * d**2 / (2*c*hbar*eps*V)

# Scale factors: Lp divides lengths used on the z grid, Fn scales Ldiff.
Fn = 1.
Lp = 1.
Ldiff = Fn * L / 0.35
#time
tmax = 500. #time limit
Ngrid = 500. #number of time intervals
ht = tmax/Ngrid #time step
# Number of samples on [0, tmax] with spacing ht.  np.linspace needs an
# integer sample count, so the float quotient is rounded and converted.
interval = int(round(tmax/ht)) + 1
t = np.linspace(0, tmax, interval)

#z space
zmax = L/Lp
w = len(t) - 1 #interval number in z (same count as in t)
z = np.linspace(0, zmax, len(t))
hz = zmax/w #z step

# Solution arrays, one row per z sample and one column per t sample.
# np.complex128 is the explicit name of NumPy's default complex dtype
# (the np.complex_ alias no longer exists in NumPy 2.x).
_grid_shape = (len(z), len(t))
Ep = np.zeros(_grid_shape, dtype=np.complex128)
Pp = np.zeros(_grid_shape, dtype=np.complex128)
N = np.zeros(_grid_shape, dtype=np.complex128)
Es = np.zeros(_grid_shape, dtype=np.complex128)
# Selectors for the first time column (all z) and the first z row (all t).
_at_t0 = np.s_[:, 0]
_at_z0 = np.s_[0, :]

#Boundary Conditions
# The field on the first z row is held at the constant 0.025j for all times.
Ep[_at_z0] = 0.025j

#Initial Conditions
# Every z row starts from the same values at the first time sample.
# NOTE(review): Pp receives the cosine and N the sine of the small angle
# theta0, so N starts near zero and Pp near 0.5.  The first figure labels N
# as the population inversion; confirm against the reference model that the
# two are not swapped.
N[_at_t0] = 0.5*np.sin(theta0)
Pp[_at_t0] = 0.5*np.cos(theta0)
#----------------solving Bloch equations----------------
def _march_bloch_in_time(n_row, p_row, e_row, dt, tau1, tau2):
    """Integrate one z-slice of the Bloch equations over the whole time grid.

    ``n_row`` and ``p_row`` are 1-D complex rows (views into ``N`` and ``Pp``)
    whose element 0 already holds the starting value; elements 1.. are
    overwritten in place.  ``e_row`` is the field row of the same slice and
    is only read.  ``tau1`` / ``tau2`` divide the decay terms of ``n_row`` /
    ``p_row`` respectively.

    Each step is a two-stage midpoint update: the slope at the start of the
    step builds a half-step state, and the slope evaluated at that state
    advances the solution by ``dt``.  The field sample at the start of the
    step is used for both stages.
    """
    for i in range(len(n_row) - 1):
        e = e_row[i]
        e_conj = np.conj(e)
        n_now = n_row[i]
        p_now = p_row[i]

        # Stage 1: slopes at the current state.
        slope_n = 1j*(p_now*e - e_conj*np.conj(p_now)) - n_now/tau1
        slope_p = 2j*e_conj*n_now - p_now/tau2

        # Stage 2: slopes at the half-step state.
        n_half = n_now + dt/2*slope_n
        p_half = p_now + dt/2*slope_p
        mid_n = 1j*(p_half*e - e_conj*np.conj(p_half)) - n_half/tau1
        mid_p = 2j*e_conj*n_half - p_half/tau2

        n_row[i + 1] = n_now + dt*mid_n
        p_row[i + 1] = p_now + dt*mid_p


def _advance_field_in_z(e_prev, p_prev, dz, coupling, decay_length):
    """Return the field row one z-step after ``e_prev``.

    Vectorised over time: ``e_prev`` and ``p_prev`` are whole rows.  The
    source term ``1j*coupling*conj(p_prev)`` is held fixed across the step
    while the loss term ``-E/decay_length`` is evaluated with a two-stage
    midpoint update.
    """
    source = (1j*coupling)*np.conj(p_prev)
    slope_start = source - e_prev/decay_length
    slope_half = source - (e_prev + dz/2*slope_start)/decay_length
    return e_prev + dz*slope_half


_tau1 = T1/TR
_tau2 = T2/TR
_decay_length = Ldiff/Lp

# Boundary slice: the field row Ep[0] is prescribed, so only N and Pp need
# to be integrated in time there.
_march_bloch_in_time(N[0], Pp[0], Ep[0], ht, _tau1, _tau2)

# March through the remaining z-slices.  Each slice is integrated with its
# own state and its own field row.
for k in range(1, len(z)):
    # Predictor: step the field from slice k-1 to slice k, then integrate
    # N and Pp on slice k with that field.
    Ep[k] = _advance_field_in_z(Ep[k - 1], Pp[k - 1], hz, constant3, _decay_length)
    _march_bloch_in_time(N[k], Pp[k], Ep[k], ht, _tau1, _tau2)

    # Average the predicted slice with the previous slice.
    N[k] = (N[k] + N[k - 1])/2
    Pp[k] = (Pp[k] + Pp[k - 1])/2

    # Corrector: recompute the field for slice k and integrate N and Pp once
    # more from the first time sample; this pass produces the stored result.
    # NOTE(review): the corrector reads row k-1 of the polarisation, which
    # the averaging above does not modify, so it reproduces the predictor
    # field exactly and the averaged values only survive in column 0.
    # Confirm whether the averaged row k was meant to drive the corrector.
    Ep[k] = _advance_field_in_z(Ep[k - 1], Pp[k - 1], hz, constant3, _decay_length)
    _march_bloch_in_time(N[k], Pp[k], Ep[k], ht, _tau1, _tau2)
# Figure 1: surface of N * conj(N) (real part) over the (t, z) grid.
X, Y = np.meshgrid(t, z)
_n_squared = (N * np.conj(N)).real

fig = plt.figure(num=1, figsize=(4, 4))
ax = fig.add_subplot(1, 1, 1, projection='3d')
surf = ax.plot_surface(
    X, Y, _n_squared,
    rstride=4, cstride=4,  # draw every 4th grid line in each direction
    alpha=0.1,
)
ax.set_xlabel('t (ns)')
ax.set_ylabel('z')
ax.set_zlabel('Population Inversion (scaled)')
ax.view_init(elev=30, azim=45)
plt.show()
# Intensity post-processing of the field array Ep.
# I_SR_scaled is the real part of 0.5*c*eps*Ep*conj(Ep)  (= Itot).
_field_squared = Ep * np.conj(Ep)
I_SR_scaled = (0.5*c*eps*_field_squared).real

# Divide out the square of d*TR/hbar to obtain I_SR from the scaled value.
_field_scale_sq = (d*TR/hbar)**2
I_SR = I_SR_scaled / _field_scale_sq

# Reference intensity I_nc built from NN, hbar*omega, 1/(A*Tsp) and the
# fraction phi_diffraction/(4*pi).
I_nc = NN*hbar*omega*(1./(A*Tsp))*(phi_diffraction/(4.*np.pi))

# Ratio I_SR / (NN * I_nc), additionally divided by 0.001.
IntensityRatio = (I_SR / (NN*I_nc)) / 0.001
# Figure 2: surface of IntensityRatio over the (t, z) grid.
X, Y = np.meshgrid(t, z)

fig = plt.figure(num=2, figsize=(4, 4))
ax = fig.add_subplot(1, 1, 1, projection='3d')
surf = ax.plot_surface(
    X, Y, IntensityRatio,
    rstride=4, cstride=4,  # draw every 4th grid line in each direction
    alpha=0.1,
)
ax.set_xlabel('t (ns)')
ax.set_ylabel('z')
ax.set_zlabel(r"$ \frac {I_{SR}}{NI_{nc}}$")
ax.view_init(elev=45, azim=45)
plt.show()