# -*- coding: utf-8 -*-
"""
Created on Thu May 26 13:31:05 2016

@author: Abhilash
"""
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.colors import Normalize
from pylab import *
import numpy as np

#from time import sleep
#from tqdm import tqdm

#constants -- every physical constant is set to 1 (dimensionless units)
hbar = 1.0
r12 = 1 * hbar          # coupling matrix element; presumably the 1-2 transition dipole (confirm)
rh = r12 / hbar         # r12/hbar, the prefactor of Ep*conj(Es) in the Bloch equations
speed = 1.0             # appears in the denominators of the coupling constants B1, B2
wp, ws = 1.0, 1.0       # pump / Stokes frequencies (presumably; they scale B1 / B2)
# NP pairs with wp in B1 and ns with ws in B2 (presumably refractive indices).
# NP was called `np` in the original code; renamed so the numpy alias survives.
NP, ns = 1.0, 1.0
N = 1.0                 # overall prefactor of B1, B2 (presumably a density)
v = 1.0                 # velocity used for the moving-frame time tau = t - z/v

#parameters
T1 = 1000.0             # relaxation time of the inversion n back toward n0
T2 = 5.0                # damping time of the polarization Ro
n0 = -1.0               # equilibrium inversion; n is also initialised to this value
alphap = alphas = 0.0   # linear loss coefficients of the pump / Stokes field equations

Ap = 1.0                # pump amplitude scale (the z = 0 pump pulse peaks at Ap/2)
As = Ap * 0.01          # Stokes seed amplitude scale: 1 % of the pump
tFWHM = 15.0            # width parameter of the Gaussian input pulses
 
#time grid: samples on [-tmax, tmax] spaced ht apart
ht = 0.1    # time step
tmax = 20.  # half-width of the time window
# np.linspace requires an integer sample count; 2*tmax/ht is a float and
# NumPy >= 1.18 rejects float `num` with a TypeError. Round before casting
# so floating-point error in the division cannot drop a sample.
interval = int(round(2*tmax/ht)) + 1
t = np.linspace(-tmax, tmax, interval)

#z grid: reuses the time-grid size, so len(z) == len(t)
zmax = 2.
w = len(t) - 1  # number of z intervals
z = np.linspace(0, zmax, len(t))
hz = zmax/w     # z step

#tau - moving frame time, tau = t - z/v
# NOTE(review): t and z have the same length, so this is an element-wise
# difference of two 1-D grids, not a 2-D (z, t) mesh. Within this script tau
# is only used through len(tau).
tau = t - z/v

#initial conditions
# Field arrays are indexed [z-slice, time sample]; row 0 is the input at z = 0.
# The builtin `complex` (complex128) replaces np.complex_, removed in NumPy 2.0.
Ep = np.zeros((len(z), len(t)), dtype=complex)
Es = np.zeros((len(z), len(t)), dtype=complex)

# Gaussian input envelopes. The amplitude exp(-2 ln2 (t/tFWHM)^2) gives an
# intensity |E|^2 proportional to exp(-4 ln2 (t/tFWHM)^2), whose full width at
# half maximum is tFWHM.
_envelope = np.exp(-2.*np.log(2)*(t/tFWHM)**2)
Ep[0] = (Ap/2.)*_envelope
Es[0] = (As/2.)*_envelope
# Alternative sech input (unused): Ep[0] = Es[0] = Ap/np.cosh(t/17.6)

# Coupling constants of the pump / Stokes field equations (purely imaginary).
B1 = 1j*2.*np.pi*N*r12*wp/(speed*NP)
B2 = 1j*2.*np.pi*N*r12*ws/(speed*ns)

#----------------solving Bloch equations----------------

def _bloch_rk2(Ep_row, Es_row, Ro_row, n_row):
    """Integrate the Bloch equations along tau for one z-slice, in place.

    All arguments are 1-D rows of Ep, Es, Ro and n for the same z-slice.
    Ro_row[0] and n_row[0] are the initial values; entries 1..end are
    overwritten using the explicit midpoint (RK2) rule with step ht for

        dRo/dtau = 1j*rh*Ep*conj(Es)*n - Ro/T2
        dn/dtau  = rh*Im(Ep*conj(Es)*conj(Ro)) - (n - n0)/T1

    The fields are held at sample i across each step (not interpolated).
    """
    for i in range(len(Ro_row) - 1):
        ep, es = Ep_row[i], Es_row[i]
        ro, nn = Ro_row[i], n_row[i]
        # slopes at the start of the step
        krB = 1j*rh*ep*np.conj(es)*nn - ro/T2
        knB = rh*(ep*np.conj(es)*np.conj(ro)).imag - (nn - n0)/T1
        # slopes at the half step
        ro_half = ro + ht/2*krB
        n_half = nn + ht/2*knB
        brB = 1j*rh*ep*np.conj(es)*n_half - ro_half/T2
        bnB = rh*(ep*np.conj(es)*np.conj(ro_half)).imag - (n_half - n0)/T1
        Ro_row[i+1] = ro + ht*brB
        n_row[i+1] = nn + ht*bnB


