#!/usr/bin/env python3
import argparse
import gzip
import json
import numpy as np

"""
This script can be used to evaluate an IDS on the Sherlock Dataset. As described in the publication, Sherlock features multiple labels:
    - Benign: Normal behavior of the power grid. It is guaranteed that no effects of (prior) attacks remain. An IDS should not emit alerts during these periods.
    - Maintenance: A legitimate but irregular maintenance or control event changing the state of the power grid. An IDS should not emit alerts during these periods.
    - Attack: Duration during which the impact of an attack takes place. An IDS should indicate those regions with an alert. The start of the alert must be within the attack period, the earlier the better, but may overshoot the attack.
    - Recovery: Actions initiated to restore the power grid to its regular state. These regions are ignored during the IDS evaluation. Alerts starting here are not counted as TP nor as FP.
"""

def open_file(filename, mode="r"):
    """
    Transparently open regular and gzip-compressed ('*.gz') files.

    The built-in open() interprets a mode without "b"/"t" (e.g. "r") as text,
    whereas gzip.open() interprets the very same mode as binary. To behave
    identically for both kinds of files, text mode is requested explicitly
    from gzip whenever the caller did not ask for "b" or "t".

    Parameters:
        filename: Path of the file; a name ending in ".gz" is opened via gzip.
        mode: Mode string as accepted by open() (Default: "r").

    Returns:
        The opened file object (usable as a context manager).
    """

    if filename.endswith(".gz"):
        # gzip defaults to binary, open() to text -> align gzip with open().
        if "b" not in mode and "t" not in mode:
            mode += "t"
        return gzip.open(filename, mode=mode)

    return open(filename, mode=mode)


def prepare_arg_parser(parser):
    """
    Register the command-line interface of the evaluation on `parser`.

    Two arguments are added:
        - input: positional, exactly one file (stored as a one-element list).
        - --attacks: mandatory option naming the JSON file with the events.
    """

    input_options = dict(
        metavar="FILE",
        nargs=1,
        help="input file of IPAL messages to evaluate ('*.gz' compressed)",
    )
    attacks_options = dict(
        dest="attacks",
        metavar="FILE",
        help="JSON file containing the attacks from the used dataset ('*.gz' compress) (Default: None)",
        required=True,
    )

    parser.add_argument("input", **input_options)
    parser.add_argument("--attacks", **attacks_options)


def get_consequtive_alerts(dataset):
    """
    Extract the start and end of each alert.  An IDS emits its alerts on a per data-point basis as a labelled dataset. Yet, we are interested in continuous alerts.

    Parameters:
        dataset: Iterable of dicts providing "ids" (truthy while the IDS
            alerts) and "timestamp" for each data point, in order.

    Returns:
        List of (start, end) tuples, one per run of consecutive alerting data
        points. Both values are timestamps taken from the dataset; for a run
        consisting of a single data point start and end are equal.
    """

    alerts = []
    # None marks "no alert open". The timestamp itself must not be used as a
    # flag since a timestamp of 0 is a valid start but evaluates to False.
    start = None
    end = None

    for d in dataset:
        if d["ids"]:
            if start is None:
                start = d["timestamp"]
            # Always track the latest alerting data point, so that alerts
            # spanning only one data point get a proper end as well.
            end = d["timestamp"]
        elif start is not None:
            alerts.append((start, end))
            start = None
            end = None

    # The dataset ended while an alert was still ongoing.
    if start is not None:
        alerts.append((start, end))

    return alerts

