from pyspark.context import SparkContext
from pyspark.sql import SparkSession

# Name under which the cross-account resource-link ("rl_") Glue catalog is
# registered as a Spark catalog below.
catalog_name = "rl_link_container_ordersdb"
aws_region = "us-west-2"
# Placeholder — substitute the consumer AWS account id before running.
aws_account_id = "<<consumer_account_id>>"

# Local filesystem warehouse used by the Glue-backed session catalog.
warehouse_path = "file:///tmp/spark-warehouse"

# Glue multi-catalog identifier of the link container: "<account_id>:<catalog_name>".
link_container_catalog_name = "<<consumer_account_id>>:rl_link_container_ordersdb"
# NOTE(review): this expands to "arn:aws:glue:<region>:<account>:<name>" with no
# "catalog/" resource segment — verify against the Glue catalog ARN format
# expected by the `glue.catalog-arn` property before relying on it.
catalog_id = f"arn:aws:glue:{aws_region}:{link_container_catalog_name}"

# Federated database/table reached through the resource-link catalog.
orders_database = "public_db"
orders_table = "rl_orderstbl"

# Native Iceberg database/table in the session catalog.
customer_database = "customerdb"
returns_table = "returnstbl_iceberg"

# Spark's built-in session catalog name, overridden below with Iceberg's
# SparkSessionCatalog so non-Iceberg tables keep working.
catalog_name1 = "spark_catalog"

# Assemble the SparkSession configuration as a mapping, then apply it in one
# pass. Two Iceberg catalogs are wired up:
#   * catalog_name1 ("spark_catalog"): Glue-backed SparkSessionCatalog for the
#     native Iceberg tables, with S3FileIO and a local warehouse path.
#   * catalog_name: a SparkCatalog pointing at the cross-account resource-link
#     Glue catalog (federated orders data).
_session_settings = {
    'spark.sql.extensions': 'org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions',
    f'spark.sql.catalog.{catalog_name1}': 'org.apache.iceberg.spark.SparkSessionCatalog',
    f'spark.sql.catalog.{catalog_name1}.catalog-impl': 'org.apache.iceberg.aws.glue.GlueCatalog',
    f'spark.sql.catalog.{catalog_name1}.client.region': aws_region,
    f'spark.sql.catalog.{catalog_name1}.glue.account-id': aws_account_id,
    f'spark.sql.catalog.{catalog_name1}.io-impl': 'org.apache.iceberg.aws.s3.S3FileIO',
    f'spark.sql.catalog.{catalog_name1}.warehouse': warehouse_path,
    f'spark.sql.catalog.{catalog_name}': 'org.apache.iceberg.spark.SparkCatalog',
    f'spark.sql.catalog.{catalog_name}.catalog-impl': 'org.apache.iceberg.aws.glue.GlueCatalog',
    f'spark.sql.catalog.{catalog_name}.glue.id': link_container_catalog_name,
    f'spark.sql.catalog.{catalog_name}.glue.account-id': aws_account_id,
    f'spark.sql.catalog.{catalog_name}.client.region': aws_region,
    f'spark.sql.catalog.{catalog_name}.glue.catalog-arn': catalog_id,
}

_builder = SparkSession.builder.appName('demo')
for _key, _value in _session_settings.items():
    _builder = _builder.config(_key, _value)
spark = _builder.getOrCreate()

# Sanity check: list databases visible in the default session catalog.
# (Fixed: the original used an f-string with no placeholders — F541.)
spark.sql("show databases").show()

# Read from the federated table via the resource-link catalog.
spark.sql(f"SHOW TABLES IN {catalog_name}.{orders_database}").show()
spark.sql(f"DESCRIBE EXTENDED {catalog_name}.{orders_database}.{orders_table}").show()
spark.sql(f"SELECT ship_mode, count(1) as cnt FROM {catalog_name}.{orders_database}.{orders_table} group by ship_mode").show()

# Read from the native Iceberg table in the session catalog.
spark.sql(f"SHOW TABLES IN {catalog_name1}.{customer_database}").show()
spark.sql(f"SELECT * FROM {catalog_name1}.{customer_database}.{returns_table}").show()


# Analysis joining the federated orders table with the Iceberg returns table
# on order_id, aggregating ordered quantity per market.
spark.sql(f"""
    SELECT
        returns_tb.market as Market,
        sum(orders_tb.quantity) as Total_Quantity
    FROM {catalog_name}.{orders_database}.{orders_table} as orders_tb
    JOIN {catalog_name1}.{customer_database}.{returns_table} as returns_tb
        ON orders_tb.order_id = returns_tb.order_id
    GROUP BY returns_tb.market
""").show()



