import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy import *
from sympy import *
import copy

def Time( radius, mass ):
    """Return ``1 - 2*mass/radius``, simplified with sympy.

    ShwarzchildMetric negates this value for its time-time entry.
    Units are presumably geometric (G = c = 1) — TODO confirm with callers.
    """
    twiceMass = 2.0 * mass
    return simplify( 1 - twiceMass / radius )

def Radius( radius, mass ):
    """Return the radial metric factor: the reciprocal of ``Time(radius, mass)``, simplified."""
    timeFactor = Time( radius, mass )
    return simplify( timeFactor ** ( -1 ) )

def Theta( radius ):
    """Return the theta-theta metric factor, ``radius`` squared."""
    squared = radius ** 2
    return squared

def Phi( radius, theta ):
    """Return the phi-phi metric factor, ``radius**2 * sin(theta)**2``."""
    sinTheta = sin( theta )
    return ( radius ** 2 ) * ( sinTheta ** 2 )

def ShwarzchildMetric( mass, radius, theta ):
    """Return the Schwarzschild metric as a 4x4 nested list.

    The matrix is diagonal, in the same (time, radius, theta, phi) order as
    ``symboles``; every off-diagonal entry is the float 0.0.
    """
    diagonal = [ -Time( radius, mass ),
                 Radius( radius, mass ),
                 Theta( radius ),
                 Phi( radius, theta ) ]
    return [ [ diagonal[ row ] if row == col else 0.0 for col in range( 4 ) ]
             for row in range( 4 ) ]

# Coordinate symbols. Their position in ``symboles`` is the metric index used
# throughout this file: 0 = time, 1 = radius, 2 = theta, 3 = phi.
time, radius, theta, phi = [ Symbol( name ) for name in ( 'time', 'radius', 'theta', 'phi' ) ]

symboles = [ time, radius, theta, phi ]
# Printable labels, index-aligned with ``symboles``. A Unicode alternative would be
# [ u't', u'r', u'\u0398', u'\u03D5' ] (U+0398 is capital Theta; lowercase is U+03B8).
display = [ 't', 'r', 'theta', 'phi' ]

def CoordinateString( a, b, c ):
    """Return a label like ``'phi, r, phi: '`` for the index triple (a, b, c).

    Each index is looked up in the module-level ``display`` list.
    """
    return u', '.join( display[ index ] for index in ( a, b, c ) ) + u': '

def DebugPrint( toPrint, debug ):
    """Print ``toPrint`` when ``debug`` compares equal to True; otherwise do nothing.

    A tuple is spread across ``print`` so its items come out space-separated;
    any other object is printed as a single value.
    """
    if not ( debug == True ):
        return
    items = toPrint if isinstance( toPrint, tuple ) else ( toPrint, )
    print( *items )


def DeriveMetricParameter( metric, x, y, withRespectTo ):
    """Return the partial derivative of ``metric[ x ][ y ]`` with respect to ``withRespectTo``.

    Plain Python numbers (int or float, including float subclasses such as
    numpy.float64) are constants, so their derivative is 0.0. Any other entry
    must provide a ``diff`` method (e.g. a sympy expression).
    ``metric`` itself is not modified.
    """
    entry = metric[ x ][ y ]
    # isinstance instead of ``type(...) != float``: an int 0 or a numpy float
    # would otherwise reach ``.diff`` and raise AttributeError.
    if isinstance( entry, ( int, float ) ):
        return 0.0
    return entry.diff( withRespectTo )

