from __future__ import annotations

import logging

import daft
from daft import col
from daft.las.functions.image.embedding.image_vit_embedding import ImageViTEmbedding
from daft.las.functions.image.image_resample import ImageResample
from daft.las.functions.udf import las_udf
from daft.las.io.tos import TOSConfig

# Configure the root logger for this script: INFO level, lines formatted as
# "<time> - <level> - <message>", written through a single StreamHandler.
# (Per the stdlib docs, basicConfig is a no-op if the root logger already has
# handlers.)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=(logging.StreamHandler(),),
)

# Logger named after this module.
logger = logging.getLogger(name=__name__)

def run_pipeline(
    input_tos_dir: str,
    *,
    model_path: str = "/opt/las/models",
    resample_tos_dir: str = "tos://tos_bucket/image_retrieval/resample_image/",
) -> None:
    """Resample the PNG images under a TOS directory and embed them with a ViT.

    Steps:

    1. Glob ``*.png`` directly under ``input_tos_dir``. The files are read via
       daft's S3 client, configured from ``TOSConfig.from_env()``.
    2. Derive ``image_tos_path`` (the file URI with a ``tos://`` scheme) and
       ``image_id`` (the file name with its trailing ``.png`` removed).
    3. Run ``ImageResample`` on each image (target size 224x224, target DPI
       72x72, method "lanczos", CPU only).
    4. Run ``ImageViTEmbedding`` (``facebook/dinov2-large``, CLS-token
       embedding, float32, 1 GPU) on the resampled images' base64 field.
    5. Keep ``image_tos_path``, ``image_id``, ``image_embedding`` and
       ``resampled_image_tos_path``, and print the frame with ``show()``.

    Args:
        input_tos_dir: Directory holding the source images, for example
            ``tos://bucket/prefix``. A trailing ``/`` is tolerated.
        model_path: Directory passed to ``ImageViTEmbedding`` as
            ``model_path``.
        resample_tos_dir: Directory passed to ``ImageResample`` as ``tos_dir``.
            This is presumably where the resampled images are written;
            confirm against ``ImageResample``.
    """
    # Strip trailing slashes so the glob below never contains "//". Then
    # rewrite only a leading "tos://" scheme to "s3://", because the bucket
    # is read through daft's S3 client (see io_config).
    input_s3_dir = input_tos_dir.rstrip("/")
    if input_s3_dir.startswith("tos://"):
        input_s3_dir = "s3://" + input_s3_dir[len("tos://"):]

    tos_config = TOSConfig.from_env()
    io_config = daft.io.IOConfig(s3=tos_config.to_s3_config())
    logger.info("图像检索pipeline开始运行...")

    df = daft.from_glob_path(f"{input_s3_dir}/*.png", io_config=io_config)

    # Anchor both rewrites. A plain substring replace would also touch any
    # "s3://" or ".png" appearing elsewhere in the path or file name
    # (e.g. "a.png.png" would become "a").
    df = df.with_column(
        "image_tos_path",
        col("path").str.replace("^s3://", "tos://", regex=True),
    )
    df = df.with_column(
        "image_id",
        col("path")
        .str.split("/")
        .list.get(-1)
        .str.replace(r"\.png$", "", regex=True),
    )
    df = df.select("image_tos_path", "image_id")

    logger.info("图像文件加载完成: %d 条记录", df.count_rows())

    logger.info("开始图像重采样...")
    df = df.with_column(
        "resampled_image",
        las_udf(
            ImageResample,
            construct_args={
                "image_suffix": ".png",
                "tos_dir": resample_tos_dir,
                "local_dir": "",
                "image_src_type": "image_url",
                "target_size": [224, 224],
                "target_dpi": [72, 72],
                "method": "lanczos",
            },
            num_gpus=0,
            batch_size=32,
            concurrency=1,
        )(col("image_tos_path")),
    )
    # Materialize here so the resampling runs once. The following count_rows()
    # and the embedding step then build on the collected result.
    df.collect()
    logger.info("图像重采样完成: %d 条记录", df.count_rows())

    logger.info("开始图像向量化...")
    df = df.with_column(
        "image_embedding",
        las_udf(
            ImageViTEmbedding,
            construct_args={
                "image_src_type": "image_base64",
                "dtype": "float32",
                "batch_size": 32,
                "model_path": model_path,
                "model_name": "facebook/dinov2-large",
                "use_cls_token_embedding": True,
                "rank": 0,
            },
            num_gpus=1,
            batch_size=32,
            concurrency=1,
        )(col("resampled_image")["base64"]),
    )
    df.collect()
    logger.info("图像向量化完成: %d 条记录", df.count_rows())

    # Keep only the resampled image's path and drop the rest of the struct
    # (including its base64 payload) before display.
    df = df.with_column(
        "resampled_image_tos_path",
        col("resampled_image")["image_path"],
    )
    df = df.exclude("resampled_image")
    df.show()

if __name__ == "__main__":
    # Ray is only needed by the script entry point, so import it here. The
    # module's top-level imports do not include it, and without this import
    # the ray.init() call below raises NameError.
    import ray

    ray.init(dashboard_host="0.0.0.0")
    # NOTE(review): Ray is already initialized by the call above. Daft's Ray
    # runner will likely reuse that context and ignore this address. Confirm
    # whether the local cluster or the one at ray://127.0.0.1:10001 is intended.
    daft.context.set_runner_ray("ray://127.0.0.1:10001")
    input_tos_dir = "tos://tos_bucket/image_retrieval/test_image"
    run_pipeline(input_tos_dir)