#!/usr/bin/env python3
"""
将 GeoJSON 格式的 POI 数据转换为项目使用的 JSON 格式

输入格式 (GeoJSON):
{
  "type": "FeatureCollection",
  "features": [
    {
      "type": "Feature",
      "geometry": {"type": "Point", "coordinates": [lng, lat]},
      "properties": {"amenity": "restaurant", "shop": null, ...}
    }
  ]
}

输出格式:
[
  {"id": 1, "type": "Restaurant", "x": 162.16, "y": 40.32},
  ...
]
"""

import json
from collections import Counter

def extract_poi_type(properties):
    """Extract the raw POI type tag from a GeoJSON feature's ``properties``.

    Priority: amenity > shop > tourism > leisure > office > healthcare > education.
    For amenity / shop / tourism / leisure the tag *value* is returned
    (e.g. ``"restaurant"``). For office / healthcare / education the key name
    itself is returned, so e.g. every office collapses into one ``"office"`` type.
    A tag counts only if its value is truthy (``None`` / ``""`` are ignored).

    Args:
        properties: the feature's ``properties`` mapping. GeoJSON (RFC 7946)
            allows this member to be ``null``, so ``None`` is accepted.

    Returns:
        The raw type string, or ``None`` when no recognised tag is set.
    """
    if not properties:
        return None
    # Tags whose value names the POI type, highest priority first.
    for key in ('amenity', 'shop', 'tourism', 'leisure'):
        value = properties.get(key)
        if value:
            return value
    # Tags whose presence alone defines a coarse category.
    for key in ('office', 'healthcare', 'education'):
        if properties.get(key):
            return key
    return None

def normalize_poi_type(poi_type):
    """Turn a raw OSM tag value into a Title Case display name.

    Underscores and hyphens act as word separators, runs of whitespace
    collapse to one space, and each word is capitalised (first letter upper,
    the rest lower), e.g. ``"fast_food"`` -> ``"Fast Food"``,
    ``"ATM"`` -> ``"Atm"``. Falsy input (``None`` or ``""``) yields ``None``.
    """
    if not poi_type:
        return None
    # Map both separator characters to spaces in a single pass.
    spaced = poi_type.translate(str.maketrans('_-', '  '))
    capitalised = [token.capitalize() for token in spaced.split()]
    return ' '.join(capitalised)

# Normalised POI types (as produced by normalize_poi_type, e.g. "fast_food" ->
# "Fast Food") that are written to the output when no explicit filter is given.
DEFAULT_TARGET_TYPES = frozenset({
    "Restaurant",
    "Cafe",
    "Fast Food",
    "Convenience",
    "Bank",
    "Bakery",
    "Supermarket",
    "Bar",
    "Clothes",
    "Pharmacy",
    "Hairdresser",
    "Cinema",
    "School",
    "Hospital",
    "Post Office",
    "Police",
})


def _collect_typed_points(features):
    """Return ``(points, n_untyped, n_bad_geometry)`` for a list of GeoJSON features.

    ``points`` is a list of ``(lng, lat, poi_type)`` tuples, one per feature that
    has a recognised type tag and a Point geometry with at least two coordinates;
    ``poi_type`` is already normalised. The two counters record why the other
    features were dropped.
    """
    points = []
    n_untyped = 0
    n_bad_geometry = 0
    for feature in features:
        # RFC 7946 allows "properties": null and "geometry": null; a .get()
        # default only covers a *missing* key, so fall back with `or {}`.
        properties = feature.get('properties') or {}
        poi_type_raw = extract_poi_type(properties)
        if not poi_type_raw:
            n_untyped += 1
            continue

        geometry = feature.get('geometry') or {}
        coordinates = geometry.get('coordinates') or []
        if geometry.get('type') != 'Point' or len(coordinates) < 2:
            n_bad_geometry += 1
            continue

        # Extra positions (e.g. altitude) beyond [lng, lat] are ignored.
        lng, lat = coordinates[0], coordinates[1]
        points.append((lng, lat, normalize_poi_type(poi_type_raw)))
    return points, n_untyped, n_bad_geometry