def Connection( mass, radius, theta, a, b, c, symboles, debug = True ):
    """Return every Christoffel symbol of the second kind for the Schwarzschild metric.

    The result is a 4x4x4 nested list with

        results[ l ][ m ][ n ] = 1/2 * sum_p g^{lp} * ( d_n g_{mp} + d_m g_{np} - d_p g_{mn} )

    where ``g^{lp}`` is the inverse of the metric built by ShwarzchildMetric and
    ``d_k`` is the partial derivative with respect to ``symboles[ k ]``.

    Args:
        mass, radius, theta: forwarded unchanged to ShwarzchildMetric.
        a, b, c: not used; kept so existing call sites remain valid.
        symboles: the four coordinate symbols, in the metric's index order.
        debug: when True, trace every term and partial result via DebugPrint.

    Returns:
        Nested lists indexed ``results[ l ][ m ][ n ]`` holding (unsimplified)
        sympy expressions.
    """
    from sympy import Matrix

    metric = ShwarzchildMetric( mass, radius, theta )
    # The upper index must be raised with the *inverse* metric g^{lp}; contracting
    # with the covariant g_{lp} yields neither kind of Christoffel symbol.
    inverse = Matrix( metric ).inv()
    # partials[ k ][ i ][ j ] = d g_{ij} / d symboles[ k ]. These do not depend on l,
    # so the 64 derivatives are computed once instead of inside the innermost loop.
    partials = [ [ [ DeriveMetricParameter( metric, i, j, symboles[ k ] ) for j in range( 4 ) ]
                   for i in range( 4 ) ]
                 for k in range( 4 ) ]
    results = [ [ [ 0.0 ] * 4 for _ in range( 4 ) ] for _ in range( 4 ) ]
    for l in range( 4 ):
        DebugPrint( '<matrix' + str( l ) + '>', debug )
        for m in range( 4 ):
            DebugPrint( '\t<row' + str( m ) + '>', debug )
            for n in range( 4 ):
                DebugPrint( '\t\t<number' + str( m ) + 'x' + str( n ) + '>', debug )
                for p in range( 4 ):
                    factor = 0.5 * inverse[ l, p ]
                    dn_mp = partials[ n ][ m ][ p ]
                    dm_np = partials[ m ][ n ][ p ]
                    dp_mn = partials[ p ][ m ][ n ]
                    DebugPrint( ( '\t\t\t', factor, ' * ( ', dn_mp, '+', dm_np, '-', dp_mn, ' )',
                                  'l: ', l, ' m: ', m, ' n: ', n, ' p: ', p ), debug )
                    results[ l ][ m ][ n ] += factor * ( dn_mp + dm_np - dp_mn )
                DebugPrint( ( '\t\t\tresult: ', results[ l ][ m ][ n ], '\n' ), debug )
                DebugPrint( '\t\t\t<ncurrently>', debug )
                for row in results[ l ]:
                    DebugPrint( ( '\t\t\t\t', row ), debug )
                DebugPrint( '\t\t\t</ncurrently>', debug )
                DebugPrint( '\t\t</number' + str( m ) + 'x' + str( n ) + '>', debug )
            DebugPrint( '\t\t<rcurrently>', debug )
            for row in results[ l ]:
                DebugPrint( ( '\t\t\t', row ), debug )
            DebugPrint( '\t\t</rcurrently>', debug )
            DebugPrint( '\t</row' + str( m ) + '>', debug )
        DebugPrint( '</matrix' + str( l ) + '>\n\n', debug )
    return results

# Compute all connection coefficients (with debug tracing) and print each
# row of each Gamma^l matrix.
connection = Connection( 1.0, radius, theta, 3, 1, 3, symboles, True )
print( "\n\n" )
for gammaSlice in connection:
    for gammaRow in gammaSlice:
        # simplify() works on a single expression, not a Python list:
        # simplify each entry of the row individually.
        print( [ simplify( entry ) for entry in gammaRow ] )
    print( '\n\n' )

def DeriveCoordinate( metric, x, y, withRespectTo ):
    """Replace ``metric[ x ][ y ]`` in place with its derivative and return ``metric``.

    Plain numeric entries (int or float) are constants, so they become 0.0.
    Any other entry must provide a ``diff`` method (e.g. a sympy expression).
    The caller's nested list is mutated; pass a copy if the original is still needed.
    """
    entry = metric[ x ][ y ]
    if isinstance( entry, ( int, float ) ):
        metric[ x ][ y ] = 0.0
    else:
        # Assign rather than add: the slot must hold d(entry), not entry + d(entry).
        metric[ x ][ y ] = entry.diff( withRespectTo )
    return metric

def IntermediateConnection( mass, radius, theta, a, b, c, symboles ):
    """Return the single Christoffel symbol Gamma^a_{bc} of the Schwarzschild metric.

        Gamma^a_{bc} = 1/2 * sum_i g^{ai} * ( d_c g_{bi} + d_b g_{ci} - d_i g_{bc} )

    where ``g^{ai}`` is the inverse of the metric built by ShwarzchildMetric and
    ``d_k`` differentiates with respect to ``symboles[ k ]``.

    Args:
        mass, radius, theta: forwarded unchanged to ShwarzchildMetric.
        a: index of the upper component.
        b, c: indices of the two lower components.
        symboles: the four coordinate symbols, in the metric's index order.

    Returns:
        A sympy expression for Gamma^a_{bc} (not simplified).
    """
    from sympy import Matrix

    # A single, never-mutated metric: each derivative is taken from the original
    # g_{ij}, so repeated derivatives of the same entry cannot compound.
    metric = ShwarzchildMetric( mass, radius, theta )
    # Raise the upper index with the inverse metric g^{ai}, not the covariant g_{ai}.
    inverse = Matrix( metric ).inv()
    result = 0.0
    for i in range( 4 ):
        dc_bi = DeriveMetricParameter( metric, b, i, symboles[ c ] )
        db_ci = DeriveMetricParameter( metric, c, i, symboles[ b ] )
        di_bc = DeriveMetricParameter( metric, b, c, symboles[ i ] )
        result += 0.5 * inverse[ a, i ] * ( dc_bi + db_ci - di_bc )
    return result

print( '-----Old-----' )

# ( a, b, c ) index triples for Gamma^a_{bc}, printed in this order. Each label
# is built from the same triple passed to IntermediateConnection, so the printed
# coordinate names always match the value shown (e.g. t, t, r for Gamma^t_{tr}).
for indices in ( ( 3, 1, 3 ), ( 3, 2, 3 ), ( 2, 2, 2 ), ( 2, 1, 3 ),
                 ( 0, 0, 1 ), ( 1, 0, 0 ), ( 1, 1, 1 ), ( 1, 3, 3 ), ( 1, 2, 2 ) ):
    print( CoordinateString( *indices ), IntermediateConnection( 1.0, radius, theta, *indices, symboles ) )