def get_detected_attacks(dataset, attacks):
    """
    Detected Attacks
    The absolute number of attacks during which an alarm starts. Alarms starting
    before the attack are considered false alarms as are alarms that only start
    after the attack (they most likely only detect the recovery to normal operation).

    Average TTD
    The average time in seconds from the start of an attack until the first alarm starts.

    Parameters:
        dataset: Labelled data points as consumed by get_consequtive_alerts.
        attacks: List of dicts with the keys "id", "start", and "end".

    Returns:
        Tuple (detected, average_ttd): the ids of the detected attacks in
        order of their first detection, and the mean time to detection. The
        mean is NaN if no attack was detected at all.
    """

    detected = []
    ttd = []

    for alert_start, _alert_end in get_consequtive_alerts(dataset):
        for attack in attacks:
            # Both boundaries of the attack period are inclusive.
            if attack["start"] <= alert_start <= attack["end"]:

                if attack["id"] not in detected:
                    # Only count the first alert for each attack (only relevant for ttd).
                    ttd.append(alert_start - attack["start"])
                    detected.append(attack["id"])

    if not ttd:
        # np.mean([]) yields NaN as well, but only after emitting
        # RuntimeWarnings. Return the same value without the noise.
        return detected, float("nan")

    return detected, np.mean(ttd)

def get_false_alarms(dataset, attacks, maintenances, recovery):
    """
    False Alarms
    The number of alarms that start during normal operations without an attack. Maintenance events are considered normal operation. In contrast, we ignore alerts during recovery phases.

    Parameters:
        dataset: Labelled data points as consumed by get_consequtive_alerts.
        attacks: List of dicts with "start" and "end" of each attack.
        maintenances: List of dicts with "start" and "end" of each maintenance event.
        recovery: List of dicts with "start" and "end" of each recovery phase.

    Returns:
        Tuple (false_alarms, during_maintenance): the number of false alarms
        and how many of those started during a maintenance event.
    """

    false_alarms = 0
    during_maintenance = 0

    # Alerts starting within these periods are either true positives (attack)
    # or ignored (recovery). Build the combined list once, not per alert.
    excluded = attacks + recovery

    for alert_start, _alert_end in get_consequtive_alerts(dataset):
        if any(period["start"] <= alert_start <= period["end"] for period in excluded):
            continue

        false_alarms += 1

        # Optionally: Check if alert occurred during maintenance
        if any(m["start"] <= alert_start <= m["end"] for m in maintenances):
            during_maintenance += 1

    return false_alarms, during_maintenance

def evaluate(args):
    """
    Run the complete evaluation and print the resulting metrics as JSON.

    The events file (args.attacks) is split into maintenance events and
    attacks, the labelled dataset (args.input[0]) is read line by line, and
    the detection and false-alarm metrics are computed on top of both.
    """

    # Read all events and sort them into their categories
    with open_file(args.attacks, "r") as handle:
        events = json.load(handle)

    maintenances = []  # Maintenance events
    attacks = []  # Attack events
    recoveries = []  # Recovery phase after attack
    for event in events:
        if "benign event" in event["id"]:
            maintenances.append(event)
            continue

        attacks.append(event)
        # The recovery phase directly follows the end of its attack.
        recoveries.append({"start": event["end"], "end": event["recovery"]})

    # Read the labelled dataset (one JSON object per line), keeping only the
    # fields needed for the evaluation
    wanted = ("timestamp", "ids")
    with open_file(args.input[0], "r") as handle:
        dataset = [
            {key: record[key] for key in wanted if key in record}
            for record in map(json.loads, handle)
        ]

    # Compute the metrics
    detected_attacks, attd = get_detected_attacks(dataset, attacks)
    false_alarms, during_benign = get_false_alarms(
        dataset, attacks, maintenances, recoveries
    )

    # Report
    summary = {
        "attacks": len(attacks),
        "detected-attacks": len(detected_attacks),
        "detected": detected_attacks,
        "false-alarms": false_alarms,
        "alarm-during-maintenance": during_benign,
        "average-ttd": attd,
    }
    print(json.dumps(summary, indent=4))

def main():
    """
    Command-line entry point: parse the arguments and run the evaluation.
    """

    cli = argparse.ArgumentParser()
    prepare_arg_parser(cli)

    evaluate(cli.parse_args())


# Only run the evaluation when executed as a script, not when imported.
if __name__ == "__main__":
    main()