def _field_step(Ep_prev, Es_prev, Ro_prev):
    """Advance the pump and Stokes fields by one z-step of size hz.

    Explicit midpoint (RK2) rule for

        dEp/dz = B1*conj(Ro)*Es - alphap*Ep
        dEs/dz = B2*Ro*Ep       - alphas*Es

    with the polarization held fixed at Ro_prev (the previous z-slice).
    Returns the new (Ep, Es) rows.
    """
    ro_conj = np.conj(Ro_prev)
    kp = B1*ro_conj*Es_prev - alphap*Ep_prev
    ks = B2*Ro_prev*Ep_prev - alphas*Es_prev
    bp = B1*ro_conj*(Es_prev + hz/2*ks) - alphap*(Ep_prev + hz/2*kp)
    bs = B2*Ro_prev*(Ep_prev + hz/2*kp) - alphas*(Es_prev + hz/2*ks)
    return Ep_prev + hz*bp, Es_prev + hz*bs


# Polarization Ro starts at zero and inversion n at its equilibrium n0.
# The builtin `complex` replaces np.complex_, which NumPy 2.0 removed.
Ro = np.zeros((len(z), len(tau)), dtype=complex)
n = np.full((len(z), len(tau)), n0, dtype=complex)

# Input slice z = 0: the fields are prescribed, so only the medium is solved.
_bloch_rk2(Ep[0], Es[0], Ro[0], n[0])

#----------------solving Maxwell with midpoint----------------
# range(1, len(z)) covers every slice through the last one (z = zmax); the
# previous bound len(z)-1 left that final slice all zeros.
# NOTE(review): the field step holds Ro at slice k-1. The previous version
# also ran a predictor (field step, Bloch solve, Ro/n averaging between
# slices k-1 and k). However, its arrays were aliases of Ep/Es/Ro/n and its
# corrector read only slice k-1, so the predictor output was always
# overwritten and never affected the result. It is therefore omitted. If a
# true predictor-corrector was intended, the field step should use the
# averaged Ro of slices k-1 and k -- TODO confirm the intended scheme.
for k in range(1, len(z)):
    Ep[k], Es[k] = _field_step(Ep[k-1], Es[k-1], Ro[k-1])
    _bloch_rk2(Ep[k], Es[k], Ro[k], n[k])

# Intensities are |E|^2. The fields are complex (B1 and B2 are imaginary), so
# E**2 would be complex, and its real part can even be negative.
Ip = np.abs(Ep)**2
Is = np.abs(Es)**2
Itot = Ip + Is

# Trapezoid-rule integrals over t, one value per z-slice (row).
# np.trapezoid is the NumPy >= 2.0 name; fall back to np.trapz on older NumPy.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz
gtot = _trapezoid(Itot, t, axis=1)       # total energy vs z
gp = _trapezoid(Ip, t, axis=1)           # pump energy vs z
gs = _trapezoid(Is, t, axis=1)           # Stokes energy vs z
gr = _trapezoid(np.abs(Ro), t, axis=1)   # time-integrated |Ro| vs z
g = gtot.copy()  # kept for compatibility; the same quantity as gtot

# Intensities are normalised by the global peak of the total intensity, so
# all three share one scale and peak at most at 1.
_Itot_peak = np.amax(Itot)
Itot_norm = Itot/_Itot_peak
Ip_norm = Ip/_Itot_peak
Is_norm = Is/_Itot_peak

# Energies are normalised by the peak total energy; gr by its own peak.
gtot_norm = np.abs(gtot/np.amax(gtot))
gp_norm = np.abs(gp/np.amax(gtot))
gs_norm = np.abs(gs/np.amax(gtot))
gr_norm = np.abs(gr/np.amax(gr))

