"""
##################################### TERMS OF USE ###########################################
# The following code is provided for demonstration purpose only, and should not              #
# be used without independent verification. Recorded Future makes no representations         #
# or warranties, express, implied, statutory, or otherwise, regarding any aspect of          #
# this code or of the information it may retrieve, and provides it both strictly “as-is”     #
# and without assuming responsibility for any information it may retrieve. Recorded Future   #
# shall not be liable for, and you assume all risk of using, the foregoing. By using this    #
# code, Customer represents that it is solely responsible for having all necessary licenses, #
# permissions, rights, and/or consents to connect to third party APIs, and that it is solely #
# responsible for having all necessary licenses, permissions, rights, and/or consents to     #
# any data accessed from any third party API.                                                #
##############################################################################################
"""

import argparse
import logging
import os
import sys
import time
from datetime import datetime, timedelta, timezone
from logging.handlers import RotatingFileHandler

import requests
from requests import ConnectTimeout, ConnectionError, HTTPError, ReadTimeout

# Identifier and version of this integration; sent together as the User-Agent
# header when submitting to Collective Insights.
APP_ID = "CS-Collective_Insights"
APP_VERSION = "1.0.0"
# Endpoint that receives the Collective Insights payload (HTTP POST).
COLLECTIVE_INSIGHTS_API_URL = "https://api.recordedfuture.com/collective-insights/detections"
# Base URL of the CrowdStrike API. Endpoint paths are appended by plain string
# concatenation, so the trailing slash is required.
CROWDSTRIKE_API_URL = "https://api.crowdstrike.com/"
# Age in seconds after which a CrowdStrike OAuth token is treated as expired
# and a new one is requested.
# NOTE(review): presumably mirrors a 30-minute token lifetime -- confirm against
# the `expires_in` value returned by the token endpoint.
OAUTH_TIMEOUT = 1800

class RecordedFutureLogger:
    """Holder for the script's logger, configured once when the class body runs.

    The logger writes to a size-rotated log file and to the console, each with
    its own message format, and does not propagate records to the root logger.
    """

    # Root logging is pointed at the log file in append mode before the
    # script-specific handlers are created.
    logging.basicConfig(filename="recordedfuture-cs-collectiveInsights.log", filemode="a")

    # Message layouts: the file format additionally records thread, function
    # and line number.
    formatter_file = logging.Formatter(
        fmt=(
            "%(asctime)s,%(msecs)03d [%(threadName)s] %(levelname)s [%(module)s] "
            "%(funcName)s:%(lineno)d - %(message)s"
        ),
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    formatter_console = logging.Formatter(
        fmt="%(asctime)s,%(msecs)03d %(levelname)s [%(module)s] - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # File output rotates at 10,000,000 bytes and keeps five backups.
    file_handler = RotatingFileHandler(
        filename="recordedfuture-cs-collectiveInsights.log",
        maxBytes=10000000,
        backupCount=5,
    )
    file_handler.setFormatter(formatter_file)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter_console)

    # Attach both handlers and keep records away from the root logger.
    logger = logging.getLogger(__name__)
    logger.propagate = False
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)


# Module-wide logger used by every function in this script; it is the logger
# object built in the RecordedFutureLogger class body.
LOG = RecordedFutureLogger.logger


class RFScriptError(Exception):
    """Base class for other exceptions raised by this script."""


class CrowdStrikeError(RFScriptError):
    """Error raised when there is an issue with the CrowdStrike API."""


class RecordedFutureAPIError(RFScriptError):
    """Error raised when there is an issue with the Recorded Future API."""


