#!/usr/bin/env python3
"""
A* 路径规划 + RViz2 可视化
"""

import rclpy
from rclpy.node import Node
from nav_msgs.msg import Path, OccupancyGrid
from geometry_msgs.msg import PoseStamped
import numpy as np
import heapq


class AStarRViz(Node):
    """ROS 2 node that plans a 4-connected A* path on ``/map`` and publishes it.

    Subscribes to ``/map`` (nav_msgs/OccupancyGrid), re-plans every 2 s from
    ``self.start`` to ``self.goal`` (world coordinates in the map's frame), and
    publishes the result as a nav_msgs/Path on ``/astar_path`` for RViz2.
    """

    # Cells whose occupancy value is strictly greater than this are obstacles.
    # ROS occupancy values: 0 = free, 100 = occupied, -1 = unknown, so unknown
    # cells are treated as traversable.
    OCCUPIED_THRESHOLD = 50

    def __init__(self):
        super().__init__('astar_rviz')

        # Map subscription
        self.map_sub = self.create_subscription(
            OccupancyGrid, '/map', self.map_callback, 10)

        # Path publisher
        self.path_pub = self.create_publisher(Path, '/astar_path', 10)

        # Timer: re-plan every 2 seconds
        self.timer = self.create_timer(2.0, self.plan_path)

        self.map_data = None  # (height, width) int8 array once a map arrives
        self.map_info = None  # nav_msgs/MapMetaData of the latest map
        # Frame of the latest map; the published path uses the same frame.
        self.map_frame = 'map'

        # Start / goal in world coordinates -- adjust for your setup
        self.start = (0.0, 0.0)
        self.goal = (3.0, 4.0)

        self.get_logger().info('A* Planner 启动，等待地图...')

    def map_callback(self, msg):
        """Store the latest occupancy grid.

        ``msg.data`` is row-major starting at cell (0, 0), which sits at
        ``msg.info.origin``. Reshaping to ``(height, width)`` therefore yields
        ``map_data[row, col]`` with ``row`` growing along +y and ``col`` along
        +x -- the same indexing that ``world_to_grid`` produces.
        """
        self.map_info = msg.info
        self.map_frame = msg.header.frame_id or 'map'
        width = msg.info.width
        height = msg.info.height

        # No vertical flip: row 0 must stay the row at origin.y, otherwise the
        # obstacle lookups in astar() would be mirrored relative to
        # world_to_grid().
        self.map_data = np.asarray(msg.data, dtype=np.int8).reshape((height, width))

        self.get_logger().info(f'地图接收: {width}x{height}')

    def world_to_grid(self, x, y):
        """Convert world coordinates to a ``(row, col)`` grid index.

        Returns None until a map has been received. The index may lie outside
        the grid; callers must bounds-check it. The map origin's orientation is
        ignored, i.e. an unrotated map is assumed.
        """
        if self.map_info is None:
            return None
        info = self.map_info
        # floor() rather than int(): int() truncates toward zero, which would
        # map points just below/left of the origin onto cell 0 instead of -1.
        col = int(np.floor((x - info.origin.position.x) / info.resolution))
        row = int(np.floor((y - info.origin.position.y) / info.resolution))
        return (row, col)  # numpy indexing is (row, col)

    def grid_to_world(self, row, col):
        """Return the world coordinates of the centre of grid cell ``(row, col)``.

        Using the centre (not the corner) keeps
        ``world_to_grid(*grid_to_world(r, c)) == (r, c)`` robust against
        floating-point rounding.
        """
        info = self.map_info
        x = (col + 0.5) * info.resolution + info.origin.position.x
        y = (row + 0.5) * info.resolution + info.origin.position.y
        return (x, y)

    def heuristic(self, a, b):
        """Manhattan distance between two ``(row, col)`` cells (admissible for 4-connectivity)."""
        return abs(a[0] - b[0]) + abs(a[1] - b[1])

    @staticmethod
    def _reconstruct_path(came_from, current):
        """Follow ``came_from`` links back from ``current``; return the path start -> current."""
        path = [current]
        while current in came_from:
            current = came_from[current]
            path.append(current)
        path.reverse()
        return path

    def astar(self, start, goal):
        """Run 4-connected, unit-cost A* on the current map.

        Args:
            start: ``(row, col)`` start cell.
            goal: ``(row, col)`` goal cell.

        Returns:
            A list of ``(row, col)`` cells from start to goal inclusive, or
            None if there is no map, an endpoint is out of bounds, the goal is
            an obstacle, or no path exists.
        """
        grid = self.map_data
        if grid is None:
            return None

        rows, cols = grid.shape

        def in_bounds(cell):
            return 0 <= cell[0] < rows and 0 <= cell[1] < cols

        # Validate start and goal
        if not (in_bounds(start) and in_bounds(goal)):
            return None
        if start == goal:
            return [start]

        # Obstacle mask (ROS: 100 = occupied, -1 = unknown, 0 = free)
        blocked = grid > self.OCCUPIED_THRESHOLD
        # An occupied goal can never be entered; fail fast instead of
        # exhausting the whole reachable free space.
        if blocked[goal]:
            return None

        open_set = [(self.heuristic(start, goal), start)]
        came_from = {}
        g_score = {start: 0}
        closed = set()

        while open_set:
            _, current = heapq.heappop(open_set)

            if current == goal:
                return self._reconstruct_path(came_from, current)

            # Skip stale heap entries: with a consistent heuristic the first
            # pop of a cell already carries its optimal cost.
            if current in closed:
                continue
            closed.add(current)

            r, c = current
            next_g = g_score[current] + 1
            for dr, dc in ((0, 1), (1, 0), (0, -1), (-1, 0)):
                neighbor = (r + dr, c + dc)
                if neighbor in closed or not in_bounds(neighbor):
                    continue
                if blocked[neighbor]:
                    continue
                if next_g < g_score.get(neighbor, float('inf')):
                    came_from[neighbor] = current
                    g_score[neighbor] = next_g
                    f = next_g + self.heuristic(neighbor, goal)
                    heapq.heappush(open_set, (f, neighbor))

        return None

    def plan_path(self):
        """Timer callback: plan from ``self.start`` to ``self.goal`` and publish.

        On success publishes a nav_msgs/Path in the map's frame. Logs a warning
        when no map has been received yet or when no path is found.
        """
        if self.map_data is None:
            self.get_logger().warn('等待地图...')
            return

        # Convert start and goal to grid cells
        start_grid = self.world_to_grid(*self.start)
        goal_grid = self.world_to_grid(*self.goal)
        if start_grid is None or goal_grid is None:
            return

        self.get_logger().info(f'规划: {self.start} -> {self.goal}')

        path_grid = self.astar(start_grid, goal_grid)
        if not path_grid:
            self.get_logger().warn('未找到路径')
            return

        # Build the ROS Path message; coordinates are relative to the map's
        # origin, so the path is expressed in the map's own frame.
        path_msg = Path()
        path_msg.header.stamp = self.get_clock().now().to_msg()
        path_msg.header.frame_id = self.map_frame

        for row, col in path_grid:
            x, y = self.grid_to_world(row, col)
            pose = PoseStamped()
            pose.header = path_msg.header
            pose.pose.position.x = x
            pose.pose.position.y = y
            pose.pose.position.z = 0.0
            pose.pose.orientation.w = 1.0  # identity orientation
            path_msg.poses.append(pose)

        self.path_pub.publish(path_msg)
        self.get_logger().info(f'路径发布: {len(path_grid)} 个点')


def main(args=None):
    """Entry point: spin the A* planner node until interrupted, then clean up.

    Args:
        args: Optional command-line arguments forwarded to ``rclpy.init``.
    """
    rclpy.init(args=args)
    node = AStarRViz()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop this node; fall through to cleanup.
        pass
    finally:
        node.destroy_node()
        # The context may already be shut down (e.g. by a signal handler).
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()
