import pandas as pd
import sys
from pathlib import Path
import os
from MeasurementPlot import MeasurementPlot
import numpy as np

class PostPlot:
    """
    Generate a plot of all data files from the individual measurements to have one plot of the whole
    data taking sequence.
    Optionally, externally computed fit data (e.g. a correlation or regression curve) can be
    overlaid on the plot via add_curvefit_to_plot().
    """
    # Length of each supported time unit in seconds. Any other unit falls back to minutes.
    _SECONDS_PER_UNIT = {'sec': 1, 'min': 60, 'hours': 3600}
    _DEFAULT_TIME_UNIT = 'min'

    def __init__(self, reference_signal_names, trace_subplot5='', legend_loc='upper left',
                 legend_bbox_to_anchor=(1.09, 1)):
        """
        Create the underlying MeasurementPlot figure.

        :param reference_signal_names: column names of the reference signals; the first two are
            used for the peak-to-peak annotation (see calc_pkpk_values).
        :param trace_subplot5: trace selection for subplot 5, passed on to MeasurementPlot.
        :param legend_loc: matplotlib legend ``loc`` used for the first subplot.
        :param legend_bbox_to_anchor: matplotlib legend ``bbox_to_anchor`` for the first subplot.
        """
        # Inside a Jupyter/IPython kernel, switch matplotlib to the Qt backend so the plot
        # opens in a separate window.
        if 'ipykernel' in sys.modules:
            from IPython import get_ipython
            ipython = get_ipython()
            # get_ipython() returns None when no interactive shell is registered.
            if ipython is not None:
                ipython.run_line_magic('matplotlib', 'qt')

        self.measplot = MeasurementPlot(reference_signal_names, trace_subplot5=trace_subplot5,
                                        legend_loc=legend_loc,
                                        legend_bbox_to_anchor=legend_bbox_to_anchor)

        self.legend_loc = legend_loc
        self.legend_bbox_to_anchor = legend_bbox_to_anchor

        # Share the MeasurementPlot figure so save_fig() stores exactly what is drawn.
        self.fig = self.measplot.fig

        self.reference_signal_names = reference_signal_names

    def import_csv(self, csv_file):
        """Read ``csv_file`` (path or file-like) into a new DataFrame via ``pandas.read_csv``."""
        return pd.read_csv(csv_file)

    def plot_frame_data(self, data_frame, title, time_unit='min', measurement_set=None):
        """
        Draw ``data_frame`` into the figure with the time axis relative to its first row.

        :param data_frame: measurement data with a ``TIMESTAMP`` column (presumably seconds,
            since it is scaled as such) and the reference signal columns. It is not modified.
        :param title: appended to "Measurement " in the figure title.
        :param time_unit: 'sec', 'min' or 'hours'; any other value is treated as 'min'.
        :param measurement_set: if given, only rows whose ``SET_NAME`` equals it are plotted.
        :raises ValueError: if there are no rows to plot.
        """
        # Work on a fresh, 0-based frame so the caller's data frame is left untouched.
        if measurement_set is None:
            postplot_data_frame = data_frame.reset_index(drop=True)
        else:
            in_set = data_frame['SET_NAME'] == measurement_set
            postplot_data_frame = data_frame.loc[in_set].reset_index(drop=True)

        # Fail before touching the figure so a bad call does not leave it half-updated.
        if postplot_data_frame.empty:
            raise ValueError("No data to plot for %r (measurement_set=%r)" % (title, measurement_set))

        # Unknown units are scaled as minutes, so they must be labelled as minutes too.
        if time_unit not in self._SECONDS_PER_UNIT:
            time_unit = self._DEFAULT_TIME_UNIT

        self.fig.suptitle("Measurement " + title, color="red")

        # Label the x axis of every subplot with the selected time unit.
        for axis in self.measplot.ax1:
            axis.set_xlabel("Time [%s]" % time_unit)

        # The first row defines t = 0.
        timestamps = postplot_data_frame['TIMESTAMP']
        time_sec = timestamps - timestamps.iloc[0]

        # Replace the absolute timestamps by the scaled relative times. Plain column assignment
        # (rather than DataFrame.update) allows int -> float and never keeps a raw timestamp
        # where the scaled value is NaN.
        postplot_data_frame['TIMESTAMP'] = self.scaling_time_axes(time_sec, time_unit)

        # Refresh the subplots with the prepared data.
        self.measplot.draw_in_this_thread(postplot_data_frame, pdf_name='')

        # Show the peak-to-peak values of the reference signals as annotation.
        self.edit_annotation_in_plot(annotate_string=self.calc_pkpk_values(postplot_data_frame))

    def edit_annotation_in_plot(self, annotate_string='', anno_fontsize=16):
        """Set text and font size of the MeasurementPlot annotation."""
        self.measplot.annotation.set_text(annotate_string)
        self.measplot.annotation.set_fontsize(anno_fontsize)

    def calc_pkpk_values(self, data_frame):
        """
        Build the peak-to-peak annotation text for the first two reference signals.

        Each line reads ``$Delta_PkPk$<name>: <max - min>`` (LaTeX Delta). NaN values are skipped
        when determining the extremes.

        :raises IndexError: if fewer than two reference signal names are configured.
        """
        def peak_to_peak(name):
            column = data_frame[name]
            return column.max() - column.min()

        names = (self.reference_signal_names[0], self.reference_signal_names[1])
        # FIXME: This used to contain units. We will need some 'pretty printing' values for the names
        return '\n'.join(r'$\Delta_{PkPk}$' + name + ': ' + str(peak_to_peak(name))
                         for name in names)

    def add_curvefit_to_plot(self, y_lim=None, xvals=None, yvals=None, trace_color='None',
                             trace_label=''):
        """
        Overlay fit points on the first subplot and rebuild its combined legend.

        :param y_lim: optional y limits for the first subplot (passed to ``set_ylim``).
        :param xvals: x coordinates of the fit points; ``None`` means no points.
        :param yvals: y coordinates of the fit points; ``None`` means no points.
        :param trace_color: color of the fit points (matplotlib treats 'None' as no color).
        :param trace_label: legend label of the fit points.
        """
        # None instead of a shared mutable [] default.
        xvals = [] if xvals is None else xvals
        yvals = [] if yvals is None else yvals

        if y_lim is not None:
            self.measplot.ax1[0].set_ylim(y_lim)

        self.measplot.path_collection_fit.set_color(trace_color)
        self.measplot.path_collection_fit.set_label(trace_label)

        # Replace the scatter points of the fit trace.
        self.measplot.path_collection_fit.set_offsets(np.c_[xvals, yvals])

        # The first subplot shares its area with twin axes, so collect the legend entries of all
        # of them to show a single combined legend.
        handles_phase, labels_phase = self.measplot.ax1[0].get_legend_handles_labels()
        handles_mag, labels_mag = self.measplot.magnitude_axis.get_legend_handles_labels()
        handles_equi0, labels_equi0 = self.measplot.equi_axis0.get_legend_handles_labels()

        handles = handles_phase + handles_mag + handles_equi0
        labels = labels_phase + labels_mag + labels_equi0

        self.measplot.ax1[0].legend(handles, labels, loc=self.legend_loc,
                                    bbox_to_anchor=self.legend_bbox_to_anchor)

        # Refresh the plot window.
        self.measplot.fig.canvas.flush_events()

    def save_fig(self, storepath, filename):
        """Save the figure as ``storepath/filename``, creating ``storepath`` if necessary."""
        # An empty storepath means the current directory; os.makedirs('') would raise.
        if storepath:
            os.makedirs(storepath, exist_ok=True)

        self.fig.savefig(os.path.join(storepath, filename))

    def scaling_time_axes(self, time_sec, time_unit):
        """
        Convert times given in seconds to ``time_unit``.

        :param time_sec: times in seconds (scalar, numpy array or pandas Series).
        :param time_unit: 'sec', 'min' or 'hours'; any other value is treated as 'min'.
        :return: ``time_sec`` itself for 'sec', otherwise ``time_sec`` divided by the unit length.
        """
        # 'sec' needs no scaling; returning the input unchanged also keeps its dtype.
        if time_unit == 'sec':
            return time_sec

        seconds_per_unit = self._SECONDS_PER_UNIT.get(
            time_unit, self._SECONDS_PER_UNIT[self._DEFAULT_TIME_UNIT])
        return time_sec / seconds_per_unit
    

