import rclpy
from rclpy.node import Node
from nav_msgs.msg import Odometry
from sensor_msgs.msg import Imu
from geometry_msgs.msg import Quaternion
import numpy as np
from tf_transformations import quaternion_from_euler, euler_from_quaternion

class EKFNode(Node):
    """Planar EKF fusing IMU data with an odometry pose source.

    State vector ``X`` (shape (6,)): ``[x, y, yaw, vx, vy, vyaw]``.
    ``x, y, yaw`` are in the ``odom`` frame (the frame id used when
    publishing), and ``vx, vy`` are integrated directly into ``x, y``, so
    they are odom-frame velocities as well.

    * ``/imu/data`` (``sensor_msgs/Imu``) drives the prediction step: the
      z angular rate and the x/y linear acceleration act as control inputs.
    * ``/odom`` (``nav_msgs/Odometry``, lidar odometry) drives the
      correction step with a direct x, y, yaw measurement.
    * After each correction the fused pose is published on ``/odom_ekf``
      (pose only; twist and covariance fields are left at their defaults).
    """

    def __init__(self):
        super().__init__("ekf_fusion_node")

        # Subscriptions: IMU feeds the prediction, /odom feeds the correction.
        self.imu_sub = self.create_subscription(Imu, "/imu/data", self.imu_callback, 10)
        self.lidar_sub = self.create_subscription(Odometry, "/odom", self.lidar_callback, 10)

        # Publisher for the fused estimate.
        self.odom_pub = self.create_publisher(Odometry, "/odom_ekf", 10)

        # EKF state: x, y, yaw, vx, vy, vyaw
        self.X = np.zeros(6)
        self.P = np.eye(6) * 0.1
        # Time of the previous IMU sample. None until the first sample arrives,
        # so the first prediction does not integrate over the start-up delay.
        self.last_time = None

        self.get_logger().info("✅ EKF 融合节点已启动（Python完整版）")

    @staticmethod
    def _wrap_angle(angle):
        """Return ``angle`` (radians) wrapped into [-pi, pi]."""
        return float(np.arctan2(np.sin(angle), np.cos(angle)))

    def imu_callback(self, msg):
        """Run the EKF prediction step for one IMU sample."""
        current_time = self.get_clock().now()
        if self.last_time is None:
            # First sample: only establish the time reference.
            self.last_time = current_time
            return

        dt = (current_time - self.last_time).nanoseconds / 1e9
        self.last_time = current_time
        if dt <= 0.0:
            # Clock did not advance or jumped backwards (e.g. a sim-time
            # reset); predicting with such a dt would move the state backwards.
            return

        yaw_vel = msg.angular_velocity.z
        ax = msg.linear_acceleration.x
        ay = msg.linear_acceleration.y

        self.predict(dt, yaw_vel, ax, ay)

    def lidar_callback(self, msg):
        """Run the EKF correction step with an odometry pose, then publish."""
        x = msg.pose.pose.position.x
        y = msg.pose.pose.position.y

        q = msg.pose.pose.orientation
        _, _, yaw = euler_from_quaternion([q.x, q.y, q.z, q.w])

        self.update(x, y, yaw)
        self.publish_odom()

    def predict(self, dt, wy, ax, ay):
        """Propagate state and covariance forward by ``dt`` seconds.

        Args:
            dt: time step in seconds.
            wy: measured yaw rate (rad/s); it replaces the ``vyaw`` state and
                advances yaw.
            ax, ay: IMU linear acceleration x/y. They are rotated by the
                current yaw estimate into the odom frame before being
                integrated. This assumes the IMU frame is aligned with the
                robot body frame — TODO confirm the IMU mounting.
        """
        x, y, yaw, vx, vy, _ = self.X

        # Body -> odom rotation of the measured acceleration.
        c, s = np.cos(yaw), np.sin(yaw)
        ax_w = c * ax - s * ay
        ay_w = s * ax + c * ay

        x += vx * dt
        y += vy * dt
        new_yaw = self._wrap_angle(yaw + wy * dt)
        vx += ax_w * dt
        vy += ay_w * dt

        self.X = np.array([x, y, new_yaw, vx, vy, wy])

        # Jacobian of the transition above with respect to the previous state.
        F = np.eye(6)
        F[0, 3] = dt
        F[1, 4] = dt
        # vx, vy depend on yaw through the rotation of the acceleration.
        F[3, 2] = (-s * ax - c * ay) * dt
        F[4, 2] = (c * ax - s * ay) * dt
        # yaw is advanced by the measured rate and vyaw is overwritten by it,
        # so neither depends on the previous vyaw state.
        F[5, 5] = 0.0

        Q = np.eye(6) * 0.01
        self.P = F @ self.P @ F.T + Q

    def update(self, x, y, yaw):
        """Correct the state with a direct (x, y, yaw) pose measurement."""
        H = np.zeros((3, 6))
        H[0, 0] = 1
        H[1, 1] = 1
        H[2, 2] = 1

        Z = np.array([x, y, yaw])
        Y = Z - H @ self.X
        # The yaw innovation must be wrapped: an estimate of 3.1 rad and a
        # measurement of -3.1 rad differ by ~0.08 rad, not ~6.2 rad.
        Y[2] = self._wrap_angle(Y[2])

        R = np.eye(3) * 0.1
        S = H @ self.P @ H.T + R
        K = self.P @ H.T @ np.linalg.inv(S)

        self.X = self.X + K @ Y
        self.X[2] = self._wrap_angle(self.X[2])

        # Joseph form keeps P symmetric positive semi-definite under rounding.
        I_KH = np.eye(6) - K @ H
        self.P = I_KH @ self.P @ I_KH.T + K @ R @ K.T

    def publish_odom(self):
        """Publish the current pose estimate as ``nav_msgs/Odometry``."""
        msg = Odometry()
        msg.header.stamp = self.get_clock().now().to_msg()
        msg.header.frame_id = "odom"
        msg.child_frame_id = "base_footprint"

        msg.pose.pose.position.x = float(self.X[0])
        msg.pose.pose.position.y = float(self.X[1])

        q = quaternion_from_euler(0, 0, self.X[2])
        msg.pose.pose.orientation = Quaternion(
            x=float(q[0]), y=float(q[1]), z=float(q[2]), w=float(q[3])
        )

        self.odom_pub.publish(msg)

def main(args=None):
    """Initialise rclpy, spin an EKFNode until interrupted, then clean up."""
    rclpy.init(args=args)
    node = EKFNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the node; not an error.
        pass
    finally:
        node.destroy_node()
        # The context may already have been shut down; calling shutdown twice
        # raises, so only shut down if it is still active.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == "__main__":
    main()