# Monte Carlo valuation of American option (Least-Squares Monte Carlo)
#
import numpy as np
import numpy.random as npr
import math

from app.entity.ConstantValue import ConstantValue
from numpy import cumsum

import warnings
# Silence the RankWarning that np.polyfit emits for poorly conditioned fits.
# Newer NumPy exposes the category in numpy.exceptions (and NumPy 2.0 removed
# the top-level alias); older releases only provide np.RankWarning, so fall
# back to the top-level namespace when numpy.exceptions does not exist.
warnings.simplefilter('ignore', getattr(np, 'exceptions', np).RankWarning)

class MonteCarloAmerican:
    """Least-Squares Monte Carlo (LSM) pricer for American-style options.

    Stock paths are simulated as geometric Brownian motion and the early
    exercise decision is made by regressing discounted future values on
    the current stock price at every time step.
    """

    def getValue(self, optionType, stockPrice, strike, volatility, expiryYears,
                 dividendYield, riskfreeRate, numPaths=1000, numSteps=100):
        """Estimate the present value of an American option.

        optionType: ConstantValue.CALL for a call; any other value is
            priced as a put.
        stockPrice: current price of the underlying.
        strike: strike price.
        volatility: volatility of the underlying (per year).
        expiryYears: time to expiry in years.
        dividendYield: continuous dividend yield; None is treated as 0.
        riskfreeRate: continuously compounded risk-free rate.
        numPaths: number of simulated paths (default 1000).
        numSteps: number of time steps per path (default 100).

        Returns the Monte Carlo estimate of the option value. The random
        numbers are not seeded here, so repeated calls give slightly
        different results.

        Raises ValueError if numPaths or numSteps is smaller than 1.
        """
        I = int(numPaths)   # number of paths
        M = int(numSteps)   # number of time steps
        if I < 1 or M < 1:
            raise ValueError("numPaths and numSteps must both be at least 1")

        S0 = stockPrice
        K = strike
        T = expiryYears
        r = riskfreeRate
        q = 0.0 if dividendYield is None else dividendYield
        sigma = volatility

        dt = float(T) / float(M)    # length of time interval
        df = np.exp(-r * dt)        # one-step discount factor

        # simulation of index levels
        S = np.zeros((M + 1, I))
        S[0] = S0
        sn = self.gen_sn(M, I)

        # Under the risk-neutral measure the drift is reduced by the
        # dividend yield; both terms are constant across time steps.
        drift = (r - q - 0.5 * sigma ** 2) * dt
        diffusion = sigma * np.sqrt(dt)
        for t in range(1, M + 1):
            S[t] = S[t - 1] * np.exp(drift + diffusion * sn[t])

        # payoff (intrinsic value) at every time step on every path
        if optionType == ConstantValue.CALL:
            h = np.maximum(S - K, 0)
        else:
            h = np.maximum(K - S, 0)

        # LSM algorithm: walk backwards from the step before expiry
        V = np.copy(h)
        for t in range(M - 1, 0, -1):
            discounted = V[t + 1] * df
            reg = np.polyfit(S[t], discounted, 7)
            C = np.polyval(reg, S[t])   # estimated continuation value
            # Exercise only when the payoff is positive and at least the
            # estimated continuation value. A zero payoff must never be
            # "exercised": a negative regression estimate would otherwise
            # throw away non-negative future cash flows on that path.
            exercise = (h[t] > 0) & (h[t] >= C)
            V[t] = np.where(exercise, h[t], discounted)

        # MCS estimator: discount the step-1 values back to time 0
        return df * np.sum(V[1]) / I

    def gen_sn(self, M, I, anti_paths=True, mo_match=True):
        '''
        Generate standard normal random numbers for the simulation.

        M: number of time intervals for discretization
        I: number of paths to be simulated
        anti_paths: use of antithetic variates
        mo_match: use of moment matching

        Returns an array of shape (M + 1, I).
        '''
        if anti_paths:
            # Integer division is required: array dimensions must be ints.
            # Round up so an odd I is still covered, then trim to I columns.
            half = (I + 1) // 2
            sn = npr.standard_normal((M + 1, half))
            sn = np.concatenate((sn, -sn), axis=1)[:, :I]
        else:
            sn = npr.standard_normal((M + 1, I))

        if mo_match:
            # rescale the whole sample to mean 0 and standard deviation 1
            sn = (sn - sn.mean()) / sn.std()

        return sn