def get_crowdstrike_token(crowdstrike_api_url, crowdstrike_client_id, crowdstrike_client_secret, request_timeout=60):
    """Get CrowdStrike access token using a client ID and a client secret.

    Args:
        crowdstrike_api_url (str): CrowdStrike API URL (must end with "/")
        crowdstrike_client_id (str): CrowdStrike Client ID
        crowdstrike_client_secret (str): CrowdStrike Client Secret
        request_timeout (float): Seconds to wait for the CrowdStrike API before giving up

    Returns:
        tuple: (access_token, start_time) where access_token is the CrowdStrike
        access token string and start_time is the time.time() value taken just
        before the token was requested

    Raises:
        CrowdStrikeError: If the request fails, the response is not valid JSON,
            or the response does not contain an access token
    """
    # Stamp the time before the request so the token's age is never underestimated
    start_time = time.time()
    token_url = crowdstrike_api_url + "oauth2/token"

    headers = {
        "Accept": "application/json",
        "Content-Type": "application/x-www-form-urlencoded"
    }

    data = {
        "client_id": crowdstrike_client_id,
        "client_secret": crowdstrike_client_secret
    }

    try:
        response = requests.post(token_url, headers=headers, data=data, timeout=request_timeout)
        response.raise_for_status()
        access_token = response.json()["access_token"]
    except (HTTPError, ConnectTimeout, ConnectionError, ReadTimeout, KeyError, ValueError) as cse:
        # Only HTTP errors carry a response; connection/timeout errors have
        # response=None and KeyError/ValueError may have no such attribute at all
        failed_response = getattr(cse, "response", None)
        if failed_response is not None:
            detail = failed_response.text
        else:
            detail = f"{type(cse).__name__}: {cse}"
        LOG.error(f"Error getting CrowdStrike access token: {detail}")
        raise CrowdStrikeError(cse) from cse

    if not access_token:
        # An empty token would only fail later when the Authorization header is built
        LOG.error("Error getting CrowdStrike access token: response contained an empty access_token")
        raise CrowdStrikeError("CrowdStrike token response contained an empty access_token")

    LOG.info("CrowdStrike Access Token successfully retrieved")
    return access_token, start_time


def get_crowdstrike_detection_ids(crowdstrike_api_url, crowdstrike_client_id, crowdstrike_client_secret, access_token, lookback_days, custom_filter, start_time, request_timeout=60):
    """Get CrowdStrike detection IDs.

    Pages through the detection query endpoint 100 IDs at a time, refreshing
    the OAuth token whenever it is older than OAUTH_TIMEOUT seconds.

    Args:
        crowdstrike_api_url (str): CrowdStrike API URL (must end with "/")
        crowdstrike_client_id (str): CrowdStrike Client ID
        crowdstrike_client_secret (str): CrowdStrike Client Secret
        access_token (str): CrowdStrike Access Token
        lookback_days (int): Number of days to look back for detections
        custom_filter (str): Custom FQL filter to apply to the CrowdStrike detections query
        start_time (float): Time when the access token was retrieved
        request_timeout (float): Seconds to wait for the CrowdStrike API before giving up

    Returns:
        [str]: A list of CrowdStrike detection IDs

    Raises:
        CrowdStrikeError: If a request fails or a response cannot be interpreted
    """
    ids_url = crowdstrike_api_url + "detects/queries/detects/v1"

    headers = {
        "Accept": "application/json",
        "Authorization": "Bearer " + access_token
    }

    # Create FQL filter string to get detections from a certain time period.
    # The timestamp is labelled "Z", so the cutoff must be computed in UTC;
    # strftime keeps the format stable even when microseconds happen to be 0.
    cutoff = datetime.now(timezone.utc) - timedelta(days=lookback_days)
    dt = cutoff.strftime("%Y-%m-%dT%H:%M:%SZ")
    filter_string = f"last_behavior:>='{dt}'"

    # Append custom FQL string if provided by user
    if custom_filter:
        filter_string += f"+{custom_filter}"

    LOG.info(f"Using FQL filter for detection IDs query: {filter_string}")

    params = {
        "filter": filter_string,
        "offset": 0,
        "limit": 100
    }

    detection_ids = []

    try:
        while True:
            LOG.info(f"Fetching CrowdStrike detections data with offset: {params['offset']}")

            # Check if a new OAuth token is needed
            if time.time() - start_time >= OAUTH_TIMEOUT:
                LOG.info("CrowdStrike OAuth token expired, fetching new token")
                access_token, start_time = get_crowdstrike_token(crowdstrike_api_url, crowdstrike_client_id, crowdstrike_client_secret)
                headers["Authorization"] = "Bearer " + access_token

            response = requests.get(ids_url, headers=headers, params=params, timeout=request_timeout)
            response.raise_for_status()
            body = response.json()

            # "resources" may be absent or null when nothing matched
            detection_ids.extend(body.get("resources") or [])

            pagination = (body.get("meta") or {}).get("pagination")
            if pagination is None:
                LOG.info("No pagination information provided, all IDs fetched")
                break

            total = pagination["total"]
            limit = pagination["limit"]
            offset = pagination["offset"]

            LOG.info(f"Fetched {len(detection_ids)} resources out of {total}")

            if offset + limit >= total:
                LOG.info("All IDs fetched")
                break

            if limit <= 0:
                # The offset could never advance; stop instead of looping forever
                LOG.warning("Pagination reported a non-positive page size, stopping ID retrieval")
                break

            params["offset"] = offset + limit
    except (HTTPError, ConnectTimeout, ConnectionError, ReadTimeout, KeyError, ValueError) as cse:
        # Only HTTP errors carry a response; connection/timeout errors have
        # response=None and KeyError/ValueError may have no such attribute at all
        failed_response = getattr(cse, "response", None)
        if failed_response is not None:
            detail = failed_response.text
        else:
            detail = f"{type(cse).__name__}: {cse}"
        LOG.error(f"Error getting CrowdStrike Detection IDs: {detail}")
        raise CrowdStrikeError(cse) from cse

    return detection_ids