def convert_geojson_to_json(geojson_path, output_path, use_projection=True, *,
                            target_types=None, extent=200.0, ndigits=None):
    """Convert a GeoJSON FeatureCollection of POIs into the project's JSON list.

    Every feature with a recognised type tag and a Point geometry contributes
    to the coordinate bounding box. Only features whose normalised type is in
    ``target_types`` are written out, as ``{"id", "type", "x", "y"}`` records
    with consecutive 1-based ids.

    Args:
        geojson_path: path of the input GeoJSON file (read as UTF-8).
        output_path: path of the output JSON file (written as UTF-8, overwritten).
        use_projection: if True, shift lng/lat so the bounding box starts at 0
            and scale both axes by one factor so the longer span becomes
            ``extent``; if False, write raw lng/lat.
        target_types: iterable of normalised type names to keep; defaults to
            ``DEFAULT_TARGET_TYPES``.
        extent: length of the longer bounding-box side after projection.
        ndigits: decimal places for x/y. Defaults to 2 when projecting and 6
            for raw lng/lat (0.01 degree is roughly 1 km, which would collapse
            nearby POIs onto the same point).

    Returns:
        None. Progress and statistics are printed. If no typed Point feature is
        found, an error is printed and no output file is written.
    """
    print(f"读取 GeoJSON 文件: {geojson_path}")

    with open(geojson_path, 'r', encoding='utf-8') as f:
        geojson_data = json.load(f)

    features = geojson_data.get('features') or []
    print(f"共读取 {len(features)} 个要素")

    poi_data, n_untyped, n_bad_geometry = _collect_typed_points(features)

    if not poi_data:
        print("错误：没有找到有效的 POI 数据")
        return

    # The bounding box spans *all* typed POIs, not just the target types, so
    # the coordinate frame does not depend on which types are selected.
    lngs = [lng for lng, _, _ in poi_data]
    lats = [lat for _, lat, _ in poi_data]
    min_lng, max_lng = min(lngs), max(lngs)
    min_lat, max_lat = min(lats), max(lats)

    print("\n坐标范围:")
    print(f"  经度: {min_lng:.6f} - {max_lng:.6f}")
    print(f"  纬度: {min_lat:.6f} - {max_lat:.6f}")

    # One shared scale factor keeps the lng:lat ratio of the box (in degrees).
    # NOTE: away from the equator a degree of longitude is shorter on the
    # ground than a degree of latitude, so projected x distances are stretched
    # relative to y; apply a cos(latitude) factor if metric distances matter.
    max_range = max(max_lng - min_lng, max_lat - min_lat)
    scale = extent / max_range if max_range > 0 else 1.0

    if ndigits is None:
        ndigits = 2 if use_projection else 6
    wanted = DEFAULT_TARGET_TYPES if target_types is None else frozenset(target_types)

    result = []
    poi_type_counter = Counter()
    for lng, lat, poi_type in poi_data:
        if poi_type not in wanted:
            continue
        poi_type_counter[poi_type] += 1

        if use_projection:
            x = (lng - min_lng) * scale
            y = (lat - min_lat) * scale
        else:
            x, y = lng, lat

        result.append({
            "id": len(result) + 1,
            "type": poi_type,
            "x": round(x, ndigits),
            "y": round(y, ndigits),
        })

    # Break the skip count down by cause; previously features dropped by the
    # type filter were mislabelled as "untyped or non-point".
    n_not_target = len(poi_data) - len(result)
    skipped = len(features) - len(result)

    print("\n转换完成:")
    print(f"  - 成功转换: {len(result)} 个 POI")
    print(f"  - 跳过: {skipped} 个（无类型 {n_untyped}，非点要素 {n_bad_geometry}，非目标类型 {n_not_target}）")
    print("\nPOI 类型统计（前10）:")
    for poi_type, count in poi_type_counter.most_common(10):
        print(f"  - {poi_type}: {count}")

    print(f"\n保存到: {output_path}")
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(result, f, ensure_ascii=False, indent=2)

    print("转换完成！")

if __name__ == "__main__":
    import sys

    # Default paths from the original author's workspace. Override them on the
    # command line:  python <script> [input.geojson] [output.json]
    geojson_path = "/home/ubuntu/codebase/yexijia/保研/colocation_mvp/download/beijing_poi.geojson"
    output_path = "/home/ubuntu/codebase/yexijia/保研/colocation_mvp/data/beijing_poi.json"

    cli_args = sys.argv[1:]
    if len(cli_args) > 0:
        geojson_path = cli_args[0]
    if len(cli_args) > 1:
        output_path = cli_args[1]

    # Projected mode shifts lng/lat to start at 0 and scales them so the longer
    # side of the bounding box spans 200 units; these units are NOT metres.
    convert_geojson_to_json(geojson_path, output_path, use_projection=True)

