"""
##################################### TERMS OF USE ###########################################
# The following code is provided for demonstration purpose only, and should not              #
# be used without independent verification. Recorded Future makes no representations         #
# or warranties, express, implied, statutory, or otherwise, regarding any aspect of          #
# this code or of the information it may retrieve, and provides it both strictly “as-is”     #
# and without assuming responsibility for any information it may retrieve. Recorded Future   #
# shall not be liable for, and you assume all risk of using, the foregoing. By using this    #
# code, Customer represents that it is solely responsible for having all necessary licenses, #
# permissions, rights, and/or consents to connect to third party APIs, and that it is solely #
# responsible for having all necessary licenses, permissions, rights, and/or consents to     #
# any data accessed from any third party API.                                                #
##############################################################################################
"""

import datetime
import logging
import argparse
from logging.handlers import RotatingFileHandler
import os
import requests

# Integration identity; main() sends it to Recorded Future as the
# "User-Agent: <APP_ID>/<APP_VERSION>" request header.
APP_ID = "MDE-Collective_Insights"
APP_VERSION = "1.0.0"

# Remote endpoints: Recorded Future detections submission, Microsoft Graph v1.0 base.
COLLECTIVE_INSIGHTS_API_URL = "https://api.recordedfuture.com/collective-insights/detections"
MSFT_GRAPH_URL = "https://graph.microsoft.com/v1.0"