def get_crowdstrike_detections_data(crowdstrike_api_url, crowdstrike_client_id, crowdstrike_client_secret, access_token, detection_ids, start_time, request_timeout=60):
    """Get CrowdStrike detection data for the given detection IDs.

    The IDs are sent in batches of 1000. Before every batch the OAuth token is
    refreshed if it is older than OAUTH_TIMEOUT seconds.

    Args:
        crowdstrike_api_url (str): CrowdStrike API URL (must end with "/")
        crowdstrike_client_id (str): CrowdStrike Client ID
        crowdstrike_client_secret (str): CrowdStrike Client Secret
        access_token (str): CrowdStrike Access Token
        detection_ids ([str]): List of detection ID strings
        start_time (float): Time when the access token was retrieved
        request_timeout (float): Seconds to wait for the CrowdStrike API before giving up

    Returns:
        list: A list of detection dicts

    Raises:
        CrowdStrikeError: If a request fails or a response cannot be interpreted
    """
    details_url = crowdstrike_api_url + "detects/entities/summaries/GET/v1"

    headers = {
        "Authorization": "Bearer " + access_token,
        "Content-Type": "application/json",
        "Accept": "application/json"
    }

    num_detection_ids = len(detection_ids)

    # Create sublists of Detection IDs to send in batches of 1000
    batch_size = 1000
    id_batches = [detection_ids[i:i + batch_size] for i in range(0, num_detection_ids, batch_size)]

    detections = []

    LOG.info(f"Fetching detection details for {num_detection_ids} IDs")
    try:
        for id_batch in id_batches:
            # Check before every batch whether a new OAuth token is needed, so a
            # long run over many batches does not continue with an expired token
            if time.time() - start_time >= OAUTH_TIMEOUT:
                LOG.info("CrowdStrike OAuth token expired, fetching new token")
                access_token, start_time = get_crowdstrike_token(crowdstrike_api_url, crowdstrike_client_id, crowdstrike_client_secret)
                headers["Authorization"] = "Bearer " + access_token

            response = requests.post(details_url, headers=headers, json={"ids": id_batch}, timeout=request_timeout)
            response.raise_for_status()
            body = response.json()

            resources = body.get("resources")
            if resources is None:
                LOG.error("No 'resources' found in the response, could not fetch detection details")
            else:
                detections.extend(resources)
                LOG.info(f"{len(detections)} of {num_detection_ids} detection details found")
    except (HTTPError, ConnectTimeout, ConnectionError, ReadTimeout, KeyError, ValueError) as cse:
        # Only HTTP errors carry a response; connection/timeout errors have
        # response=None and KeyError/ValueError may have no such attribute at all
        failed_response = getattr(cse, "response", None)
        if failed_response is not None:
            detail = failed_response.text
        else:
            detail = f"{type(cse).__name__}: {cse}"
        LOG.error(f"Error getting CrowdStrike Detection Data: {detail}")
        raise CrowdStrikeError(detail) from cse

    return detections


