from datetime import timedelta
import airflow
from airflow import DAG
import json
import boto3
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.operators.python_operator import PythonVirtualenvOperator
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTrainingOperator
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator
from airflow.utils.trigger_rule import TriggerRule
from airflow.providers.amazon.aws.operators.glue import GlueJobOperator
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.lambda_function import LambdaInvokeFunctionOperator


# Define DAG name
dag_name = 'cross_account_machine_learning'

# Unique identifier for the DAG run (Jinja template, rendered at task runtime)
correlation_id = "{{ run_id }}"

# Default arguments applied to every task in the DAG
default_args = {
    'owner': 'airflow',
    'depends_on_past': False,
    # NOTE(review): airflow.utils.dates.days_ago is deprecated in Airflow 2.x;
    # consider a fixed pendulum datetime once the Airflow version is confirmed.
    'start_date': airflow.utils.dates.days_ago(1),
    'retries': 0,
    'retry_delay': timedelta(minutes=2),
    # 'provide_context' removed: it is a deprecated no-op in Airflow 2.x, where
    # the task context is always passed to Python callables.
}

# Initialize the DAG with default arguments
dag = DAG(
    dag_name,
    default_args=default_args,
    dagrun_timeout=timedelta(hours=2),  # fail the run if it exceeds 2 hours
    schedule=None  # None means no regular schedule; runs only when triggered
)

# Parameterize S3 bucket names for flexibility
source_bucket_name = "<INSERT-MACHINE-LEARNING-BUCKET-NAME-US-WEST-2>"  # ML Pipeline bucket created by stackset "DPML_AccountB_Setup.yaml in us-west-2"
validation_data_key = 'xgboost/validate/validate.csv'  # Validation data S3 key
training_data_key = 'xgboost/train/train.csv'  # Training data S3 key

# Wait until the validation dataset has landed in the source bucket.
s3_sensor_validate_data = S3KeySensor(
    task_id='s3_sensor_validate_data',
    bucket_name=source_bucket_name,
    bucket_key=validation_data_key,
    aws_conn_id='aws_crossaccount_role_conn_west2',  # cross-account role connection
    poke_interval=60,  # re-check once per minute
    timeout=60 * 60,   # give up after one hour of waiting
    dag=dag,
)

# Wait until the training dataset has landed in the source bucket.
s3_sensor_training_data = S3KeySensor(
    task_id='s3_sensor_training_data',
    bucket_name=source_bucket_name,
    bucket_key=training_data_key,
    aws_conn_id='aws_crossaccount_role_conn_west2',  # cross-account role connection
    poke_interval=60,  # re-check once per minute
    timeout=60 * 60,   # give up after one hour of waiting
    dag=dag,
)

# Custom SageMaker Hook for cross-account access
class CrossAccountSageMakerHook(SageMakerHook):
    """SageMaker hook that talks to another AWS account via STS AssumeRole.

    ``get_conn`` returns a boto3 SageMaker client built from temporary
    credentials for the role named in ``self.extra_config['role_arn']``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Cache for the cross-account client. Do NOT reuse ``self.conn`` here:
        # in recent Amazon provider versions ``conn`` is a cached_property on
        # the base hook, so ``self.conn is None`` would eagerly build (and
        # then return) a *same-account* client and the assume-role branch in
        # get_conn would never execute.
        self._cross_account_conn = None

    def get_conn(self):
        """Return a cached boto3 SageMaker client for the assumed role."""
        if self._cross_account_conn is None:
            session = self.get_session()
            sts_client = session.client('sts')

            # Assume the cross-account role.
            # Assumes 'role_arn' is present in the connection's extra config —
            # TODO confirm against the Airflow connection definition.
            assumed_role_object = sts_client.assume_role(
                RoleArn=self.extra_config['role_arn'],
                RoleSessionName="AssumeRoleSession"
            )

            # NOTE(review): these static credentials expire (1 hour by default)
            # and are not refreshed; a training job that waits longer than the
            # session duration may fail mid-poll — confirm the role's
            # MaxSessionDuration covers the expected training time.
            credentials = assumed_role_object['Credentials']
            cross_account_session = boto3.Session(
                aws_access_key_id=credentials['AccessKeyId'],
                aws_secret_access_key=credentials['SecretAccessKey'],
                aws_session_token=credentials['SessionToken'],
                region_name=self.region_name
            )

            self._cross_account_conn = cross_account_session.client('sagemaker')

        return self._cross_account_conn

# Custom SageMaker Training Operator for cross-account access
class CrossAccountSageMakerTrainingOperator(SageMakerTrainingOperator):
    # Swaps the default SageMakerHook for the cross-account variant above.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): whether this assignment takes effect depends on the
        # installed Amazon provider version — recent versions expose ``hook``
        # as a cached_property on SageMakerBaseOperator (assignment overrides
        # it), but older versions re-create the hook inside ``execute`` and
        # would silently discard this one. Verify against the provider in use.
        self.hook = CrossAccountSageMakerHook(aws_conn_id=self.aws_conn_id)

# Read the SageMaker training configuration from S3 (parameterized bucket).
# NOTE(review): this executes at DAG *parse* time, so the scheduler performs an
# S3 read on every parsing loop — consider moving the fetch into a task (or
# caching it) per Airflow top-level-code best practices.
s3_hook = S3Hook(aws_conn_id='aws_crossaccount_role_conn_west2')
# Assumes the JSON file holds a list whose first element is the training-job
# config dict expected by SageMakerTrainingOperator — TODO confirm file schema.
config = json.loads(s3_hook.read_key(bucket_name=source_bucket_name, key='task_storage/sagemaker_config.json'))[0]

# Launch the SageMaker training job in the remote account and block until done.
sagemaker_train_model_task = CrossAccountSageMakerTrainingOperator(
    task_id='sagemaker_train_model_task',
    config=config,  # training-job definition loaded from S3 above
    aws_conn_id='aws_crossaccount_role_conn_west2',  # cross-account connection
    wait_for_completion=True,  # hold the task until the job finishes
    check_interval=30,         # poll SageMaker every 30 seconds
    dag=dag,
)

# Synchronously invoke the cleanup Lambda once training has finished.
invoke_lambda_function = LambdaInvokeFunctionOperator(
    task_id='invoke_lambda_function',
    function_name='lambda-cleanup',
    qualifier='$LATEST',                # call the latest version of the function
    invocation_type='RequestResponse',  # synchronous: wait for the response
    log_type='Tail',                    # capture the tail of the execution log
    aws_conn_id='aws_crossaccount_role_conn_west2',
    dag=dag,
)

# Task ordering: confirm both datasets exist, then train, then clean up.
s3_sensor_validate_data >> s3_sensor_training_data
s3_sensor_training_data >> sagemaker_train_model_task
sagemaker_train_model_task >> invoke_lambda_function
