package controllers

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.mongodb.org/mongo-driver/mongo/options"
	"log"
	"streamServer/database"
	"sync"
)

// getLastCheckpoint reads last_mongo_id from sync_status for the given
// schema/table pair. A missing row is not an error and yields "". Any other
// error from the query is logged (never returned), and whatever Scan left in
// the destination is returned.
func getLastCheckpoint(schemaName, tableName string) string {
	const selectCheckpoint = "SELECT last_mongo_id FROM sync_status WHERE schema_name=$1 AND table_name=$2"

	var checkpoint string
	row := database.PgDB.QueryRow(selectCheckpoint, schemaName, tableName)
	switch scanErr := row.Scan(&checkpoint); {
	case scanErr == nil, scanErr == sql.ErrNoRows:
		// Found a row, or this table has never been synced: nothing to report.
	default:
		log.Printf("获取最后同步ID失败: %v", scanErr)
	}
	return checkpoint
}

// updateCheckpoint records lastMongoID as the resume point for
// schemaName.tableName in sync_status. It inserts the row the first time and
// afterwards overwrites last_mongo_id via ON CONFLICT (schema_name, table_name).
// Errors are logged, not returned.
func updateCheckpoint(schemaName, tableName, lastMongoID string) {
	const upsertCheckpoint = `
    INSERT INTO sync_status (schema_name, table_name, last_mongo_id)
    VALUES ($1, $2, $3)
    ON CONFLICT (schema_name, table_name) DO UPDATE SET last_mongo_id = EXCLUDED.last_mongo_id
    `
	args := []interface{}{schemaName, tableName, lastMongoID}
	if _, execErr := database.PgDB.Exec(upsertCheckpoint, args...); execErr != nil {
		log.Printf("更新同步标记失败: %v", execErr)
	}
}

// convertObjectIDToStr walks *document recursively and replaces, in place,
// every primitive.ObjectID value with its hex string (ObjectID.Hex).
//
// Nested documents (map[string]interface{}, primitive.M, primitive.D) and
// arrays ([]interface{}, primitive.A) are traversed at any depth. This
// includes ObjectIDs stored directly as array elements and arrays nested
// inside arrays. A nil pointer is ignored.
func convertObjectIDToStr(document *map[string]interface{}) {
	if document == nil {
		return
	}
	// Overwriting existing keys while ranging over a map is allowed in Go.
	for key, value := range *document {
		(*document)[key] = convertObjectIDValue(value)
	}
}

// convertObjectIDValue returns value with any ObjectIDs converted to hex
// strings. Maps and slices are updated in place and returned unchanged.
func convertObjectIDValue(value interface{}) interface{} {
	switch v := value.(type) {
	case primitive.ObjectID:
		return v.Hex()
	case map[string]interface{}:
		convertObjectIDToStr(&v)
	case primitive.M:
		// Named map type: it does not match the unnamed case above.
		m := map[string]interface{}(v)
		convertObjectIDToStr(&m)
	case []interface{}:
		for i := range v {
			v[i] = convertObjectIDValue(v[i])
		}
	case primitive.A:
		// Named slice type; the driver may decode BSON arrays into it.
		for i := range v {
			v[i] = convertObjectIDValue(v[i])
		}
	case primitive.D:
		for i := range v {
			v[i].Value = convertObjectIDValue(v[i].Value)
		}
	}
	return value
}

// dealJSONDeleteID returns the JSON encoding of *document with its "_id"
// field omitted. The caller's map is NOT modified: the encoding is built from
// a shallow copy. (Assigning a map only copies its header, so deleting from
// the "copy" would also delete from the original.) A nil pointer or nil map
// encodes as "null". If encoding fails, the error is logged and nil is
// returned.
func dealJSONDeleteID(document *map[string]interface{}) []byte {
	var docCopy map[string]interface{}
	if document != nil && *document != nil {
		docCopy = make(map[string]interface{}, len(*document))
		for key, value := range *document {
			if key != "_id" {
				docCopy[key] = value
			}
		}
	}

	jsonBytes, err := json.Marshal(docCopy)
	if err != nil {
		log.Printf("JSON编码失败: %v", err)
		return nil
	}
	return jsonBytes
}

func SyncData(wg *sync.WaitGroup, mongoDBName, collectionName, schemaName, tableName string) {
	defer wg.Done()

	mongoDB := database.MongoClient.Database(mongoDBName)
	collection := mongoDB.Collection(collectionName)

	lastCheckpoint := getLastCheckpoint(schemaName, tableName)
	filter := bson.M{}
	if lastCheckpoint != "" {
		objectId, _ := primitive.ObjectIDFromHex(lastCheckpoint)
		filter["_id"] = bson.M{"$gt": objectId}
	}

	cursor, err := collection.Find(context.TODO(), filter, options.Find().SetSort(bson.M{"_id": 1}).SetBatchSize(1000))
	if err != nil {
		log.Printf("查找文档失败: %v", err)
		return
	}
	defer cursor.Close(context.TODO())

	concurrency := 10
	ch := make(chan struct{}, concurrency)

	var wgDocs sync.WaitGroup

	for cursor.Next(context.TODO()) {
		var document map[string]interface{}
		if err := cursor.Decode(&document); err != nil {
			log.Printf("解码文档失败: %v", err)
			continue
		}

		convertObjectIDToStr(&document)

		wgDocs.Add(1)
		ch <- struct{}{}

		go func(doc map[string]interface{}) {
			defer wgDocs.Done()
			defer func() { <-ch }()

			_, err := database.PgDB.Exec(fmt.Sprintf(`
            INSERT INTO %s.%s (__sys_obj_id__, doc, optype)
            VALUES ($1, $2, 'insert')
            ON CONFLICT (__sys_obj_id__) DO UPDATE SET doc = EXCLUDED.doc;
            `, schemaName, tableName),
				doc["_id"], dealJSONDeleteID(&doc))
			if err != nil {
				log.Printf("插入文档 %s 时出错: %v", doc["_id"], err)
				return
			}

			updateCheckpoint(schemaName, tableName, doc["_id"].(string))
		}(document)
	}

	wgDocs.Wait()

	if err := cursor.Err(); err != nil {
		log.Printf("游标错误: %v", err)
	}
}