def build_collective_insights_payload(json_options, detections, lookback_days, severity, confidence):
    """Builds the JSON payload for data submission to Collective Insights

    One insight is created for every behavior that is newer than the lookback
    cutoff, has a supported IOC type with a non-empty value, and meets both the
    severity and confidence thresholds.

    Args:
        json_options (dict): Options for debugging and returned summaries
        detections (list): List of detection data dicts
        lookback_days (int): Number of days to look back for detections
        severity (int): Severity threshold for detections to include
        confidence (int): Confidence threshold for detections to include

    Returns:
        dict: JSON payload for Collective Insights submission (json_options plus
        a "data" key holding the list of insights)

    Raises:
        RFScriptError: If a detection or behavior is missing an expected key or
            carries a timestamp that is not in the expected format
    """
    # CrowdStrike IOC type -> Collective Insights IOC type
    ioc_types = {
        "hash_sha256": "hash",
        "hash_md5": "hash",
        "sha256": "hash",
        "md5": "hash",
        "domain": "domain",
        "ipv4": "ip",
        "ipv6": "ip"
    }
    timestamp_format = "%Y-%m-%dT%H:%M:%SZ"

    # Behavior timestamps are "Z"-suffixed, so both sides of the comparison are
    # handled as UTC; a naive local-time comparison would be off by the host's
    # UTC offset. Behavior timestamps have second resolution, hence the truncation.
    cutoff = (datetime.now(timezone.utc) - timedelta(days=lookback_days)).replace(microsecond=0)

    insights = []
    try:
        for detection in detections:
            for behavior in detection["behaviors"]:
                behavior_time = datetime.strptime(behavior["timestamp"], timestamp_format).replace(tzinfo=timezone.utc)
                if behavior_time < cutoff:
                    continue
                cs_ioc_type = behavior["ioc_type"]
                if cs_ioc_type not in ioc_types or not behavior["ioc_value"]:
                    continue
                if behavior["severity"] < severity or behavior["confidence"] < confidence:
                    continue
                insights.append({
                    "timestamp": behavior["timestamp"],
                    "ioc": {
                        "type": ioc_types[cs_ioc_type],
                        "value": behavior["ioc_value"],
                        "source_type": "CrowdStrike Falcon"
                    },
                    "incident": {
                        "id": detection["detection_id"],
                        "name": behavior["description"],
                        "type": "crowdstrike-falcon-threat-detection"
                    },
                    "detection": {
                        "name": behavior["description"],
                        "type": "correlation",
                        "sub_type": ""
                    }
                })
    except (KeyError, ValueError) as err:
        # KeyError: unexpected detection/behavior shape; ValueError: bad timestamp
        LOG.error(f"Error building Collective Insights payload: {err}")
        raise RFScriptError(err) from err

    LOG.info(f"Created {len(insights)} insights for payload submission")
    payload = dict(json_options, data=insights)

    return payload


def submit_collective_insights(collective_insights_upload, rf_api_key, request_timeout=60):
    """Submits data to Recorded Future Collective Insights API.

    Args:
        collective_insights_upload (dict): Collective Insights payload to submit
        rf_api_key (str): Recorded Future API Key
        request_timeout (float): Seconds to wait for the Recorded Future API before giving up

    Returns:
        response: response object for submission result

    Raises:
        RecordedFutureAPIError: If the request fails or returns an HTTP error status
    """
    headers = {
        "User-Agent": f"{APP_ID}/{APP_VERSION}",
        "Content-Type": "application/json",
        "Accept": "application/json",
        "X-RFToken": f"{rf_api_key}",
    }

    LOG.debug(f"Submitting data to Recorded Future: {collective_insights_upload}")

    # Post data to collective insights using arguments created above
    try:
        response = requests.post(
            COLLECTIVE_INSIGHTS_API_URL,
            headers=headers,
            json=collective_insights_upload,
            timeout=request_timeout,
        )
        response.raise_for_status()
    except (HTTPError, ConnectTimeout, ConnectionError, ReadTimeout) as err:
        # Only HTTP errors carry a response; connection and timeout errors have
        # response=None, so fall back to the exception text for those
        failed_response = getattr(err, "response", None)
        if failed_response is not None:
            detail = failed_response.text
        else:
            detail = f"{type(err).__name__}: {err}"
        LOG.error(f"Error submitting to Collective Insights: {detail}")
        raise RecordedFutureAPIError(err) from err

    return response