if __name__ == '__main__':
    import argparse

    # Selection of the measurement data plotted in subplot 5 (--trace-selection):
    # 'logger_sens' : temperature and humidity of the logger sensor in the
    #                 measurement instrument chamber
    # 'chamber_sens' : temperature and humidity read back from the chamber sensors
    #                  of the chamber for the measurement instruments
    # 'heater_dut_chamber' : activity of the temperature and humidity heaters read back
    #                        from the DUT chamber (described as the default; confirm in MeasurementPlot)
    parser = argparse.ArgumentParser(
        description='Plot every measurement CSV file below a results folder, plus one '
                    'plot of all files concatenated (full transition).')
    parser.add_argument('reference_signal_names', nargs=2, metavar='SIGNAL',
                        help='column names of the two reference signals')
    parser.add_argument('--results-path', default='TestData_JBY240',
                        help='folder searched recursively for *.csv files')
    parser.add_argument('--time-unit', default='min', choices=('sec', 'min', 'hours'),
                        help='unit of the time axis')
    parser.add_argument('--trace-selection', default='',
                        help='trace selection for subplot 5 (see comment above)')
    args = parser.parse_args()

    # The post plots are stored in a subfolder of the results folder.
    storepath = os.path.join(args.results_path, 'PostPlots')

    # Path.glob yields files in arbitrary, filesystem-dependent order; sort so the
    # full-transition concatenation is deterministic.
    csv_file_list = sorted(Path(args.results_path).glob("**/*.csv"))
    if not csv_file_list:
        sys.exit("No CSV files found below %s" % args.results_path)

    plot_obj = PostPlot(args.reference_signal_names, trace_subplot5=args.trace_selection)

    # Plot and save each CSV file on its own; keep the frames for the full-transition plot.
    data_frames = []
    for csv_file in csv_file_list:
        data_frame = plot_obj.import_csv(str(csv_file))
        data_frames.append(data_frame)

        plot_obj.plot_frame_data(data_frame, csv_file.name, args.time_unit)
        plot_obj.save_fig(storepath, csv_file.stem + '.pdf')

    # Concatenate once (not inside the loop) and plot all steps of the sweep in one diagram.
    concat_data_frame = pd.concat(data_frames, ignore_index=True, sort=False)
    plot_obj.plot_frame_data(concat_data_frame, 'Full Transition', args.time_unit)
    plot_obj.save_fig(storepath, 'Full_Transition.pdf')