#!/usr/bin/env python3
"""
A* 路径规划 + RViz2 可视化 (修复版)
修复：确保地图加载后再规划
"""

import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, DurabilityPolicy
from nav_msgs.msg import Path, OccupancyGrid
from geometry_msgs.msg import PoseStamped, PoseWithCovarianceStamped
import numpy as np
import heapq


class AStarRVizFixed(Node):
    """ROS 2 node that plans an 8-connected A* path on a static occupancy grid.

    Subscriptions:
        ``/map`` (``OccupancyGrid``, reliable + transient-local): only the
        first map received is used.
        ``/initialpose`` (``PoseWithCovarianceStamped``): path start, e.g. RViz
        "2D Pose Estimate".
        ``/goal_pose`` (``PoseStamped``): path goal, e.g. RViz "2D Goal Pose".

    Publications:
        ``/astar_path`` (``nav_msgs/Path``): planned path through cell centers.

    A plan is computed whenever the map, start and goal are all available;
    each new start or goal triggers a re-plan.
    """

    # Cells whose occupancy value is strictly greater than this are obstacles.
    # Unknown cells (-1) are therefore treated as free space.
    OCCUPIED_THRESHOLD = 50
    # Exact cost of a diagonal step. A truncated value such as 1.414 is below
    # sqrt(2) and makes a Euclidean/octile heuristic overestimate (inadmissible).
    DIAGONAL_COST = float(np.sqrt(2.0))
    # 8-connected neighbourhood as (d_row, d_col) offsets.
    NEIGHBORS = (
        (0, 1), (1, 0), (0, -1), (-1, 0),
        (1, 1), (1, -1), (-1, 1), (-1, -1),
    )

    def __init__(self):
        super().__init__('astar_rviz_fixed')

        # Transient-local durability lets a late-joining subscriber still
        # receive a map that was published before this node started.
        map_qos = QoSProfile(
            reliability=ReliabilityPolicy.RELIABLE,
            durability=DurabilityPolicy.TRANSIENT_LOCAL,
            depth=1
        )

        self.map_sub = self.create_subscription(
            OccupancyGrid, '/map', self.map_callback, map_qos)
        self.start_sub = self.create_subscription(
            PoseWithCovarianceStamped, '/initialpose', self.start_callback, 10)
        self.goal_sub = self.create_subscription(
            PoseStamped, '/goal_pose', self.goal_callback, 10)
        self.path_pub = self.create_publisher(Path, '/astar_path', 10)

        self.map_data = None    # (height, width) int8 occupancy array
        self.map_info = None    # MapMetaData of the loaded map
        self.map_frame = 'map'  # replaced by the map's own frame once loaded
        self.start_pose = None  # (x, y) in world coordinates
        self.goal_pose = None   # (x, y) in world coordinates
        self.has_start = False
        self.has_goal = False
        self.map_ready = False

        self.get_logger().info('='*50)
        self.get_logger().info('A* Planner 修复版启动！')
        self.get_logger().info('等待地图加载...')
        self.get_logger().info('='*50)

    def map_callback(self, msg):
        """Store the first valid map received; later map messages are ignored.

        A map whose data length does not match ``width * height``, or whose
        resolution is not positive, is rejected. In that case the node keeps
        waiting for a valid one.
        """
        if self.map_data is not None:
            return  # load the map only once
        width = msg.info.width
        height = msg.info.height
        data = np.asarray(msg.data, dtype=np.int8)
        if data.size != width * height or msg.info.resolution <= 0:
            self.get_logger().error(
                f'地图数据无效: 数据长度 {data.size}, 尺寸 {width}x{height}, '
                f'分辨率 {msg.info.resolution}')
            return
        # Commit state only after validation so a bad map cannot leave the
        # node half-initialised.
        self.map_info = msg.info
        self.map_data = data.reshape((height, width))
        self.map_frame = msg.header.frame_id or 'map'
        self.map_ready = True
        self.get_logger().info(f'地图加载完成: {width}x{height}')
        self.get_logger().info('现在可以设置起点和终点了')
        # Start and goal may have been set before the map arrived; plan now.
        if self.has_start and self.has_goal:
            self.plan_path()

    def _plan_if_ready(self, counterpart_set, prompt):
        """Plan if the map is loaded and the other endpoint is set, else log.

        Args:
            counterpart_set: Whether the other endpoint (goal or start) is set.
            prompt: Message logged when the other endpoint is still missing.
        """
        if not self.map_ready:
            self.get_logger().warn('地图还在加载中，请稍候...')
            return
        if counterpart_set:
            self.plan_path()
        else:
            self.get_logger().info(prompt)

    def start_callback(self, msg):
        """Record the start position from ``/initialpose`` and try to plan."""
        self.start_pose = (msg.pose.pose.position.x, msg.pose.pose.position.y)
        self.has_start = True
        self.get_logger().info(f'设置起点: {self.start_pose}')
        self._plan_if_ready(self.has_goal, '请设置终点（2D Goal Pose）')

    def goal_callback(self, msg):
        """Record the goal position from ``/goal_pose`` and try to plan."""
        self.goal_pose = (msg.pose.position.x, msg.pose.position.y)
        self.has_goal = True
        self.get_logger().info(f'设置终点: {self.goal_pose}')
        self._plan_if_ready(self.has_start, '请设置起点（2D Pose Estimate）')

    def world_to_grid(self, x, y):
        """Convert a world point (x, y) to a ``(row, col)`` cell index.

        Floor division is used so that points left of or below the map
        origin yield negative indices, which ``astar`` then reports as out of
        bounds. Truncation toward zero would instead fold them into row/col 0.
        The origin's orientation is ignored; this assumes an axis-aligned map.

        Returns:
            ``(row, col)``, or ``None`` if no map has been loaded.
        """
        if self.map_info is None:
            return None
        res = self.map_info.resolution
        origin = self.map_info.origin.position
        col = int(np.floor((x - origin.x) / res))
        row = int(np.floor((y - origin.y) / res))
        return (row, col)

    def grid_to_world(self, row, col):
        """Return the world (x, y) of the center of cell ``(row, col)``.

        Using the cell center, rather than its lower-left corner, avoids a
        half-cell offset in the published path. It also makes
        ``world_to_grid(*grid_to_world(r, c)) == (r, c)`` robust to
        floating-point error.
        """
        res = self.map_info.resolution
        origin = self.map_info.origin.position
        x = (col + 0.5) * res + origin.x
        y = (row + 0.5) * res + origin.y
        return (x, y)

    def heuristic(self, a, b):
        """Return the octile distance between cells ``a`` and ``b``.

        This is the exact cost-to-go on an obstacle-free 8-connected grid with
        unit straight steps and sqrt(2) diagonal steps. It is therefore both
        admissible and consistent, and tighter than plain Euclidean distance.
        """
        d_row = abs(a[0] - b[0])
        d_col = abs(a[1] - b[1])
        return max(d_row, d_col) + (self.DIAGONAL_COST - 1.0) * min(d_row, d_col)

    @staticmethod
    def _reconstruct_path(came_from, node):
        """Follow ``came_from`` back from ``node``; return the start->node path."""
        path = [node]
        while node in came_from:
            node = came_from[node]
            path.append(node)
        path.reverse()
        return path

    def astar(self, start, goal):
        """Run A* on the loaded grid from cell ``start`` to cell ``goal``.

        Moves are 8-connected. A diagonal move is allowed only when both
        orthogonally adjacent cells are free, so the path never squeezes
        between two obstacles that touch at a corner.

        Args:
            start: ``(row, col)`` start cell.
            goal: ``(row, col)`` goal cell.

        Returns:
            List of ``(row, col)`` cells from start to goal inclusive, or
            ``None`` if there is no map, an endpoint is out of bounds or
            occupied, or no path exists.
        """
        if self.map_data is None:
            self.get_logger().error('地图未加载，无法规划')
            return None
        rows, cols = self.map_data.shape

        if not (0 <= start[0] < rows and 0 <= start[1] < cols):
            self.get_logger().error(f'起点越界: {start}')
            return None
        if not (0 <= goal[0] < rows and 0 <= goal[1] < cols):
            self.get_logger().error(f'终点越界: {goal}')
            return None

        start_val = self.map_data[start[0], start[1]]
        goal_val = self.map_data[goal[0], goal[1]]
        self.get_logger().info(f'起点值: {start_val}, 终点值: {goal_val}')

        # Compute the obstacle mask once instead of comparing per neighbour.
        blocked = self.map_data > self.OCCUPIED_THRESHOLD
        if blocked[start[0], start[1]]:
            self.get_logger().error('起点在障碍物中')
            return None
        if blocked[goal[0], goal[1]]:
            self.get_logger().error('终点在障碍物中')
            return None

        diagonal_cost = self.DIAGONAL_COST
        open_heap = [(self.heuristic(start, goal), start)]
        came_from = {}
        g_score = {start: 0.0}
        closed = set()

        while open_heap:
            _, current = heapq.heappop(open_heap)
            if current in closed:
                # Stale duplicate: a cheaper entry for this cell was already
                # expanded (lazy deletion; the heuristic is consistent).
                continue
            if current == goal:
                return self._reconstruct_path(came_from, current)
            closed.add(current)

            row, col = current
            g_current = g_score[current]
            for d_row, d_col in self.NEIGHBORS:
                n_row, n_col = row + d_row, col + d_col
                if not (0 <= n_row < rows and 0 <= n_col < cols):
                    continue
                if blocked[n_row, n_col]:
                    continue
                is_diagonal = d_row != 0 and d_col != 0
                # No corner cutting: both orthogonal side cells must be free.
                if is_diagonal and (blocked[row, n_col] or blocked[n_row, col]):
                    continue
                neighbor = (n_row, n_col)
                if neighbor in closed:
                    continue
                tentative_g = g_current + (diagonal_cost if is_diagonal else 1.0)
                if tentative_g < g_score.get(neighbor, float('inf')):
                    came_from[neighbor] = current
                    g_score[neighbor] = tentative_g
                    f_score = tentative_g + self.heuristic(neighbor, goal)
                    heapq.heappush(open_heap, (f_score, neighbor))
        self.get_logger().warn('A* 搜索完成，未找到路径')
        return None

    def plan_path(self):
        """Plan from the stored start to the stored goal and publish the path.

        The path is published on ``/astar_path`` in the map's frame, with one
        pose per cell center and identity orientation. Failures are logged
        and nothing is published.
        """
        if not self.has_start or not self.has_goal:
            self.get_logger().error('起点或终点未设置')
            return
        if not self.map_ready:
            self.get_logger().error('地图未加载')
            return

        start = self.start_pose
        goal = self.goal_pose
        start_grid = self.world_to_grid(*start)
        goal_grid = self.world_to_grid(*goal)

        if start_grid is None or goal_grid is None:
            self.get_logger().error('坐标转换失败')
            return

        self.get_logger().info('='*50)
        self.get_logger().info(f'开始规划: {start} -> {goal}')
        self.get_logger().info(f'栅格坐标: {start_grid} -> {goal_grid}')

        path_grid = self.astar(start_grid, goal_grid)

        if not path_grid:
            self.get_logger().error('路径规划失败')
            return

        path_msg = Path()
        path_msg.header.stamp = self.get_clock().now().to_msg()
        # Publish in the frame the map was received in (defaults to 'map').
        path_msg.header.frame_id = self.map_frame
        for row, col in path_grid:
            x, y = self.grid_to_world(row, col)
            pose = PoseStamped()
            pose.header = path_msg.header
            pose.pose.position.x = x
            pose.pose.position.y = y
            pose.pose.position.z = 0.0
            pose.pose.orientation.w = 1.0
            path_msg.poses.append(pose)
        self.path_pub.publish(path_msg)
        self.get_logger().info(f'路径发布成功: {len(path_grid)} 个点')
        self.get_logger().info('='*50)


def main(args=None):
    """Start the A* planner node and spin until shutdown (e.g. Ctrl+C).

    Args:
        args: Optional command-line arguments forwarded to ``rclpy.init``.
            ``None`` makes rclpy use ``sys.argv``.
    """
    rclpy.init(args=args)
    node = AStarRVizFixed()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass  # Ctrl+C is the normal way to stop this node; exit cleanly.
    finally:
        # Always release the node. Also shut rclpy down, but only if its
        # context is still valid: a signal handler may already have done it.
        node.destroy_node()
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()