def _parse_bool_flag(value):
    """Convert a command-line string such as "true" or "0" into a bool."""
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args():
    """Gets arguments from the command line.

    The Recorded Future API key and the CrowdStrike client ID/secret fall back
    to the RF_API_KEY, CS_CLIENT_ID and CS_CLIENT_SECRET environment variables
    when they are not given on the command line; if neither source provides a
    value the attribute is None.

    Returns:
        argparse.Namespace: namespace with the arguments and their values
    """
    # dest name -> environment variable used when the option is not supplied
    env_fallbacks = {
        "rf_api_key": "RF_API_KEY",
        "crowdstrike_client_id": "CS_CLIENT_ID",
        "crowdstrike_client_secret": "CS_CLIENT_SECRET",
    }

    parser = argparse.ArgumentParser(
        description="Recorded Future CrowdStrike Collective Insights Integration",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # The credential options use default=SUPPRESS and are filled from the
    # environment after parsing. Passing the environment value as the argparse
    # default would make ArgumentDefaultsHelpFormatter print the secret in --help.
    parser.add_argument(
        "-k",
        "--key",
        dest="rf_api_key",
        help="Recorded Future API Key, falls back to the RF_API_KEY environment variable",
        default=argparse.SUPPRESS,
        required=False,
    )
    parser.add_argument(
        "-ccid",
        "--crowdstrike_client_id",
        dest="crowdstrike_client_id",
        help="CrowdStrike Client ID, falls back to the CS_CLIENT_ID environment variable",
        default=argparse.SUPPRESS,
        required=False,
    )
    parser.add_argument(
        "-ccs",
        "--crowdstrike_client_secret",
        dest="crowdstrike_client_secret",
        help="CrowdStrike Client Secret, falls back to the CS_CLIENT_SECRET environment variable",
        default=argparse.SUPPRESS,
        required=False,
    )
    parser.add_argument(
        "-lb",
        "--lookback",
        dest="lookback",
        help="Lookback time for detections in days, defaults to 1",
        type=int,
        required=False,
    )
    parser.add_argument(
        "-s",
        "--severity",
        dest="severity",
        help="Behavior severity threshold for insights, defaults to 0",
        type=int,
        required=False,
    )
    parser.add_argument(
        "-c",
        "--confidence",
        dest="confidence",
        help="Behavior confidence threshold for insights, defaults to 0",
        type=int,
        required=False,
    )
    parser.add_argument(
        "-fs",
        "--filter_string",
        dest="filter_string",
        help="Custom FQL filter string to filter detections, defaults to ''",
        type=str,
        required=False,
    )
    # type=bool would turn every non-empty string (including "False") into True.
    # nargs="?" keeps "--debug True" / "--debug False" working and also accepts
    # a bare "--debug" as True.
    parser.add_argument(
        "--debug",
        dest="debug_flag",
        help="Debug mode, insights not posted to Recorded Future, defaults to False",
        type=_parse_bool_flag,
        nargs="?",
        const=True,
        required=False,
    )
    parser.add_argument(
        "-l",
        "--loglevel",
        help="Log level, defaults to INFO",
        type=str,
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        default="INFO",
        dest="loglevel",
        required=False,
    )
    parser.set_defaults(lookback=1, severity=0, confidence=0, filter_string="", debug_flag=False)

    args = parser.parse_args()

    # Fill in credentials that were not passed on the command line
    for dest, env_var in env_fallbacks.items():
        if getattr(args, dest, None) is None:
            setattr(args, dest, os.environ.get(env_var))

    return args


def main():
    """Main function that calls other functions to query CrowdStrike and submit the results
    to the Recorded Future Collective Insights API.

    Exits the process with status 1 when arguments are missing or when any of
    the CrowdStrike, payload-building or submission steps fails.
    """
    try:
        # Verify all necessary arguments are present
        args = get_args()

        missing = sorted(name for name, value in vars(args).items() if value is None)
        if missing:
            LOG.error("Args not found. Exiting program...")
            raise RFScriptError(f"Args not found: {', '.join(missing)}")
    except (argparse.ArgumentError, RFScriptError):
        LOG.error("Error parsing arguments", exc_info=True)
        sys.exit(1)

    # Set up CI options
    json_options = {"options": {"debug": args.debug_flag, "summary": True}}

    # Set up logger; setLevel accepts the level name and, unlike assigning to
    # LOG.level, also resets the logger's cached level checks
    LOG.setLevel(args.loglevel)

    LOG.info("Starting CrowdStrike Collective Insights Integration")

    # Get CrowdStrike detections data
    try:
        cs_access_token, start_time = get_crowdstrike_token(CROWDSTRIKE_API_URL, args.crowdstrike_client_id, args.crowdstrike_client_secret)
        detection_ids = get_crowdstrike_detection_ids(CROWDSTRIKE_API_URL, args.crowdstrike_client_id, args.crowdstrike_client_secret, cs_access_token, args.lookback, args.filter_string, start_time)
        detections = get_crowdstrike_detections_data(CROWDSTRIKE_API_URL, args.crowdstrike_client_id, args.crowdstrike_client_secret, cs_access_token, detection_ids, start_time)
    except CrowdStrikeError:
        LOG.error("Error querying CrowdStrike", exc_info=True)
        sys.exit(1)

    # Build payload for Collective Insights. This is the script's top-level
    # boundary, so any failure is logged and turned into a non-zero exit
    # instead of an unhandled traceback.
    try:
        payload = build_collective_insights_payload(json_options, detections, args.lookback, args.severity, args.confidence)
    except Exception as e:
        LOG.error(f"Error building Collective Insights payload: {e}", exc_info=True)
        sys.exit(1)

    # Send payload to Collective Insights. submit_collective_insights wraps
    # request failures in RecordedFutureAPIError, so that is what is caught here.
    try:
        response = submit_collective_insights(payload, args.rf_api_key)
    except RecordedFutureAPIError as e:
        LOG.error(f"Error submitting to Collective Insights: {e}")
        sys.exit(1)

    LOG.info(response.text)
    LOG.info("Script completed successfully")

# Run the integration only when the file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()