Deshwal
Deshwal

Reputation: 4152

How to add multiple Moving Average lines DYNAMICALLY to Plotly Candlestick

I have written this function to plot a candlestick chart using Plotly. I want to add the ability to add N different lines dynamically, each in a different color. Right now it adds just one line. I could hard-code another Scatter, but that is not dynamic.

Here's the code:

def plot_candlesticks(df, names = ('DATE','OPEN','CLOSE','LOW','HIGH'), mv = 44):
    '''
    Plot a candlestick chart of a DataFrame, overlaid with one simple moving average.

    args:
        df: DataFrame with one row per bar. Besides the columns named in
            `names`, it must contain a 'SYMBOL' column (used for the title).
        names: Tuple of column names in the order
               (DATE, OPEN, CLOSE, LOW, HIGH) -- this is the order in which
               the tuple is unpacked below.
        mv: Window length of the simple moving average computed on the
            CLOSE column (min_periods=1, so the first bars use shorter windows).

    Notes:
        Relies on module-level names defined outside this snippet:
        `go` (presumably `plotly.graph_objects`) and `all_stocks`
        (presumably a mapping from symbol to display title -- confirm).
        The figure is shown immediately; nothing is returned.
    '''
    stocks = df.copy()
    Date, Open, Close, Low, High = names
    # NOTE(review): rolling() looks back over *row order*. This sort must leave
    # the rows oldest-first for the average to be a trailing one; that holds
    # only if the incoming index is newest-first -- verify against the data source.
    stocks.sort_index(ascending=False, inplace = True)
    sma_col = f'{mv}-SMA'
    stocks[sma_col] = stocks[Close].rolling(mv, min_periods = 1).mean()

    candle = go.Figure(data = [go.Candlestick(x = stocks[Date], name = 'Trade',
                                              open = stocks[Open],
                                              high = stocks[High],
                                              low = stocks[Low],
                                              close = stocks[Close]),

                               go.Scatter(name=f'{mv} MA', x=stocks[Date], y=stocks[sma_col],
                                          line=dict(color='blue', width=1)),])

    # Range slider plus quick-zoom buttons (1 month, 6 months, YTD, 1 year, all).
    candle.update_xaxes(
        title_text = 'Date',
        rangeslider_visible = True,
        rangeselector = dict(
            buttons = list([
                dict(count = 1, label = '1M', step = 'month', stepmode = 'backward'),
                dict(count = 6, label = '6M', step = 'month', stepmode = 'backward'),
                dict(count = 1, label = 'YTD', step = 'year', stepmode = 'todate'),
                dict(count = 1, label = '1Y', step = 'year', stepmode = 'backward'),
                dict(step = 'all')])))

    # Title is looked up from the symbol in the row labelled 0.
    candle.update_layout(autosize = True,
                         title = {'text': all_stocks[stocks['SYMBOL'][0]],'y':0.97,'x':0.5,
                                  'xanchor': 'center','yanchor': 'top'},
                         margin=dict(l=30,r=30,b=30,t=30,pad=2),
                         paper_bgcolor="lightsteelblue",)

    candle.update_yaxes(title_text = 'Close Price', tickprefix = u"\u20B9" ) # Rupee symbol
    candle.show()

Upvotes: 0

Views: 1146

Answers (1)

r-beginners
r-beginners

Reputation: 35205

I modified the code on the understanding that you want to automatically add multiple moving averages to the candlestick chart, with the window lengths passed as a list in the function argument. The key is to call add_trace() on the candlestick figure once for each moving average.

import plotly.graph_objects as go
import pandas as pd
import yfinance as yf

# Download two months of daily AAPL bars, keep the first four price columns
# (assumed to be Open, High, Low, Close in the order yfinance returned them --
# verify for your yfinance version), and turn the Date index into a column.
data = (
    yf.download("AAPL", start="2021-01-01", end="2021-03-01")
    .iloc[:, :4]
    .reset_index()
)

def plot_candlesticks(df, names = ('Date','Open','High','Low','Close'), mv = (5, 25, 75), colors = None):
    '''
    Plot a candlestick chart overlaid with any number of simple moving averages.

    args:
        df: DataFrame with one row per bar.
        names: Tuple of column names in the order (DATE, OPEN, HIGH, LOW, CLOSE).
        mv: Window length(s) of the simple moving averages, computed on the
            CLOSE column with min_periods=1. Either a single integer or an
            iterable of integers; one line is added per window.
        colors: Optional sequence of line colors. Colors are reused in order
                when there are more windows than colors. Defaults to
                ('red', 'blue', 'yellow').

    Notes:
        Uses the module-level `go` (plotly.graph_objects). The DATE column is
        assumed to sort chronologically (e.g. datetime64, as yfinance returns).
        The figure is shown immediately; nothing is returned.
    '''
    import numbers

    # Unpack in the same order as the default / documented tuple; unpacking as
    # (Date, Open, Close, Low, High) swapped the High and Close columns.
    Date, Open, High, Low, Close = names

    # Accept a single window as well as a collection of windows.
    if isinstance(mv, numbers.Integral):
        mv = (mv,)
    # An empty or missing colors argument falls back to the default palette,
    # which also keeps the modulo below from dividing by zero.
    palette = tuple(colors) if colors else ('red', 'blue', 'yellow')

    # rolling() looks back over row order, so rows must be oldest-first;
    # sorting newest-first made every average use *future* prices.
    # sort_values returns a new frame, so the caller's df is untouched.
    stocks = df.sort_values(Date)

    candle = go.Figure(data = [go.Candlestick(x = stocks[Date], name = 'Trade',
                                              open = stocks[Open],
                                              high = stocks[High],
                                              low = stocks[Low],
                                              close = stocks[Close]),])

    # One trace per window; colors cycle so any number of windows works.
    for i, window in enumerate(mv):
        sma = stocks[Close].rolling(window, min_periods = 1).mean()
        candle.add_trace(go.Scatter(name=f'{window} MA', x=stocks[Date], y=sma,
                                    line=dict(color=palette[i % len(palette)], width=2)))

    # Range slider plus quick-zoom buttons (1 month, 6 months, YTD, 1 year, all).
    candle.update_xaxes(
        title_text = 'Date',
        rangeslider_visible = True,
        rangeselector = dict(
            buttons = list([
                dict(count = 1, label = '1M', step = 'month', stepmode = 'backward'),
                dict(count = 6, label = '6M', step = 'month', stepmode = 'backward'),
                dict(count = 1, label = 'YTD', step = 'year', stepmode = 'todate'),
                dict(count = 1, label = '1Y', step = 'year', stepmode = 'backward'),
                dict(step = 'all')])))

    # NOTE(review): the title below is a literal placeholder string carried over
    # from the question, not an evaluated expression.
    candle.update_layout(autosize = True,
                         title = {'text': "all_stocks[stocks['SYMBOL'][0]]",'y':0.97,'x':0.5,
                                  'xanchor': 'center','yanchor': 'top'},
                         margin=dict(l=30,r=30,b=30,t=30,pad=2),
                         paper_bgcolor="lightsteelblue",)

    candle.update_yaxes(title_text = 'Close Price', tickprefix = u"\u20B9" ) # Rupee symbol
    candle.show()

# Draw the downloaded bars with the default column names and moving-average windows.
plot_candlesticks(df=data)

enter image description here

Upvotes: 1

Related Questions