# NOTE(review): this rebinds NP, which earlier held the pump constant used to
# build B1. B1 is already computed, so it is unaffected. Here NP is the
# inversion shifted up by one, which figure 4 plots; the original author
# also asked "why use this?" -- purpose unconfirmed.
NP = n + 1


def _surface_figure(num, data, elev, azim, zlabel=None, colorbar=False,
                    **surface_kw):
    """Draw `data` over the (t, z) mesh as a 3-D surface in figure `num`, then show it.

    `surface_kw` is forwarded to plot_surface. When `zlabel` is given, the
    z axis is labelled and clamped to [0, 1.1] for the normalised intensity
    plots. Returns (fig, X, Y, ax, surf).
    """
    fig = plt.figure(num, figsize=(4, 4))
    X, Y = np.meshgrid(t, z)
    ax = fig.add_subplot(1, 1, 1, projection='3d')
    surf = ax.plot_surface(X, Y, data, **surface_kw)
    ax.set_xlabel('t (ns)')
    ax.set_ylabel('z')
    if zlabel is not None:
        ax.set_zlabel(zlabel)
        ax.set_zlim3d(0, 1.1)
    ax.view_init(elev, azim)
    if colorbar:
        fig.colorbar(surf, shrink=0.5, aspect=5)
    plt.show()
    return fig, X, Y, ax, surf


# Figure 1: normalised pump intensity, full-resolution mesh with a colour bar
fig, X, Y, ax, surf = _surface_figure(
    1, Ip_norm, 30, -20, zlabel='Normalized intensity', colorbar=True,
    rstride=1, cstride=1, cmap=cm.coolwarm, linewidth=0, antialiased=False,
    alpha=0.25)
# Figure 2: normalised Stokes intensity
fig, X, Y, ax, surf = _surface_figure(
    2, Is_norm, 25, -20, zlabel='Normalized intensity',
    rstride=4, cstride=4, alpha=0.25)
# Figure 3: polarization magnitude |Ro|
fig, X, Y, ax, surf = _surface_figure(
    3, (np.abs(Ro)).real, 30, -30, rstride=4, cstride=4, alpha=0.25)
# Figure 4: NP = n + 1
fig, X, Y, ax, surf = _surface_figure(
    4, NP.real, 25, -70, rstride=4, cstride=4, alpha=0.25)

plt.figure(5)
# One panel per z-slice: (subplot code, row of the z grid, title,
# draw x-label?, draw y-label?). Only the outer panels carry axis labels.
# With the 401-point z grid, rows 100/200/300 are z = 0.5/1.0/1.5 exactly.
# NOTE(review): row 399 is z = 1.995, not 2.0. The last row (400) is z = 2.0,
# but it is only filled if the Maxwell loop runs through the final slice.
_panels = (
    (221, 100, 'z = 0.5', False, True),
    (222, 200, 'z = 1.0', False, False),
    (223, 300, 'z = 1.5', True, True),
    (224, 399, 'z = 2.0', True, False),
)
for _pos, _row, _title, _with_xlabel, _with_ylabel in _panels:
    plt.subplot(_pos)
    plt.plot(t, Itot_norm[_row, :], 'k:', linewidth=1.5, label='Total intensity')
    plt.plot(t, Ip_norm[_row, :], 'b', linewidth=1.5, label='Pump intensity')
    plt.plot(t, Is_norm[_row, :], 'r', linewidth=1.5, label='Stokes intensity')
    plt.title(_title)
    if _with_xlabel:
        plt.xlabel('t (ns)')
    if _with_ylabel:
        plt.ylabel('Normalized intensity')
    plt.xlim([np.amin(t), np.amax(t)])
    plt.ylim([0, np.amax(Itot_norm) + 0.1])
    # legends are left off these small panels; figure 6 carries one

plt.figure(6)
# Normalised energy of each component along z: (data, line format, label).
_energy_curves = (
    (gtot_norm, 'k:', 'Total energy'),
    (gp_norm, 'b', 'Pump'),
    (gs_norm, 'r', 'Stokes'),
)
for _curve, _fmt, _label in _energy_curves:
    plt.plot(z, _curve, _fmt, linewidth=1.5, label=_label)
plt.title('Energy in the system')
plt.xlabel('z')
plt.ylabel('Normalized energy')
plt.xlim([np.amin(z), np.amax(z)])
plt.ylim([0, np.amax(gtot_norm) + 0.1])
plt.legend()
plt.show()