from __future__ import annotations

import logging
import os
import re
from datetime import datetime
from typing import List, Dict, Any

import daft
from daft import col
from daft.las.functions.video.video_extract_audio import VideoExtractAudio
from daft.las.functions.audio.audio_asr_doubao import AudioAsrDoubao
from daft.las.functions.text.pre_sign_url_for_tos import PreSignUrlForTos
from daft.las.functions.ark_llm.ark_llm_thinking_vision import ArkLLMThinkingVision
from daft.las.functions.udf import las_udf
from daft.las.io.tos import TOSConfig
from daft.las.functions.types import Operator
from daft.dependencies import pa

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

class AsrTextParser(Operator):
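    """Parse the "simple" ASR transcript into structured speaker segments.

    Expected input (assumed from the header regex in parse_asr_segments):
    blank-line-separated blocks, each starting with a header such as
    "说话人 1 0:00:04 0:00:07" followed by the spoken text on the next line(s).
    """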
    def parse_asr_segments(self, asr_text_simple: str) -> List[Dict[str, Any]]:
        if not asr_text_simple or asr_text_simple.strip() == "None":
            return []
        
        segments = []
        speaker_blocks = asr_text_simple.strip().split('\n\n')
        
        for block in speaker_blocks:
            lines = block.strip().split('\n')
            if len(lines) < 2:
                continue
                
            header_line = lines[0]
            text_content = '\n'.join(lines[1:]).strip()
            
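            # Header line format: 说话人 <speaker_id> <start H:MM:SS> <end H:MM:SS>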
            pattern = r'说话人\s+(\d+)\s+(\d+:\d+:\d+)\s+(\d+:\d+:\d+)'
            match = re.match(pattern, header_line)
            
            if match:
                speaker_id = match.group(1)
                start_time = match.group(2)
                end_time = match.group(3)
                
                segments.append({
                    "speaker": f"说话人 {speaker_id}",
                    "start_time": start_time,
                    "end_time": end_time,
                    "text": text_content
                })
        
        return segments
    
    def transform(self, asr_texts: pa.Array) -> pa.Array:
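        """Parse each ASR transcript in the batch into a list of segment structs."""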
        results = []
        for asr_text in asr_texts.to_pylist():
            parsed_segments = self.parse_asr_segments(asr_text)
            results.append(parsed_segments)
        
        return pa.array(results, type=self.__return_column_type__())
    
    @staticmethod
    def __return_column_type__() -> pa.DataType:
        return pa.list_(pa.struct([
            pa.field("speaker", pa.string()),
            pa.field("start_time", pa.string()),
            pa.field("end_time", pa.string()),
            pa.field("text", pa.string())
        ]))

class VlmPromptGenerator(Operator):
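    """Build, per video, the VLM prompt that asks the model to check subtitle/audio sync."""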
    def time_str_to_seconds(self, time_str: str) -> float:
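        """Convert an H:MM:SS timestamp string to a number of seconds."""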
        parts = time_str.split(':')
        hours = int(parts[0])
        minutes = int(parts[1])
        seconds = int(parts[2])
        return hours * 3600 + minutes * 60 + seconds
    
    def generate_vlm_prompt(self, asr_segments: List[Dict[str, Any]]) -> str:
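        """Fill the Chinese prompt template with the ASR subtitles and their timestamps."""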
        subtitles = []
        timestamps = []
        
        for segment in asr_segments:
            subtitles.append(segment['text'])
            start_seconds = self.time_str_to_seconds(segment['start_time'])
            end_seconds = self.time_str_to_seconds(segment['end_time'])
            timestamps.append([start_seconds, end_seconds])
        
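        # Chinese prompt for the vision model: given the ASR subtitle list and
        # timestamp list, match each on-screen subtitle to an ASR entry, judge
        # whether start/end times agree within ±0.5 s, and return the per-entry
        # verdicts as a JSON list with no extra explanation.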
        original_prompt = """你是一名视频字幕核对助手，负责核对视频声音和字幕是否同步。具体任务是结合ASR提取的一一对应的视频字幕以及时间戳列表，对比分析视频画面中的字幕时间是否和结构化视频字幕的开始时间和结束时间同步，并记录分析结果。

如下ASR 字幕示例，字幕 "很高兴你能过来" 开始时间是 4.8 秒，结束的时间是 7.46 秒。 
<ASR 字幕示例>
[……,"很高兴你能过来",……]
</ASR 字幕示例>
<ASR 字幕时间戳示例>
[……,[4.8,7.46],……]
</ASR 字幕时间戳示例>

首先，请仔细阅读以下ASR提取的视频字幕列表、字幕起始和终止时间戳列表：
<ASR_list_subtitles>
{{ASR_LIST_SUBTITLES}}
</ASR_list_subtitles>
<ASR_list_timestamps>
{{ASR_LIST_TIMESTAMPS}}
</ASR_list_timestamps>
接下来，请仔细分析视频画面中的字幕，按照以下步骤进行：
1. 逐一对视频画面中的字幕与结构化的语音字幕逐一进行匹配。
2. 对比每一条匹配字幕的开始时间和结束时间。
3. 若开始时间和结束时间一致或在可接受的误差范围内（如±0.5秒），则判定为同步；否则判定为不同步。
4. 记录每一条字幕的核对结果，如果不同步，在结果 "notes" 字段说明视频画面字幕是延迟出现还是提前出现，并记录具体的时间。

请结合结构化的视频字幕和时间戳列表仔细分析视频，然后逐条给出视频字幕和语音字幕匹配检测的结果，结果格式为 JSON，不做额外解释。
其中，start_time为语音字幕开始时间，end_time为语音字幕结束时间，text为语音字幕，vlm为检索结果，如下为输出示例：
[……,
  {
    "start_time": 4.8,
    "end_time": 7.46,
    "text": "很高兴你能过来",
    "vlm": {"是否同步":"<是/否>", "notes":"<视频画面字幕的延迟出现/提前出现> X 秒"}
  },
……
]"""
        
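        # Note: str() renders the lists in Python repr form (single-quoted
        # strings rather than strict JSON), which the prompt tolerates.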
        prompt = original_prompt.replace("{{ASR_LIST_SUBTITLES}}", str(subtitles))
        prompt = prompt.replace("{{ASR_LIST_TIMESTAMPS}}", str(timestamps))
        
        return prompt
    
    def transform(self, asr_segments_list: pa.Array) -> pa.Array:
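        """Generate one prompt string per row of parsed ASR segments."""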
        results = []
        for asr_segments in asr_segments_list.to_pylist():
            prompt = self.generate_vlm_prompt(asr_segments)
            results.append(prompt)
        
        return pa.array(results, type=self.__return_column_type__())
    
    @staticmethod
    def __return_column_type__() -> pa.DataType:
        return pa.string()

def run_pipeline(input_tos_dir: str, audio_output_tos_dir: str):
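    """Run the end-to-end video subtitle sync check.

    Stages: list .mp4 files under input_tos_dir -> extract audio to
    audio_output_tos_dir -> pre-sign the audio TOS URLs -> Doubao ASR ->
    parse the transcript into segments -> build the VLM prompt ->
    ArkLLMThinkingVision analysis -> write cleaned results to TOS as Parquet.
    """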
    input_s3_dir = input_tos_dir.replace("tos://", "s3://", 1)
    tos_config = TOSConfig.from_env()
    IO_CONFIG = daft.io.IOConfig(s3=tos_config.to_s3_config())
    
    logger.info("Video subtitle sync check pipeline starting...")
    
    df = daft.from_glob_path(
        f"{input_s3_dir}/*.mp4",
        io_config=IO_CONFIG,
    )
    
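    # Derive the tos:// form of each path and a video_id from the file name stem.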
    df = df.with_column(
        "video_tos_path",
        col("path").str.replace("s3://", "tos://")
    )
    df = df.with_column(
        "video_id",
        col("path").str.split("/").list.get(-1).str.replace(".mp4", "")
    )
    df = df.select("video_tos_path", "video_id")
    
    logger.info("Video files loaded: %d records", df.count_rows())
    
    logger.info("Starting audio extraction from videos...")
    df = df.with_column(
        "audio_extraction_result",
        las_udf(
            VideoExtractAudio,
            construct_args={
                "output_tos_dir": audio_output_tos_dir,
                "output_audio_binary": True,
                "output_audio_array": False,
                "return_first_stream": True,       
                "output_format": "mp3",
                "output_sample_rate": 16000,
            },
            num_gpus=0,
            batch_size=2,
            concurrency=1,
        )(col("video_tos_path")),
    )
    df.collect()
    logger.info("Audio extraction finished: %d records", df.count_rows())

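    # Unpack the first audio stream from the extraction result struct.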
    df = df.with_columns({
        "audio_tos_path": col("audio_extraction_result").struct.get("audio_paths").list.get(0),
        "audio_binary": col("audio_extraction_result").struct.get("binaries").list.get(0),
        "original_sample_rate": col("audio_extraction_result").struct.get("original_audio_sampling_rates").list.get(0)
    })
    
    logger.info("Starting audio URL pre-signing...")
    df = df.with_column(
        "audio_presigned_url",
        las_udf(
            PreSignUrlForTos,
            construct_args={"expire_seconds": 3600},
        )(col("audio_tos_path")),
    )
    df.collect()
    logger.info("Audio URL pre-signing finished: %d records", df.count_rows())
    
    logger.info("Starting audio ASR recognition...")
    appid = os.getenv("OPENSPEECH_APPID")
    token = os.getenv("OPENSPEECH_TOKEN")
    
    if not appid or not token:
        logger.error("Missing OPENSPEECH_APPID or OPENSPEECH_TOKEN environment variables")
        return
    
    df = df.with_column(
        "asr_result",
        las_udf(
            AudioAsrDoubao,
            construct_args={
                "appid": appid,
                "token": token,
                "uid": "video_subtitle_sync",
                "enable_speaker_info": True,
                "enable_punc": True,
                "enable_ddc": True,
                "poll_interval": 10,
                "num_coroutines": 1,
            },
            num_gpus=0,
            batch_size=1,
            concurrency=1,
        )(col("audio_presigned_url")),
    )
    df.collect()
    logger.info("Audio ASR recognition finished: %d records", df.count_rows())
    
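    # Pull the simple, raw, and plain-text transcripts out of the ASR result struct.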
    df = df.with_columns({
        "asr_text_simple": col("asr_result").struct.get("asr_result_simple"),
        "asr_text_raw": col("asr_result").struct.get("asr_result_raw"),
        "asr_text_content": col("asr_result").struct.get("asr_result_text")
    })
    
    logger.info("Starting structured parsing of ASR text...")
    df = df.with_column(
        "asr_segments_structured",
        las_udf(AsrTextParser)(col("asr_text_simple")),
    )
    df.collect()
    logger.info("Structured ASR text parsing finished: %d records", df.count_rows())
    
    logger.info("Starting VLM prompt generation...")
    df = df.with_column(
        "vlm_prompt",
        las_udf(VlmPromptGenerator)(col("asr_segments_structured")),
    )
    df.collect()
    logger.info("VLM prompt generation finished: %d records", df.count_rows())
    
    logger.info("Starting video subtitle sync analysis...")
    
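    # Send each video (by TOS path) together with its generated prompt to the
    # thinking-vision model; video_fps presumably controls the frame sampling rate.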
    df = df.with_column(
        "vlm_analysis_struct",
        las_udf(
            ArkLLMThinkingVision,
            construct_args={
                "model": "doubao-1.5-thinking-vision-pro",
                "version":"250428",
                "multimodal_type": "video",
                "inference_type": "online",
                "source_type": "url",
                "video_format": "mp4",
                "video_fps": 1.0,
                "max_tokens": 4000,
                "temperature": 0.1,
            },
            num_gpus=0,
            batch_size=1,
            concurrency=1,
        )(col("video_tos_path"), col("vlm_prompt")),
    )
    df.collect()
    logger.info("Video subtitle sync analysis finished: %d records", df.count_rows())
    
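    # Keep only the model's final text answer from the response struct.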
    df = df.with_column(
        "vlm_analysis_result",
        col("vlm_analysis_struct").struct.get("llm_result")
    )
    
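    # Drop bulky intermediate columns before previewing and exporting the results.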
    df_clean = df.exclude(
        "audio_extraction_result",
        "audio_binary",
        "asr_result",
        "asr_text_raw",
        "audio_presigned_url",
        "original_sample_rate",
        "asr_segments_structured",
        "vlm_prompt",
        "vlm_analysis_struct",
    )
    df_clean.show()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_tos_dir = "tos://tos_bucket/video_subtitle_sync/results/"
    output_s3_dir = output_tos_dir.replace("tos://", "s3://", 1)
    
    parquet_s3_path = f"{output_s3_dir}video_subtitle_sync_parquet_{timestamp}"
    parquet_tos_path = f"{output_tos_dir}video_subtitle_sync_parquet_{timestamp}"
    
    df_clean.write_parquet(parquet_s3_path, io_config=IO_CONFIG)
    
    logger.info("Video subtitle sync results written to TOS:")
    logger.info("Parquet: %s", parquet_tos_path)

if __name__ == "__main__":
    input_tos_dir = "tos://tos_bucket/video_subtitle_sync/test_video"
    audio_output_tos_dir = "tos://tos_bucket/video_subtitle_sync/video_audio_extracted/"
    
    run_pipeline(input_tos_dir, audio_output_tos_dir)