class RecordedFutureLogger:
    """Rotating logger for capturing any diagnostic messages.

    Everything in this class body executes once, when the class is defined
    (i.e. at module import). The configured logger is exposed as
    ``RecordedFutureLogger.logger`` and aliased to the module-level ``LOG``.
    Output goes both to a size-rotated log file and to the console.
    """

    # NOTE(review): if the root logger has no handlers yet, basicConfig attaches
    # a plain FileHandler for this same file to the *root* logger. This logger
    # sets propagate = False below, so it never writes through that handler, but
    # the extra open handle may interfere with RotatingFileHandler rollover
    # (notably on Windows) -- confirm whether this call is still needed.
    logging.basicConfig(filename="recordedfuture-cs-collectiveInsights.log", filemode="a")
    logger = logging.getLogger(__name__)
    # Roll the file over at ~10 MB, keeping at most 5 backup files.
    file_handler = RotatingFileHandler(
        "recordedfuture-cs-collectiveInsights.log", maxBytes=10000000, backupCount=5
    )
    # Verbose file format: adds thread name, function name and line number.
    formatter_file = logging.Formatter(
        fmt="%(asctime)s,%(msecs)03d [%(threadName)s] %(levelname)s [%(module)s] "
        + "%(funcName)s:%(lineno)d - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    console_handler = logging.StreamHandler()
    # Shorter console format: timestamp, level, module and message only.
    formatter_console = logging.Formatter(
        fmt="%(asctime)s,%(msecs)03d %(levelname)s [%(module)s] - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    file_handler.setFormatter(formatter_file)
    console_handler.setFormatter(formatter_console)
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    # Keep records out of the root logger so they are not emitted a second time
    # through whatever handlers the root logger has.
    logger.propagate = False

# Module-wide logger used by every function in this script.
LOG = RecordedFutureLogger.logger

class MSFTClient:
    """Minimal Microsoft Graph client for reading Defender alerts (``alerts_v2``).

    Authenticates with the OAuth2 client-credentials grant when constructed and
    stores the resulting bearer token in ``self.bearer_token``.
    """

    # Seconds to wait on Microsoft endpoints; without a timeout a stalled
    # connection would hang the script indefinitely.
    REQUEST_TIMEOUT = 60

    def __init__(self, tenant, client_id, client_secret, lookback, custom_filter):
        """Create the client and immediately fetch a bearer token.

        Args:
            tenant: Azure tenant, interpolated into the token endpoint URL.
            client_id: Application (client) ID of the app registration.
            client_secret: Client secret of the app registration.
            lookback: How many days back from now to fetch alerts for.
            custom_filter: Optional extra OData filter expression, ANDed with
                the lookback filter when non-empty.

        Raises:
            requests.HTTPError: if the token request fails.
        """
        self.base_url = MSFT_GRAPH_URL
        self.client_id = client_id
        self.client_secret = client_secret
        self.tenant = tenant
        self.bearer_token = self._get_bearer_token()
        self.lookback = lookback
        self.custom_filter = custom_filter

    def _get_bearer_token(self):
        """Fetch an OAuth2 bearer token using the client-credentials grant.

        Returns:
            str: the ``access_token`` field of the token endpoint's JSON response.

        Raises:
            requests.HTTPError: if the token endpoint returns an error status.
        """
        LOG.debug('fetching Access token from Microsoft')

        url = f'https://login.microsoftonline.com/{self.tenant}/oauth2/v2.0/token'
        headers = {'Content-Type': 'application/x-www-form-urlencoded'}
        payload = {
            'client_id': self.client_id,
            'scope': 'https://graph.microsoft.com/.default',
            'grant_type': 'client_credentials',
            'client_secret': self.client_secret
        }
        # The v2.0 token endpoint only accepts POST with a form-encoded body.
        # headers must be passed by keyword: the second positional argument of
        # requests.post is `data`, not `headers`.
        res = requests.post(url, headers=headers, data=payload, timeout=self.REQUEST_TIMEOUT)
        res.raise_for_status()
        return res.json()['access_token']

    def refresh_bearer_token(self):
        """Replace ``self.bearer_token`` with a freshly fetched token."""
        self.bearer_token = self._get_bearer_token()

    def _build_filter(self, now=None):
        """Build the OData ``$filter`` expression for the alerts query.

        Args:
            now: Reference time for the lookback window. Defaults to the
                current time in UTC. A naive datetime is taken to be UTC.

        Returns:
            str: ``createdDateTime gt <UTC ISO timestamp>Z``, followed by
            `` and (<custom_filter>)`` when a custom filter is set.
        """
        if now is None:
            now = datetime.datetime.now(datetime.timezone.utc)
        start = now - datetime.timedelta(days=self.lookback)
        # Normalize to naive UTC so that the literal 'Z' suffix is truthful.
        if start.tzinfo is not None:
            start = start.astimezone(datetime.timezone.utc).replace(tzinfo=None)
        filter_ = f"createdDateTime gt {start.isoformat()}Z"
        if self.custom_filter:
            # Parenthesize: OData 'and' binds tighter than 'or', so an unwrapped
            # "x or y" custom filter would otherwise bypass the lookback bound.
            filter_ += f" and ({self.custom_filter})"
        return filter_

    def list_v2_alerts(self):
        """Fetch every Defender alert created inside the lookback window.

        Follows ``@odata.nextLink`` so that alerts beyond the first response
        page are not silently dropped.

        Returns:
            dict: the first response page, with ``value`` replaced by the list
            of alerts collected from all pages and ``@odata.nextLink`` removed.

        Raises:
            requests.HTTPError: if any page request fails.
        """
        LOG.debug('fetching alerts from Defender')
        url = f'{self.base_url}/security/alerts_v2'
        params = {"$filter": self._build_filter()}
        headers = {'Authorization': f'Bearer {self.bearer_token}'}

        result = None
        alerts = []
        while url:
            res = requests.get(url, params=params, headers=headers, timeout=self.REQUEST_TIMEOUT)
            res.raise_for_status()
            page = res.json()
            if result is None:
                result = page
            alerts.extend(page.get('value', []))
            url = page.get('@odata.nextLink')
            # nextLink already carries the full query string; don't resend it.
            params = None

        result['value'] = alerts
        result.pop('@odata.nextLink', None)
        return result

def get_type(evidence):
    """Convert a Defender evidence item to a Recorded Future IOC type and value.

    Args:
        evidence: One entry of an alert's ``evidence`` list (a dict decoded
            from the Graph JSON response).

    Returns:
        tuple: ``(type, value)``. ``type`` is one of 'url', 'domain', 'ip' or
        'hash'. ``(None, None)`` is returned for unsupported evidence types.
        ``value`` may be None when the expected field is missing or JSON-null;
        callers should skip such items.
    """
    t = evidence.get('@odata.type')
    if t == '#microsoft.graph.security.urlEvidence':
        ioc = evidence.get('url')
        # Anything with a scheme/authority separator is a full URL;
        # otherwise the value is treated as a bare domain.
        if ioc and '//' in ioc:
            return 'url', ioc
        return 'domain', ioc
    elif t == '#microsoft.graph.security.ipEvidence':
        return 'ip', evidence.get('ipAddress')
    elif t == '#microsoft.graph.security.fileEvidence':
        # `or {}` also covers an explicit JSON null, which .get(key, {}) does not.
        return 'hash', (evidence.get('fileDetails') or {}).get('sha256')
    return None, None


def get_args(argv=None):
    """Parse command-line arguments.

    The four credential options fall back to environment variables when they
    are not given on the command line: ``RF_API_KEY``, ``MS_CLIENT_ID``,
    ``MS_CLIENT_SECRET`` and ``MS_TENANT``. The fallback is applied *after*
    parsing, so ``ArgumentDefaultsHelpFormatter`` never prints secret values
    from the environment in ``--help`` output.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: parsed arguments with attributes ``rf_api_key``,
        ``microsoft_client_id``, ``microsoft_client_secret``, ``tenant_id``,
        ``lookback``, ``filter_string`` and ``loglevel``.

    Raises:
        SystemExit: via ``parser.error`` when a credential is provided neither
            on the command line nor in the environment. Also raised on ``--help``.
    """

    parser = argparse.ArgumentParser(
        description="Recorded Future Microsoft Defender Insights Integration",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Credential options default to None on purpose; see the env fallback below.
    parser.add_argument(
        "-k",
        "--key",
        dest="rf_api_key",
        help="Recorded Future API Key (falls back to $RF_API_KEY)",
        default=None,
        required=False,
    )
    parser.add_argument(
        "-cid",
        "--microsoft_client_id",
        dest="microsoft_client_id",
        help="Microsoft Client ID (falls back to $MS_CLIENT_ID)",
        default=None,
        required=False,
    )

    parser.add_argument(
        "-cs",
        "--microsoft_client_secret",
        dest="microsoft_client_secret",
        help="Microsoft Client Secret (falls back to $MS_CLIENT_SECRET)",
        default=None,
        required=False,
    )
    parser.add_argument(
        "-t",
        "--tenant",
        dest="tenant_id",
        help="Azure Tenant (falls back to $MS_TENANT)",
        default=None,
        required=False,
    )
    parser.add_argument(
        "-lb",
        "--lookback",
        dest="lookback",
        help="Lookback time for detections in days, defaults to 1",
        default=1,
        type=int,
        required=False,
    )
    parser.add_argument(
        "-fs",
        "--filter_string",
        dest="filter_string",
        help="Custom ODATA filter string to filter alerts, defaults to ''",
        default="",
        type=str,
        required=False,
    )
    parser.add_argument(
        "-l",
        "--loglevel",
        help="Log level, defaults to INFO",
        type=str,
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        default="INFO",
        dest="loglevel",
        required=False,
    )

    args = parser.parse_args(argv)

    # (namespace attribute, environment variable, flags shown in the error message)
    credential_sources = (
        ("rf_api_key", "RF_API_KEY", "-k/--key"),
        ("microsoft_client_id", "MS_CLIENT_ID", "-cid/--microsoft_client_id"),
        ("microsoft_client_secret", "MS_CLIENT_SECRET", "-cs/--microsoft_client_secret"),
        ("tenant_id", "MS_TENANT", "-t/--tenant"),
    )
    missing = []
    for dest, env_var, flags in credential_sources:
        if getattr(args, dest) is None:
            setattr(args, dest, os.environ.get(env_var))
        if not getattr(args, dest):
            missing.append("{} (or ${})".format(flags, env_var))
    if missing:
        # Fail fast with a usage message rather than an opaque HTTP 401 later.
        parser.error("missing required credentials: " + ", ".join(missing))

    return args



def _alert_to_detections(alert):
    """Build Collective Insights detection records for a single Defender alert.

    Args:
        alert: One alert dict from the Graph ``alerts_v2`` response.

    Returns:
        list: one detection dict per evidence item that ``get_type`` maps to a
        non-empty IOC type and value. Empty if the alert has no usable evidence.
    """
    detections = []
    # `or []` tolerates alerts whose evidence list is missing or null.
    for ev in alert.get('evidence') or []:
        type_, value = get_type(ev)
        if not type_ or not value:
            continue
        detections.append({
            'timestamp': alert['createdDateTime'],
            'detection': {
                'type': 'playbook',
                'name': alert['title'],
            },
            'incident': {
                'id': alert['id'],
                'type': alert['serviceSource'],
            },
            'ioc': {
                'type': type_,
                'value': value,
                'source_type': alert.get('detectionSource', 'unknown'),
            },
            'mitre_codes': alert['mitreTechniques'],
        })
    return detections


def main():
    """Fetch recent Defender alerts and submit their IOCs to Collective Insights.

    One POST is made per alert that yields at least one IOC. Alerts with no
    usable evidence are logged and skipped.

    Raises:
        requests.HTTPError: if any Microsoft or Recorded Future request fails.
            Alerts after the failing one are not submitted.
    """
    args = get_args()
    # setLevel (rather than assigning .level) validates the level name and
    # clears logging's cached isEnabledFor results.
    LOG.setLevel(args.loglevel)
    LOG.info("Starting M365 Defender Collective Insights Integration")

    ms_client = MSFTClient(args.tenant_id, args.microsoft_client_id, args.microsoft_client_secret, args.lookback, args.filter_string)
    alerts = ms_client.list_v2_alerts()['value']
    LOG.info('Fetched {} alerts from Microsoft Defender'.format(len(alerts)))
    headers = {'X-RFToken': args.rf_api_key, 'User-Agent': '{}/{}'.format(APP_ID, APP_VERSION)}
    for al in alerts:
        data = _alert_to_detections(al)
        if not data:
            # Nothing to submit; don't POST an empty batch.
            LOG.info('No supported IOCs for Incident {}, skipping'.format(al['id']))
            continue
        LOG.info('Posting IOCS for Incident {}'.format(al['id']))
        res = requests.post(
            COLLECTIVE_INSIGHTS_API_URL,
            headers=headers,
            json={'data': data},
            timeout=60,  # don't hang forever on a stalled connection
        )
        res.raise_for_status()
        iocs = ', '.join(d['ioc']['value'] for d in data)
        LOG.info('posted IOCS {}'.format(iocs))

# Run the integration only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
