#! /usr/bin/python3

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from timeit import default_timer as timer
import statistics as stat

# matplotlib stuff: global rc settings applied to every figure this script creates
mpl.rc('legend', fontsize=10)
# keymap.quit is list-valued; assigning it replaces the whole list, so only 'q' quits
mpl.rc('keymap', quit=['q'])
# a real boolean rather than the string 'True' (which only works via rc string parsing)
mpl.rc('axes', grid=True)
def onresize(event):
	"""Re-run tight_layout on the figure whose canvas was resized.

	Intended to be connected as a matplotlib ``'resize_event'`` callback.
	The event's own canvas is used instead of ``plt.tight_layout()``
	because the pyplot call acts on whichever figure is *current*, which
	is not necessarily the one being resized when several figures are open.
	"""
	event.canvas.figure.tight_layout()


def f_inv(x, a, b, c):
	"""Evaluate the shifted hyperbola ``a / (x + b) + c``.

	Plain arithmetic only, so it applies element-wise when *x* is a NumPy
	array.  The pole at ``x == -b`` divides by zero.
	"""
	shifted = x + b
	return a / shifted + c

def hypfit(x, y):
	"""Fit ``y = a / (x + b) + c`` to data by linearised least squares.

	Multiplying the model by ``(x + b)`` and rearranging gives

	    x*y = a + b*c - b*y + c*x

	and substituting ``a' = a + b*c`` makes it linear in ``(a', b, c)``.
	With one row per sample this is

	    [x1*y1]   [1  -y1  x1]   [a']   [eps1]
	    [x2*y2] = [1  -y2  x2] * [b ] + [eps2]
	    [ ... ]   [   ...    ]   [c ]   [ ... ]

	      xy    =      F       *  p   +  eps

	and ``p`` is chosen to minimise ``|eps|^2``.  The solution is computed
	with ``np.linalg.lstsq`` instead of the textbook normal equations
	``(F^T F)^-1 F^T xy``: forming ``F^T F`` squares the condition number
	of ``F`` and throws away precision for no benefit.

	Note that what is minimised is the residual of the *linearised*
	equation, not ``y - f(x)``.  Because the (noisy) ``y`` also appears
	inside ``F``, the estimates are in general biased for noisy data.

	Parameters
	----------
	x, y : array_like
		Sample coordinates, same number of elements (at least 3).

	Returns
	-------
	list
		``[a, b, c]``, with ``a`` recovered as ``a' - b*c``.

	Raises
	------
	ValueError
		If ``x`` and ``y`` have different lengths (the old ``zip``-based
		version silently truncated the longer one).
	numpy.linalg.LinAlgError
		If the linear system is rank deficient (e.g. fewer than three
		distinct samples), so the coefficients are not uniquely determined.
	"""
	x = np.asarray(x, dtype=float).ravel()
	y = np.asarray(y, dtype=float).ravel()
	if x.shape != y.shape:
		raise ValueError('hypfit: x and y must have the same length')
	if x.size < 3:
		raise np.linalg.LinAlgError('hypfit: need at least 3 samples to fit 3 coefficients')

	# Columns of F correspond to the unknowns a', b and c respectively.
	F = np.column_stack((np.ones_like(x), -y, x))
	p, _, rank, _ = np.linalg.lstsq(F, x * y, rcond=None)
	if rank < F.shape[1]:
		raise np.linalg.LinAlgError('hypfit: design matrix is rank deficient (rank %d < 3)' % rank)

	a_prime, b, c = p
	# undo the substitution a' = a + b*c
	return [a_prime - b * c, b, c]

# ground truth: use realistic values for a, b and c
a = 3000
b = 1
c = 107
print('actual: a', a, 'b', b, 'c', c)

# data set
# x = np.array([5, 10, 15])
# x = np.array(range(5, 16))
x = np.linspace(15, 5, 1000)
y = f_inv(x, a, b, c)	# f_inv is plain arithmetic, so it evaluates the whole array at once
sd = 10	# standard deviation of the additive Gaussian noise
y += sd * np.random.randn(*y.shape)
# print(x)
# print(y)

# now estimate the coefficients
start = timer()
est_abc = hypfit(x, y)
end = timer()
# default_timer() returns seconds; multiply by 1000 to report milliseconds
duration_ms = (end - start) * 1000
print('fit:    a %.2f b %.2f c %.2f duration %.2e ms' % (est_abc[0], est_abc[1], est_abc[2], duration_ms))

# quality of fit: data minus fitted curve, sample by sample
residual = y - f_inv(x, *est_abc)
print('residual min %.2f max %.2f mean %.2f stddev %.2f' % (
	residual.min(), residual.max(), stat.mean(residual), stat.pstdev(residual)))

# and plot the results
fig = plt.figure()
fig.canvas.mpl_connect('resize_event', onresize)

ax1 = plt.subplot(2, 1, 1)
plt.plot(x, f_inv(x, a, b, c), label='orig')
# plt.plot(x, y, '.', markersize=10, label='data')	# only for a small number of samples
# raw string: '\s' is not a valid Python escape (SyntaxWarning on 3.12+)
plt.plot(x, y, label=r'data ($\sigma$=%r)' % sd)
plt.plot(x, f_inv(x, *est_abc), label='fit')
plt.legend(loc=2)
# reverse X axis (largest x on the left) with a margin of 1, whatever order x is stored in
plt.xlim(x.max() + 1, x.min() - 1)

ax2 = plt.subplot(2, 1, 2, sharex=ax1)
plt.plot(x, residual, label='residual')
plt.legend(loc=2)

plt.tight_layout()
plt.show()
