#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Sep  2 12:35:16 2021

@author: patrickspitz
"""

#%% Usual imports
import numpy as np
import matplotlib.pyplot as plt

#%% Exercise 1
# Implement euler

def forward_euler( F, x_0, n, T):
    """
    Solve the 1st order ODE x'(t) = F(t, x(t)) with the explicit (forward) Euler method,
    where x might be R^dim-valued.

    Parameters
    ----------
    F : Function with 2 parameters
        Describes the right-hand side of the ODE; it is called as F(t, x) where x is a
        1-dim np.array of length dim.
    x_0 : float/double, sequence of floats or 1-dim np.array
        Initial value. Scalars are treated as vectors of length 1.
    n : int
        Number of iteration steps to compute, one will yield a matrix with n+1 columns
    T : float/double/int
        Up to which time euler will compute

    Returns
    -------
    data : (dim, n+1)-np.ndarray
        Every column describes x(t_j) where the j-th column corresponds to t_j = j*T/n

    Raises
    ------
    ValueError
        If x_0 has more than one dimension.
    ZeroDivisionError
        If n == 0.
    """
    # np.atleast_1d turns scalars (and 0-dim arrays) into vectors of length 1 and also
    # accepts plain lists/tuples, which a `type(x_0) != np.ndarray` check would mishandle
    x_0 = np.atleast_1d(x_0)
    if x_0.ndim != 1:
        raise ValueError('x_0 must be a scalar or a 1-dim array, got shape ' + str(x_0.shape))
    dim = x_0.shape[0]
    # allocate the complete approximation of the solution at once
    data = np.zeros((dim, n+1))
    # copy over initial data
    data[:, 0] = x_0
    # step size
    h = T/n
    for i in range(n):
        # Euler iteration rule x_{i+1} = x_i + h*F(t_i, x_i) with t_i = i*T/n
        data[:, i+1] = data[:, i] + h*F(i*T/n, data[:, i])
    return data

#%% Exercises 2 & 3
# Use euler to solve SIR with beta_0 = 0.5, gamma = 0.15 and k = 0.3 resp. k = 0.6
# T and n up to your choice
# Try different initial values

# first try to assemble the F for SIR, depending on k
# create a small helper function to deal with different k
def build_SIR_right_side( beta_0, gamma, k):
    """
    Create the right-hand side of the SIR-ODE for the given parameters.

    The returned function maps (t, x) with x = (s, i) to
    (-beta*s*i, beta*s*i - gamma*i), where beta = beta_0*(1 - k).

    Parameters
    ----------
    beta_0 : float/double
        The infectiousness of the disease without countermeasures, see SIR-model for details
    gamma : float/double
        The inverse of the average duration
    k : float/double
        Effect of counter-measures, ranges from 0(=no effect) to 1(=no new infections)

    Returns
    -------
    F : Function that accepts anything as first and a np.array(2) as second parameter
        Describes the right-hand side of the SIR-ODE; the first parameter (time) is ignored.
    """
    # effective infectiousness once the counter-measures are taken into account
    beta = beta_0 * (1 - k)

    def F(t, x):
        # rate of new infections; the time argument t is not used by the SIR-model
        newInfections = beta*x[0]*x[1]
        return np.array([-newInfections, newInfections - gamma*x[1]])

    return F

# Now we want to solve the ODE with different parameters
# As we are doing exactly the same thing at least 4 times it is wise to use another function
def plot_SIR( beta_0, gamma, k, T, s_0, subplotIndex, n=10**4, tol=10**(-4)):
    """
    Solve the SIR-model for the given parameters and initial values (s_0, 1-s_0) up to time T.
    The resulting data will be plotted into a 2x2-matrix of plots where subplotIndex is the index of the subplot.
    Afterwards the computed maximum of i is compared with the prediction of theorem 1.1
    and the outcome of this comparison is printed.

    Parameters
    ----------
    beta_0 : float
        Infectiousness of the disease
    gamma : float
        Inverse of average duration of disease
    k : float, 0 <= k <= 1
        Describes the effect of measures onto infectiousness, 0 = no effect, 1 = no new infections
    T : float
        Up to this time the SIR-ODE will be solved
    s_0 : float, 0 <= s_0 <= 1
        Share of population that is susceptible at t=0
    subplotIndex : int, 1 to 4
        Index of the subplot where the resulting plot will be drawn.
    n : int, optional
        Number of Euler steps (default 10**4). Increase it for a more detailed computation.
    tol : float, optional
        Largest deviation between predicted and computed maximum of i that is still
        accepted as agreement with theorem 1.1 (default 10**(-4)).

    Returns
    -------
    None.

    """
    # build right-hand side and solve with forward Euler
    F = build_SIR_right_side( beta_0, gamma, k)
    sol = forward_euler(F, np.array([s_0, 1-s_0]), n, T)

    # plot solution into the requested subplot of a 2x2 grid
    plt.subplot(2,2,subplotIndex)
    # time grid matching the columns of sol
    myTime = np.linspace( 0, T, num = n+1)
    plt.plot( myTime, sol[0,:], 'b', label = 'Sus')
    plt.plot( myTime, sol[1,:], 'r', label = 'Inf')
    plt.xlabel('Time in days')
    plt.ylabel('Share of population')
    plt.title('SIR-model for beta = ' + str(beta_0*(1-k)) + ' and gamma = '+str(gamma))
    plt.xlim([0,T])
    plt.ylim([0,1])
    plt.grid()
    plt.legend()

    # confirm theorem 1.1: compare the computed maximum of i with its prediction
    maxInf = np.max(sol[1,:])
    sigma_0 = beta_0 * (1-k)/gamma
    if sigma_0 * s_0 <= 1:
        print('Prediction: i is monotonically decreasing.')
        # then the maximum is attained at t = 0, i.e. it is i_0 = 1 - s_0
        theoInfMax = 1 - s_0
    else:
        # max of i = i_0 + s_0 - (1 + ln(sigma_0*s_0))/sigma_0, using s_0 + i_0 = 1 here
        theoInfMax = 1 - (1 + np.log(sigma_0*s_0))/sigma_0
        print('Expected max number of infections: '+str(theoInfMax))
    print('Actual max: '+str(maxInf))

    # tol is somewhat arbitrary; the Euler error shrinks with growing n
    if abs(theoInfMax - maxInf) < tol:
        print('Theorem 1.1 seems to work out.')
    else:
        print('This might be a counter example to theorem 1.1.')
        print('Please check again with a more detailed computation.')
    print('')

# now we can solve the exercise at hand:
# k = 0.3 and k = 0.6, each with s_0 = 0.99 and s_0 = 0.5, one subplot per scenario.
# On my machine the default figure size results in overlapping texts;
# a bigger figure size avoids this.
scaleFactor = 1.7
plt.figure(figsize = (8*scaleFactor,6*scaleFactor))
# run different scenarios: plot_SIR(beta_0, gamma, k, T, s_0, subplotIndex)
plot_SIR( 0.5, 0.15, 0.3, 100, 0.99, 1)
plot_SIR( 0.5, 0.15, 0.3, 50, 0.50, 2)
plot_SIR( 0.5, 0.15, 0.6, 200, 0.99, 3)
plot_SIR( 0.5, 0.15, 0.6, 50, 0.50, 4)
# and save it to file
# instead of pdf any other vector graphics format is fine as well;
# raster graphics don't scale well if you want to change the size of your image.
# NOTE: calling plt.show() before plt.savefig can leave savefig with an empty figure.
# plt.show()
plt.savefig('./SIRPlots.pdf')

#%% Exercise 4

# Write a function that reads a text file and returns its contents as a dictionary
def read_file( fileName):
    """
    Read a file with the format specified in exercise 4 and return its contents as a dictionary.

    Every non-empty line is expected to look like ``name value``: a name and a value
    separated by whitespace, where value can be converted to float. Further fields
    on a line are ignored, and blank lines are skipped.

    Parameters
    ----------
    fileName : string
        Path to file. A FileNotFoundError is raised if this is invalid.

    Returns
    -------
    myDict : Dictionary
        Maps every name in the file to its value as float.

    Raises
    ------
    FileNotFoundError
        If fileName does not exist.
    ValueError
        If a non-empty line does not consist of a name and a number.

    """
    # One could argue that a "try except FileNotFoundError"-construction might be useful here.
    # The author doesn't think so: if a file name is invalid, the program would fail somewhere
    # down the line anyway due to relying on a nonexistent file, so it may as well fail here.
    myDict = {}
    # the with-statement closes the file even if parsing a line raises
    with open(fileName, 'r', encoding='utf-8') as stream:
        for lineNumber, line in enumerate(stream, start=1):
            # split at any run of whitespace (spaces, tabs) and drop the trailing newline
            fields = line.split()
            if not fields:
                # tolerate blank lines, e.g. an empty last line
                continue
            if len(fields) < 2:
                raise ValueError('Line ' + str(lineNumber) + ' of ' + str(fileName)
                                 + ' is not of the form "name value": ' + repr(line))
            myDict[fields[0]] = float(fields[1])
    return myDict

# Load the model parameters and the two test settings.
# NOTE: the paths are case-sensitive on case-sensitive file systems ('Ressources', 'Param.txt', ...).
# Keys used further down: param -> 'beta_0', 'gamma', 'kappa', 'N';
# test_1/test_2 -> 'k', 's_0', 'e_0', 'i_0', 'v_rel'.
param = read_file('./Ressources/Param.txt')
test_1 = read_file('./Ressources/Test1.txt')
test_2 = read_file('./Ressources/Test2.txt')

#%% Exercise 5

# Hard-coded discretisation for exercise 5: n Euler steps up to day T.
# NOTE: n and T are module-level names that later sections of this script reuse.
n = 10**5
T = 250

# Solve SIR-model with test1 (beta_0 and gamma from Param.txt, k and initial values from Test1.txt)
newF = build_SIR_right_side( param['beta_0'], param['gamma'], test_1['k'])
solSIR = forward_euler( newF, np.array([test_1['s_0'], test_1['i_0']]), n, T)

# Prep stuff to deal with vSEIR more easily
def build_vSEIR_right_side( beta_0, k, kappa, gamma, vrel):
    """
    Build the right-hand side of the vSEIR-ODE for the state x = (s, e, i).

    The returned function computes
        s' = -beta(t)*s*i - min(s, vrel(t))
        e' =  beta(t)*s*i - kappa*e
        i' =  kappa*e - gamma*i
    with beta(t) = beta_0*(1 - k(t)).

    Parameters
    ----------
    beta_0 : float
        Infectiousness without counter measures
    k : function that returns a float in [0,1]
        Effectiveness of counter measures, 0 = no effect, 1 = no new infections
    kappa : float
        Inverse of length of latent period
    gamma : float
        Inverse of average length of infectiousness
    vrel : Function that returns float
        This function describes the available vaccines per day

    Returns
    -------
    F : Function R times R^3 to R^3
        The desired right-hand side of the ODE

    """
    def F(t, x):
        s, e, i = x[0], x[1], x[2]
        beta = beta_0 * (1 - k(t))
        # vaccination cannot remove more susceptibles than there are
        vaccinated = min(s, vrel(t))
        return np.array([ -beta*s*i - vaccinated,
                          beta*s*i - kappa*e,
                          kappa*e - gamma*i ])

    return F

                             
# Solve vSEIR-model with test1 (constant k and vaccination rate v_rel from Test1.txt)
vSEIRF1 = build_vSEIR_right_side( param['beta_0'], lambda t: test_1['k'], param['kappa'], param['gamma'], lambda t: test_1['v_rel'])
solVSEIR1 = forward_euler( vSEIRF1, np.array([test_1['s_0'], test_1['e_0'], test_1['i_0']]), n, T)

# Same for test_2
vSEIRF2 = build_vSEIR_right_side( param['beta_0'], lambda t: test_2['k'], param['kappa'], param['gamma'], lambda t: test_2['v_rel'])
solVSEIR2 = forward_euler( vSEIRF2, np.array([test_2['s_0'], test_2['e_0'], test_2['i_0']]), n, T)

# Build a third setting test3 with nonconstant k and vrel and solve it.
# NOTE: x0, k and vrel are module-level names that later exercises reuse.
# initial value: a tiny exposed share e_0, nobody infectious yet
e_0 = 10**(-5)
x0 = np.array([1-e_0, e_0, 0])
# k is a sine now: k(t) = 0.3*(1 + sin(2*pi*0.05*t)) oscillates in [0, 0.6] with a period of 20 days
k = lambda t: (1 + np.sin(0.05*t  * 2 * np.pi)) * 0.3
# vrel is a logistic curve: grows from ~0 to 0.01 vaccines per day, reaching half of that at t = 200
vrel = lambda t: 0.01 / (1 + np.exp(- 0.1*(t - 200)))
# actual computation
vSEIRF3 = build_vSEIR_right_side( param['beta_0'], k, param['kappa'], param['gamma'], vrel)
solVSEIR3 = forward_euler( vSEIRF3, x0, n, T)

# Plot all resulting i in a single graph

# This stuff is always identical so write another function for it.
def make_plot_pretty( title):
    """
    Apply the decorations shared by all exercise-5 plots to the current axes.

    Adds a legend, limits the x-axis to [0, T] (module-level T) and the y-axis
    to [0, 1], labels both axes and sets the given title.

    Parameters
    ----------
    title : string
        Title of the current axes.

    Returns
    -------
    None.
    """
    axes = plt.gca()
    axes.legend()
    axes.set_xlim([0,T])
    axes.set_ylim([0,1])
    axes.set_xlabel('Time in days')
    axes.set_ylabel('Share of population')
    axes.set_title( title)
    
    
# One subplot per scenario
plt.figure(figsize = (12,9))
timeAxis = np.linspace(0, T, n+1)

plt.subplot(2,2,1)
plt.plot( timeAxis, solSIR[1,:], label = 'Inf')
make_plot_pretty( 'SIR model')

plt.subplot(2,2,2)
plt.plot( timeAxis, solVSEIR1[2,:], label = 'Inf')
make_plot_pretty( 'vSEIR model(test 1)')

plt.subplot(2,2,3)
plt.plot( timeAxis, solVSEIR2[2,:], label = 'Inf')
make_plot_pretty( 'vSEIR model(test 2)')

plt.subplot(2,2,4)
plt.plot( timeAxis, solVSEIR3[2,:], label = 'Inf')
plt.plot( timeAxis, solVSEIR3[0,:], label = 'Sus')
plt.plot( timeAxis, k(timeAxis), label = 'k(t)')
make_plot_pretty( 'vSEIR model(test 3)')

plt.savefig('ScenariosSubPlot.pdf')


# Compare behaviour:
# for this, plot all i's in a single plot
plt.figure()
plt.plot( timeAxis, solSIR[1,:], label = 'SIR(test1)')
plt.plot( timeAxis, solVSEIR1[2,:], label = 'vSEIR(test1)')
plt.plot( timeAxis, solVSEIR2[2,:], label = 'vSEIR(test2)')
plt.plot( timeAxis, solVSEIR3[2,:], label = 'vSEIR(test3)')
make_plot_pretty('i(t) for all scenarios')
# zoom in on the y-axis, the default range [0, 1] is too coarse for i
plt.ylim([0, 0.3])

# Save before showing: once the plot windows are closed, plt.show() may have discarded
# the figures, and a later savefig would then write an empty figure.
plt.savefig('Scenarios_i.pdf')
plt.show()
# Effect of e?
# e delays the course of the pandemic
# How do k and v_rel change dynamic?
# Higher k -> Less rapid growth
# v_rel kills the virus for good

#%% Exercise 6
# Implement algo 3 and experiment with it
def total_infected( sol, kappa, N, T, t):
    """
    Calculate the total number of infected individuals up to time t.

    Everybody who is infectious at time 0 is counted, plus everybody who has left the
    exposed compartment (rate kappa*e) during [0, t]:
        N * ( i(0) + integral_0^t kappa*e(s) ds ),
    where the integral is approximated by the left Riemann sum on the Euler grid
    t_j = j*T/n, i.e. with the step size h = T/n used to compute sol.

    Parameters
    ----------
    sol : np.ndarray
        Contains the solution to the vSEIR-model as returned from forward_euler
        (rows s, e, i; n+1 columns).
    kappa : float
        Parameter describing the latent period.
    N : int
        Size of population.
    T : float
        Upto this time euler solved the model.
    t : float
        Time up to which infections are counted. It is rounded to the nearest grid
        point and clamped to [0, T].

    Returns
    -------
    float
        Total number of individuals infected up to time t.

    """
    n = sol.shape[1] - 1
    h = T / n
    # number of full Euler steps inside [0, t]; rounding keeps grid-aligned t (e.g. whole days)
    # from losing the last interval to floating point error, clamping avoids reading past sol
    steps = min(max(int(round(t / h)), 0), n)
    i_total = sol[2, 0] + h*kappa*np.sum(sol[1, :steps])
    return i_total*N

# day up to which the infections are counted in exercise 6
tau = 150
# for test 1: total number of infections up to day tau via total_infected
infected1 = total_infected( solVSEIR1, param['kappa'], param['N'], T, tau)

def get_infected_trivial( tau, sol, N=None, T_max=None):
    """
    Calculate the number of people infected up to time tau as N*(1 - s(tau) - e(tau)).

    Everybody who is neither susceptible nor exposed at time tau is counted. This
    also includes everybody who left s without being infected (e.g. through the
    vaccination term of the vSEIR-model).

    Parameters
    ----------
    tau : float
        The time one is interested in.
    sol : np.ndarray
        Solution to vSEIR as returned by forward_euler (rows s, e, i).
    N : float, optional
        Size of population. Defaults to param['N'].
    T_max : float, optional
        Final time up to which sol was computed. Defaults to the module-level T.

    Returns
    -------
    float
        Number of infected persons upto time tau.

    """
    if N is None:
        N = param['N']
    if T_max is None:
        T_max = T
    # take the number of Euler steps from sol itself instead of relying on a global n
    steps = sol.shape[1] - 1
    # round instead of truncating, so a grid-aligned tau is not shifted by floating point error
    correspIndex = int(round(tau / T_max * steps))
    return N*(1 - sol[0, correspIndex] - sol[1, correspIndex])

# Compare both ways of counting the infections up to day tau, starting with test 1
naiveInfected1 = get_infected_trivial(tau, solVSEIR1)
print('Setting: Test 1')
print('By i(tau) + r(tau): '+str(naiveInfected1))
print('By total_infected: '+str(infected1))
print('')
# The values coincide because there are no vaccines in test_1 (presumably v_rel = 0 in Test1.txt)

# Same calculations for test_2 and test_3
infected2 = total_infected( solVSEIR2, param['kappa'], param['N'], T, tau)
naiveInfected2 = get_infected_trivial(tau, solVSEIR2)
print('Setting: Test 2')
print('By i(tau) + r(tau): '+str(naiveInfected2))
print('By total_infected: '+str(infected2))
print('')

infected3 = total_infected( solVSEIR3, param['kappa'], param['N'], T, tau)
naiveInfected3 = get_infected_trivial(tau, solVSEIR3)
print('Setting: Test 3')
print('By i(tau) + r(tau): '+str(naiveInfected3))
print('By total_infected: '+str(infected3))
print('')

# Results are expected not to coincide here because of vaccines: vaccinated individuals
# leave s without ever being infected, so 1 - s - e overcounts the infections

#%% Exercise 7

# Consider settings test_2 and test_3:
# calculate NInf (total number of infections up to that day) for every single day and
# save it to totalInf1.txt (test 2) resp. totalInf2.txt (test 3)

# This is a deliberately simple solution: every call of total_infected sums up from t = 0 again.
# Ideally one would only compute the difference of infections between two consecutive days.

# the with-statement guarantees that both files are closed even if a computation fails
with open('./totalInf1.txt', 'w') as file1, open('./totalInf2.txt', 'w') as file2:
    file1.write('Day NInf \n')
    file2.write('Day NInf \n')
    for t in range(T):
        inf2 = total_infected( solVSEIR2, param['kappa'], param['N'], T, t+1)
        inf3 = total_infected( solVSEIR3, param['kappa'], param['N'], T, t+1)
        file1.write(str(t+1) + ' ' + str(inf2) + '\n')
        file2.write(str(t+1) + ' ' + str(inf3) + '\n')
        print('Done with '+str(t+1))

#%% Exercise 8
# Phase portraits for all 3 settings
# Plot trajectories in s-i-plot
# May use DrawArrows.py
# Interpret it
from DrawArrows import draw_arrow

# build_vSEIR_right_side( beta_0, k, kappa, gamma, vrel):

def generate_phase_portrait( F, n=10**4, T=400):
    """
    Generate a phase portrait in the s-i-plane using the vSEIR-model and the right-hand side provided.

    Ten trajectories are started on the line s_0 + i_0 = 1 (with e_0 = 0), solved with
    forward_euler and drawn as dashed blue lines, each marked with an arrow via draw_arrow.
    The line s + i = 1 of all starting points is drawn in red. A new figure is
    created; setting a title and saving is left to the caller.

    Parameters
    ----------
    F : Function R times R^3 to R^3
        Right-hand side of vSEIR ODE
    n : int, optional
        Number of Euler steps per trajectory (default 10**4).
    T : float, optional
        Final time up to which every trajectory is computed (default 400).

    Returns
    -------
    None.

    """
    # 10 different starting points s_0 = 1/11, ..., 10/11 with i_0 = 1 - s_0
    myInitialS = (1+np.arange(10))/11
    plt.figure()
    for s in myInitialS:
        x0 = np.array([s, 0, 1-s])
        sol = forward_euler( F, x0, n, T)
        # plot i against s; keep the Line2D object for the arrow
        line = plt.plot( sol[0,:], sol[2,:], 'b--')[0]
        # NOTE(review): position presumably is the relative place of the arrow along the line
        draw_arrow(line, position = 0.075)
    plt.xlabel('s')
    plt.ylabel('i')
    plt.xlim([0,1])
    plt.ylim([0,1])
    plt.grid()
    # red line s + i = 1 on which all starting points lie
    startingPoints = np.linspace(0,1,500)
    plt.plot( 1-startingPoints, startingPoints, 'r-')

# Phase portraits for the three settings, each saved to its own file
F1 = build_vSEIR_right_side( param['beta_0'], lambda t: test_1['k'], param['kappa'], param['gamma'], lambda t: test_1['v_rel'])
generate_phase_portrait( F1)
plt.title('Phase portrait for test 1')
plt.savefig('PhaseDiag1.pdf')

F2 = build_vSEIR_right_side( param['beta_0'], lambda t: test_2['k'], param['kappa'], param['gamma'], lambda t: test_2['v_rel'])
generate_phase_portrait( F2)
plt.title('Phase portrait for test 2')
plt.savefig('PhaseDiag2.pdf')

# test 3 uses the time-dependent k and vrel defined for exercise 5
F3 = build_vSEIR_right_side( param['beta_0'], k, param['kappa'], param['gamma'], vrel)
generate_phase_portrait( F3)
plt.title('Phase portrait for test 3')
# own file name, so the phase portrait of test 2 is not overwritten
plt.savefig('PhaseDiag3.pdf')


#%% Exercise 9
# Simple lockdown strategy

def adaptive_euler(k_a, k_b, i_a, i_b, x_0, beta_0, gamma, kappa, v_rel, T, n):
    """
    Solve the vSEIR-model with forward Euler under a simple lockdown strategy.

    Without lockdown the counter-measures have effect k_a, during a lockdown k_b.
    After every Euler step the lockdown state is updated: a running lockdown is
    lifted once i < i_a, and a lockdown is imposed once i >= i_b. The updated state
    is used for the next step.

    Parameters
    ----------
    k_a : float
        Effect of counter-measures without lockdown
    k_b : float
        Effect of counter-measures during lockdown
    i_a : float
        Lockdown is lifted once i drops below this value
    i_b : float
        Lockdown is imposed once i reaches this value
    x_0 : np.array(3)
        Initial value (s_0, e_0, i_0)
    beta_0, gamma, kappa : float
        Model parameters, passed on to build_vSEIR_right_side
    v_rel : Function that returns float
        Available vaccines per day
    T : float
        Final time
    n : int
        Number of Euler steps

    Returns
    -------
    sol : (3, n+1)-np.ndarray
        Approximation of (s, e, i) at the times j*T/n
    lockdownStates : (n+1)-np.ndarray
        Entry j is 1 if the step leading to column j was computed under lockdown,
        otherwise 0 (entry 0 is always 0)
    """
    # one right-hand side per lockdown state, both with a constant k
    rhs = {
        True: build_vSEIR_right_side(beta_0, lambda t: k_b, kappa, gamma, v_rel),
        False: build_vSEIR_right_side(beta_0, lambda t: k_a, kappa, gamma, v_rel),
    }
    step = T/n
    sol = np.zeros((3, n+1))
    lockdownStates = np.zeros(n+1)
    sol[:, 0] = x_0
    inLockdown = False
    for j in range(n):
        lockdownStates[j+1] = 1 if inLockdown else 0
        current = sol[:, j]
        sol[:, j+1] = current + step*rhs[inLockdown](step*j, current)
        infected = sol[2, j+1]
        # hysteresis: lift a running lockdown below i_a, impose a new one from i_b on
        if inLockdown and infected < i_a:
            inLockdown = False
        if not inLockdown:
            inLockdown = infected >= i_b
    return sol, lockdownStates

# parameters of the lockdown experiments
T = 350
n = 10**5
# lockdown is lifted below i_a and imposed from i_b on
i_a = 0.0001
i_b = 0.0003
timeAxis = np.linspace(0, T, n+1)


def _plot_lockdown( sol, ldStates, testName, fileName):
    """
    Plot i(t) and the lockdown state of one scenario into a new figure and save it.

    The lockdown state is drawn at height i_a (no lockdown) resp. i_b (lockdown).
    Reads the module-level timeAxis, T, i_a and i_b.

    Parameters
    ----------
    sol : np.ndarray
        Solution as returned by adaptive_euler (rows s, e, i).
    ldStates : np.ndarray
        Lockdown states as returned by adaptive_euler (0 or 1 per column).
    testName : string
        Name of the scenario used in the title, e.g. 'test 1'.
    fileName : string
        File the figure is saved to.
    """
    plt.figure()
    plt.plot( timeAxis, sol[2,:], 'r', label = 'i(t)')
    plt.plot( timeAxis, i_a + (i_b - i_a)*ldStates, 'b', label = 'lockdown state')
    plt.xlim([0,T])
    plt.ylim([0, .015])
    plt.ylabel('Share of population')
    plt.xlabel('Time in days')
    plt.legend()
    plt.title('i(t) with lockdown strategy(' + testName + ')')
    plt.savefig(fileName)


# Scenario 1
x0 = np.array([test_1['s_0'],  test_1['e_0'], test_1['i_0']])
sol, ldStates = adaptive_euler( 0.4, 0.8, i_a, i_b, x0, param['beta_0'], param['gamma'], param['kappa'], lambda t: test_1['v_rel'], T, n)
_plot_lockdown( sol, ldStates, 'test 1', 'Lockdown1.pdf')

# Scenario 2
x0 = np.array([test_2['s_0'],  test_2['e_0'], test_2['i_0']])
sol, ldStates = adaptive_euler( 0.4, 0.8, i_a, i_b, x0, param['beta_0'], param['gamma'], param['kappa'], lambda t: test_2['v_rel'], T, n)
_plot_lockdown( sol, ldStates, 'test 2', 'Lockdown2.pdf')

# Scenario 3: time-dependent vaccination rate vrel of test 3
# NOTE(review): uses the initial values of test 1, not the test-3 x0 from exercise 5 — confirm intended
x0 = np.array([test_1['s_0'],  test_1['e_0'], test_1['i_0']])
sol, ldStates = adaptive_euler( 0.4, 0.8, i_a, i_b, x0, param['beta_0'], param['gamma'], param['kappa'], vrel, T, n)
_plot_lockdown( sol, ldStates, 'test 3', 'Lockdown3.pdf')
