#!/usr/bin/python

#
# ===========================================================================
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
#
# Copyright (c) 2011-2015, Marvell International Ltd.
#
# Alternatively, this software may be distributed under the terms of the GNU
# General Public License Version 2, and any use shall comply with the terms and
# conditions of the GPL.  A copy of the GPL is available at
# http://www.gnu.org/licenses/old-licenses/gpl-2.0.html
#
# THE FILE IS DISTRIBUTED AS-IS, WITHOUT WARRANTY OF ANY KIND, AND THE
# IMPLIED WARRANTIES OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE
# ARE EXPRESSLY DISCLAIMED.  The GPL license provides additional details about
# this warranty disclaimer.
# ================================================================================

# 
# Take a list of Q60 output data and make individual plots (originally
# GNUplot graphs; the plots are now rendered to PNG files with matplotlib).
# Different fields will be sorted differently according to how things are laid
# out on the Q60. For example, column 13 (Cyan of constant chroma but
# decreasing luminance) will be made into one graph.
#
# Adding plot of fleshtones. davep 25-Sep-2009
#

import sys
import math
import os
import logging
from matplotlib.figure import Figure, SubplotParams
from matplotlib.backends.backend_agg import FigureCanvasAgg

# module-wide logger; configured by the __main__ guard at the bottom of the file
dlog = logging.getLogger("q60plot")

# Colour space of the data being plotted. Selects the per-channel titles below.
output = "Lab"
#output = "YCC"
if output == "YCC":
    # YCC
    title_list = ("Y", "Cb", "Cr")
elif output == "Lab":
    # CIELAB
    title_list = ("L", "a", "b")

# Row letters used to build Q60 patch names.
q60_rows = tuple("ABCDEFGHIJKL")

# Names of the data file columns, as output from csv2dat.py as of 23-Aug-06.
# The intent was to turn the columns of the data file into elements of a hash.
datfile_column_names = (
    "SAMPLE_ID",
    "Q60_L", "Q60_a", "Q60_b",
    "Test_L", "Test_a", "Test_b",
)

class BadFile(Exception):
    """Raised when a data file cannot be parsed.

    Attributes:
        msg: human readable description of the problem.
        filename: name of the file being read.
        line_number: 1-based number of the offending line.
    """

    def __init__(self, msg, filename, line_number):
        # Pass all three values on so that .args (and therefore pickling and
        # repr) stay exactly as they were when only BaseException.__new__
        # recorded them.
        super().__init__(msg, filename, line_number)
        self.filename = filename
        self.line_number = line_number
        self.msg = msg

    def __str__(self):
        # Without this, str(e) is the raw args tuple, which is useless in a
        # traceback or log line.
        return "{0}:{1}: {2}".format(self.filename, self.line_number, self.msg)

def read_datafile(filename):
    """Parse a whitespace separated patch data file.

    Blank lines and lines starting with '#' are skipped. Every other line
    must hold a sample id followed by at least six numbers: three target
    values and three test values. An optional eighth field is read as the
    delta-e value; further fields are ignored.

    Args:
        filename: path of the data file to read.

    Returns:
        A dict keyed by sample id. Each value is a dict with the keys
        "target" (list of 3 floats), "test" (list of 3 floats) and
        "delta-e" (float, or None when the line has no eighth field).

    Raises:
        BadFile: a line has fewer than 7 fields, repeats a sample id that
            was already seen, or holds a field that is not a number.
    """
    data_dict = {}

    # the context manager closes the file even when a bad line raises
    with open(filename, "r") as fd:
        for line_number, raw_line in enumerate(fd, start=1):
            line = raw_line.strip()
            # ignore blank lines and comment lines
            if not line or line.startswith('#'):
                continue

            fields = line.split()
            if len(fields) < 7:
                raise BadFile(
                    "Wrong number of fields: found %d, should be %d." % (len(fields), 7),
                    filename, line_number)

            # no duplicates allowed; raised explicitly because an assert
            # would disappear when running with -O
            sample_id = fields[0]
            if sample_id in data_dict:
                raise BadFile("Duplicate sample id %s." % sample_id,
                              filename, line_number)

            try:
                target = [float(s) for s in fields[1:4]]
                test = [float(s) for s in fields[4:7]]
                # The field count check above only guarantees 7 fields, so
                # the delta-e column may be absent.
                delta_e = float(fields[7]) if len(fields) > 7 else None
            except ValueError as err:
                raise BadFile("Invalid number: %s." % err,
                              filename, line_number) from err

            data_dict[sample_id] = {"target": target,
                                    "test": test,
                                    "delta-e": delta_e}

    return data_dict

def make_figure(num_rows, num_cols):
    """Create an empty Figure sized for a num_rows x num_cols grid of subplots.

    Each grid cell is given 11 x 8 inches (width x height) at 300 dpi, and the
    area surrounding the subplots is shrunk to 5% on every side.
    """
    # shrink the surrounding area
    tight_margins = SubplotParams(left=0.05, right=0.95, top=0.95, bottom=0.05)

    # figsize is width,height in inches
    width_inches = num_cols * 11
    height_inches = num_rows * 8

    return Figure(figsize=(width_inches, height_inches), dpi=300,
                  subplotpars=tight_margins)

def axplot(ax, test_data, target_data, title):
    """Draw one test-versus-target comparison onto the axes `ax`.

    test_data is drawn as a line plot labelled 'test' and target_data as a
    stem plot labelled 'target'. The axes also get a grid, 5% margins, a
    frameless legend and `title` as their title.
    """
    ax.set_title(title)
    ax.grid()
    ax.margins(0.05, 0.05)

    # measured values: format string 'gx--'
    ax.plot(test_data, 'gx--', label='test')

    # reference values: stems, all in red
    stem_formats = dict(linefmt='r-', markerfmt='r.', basefmt='r--')
    ax.stem(target_data, label='target', **stem_formats)

    # the legend is built last so that both labelled artists already exist
    ax.legend(frameon=False)

def plot_columns(data_dict, patch_column_list, color_names, outfilename):
    """Plot whole patch columns into a single image file.

    The figure has one grid column per entry of patch_column_list and three
    grid rows, one per channel in title_list. For every patch column the
    patches of all rows in q60_rows are plotted, test against target.

    Args:
        data_dict: patch data keyed by patch name; each value holds the
            "target" and "test" triples.
        patch_column_list: patch column numbers to plot.
        color_names: one label per entry of patch_column_list.
        outfilename: name of the image file to write.
    """
    dlog.debug("plot_columns")

    grid_cols = len(patch_column_list)
    assert grid_cols==len(color_names),(grid_cols,len(color_names))

    grid_rows = 3
    fig = make_figure(grid_rows, grid_cols)

    labelled_columns = zip(patch_column_list, color_names)
    for grid_col, (colnum, color_name) in enumerate(labelled_columns, start=1):
        patch_names = ["{0}{1:02}".format(row, colnum) for row in q60_rows]
        dlog.debug("{0} {1}".format(color_name, " ".join(patch_names)))

        # look every patch up once; the channel loop below only indexes
        patches = [data_dict[name] for name in patch_names]

        for idx, title in enumerate(title_list):
            # [0] == L or Y
            # [1] == a or Cb
            # [2] == b or Cr
            target_data = [patch["target"][idx] for patch in patches]
            test_data = [patch["test"][idx] for patch in patches]

            # matplotlib numbers subplots along a row first, but one patch
            # column has to run down one grid column, so skip a full grid
            # row per channel
            position = grid_col + idx * grid_cols
            ax = fig.add_subplot(grid_rows, grid_cols, position)

            s = "{0} {1} of {2}".format(color_name, title, output)
            dlog.info(s)

            axplot(ax, test_data, target_data, s)

    FigureCanvasAgg(fig).print_figure(outfilename)
    dlog.info("wrote {0}".format(outfilename))

def plot_patches(data_dict, patches_list, outfilename):
    """Plot an explicit list of patches into a single image file.

    The figure is one column of three subplots, one per channel in
    title_list, each comparing test against target over all listed patches.

    Args:
        data_dict: patch data keyed by patch name; each value holds the
            "target" and "test" triples.
        patches_list: non-empty sequence of patch names; the first and last
            names are used in the subplot titles.
        outfilename: name of the image file to write.
    """
    dlog.debug("plot_patches {0}".format(" ".join(patches_list)))

    # one col, three rows
    fig = make_figure(3, 1)

    for idx, title in enumerate(title_list):
        # [0] == L or Y
        # [1] == a or Cb
        # [2] == b or Cr
        target_data = []
        for patch_name in patches_list:
            target_data.append(data_dict[patch_name]["target"][idx])
        test_data = []
        for patch_name in patches_list:
            test_data.append(data_dict[patch_name]["test"][idx])

        ax = fig.add_subplot(3, 1, idx + 1)

        s = "{0}..{1} {2} of {3}".format(
            patches_list[0], patches_list[-1], title, output)
        dlog.info(s)

        axplot(ax, test_data, target_data, s)

    FigureCanvasAgg(fig).print_figure(outfilename)
    dlog.info("wrote {0}".format(outfilename))

def plot(data_dict):
    """Write all plots for one data set into the current directory.

    data_dict is keyed by the patch name. Each value is another dictionary:
      "target" - array of 3 floats, the target Lab values (from a .Q60 file)
      "test" - array of 3 floats, the test Lab values

    Writes cmykrgb.png, AtoL.png, gs.png and fleshtones.png.
    """
    # columns [13-19] are the CMYK,RGB patches
    cmykrgb_names = ("Cyan", "Magenta", "Yellow", "Black", "Red", "Green", "Blue")
    plot_columns(data_dict, list(range(13, 20)), cmykrgb_names, "cmykrgb.png")

    # columns [1-12]; the twelve row letters are reused as the column labels
    plot_columns(data_dict, list(range(1, 13)), q60_rows, "AtoL.png")

    # GS0 .. GS23
    gs_patches = ["GS{0}".format(step) for step in range(24)]
    plot_patches(data_dict, gs_patches, "gs.png")

    # rows I..L of columns 20..22, row by row
    fleshtones = tuple("{0}{1}".format(row, col)
                       for row in "IJKL"
                       for col in (20, 21, 22))
    plot_patches(data_dict, fleshtones, "fleshtones.png")


def main():
    """Command line entry point: read the data file named by argv[1] and plot it.

    Exits with status 1 when no data file is given on the command line or
    when the data file cannot be parsed.
    """
    # A missing argument used to surface as a bare IndexError traceback.
    if len(sys.argv) < 2:
        prog = sys.argv[0] if sys.argv else "q60plot"
        print("usage: {0} datafile".format(prog), file=sys.stderr)
        sys.exit(1)

    infilename = sys.argv[1]

    try:
        data_dict = read_datafile(infilename)
    except BadFile as e:
        print("Error in file",e.filename, "at line",e.line_number,"::", e.msg)
        sys.exit(1)

    plot(data_dict)

if __name__ == '__main__':
    # Run as a script: log at DEBUG level, then plot.
    # Alternatives: add "%(lineno)d" after "%(filename)s" to get line numbers
    # in the output, or use level=logging.INFO for less output.
    logging.basicConfig(
        format="%(filename)s %(name)s %(message)s",
        level=logging.DEBUG,
    )

    main()

