#!/usr/bin/env python
# coding: utf-8

# In[95]:


### This script calculates the band structure of graphene nanoribbons using a nearest-neighbour tight-binding model ###


# In[96]:


# Import Libraries

import numpy as np
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt
from cmath import exp
from math import pi


# In[97]:


# Input Data:

a = 1.42      # carbon-carbon bond length (presumably in Angstrom)
N = 7         # atoms per column of the unit cell
nunit = 1     # number of unit cells laid out along x
Ntot = 2 * N * nunit  # two columns of N atoms per unit cell
t = -2.7      # nearest-neighbour hopping energy (presumably in eV)


# In[98]:


# Preallocation:
# Atom coordinates as (Ntot, 1) column vectors, so they can later be
# concatenated side by side into an (Ntot, 2) array.
PosX = np.zeros((Ntot, 1))
PosY = np.zeros((Ntot, 1))
# Inter-cell hopping matrix. NOTE: this is 2N x 2N whereas H00 (built from the
# Ntot x Ntot distance matrix) is Ntot x Ntot, so the band-structure step only
# works with nunit == 1.
H01 = np.zeros((2 * N, 2 * N))
# 50 evenly spaced wavevectors over [-pi/(3a), pi/(3a)], endpoints included;
# 3a is the period of the unit cell along x.
k_max = pi / (3 * a)
Kpoints = np.linspace(-k_max, k_max, num=50)
# Eigenvalue table: one row per band, one column per k-point.
Ev = np.zeros((2 * N, Kpoints.size))


# In[99]:


# Finding Points:

# Each unit cell holds two columns (col = 0, 1) of N atoms stacked along y
# with spacing sqrt(3)*a/2. Even-index atoms sit at x = a (col 0) or 2a
# (col 1); odd-index atoms sit at x = a/2 (col 0) or 2.5a (col 1).
# Successive unit cells are shifted by 3a along x.
for cell in range(nunit):
    x_shift = cell * 3 * a
    for col in range(2):
        base = col * N + cell * 2 * N
        for row in range(N):
            idx = base + row
            if row % 2 == 0:
                PosX[idx] = a + col * a + x_shift
            else:
                PosX[idx] = (a / 2.) + col * 2 * a + x_shift
            PosY[idx] = row * (3 ** 0.5) * (a / 2.)

print('PosX= ', PosX)
print()
print('PosY= ', PosY)


# In[100]:


# Stack the coordinate columns into an (Ntot, 2) array of (x, y) positions.
PosXY = np.concatenate((PosX, PosY), axis=1)


# In[101]:


print(PosXY)


# In[102]:


# Calculating distance between points:
# dist[i, j] is the Euclidean distance between atoms i and j (Ntot x Ntot).
dist = cdist(PosXY, PosXY, 'euclidean')


# In[103]:


print(dist)


# In[104]:


# A bare `dist.shape` only displays inside a notebook; print it so the
# shape is also shown when this file is run as a script.
print(dist.shape)


# In[105]:


# Calculating Hamiltonian:


# Intra-cell Hamiltonian: hopping t between every pair of atoms whose
# separation equals the bond length a (within bond_tol), and 0 elsewhere,
# including the diagonal where dist == 0.
# The window is centred on `a` rather than hard-coded as 1.41-1.43, so the
# neighbour search stays correct if the bond length in the input data is
# changed. For a = 1.42 it selects exactly the same pairs as before.
bond_tol = 0.01
H00 = np.where(np.abs(dist - a) <= bond_tol, t, 0)


# In[106]:


print('H00 = ', H00)


# In[107]:


# Calculating Hopping Matrix:

# Inter-cell bonds: for odd i, atom i+N (second column, x = 2.5a) and atom i
# of the adjacent cell (first column, x = a/2 + 3a) lie a apart at the same
# height, so each such pair gets hopping t. All other entries stay zero.
# NOTE(review): the entry is stored at [i, i+N]. Using the transpose instead
# only maps kx -> -kx in the Bloch Hamiltonian, which leaves the eigenvalues
# unchanged for this real-valued model; confirm the intended convention.
odd_rows = np.arange(1, N, 2)
H01[odd_rows, odd_rows + N] = t


# In[108]:


print('H01 = ', H01)

# up to here I am sure that the code is correct#

# In[109]:


# Calculating Band Structure

# For every wavevector kx in Kpoints, build the Bloch Hamiltonian
#     H(kx) = H00 + H01 * exp(i*kx*L) + H01^T * exp(-i*kx*L),   L = 3*a
# and store its eigenvalues in column ik of Ev.
# - Iterate over the actual wavevectors: the previous loop used the integer
#   index as kx, skipped column 0 and never computed the last k-point.
# - H01 is real, so the two hopping terms are Hermitian conjugates and H(kx)
#   is Hermitian. eigvalsh returns real eigenvalues already sorted in
#   ascending order, so each row of Ev is one band. The previous
#   np.sort(Ev) sorted along the k-axis and scrambled the bands.
cell_length = 3 * a  # period of the unit cell along x
for ik, kx in enumerate(Kpoints):
    H_tot = (H00
             + H01 * exp(1j * kx * cell_length)
             + H01.T * exp(-1j * kx * cell_length))
    Ev[:, ik] = np.linalg.eigvalsh(H_tot)


# In[110]:


print(Ev)
print(Ev.shape)


# In[120]:


plt.figure()
# Ev.T has one column per band, and plt.plot draws one line per column of a
# 2-D y argument, so every band is plotted against Kpoints in a single call.
plt.plot(Kpoints, Ev.T)
plt.show()


# In[ ]:




