#!/usr/bin/env python3
"""psi2log"""

import os
from time import sleep, monotonic
from ctypes import CDLL
from sys import stdout, exit
from argparse import ArgumentParser
from signal import signal, SIGTERM, SIGINT, SIGQUIT, SIGHUP


def form(num):
    """Format a number as fixed-point text with two decimal places.

    The earlier implementation split ``str(num)`` on '.', which raised
    IndexError whenever that text had no decimal point (an int, or a float
    rendered in exponent notation such as 1e-05).

    Args:
        num: An int, a float, or a string that float() accepts.

    Returns:
        The value as a string with exactly two digits after the decimal
        point, e.g. 1.5 -> '1.50'. Values with more than two decimals are
        rounded to two.
    """
    return '{:.2f}'.format(float(num))


def signal_handler(signum, frame):
    """Log the table of recorded peak values, then exit the process.

    Used as the handler for the signals listed in sig_list. If peaks_dict
    does not hold exactly 15 entries, the process exits right after logging
    an empty line, without printing the table.

    Args:
        signum: Signal number passed by the signal module (unused).
        frame: Stack frame passed by the signal module (unused).
    """
    def ignore_signal(signum, frame):
        pass

    # From here on every signal in sig_list is handled by a no-op.
    for handled in sig_list:
        signal(handled, ignore_signal)

    log('')

    if len(peaks_dict) != 15:
        exit()

    rule = '-----------  ------ ------ ------'

    # Each section is preceded by a rule line; every row pairs its format
    # string with the three peaks_dict keys it displays.
    sections = (
        (
            ('some cpu     {:>6} {:>6} {:>6}',
             ('c_some_avg10', 'c_some_avg60', 'c_some_avg300')),
        ),
        (
            ('some memory  {:>6} {:>6} {:>6}',
             ('m_some_avg10', 'm_some_avg60', 'm_some_avg300')),
            ('full memory  {:>6} {:>6} {:>6}',
             ('m_full_avg10', 'm_full_avg60', 'm_full_avg300')),
        ),
        (
            ('some io      {:>6} {:>6} {:>6}',
             ('i_some_avg10', 'i_some_avg60', 'i_some_avg300')),
            ('full io      {:>6} {:>6} {:>6}',
             ('i_full_avg10', 'i_full_avg60', 'i_full_avg300')),
        ),
    )

    log('=================================')
    log('Peak values:  avg10  avg60 avg300')

    for rows in sections:
        log(rule)
        for template, keys in rows:
            cells = [form(peaks_dict[key]) for key in keys]
            log(template.format(*cells))

    exit()


def cgroup2_root():
    """Find the cgroup2 mount point in the file named by ``mounts``.

    The first line containing ``cgroup2_separator`` is used. From that line
    the text before the separator is taken, and of that, everything after
    the first space is returned.

    Returns:
        The extracted mount point string, or None when no line contains
        the separator.
    """
    with open(mounts) as mounts_file:
        matching = (
            entry for entry in mounts_file if cgroup2_separator in entry
        )
        first_match = next(matching, None)

    if first_match is None:
        return None

    before_separator = first_match.partition(cgroup2_separator)[0]
    return before_separator.partition(' ')[2]


def mlockall():
    """Call mlockall() from libc.so.6 with all three lock flags combined.

    The result of the libc call is not inspected.
    """
    # Hard-coded flag values for MCL_CURRENT, MCL_FUTURE and MCL_ONFAULT
    # (presumably matching the Linux headers -- verify when porting).
    mcl_current, mcl_future, mcl_onfault = 1, 2, 4
    lock_flags = mcl_current | mcl_future | mcl_onfault

    libc = CDLL('libc.so.6')
    libc.mlockall(lock_flags)


def psi_file_mem_to_metrics(psi_path):
    """Extract the averages from the first two lines of a pressure file.

    Each line is split on single spaces; fields 1, 2 and 3 are expected to
    look like ``name=value`` and only the value part is kept. The first
    line is treated as the "some" record and the second as the "full" one.

    Args:
        psi_path: Path of the pressure file to read.

    Returns:
        A tuple of six strings: (some_avg10, some_avg60, some_avg300,
        full_avg10, full_avg60, full_avg300).
    """
    with open(psi_path) as psi_file:
        lines = psi_file.readlines()

    some_fields = lines[0].split(' ')
    full_fields = lines[1].split(' ')

    return tuple(
        fields[position].split('=')[1]
        for fields in (some_fields, full_fields)
        for position in (1, 2, 3)
    )


def psi_file_cpu_to_metrics(psi_path):
    """Extract the averages from the first line of a pressure file.

    The line is split on single spaces; fields 1, 2 and 3 are expected to
    look like ``name=value`` and only the value part is kept.

    Args:
        psi_path: Path of the pressure file to read.

    Returns:
        A tuple of three strings: (some_avg10, some_avg60, some_avg300).
    """
    with open(psi_path) as psi_file:
        lines = psi_file.readlines()

    some_fields = lines[0].split(' ')

    return tuple(
        some_fields[position].split('=')[1] for position in (1, 2, 3)
    )


def psi_file_mem_to_total(psi_path):
    """Read the ``total`` counters from the first two lines of a PSI file.

    Each line is split on single spaces; field 4 is expected to look like
    ``name=value`` and the value part is converted with int().

    Args:
        psi_path: Path of the pressure file to read.

    Returns:
        A tuple of two ints: (some_total, full_total), taken from the first
        and the second line respectively.
    """
    with open(psi_path) as psi_file:
        lines = psi_file.readlines()

    # Pull out both raw values first, convert afterwards.
    raw_totals = [
        fields[4].split('=')[1]
        for fields in (lines[0].split(' '), lines[1].split(' '))
    ]

    return tuple(int(raw) for raw in raw_totals)


def psi_file_cpu_to_total(psi_path):
    """Read the ``total`` counter from the first line of a PSI file.

    The line is split on single spaces; field 4 is expected to look like
    ``name=value`` and the value part is converted with int().

    Args:
        psi_path: Path of the pressure file to read.

    Returns:
        The counter as an int.
    """
    with open(psi_path) as psi_file:
        lines = psi_file.readlines()

    total_field = lines[0].split(' ')[4]
    return int(total_field.split('=')[1])


def print_head_1():
    """Log the five header lines of the avg10/avg60/avg300 table."""
    header_lines = (
        '=================================================================='
        '================================================',
        ' some cpu pressure   || some memory pressure | full memory pressur'
        'e ||   some io pressure   |   full io pressure',
        '-------------------- || -------------------- | -------------------'
        '- || -------------------- | --------------------',
        ' avg10  avg60 avg300 ||  avg10  avg60 avg300 |  avg10  avg60 avg30'
        '0 ||  avg10  avg60 avg300 |  avg10  avg60 avg300',
        '------ ------ ------ || ------ ------ ------ | ------ ------ -----'
        '- || ------ ------ ------ | ------ ------ ------',
    )
    for header_line in header_lines:
        log(header_line)


def print_head_2():
    """Log the five header lines of the cpu/memory/io/interval table."""
    header_lines = (
        '============================================',
        ' cpu  |    memory   |     io      |',
        '----- | ----------- | ----------- |',
        ' some |  some  full |  some  full | interval',
        '----- | ----- ----- | ----- ----- | --------',
    )
    for header_line in header_lines:
        log(header_line)


def log(*msg):
    """Print the values to stdout and, if enabled, write them to the log file.

    The values are always printed with print(*msg). When ``separate_log``
    is true, the same text (each value converted with str() and joined by
    single spaces) is also recorded through logging.info().

    The values are joined before being handed to logging: passing them
    straight through would make logging treat the first value as a
    %-format string and the remaining ones as its arguments, and calling
    log() with no values would raise TypeError.

    Args:
        *msg: Values to output.
    """
    print(*msg)
    if separate_log:
        logging.info(' '.join(str(part) for part in msg))


# Command-line interface. Every option is optional and has a default.
parser = ArgumentParser()

# What to monitor: the literal 'SYSTEM_WIDE' selects the files under
# /proc/pressure, any other value is used as a cgroup path.
parser.add_argument(
    '-t',
    '--target',
    help="""target (cgroup_v2 or SYSTEM_WIDE)""",
    default='SYSTEM_WIDE',
    type=str
)


# Seconds to sleep between two consecutive readings.
parser.add_argument(
    '-i',
    '--interval',
    help="""interval in sec""",
    default=2,
    type=float
)


# Optional file that receives a copy of everything passed to log().
parser.add_argument(
    '-l',
    '--log',
    help="""path to log file""",
    default=None,
    type=str
)


# Output mode, kept as a string: '1' prints the avg10/avg60/avg300 values,
# '2' prints values computed from the 'total' counters.
parser.add_argument(
    '-m',
    '--mode',
    help="""mode (1 or 2)""",
    default='1',
    type=str
)


args = parser.parse_args()
target = args.target
interval = args.interval
log_file = args.log
mode = args.mode


# File logging is enabled only when a log path was given; the logging
# module is imported only in that case.
separate_log = log_file is not None
if separate_log:
    import logging

sig_list = [SIGTERM, SIGINT, SIGQUIT, SIGHUP]


# Route every signal in sig_list to signal_handler.
for handled_signal in sig_list:
    signal(handled_signal, signal_handler)


if separate_log:
    logging.basicConfig(
        filename=log_file,
        level=logging.INFO,
        format="%(asctime)s: %(message)s")

# Startup banner: collect the lines first, then emit them in order.
startup_report = [
    'Starting psi2log',
    'target: {}'.format(target),
    'interval: {} sec'.format(interval),
]
if separate_log:
    startup_report.append('log file: {}'.format(log_file))
startup_report.append('mode: {}'.format(mode))

for report_line in startup_report:
    log(report_line)


# Probe the system-wide memory pressure file once; any failure to read or
# parse it is reported as missing PSI support and ends the program.
try:
    psi_file_mem_to_metrics('/proc/pressure/memory')
except Exception as error:
    print('ERROR: {}'.format(error))
    print('PSI metrics are not provided by the kernel. Exit.')
    exit(1)


# Work out which three pressure files to read.
if target == 'SYSTEM_WIDE':

    cpu_file = "/proc/pressure/cpu"
    memory_file = "/proc/pressure/memory"
    io_file = "/proc/pressure/io"

else:

    # These two names are read by cgroup2_root().
    mounts = '/proc/mounts'
    cgroup2_separator = ' cgroup2 rw,'
    cgroup2_mountpoint = cgroup2_root()

    if cgroup2_mountpoint is None:

        log('ERROR: unified cgroup hierarchy is not mounted, exit')
        exit(1)

    log('cgroup_v2 mountpoint: {}'.format(cgroup2_mountpoint))

    # Build the paths with os.path.join instead of plain concatenation so
    # that a target given without a leading slash (e.g. 'user.slice') still
    # resolves below the mount point. Leading slashes are stripped because
    # os.path.join would otherwise discard the mount point.
    cgroup_dir = os.path.join(cgroup2_mountpoint, target.lstrip('/'))

    cpu_file = os.path.join(cgroup_dir, 'cpu.pressure')
    memory_file = os.path.join(cgroup_dir, 'memory.pressure')
    io_file = os.path.join(cgroup_dir, 'io.pressure')


# Highest value seen so far for each metric; read by signal_handler().
peaks_dict = dict()


mlockall()


# Mode 2: report values derived from the growth of the 'total' counters
# between two consecutive readings. This loop never ends on its own.
if mode == '2':

    print_head_2()

    # Counters in output order: cpu some, memory some, memory full,
    # io some, io full.
    previous_totals = (
        (psi_file_cpu_to_total(cpu_file),)
        + psi_file_mem_to_total(memory_file)
        + psi_file_mem_to_total(io_file)
    )
    previous_time = monotonic()
    sleep(interval)

    while True:

        current_totals = (
            (psi_file_cpu_to_total(cpu_file),)
            + psi_file_mem_to_total(memory_file)
            + psi_file_mem_to_total(io_file)
        )
        current_time = monotonic()
        elapsed = current_time - previous_time
        previous_time = current_time

        # Counter growth per second divided by 10000 (a percentage if the
        # counters are in microseconds -- TODO confirm), one decimal place.
        columns = [
            round((current - previous) / elapsed / 10000, 1)
            for current, previous in zip(current_totals, previous_totals)
        ]
        previous_totals = current_totals

        # Last column: measured time between the two readings.
        columns.append(round(elapsed, 3))

        log('{:>5} | {:>5} {:>5} | {:>5} {:>5} | {}'.format(*columns))

        stdout.flush()
        sleep(interval)


# Mode '2' never leaves its loop, so anything other than '1' reaching this
# point is an unsupported mode.
if mode != '1':
    log('ERROR: invalid mode. Exit.')
    exit(1)


print_head_1()


# peaks_dict keys in the same order as the values of one output row:
# cpu some, memory some, memory full, io some, io full.
peak_keys = (
    'c_some_avg10', 'c_some_avg60', 'c_some_avg300',

    'm_some_avg10', 'm_some_avg60', 'm_some_avg300',
    'm_full_avg10', 'm_full_avg60', 'm_full_avg300',

    'i_some_avg10', 'i_some_avg60', 'i_some_avg300',
    'i_full_avg10', 'i_full_avg60', 'i_full_avg300',
)

row_template = (
    '{:>6} {:>6} {:>6} || {:>6} {:>6} {:>6} | {:>6} {:>6} {:>6} || {:>6}'
    ' {:>6} {:>6} | {:>6} {:>6} {:>6}'
)


while True:

    # A missing file is reported and retried after the usual interval
    # instead of aborting.
    if not os.path.exists(cpu_file):
        log('ERROR: cpu pressure file does not exist:    {}'.format(cpu_file))
        sleep(interval)
        continue

    if not os.path.exists(memory_file):
        log('ERROR: memory pressure file does not exist: {}'.format(
            memory_file))
        sleep(interval)
        continue

    if not os.path.exists(io_file):
        # Report the io path here (the cpu path used to be printed).
        log('ERROR: io pressure file does not exist:     {}'.format(io_file))
        sleep(interval)
        continue

    # One row of 15 values, still as the strings read from the files.
    row = (
        psi_file_cpu_to_metrics(cpu_file)
        + psi_file_mem_to_metrics(memory_file)
        + psi_file_mem_to_metrics(io_file)
    )

    log(row_template.format(*row))

    # Remember the largest value seen so far for every metric.
    for key, text in zip(peak_keys, row):
        value = float(text)
        if key not in peaks_dict or peaks_dict[key] < value:
            peaks_dict[key] = value

    stdout.flush()
    sleep(interval)
