# -*- coding: utf-8 -*-
"""
Created on Thu Sep 17 18:38:11 2015

@author: ka
"""

from app.model.BlackScholesCalculator import *


class NewtonRaphson:
    """Implied-volatility solver driven by BlackScholesCalculator.

    Starting from a closed-form seed, the volatility estimate is refined
    until the model premium matches the observed premium to within
    MIN_PREMIUM_DELTA.  Each refinement is a Newton step (using vega) or,
    when vega is negligible, a secant step through the two latest points.
    """

    # Abort when the volatility estimate grows beyond this bound.
    MAX_TRY_VOLATILITY = 1000000
    # Maximum number of refinement iterations before giving up.
    MAX_TRY_LOOP = 2500
    # Absolute premium error at which the estimate is accepted.
    MIN_PREMIUM_DELTA = 0.0002

    def calculateVolatility(self, premium, callOrPut, stockPrice, strike, expiryYears, dividendYield, riskfreeRate):
        """Return the volatility at which the model price equals ``premium``.

        Parameters are forwarded unchanged to
        ``BlackScholesCalculator.getValue`` / ``getVega``.

        Returns:
            * the volatility estimate once the model premium is within
              MIN_PREMIUM_DELTA of ``premium``;
            * ``0.`` when ``stockPrice``, ``premium`` or ``expiryYears`` is
              not strictly positive;
            * ``-1.`` when the search fails: iteration limit reached,
              volatility above MAX_TRY_VOLATILITY, model premium below
              0.00001, or the estimate can no longer be improved.
        """
        if stockPrice <= 0 or premium <= 0 or expiryYears <= 0:
            return 0.

        bs = BlackScholesCalculator()

        loopVolatility = self.initVolatility(stockPrice, expiryYears, premium)

        # Seed the "previous" point at half the initial guess so the secant
        # fallback has two distinct points from the first iteration on.
        prevVolatility = loopVolatility / 2.
        prevPremium = bs.getValue(callOrPut, stockPrice, strike, prevVolatility,
                                  expiryYears, dividendYield, riskfreeRate)

        for _ in range(NewtonRaphson.MAX_TRY_LOOP):
            if loopVolatility > NewtonRaphson.MAX_TRY_VOLATILITY:
                return -1.

            loopPremium = bs.getValue(callOrPut, stockPrice, strike, loopVolatility,
                                      expiryYears, dividendYield, riskfreeRate)

            if loopPremium < 0.00001:
                return -1.

            if abs(loopPremium - premium) <= NewtonRaphson.MIN_PREMIUM_DELTA:
                return loopVolatility

            # NOTE(review): the factor 100 presumably rescales a vega quoted
            # per percentage point to one per unit of volatility -- confirm
            # against BlackScholesCalculator.getVega.
            loopVega = bs.getVega(stockPrice, strike, loopVolatility,
                                  expiryYears, dividendYield, riskfreeRate) * 100
            newVolatility = self.adjVolatility(loopVolatility, loopPremium, loopVega,
                                               prevVolatility, prevPremium, premium)

            # A step that leaves the estimate unchanged would repeat the same
            # evaluation until the iteration limit; report failure right away.
            if newVolatility == loopVolatility:
                return -1.

            # Keep the previous volatility and its premium in lock-step so
            # the secant step always works on a consistent pair of points.
            prevVolatility, prevPremium = loopVolatility, loopPremium
            loopVolatility = newVolatility

        return -1.

    def initVolatility(self, stockPrice, expiryYears, primium):
        """Return the starting guess ``sqrt(2*pi / T) * premium / S``.

        This has the form of the Brenner-Subrahmanyam at-the-money
        approximation.  ``expiryYears`` must be positive and ``stockPrice``
        non-zero; ``calculateVolatility`` checks both before calling.
        """
        return math.sqrt(2. * math.pi / expiryYears) * primium / stockPrice

    def adjVolatility(self, loopVolatility, loopPrimium, loopVega, prevVolatility, prevPrimium, primium):
        """Return the next volatility estimate.

        Takes a Newton step when ``abs(loopVega) >= 0.0001``; otherwise a
        secant step through (prevVolatility, prevPrimium) and
        (loopVolatility, loopPrimium).  If the two premiums are equal the
        secant slope is undefined and ``loopVolatility`` is returned
        unchanged.
        """
        premiumError = loopPrimium - primium

        if abs(loopVega) >= 0.0001:
            return loopVolatility - premiumError / loopVega

        premiumChange = loopPrimium - prevPrimium
        if premiumChange == 0:
            # Flat secant: no usable slope, so leave the estimate as it is
            # instead of dividing by zero.
            return loopVolatility

        return loopVolatility - \
            premiumError * (loopVolatility - prevVolatility) / premiumChange
        
        
        
