/*-------------------------------------------------------------------------
 *
 * sharding_conn.h
 *	  definitions of types, and declarations of functions, that are used in
 *	  table sharding connection functionality.
 *
 *
 * Copyright (c) 2019-2024 ZettaDB inc. All rights reserved.
 *
 * This source code is licensed under Apache 2.0 License,
 * combined with Common Clause Condition 1.0, as detailed in the NOTICE file.
 *
 * src/include/sharding/sharding_conn.h
 *
 *-------------------------------------------------------------------------
 */
#ifndef SHARDING_CONN_H
#define SHARDING_CONN_H

#include "sharding/mysql/mysql.h"
#include "sharding/mysql/server/private/sql_cmd.h"
#include "sharding/sharding.h"
#include "sharding/stmt_hdl.h"
#include "nodes/nodes.h"

#include <sys/time.h>

/*
 * GUC options. These are only declared here; the definitions (and the GUC
 * registration with defaults, units and ranges) live elsewhere.
 * NOTE(review): units of the timeouts and of mysql_max_packet_size are not
 * visible from this header -- confirm at the definition site.
 */
extern int mysql_connect_timeout;
extern int mysql_read_timeout;
extern int mysql_write_timeout;
extern int mysql_max_packet_size;
extern bool mysql_transmit_compress;
extern bool enable_sql_log;
extern bool enable_remote_cursor;
extern int remote_cursor_prefetch_rows;
/*
 * NOTE(review): this declares the struct tag `MYSQL`, while the structs below
 * use the plain identifier `MYSQL` (presumably a typedef supplied by the
 * included mysql headers) -- confirm whether this forward declaration is
 * still needed.
 */
struct MYSQL;
/**
 * Flag bits for AsyncStmtInfo::conn_flags. Each flag is a distinct bit, so
 * they can be OR'ed together.
 * CONN_VALID	: Connection is valid. If not, need to reconnect at next use of the connection.
 * CONN_RESET	: Connection is reset. If so, need to resend SET NAMES and cached
 * 				  mysql session vars before sending any stmt through it.
 * CONN_MASTER	: connected to a master
 * CONN_MVCC	: global MVCC effective on the connection
 * CONN_INUSE	: the asi is used by current stmt. Cleared at end of every stmt.
 */
#define CONN_VALID 0x1
#define CONN_RESET 0x2
#define CONN_MASTER 0x4
#define CONN_MVCC 0x8
#define CONN_INUSE 0x10

/*
  Tell whether 'nodeid' denotes a parallel snapshot channel rather than a
  plain shard node.

  Plain node ids are confined to the value space 0x00ffffff, which is still
  big enough. The highest byte is left for numbering the extra channels opened
  to one shard node, so for every client session at most 256 connections can
  be started to a shard node, which should be sufficient for parallel query
  execution. Any bit set in that highest byte marks a snapshot channel.
*/
static inline bool IsSnapshotNode(Oid nodeid)
{
	const Oid channel_bits = nodeid & 0xff000000;

	return channel_bits != 0;
}


/*
 * The set of connections this session holds to the nodes of one shard.
 * Instances are chained through next_parallel_conns; see the comment on that
 * member for what the head and the following elements contain.
 */
typedef struct ShardConnection
{
	Oid shard_id;	   // Id of the shard whose nodes these connections belong to.
	uint8_t num_nodes; // Number of payload slots in the 3 arrays below.
	/*
	  We may start multiple channels to the same shard node for parallel query execution.
	  Apart from the head of the ShardConnection list, the rest are all parallel snapshot channels (PSC),
	  and the nodeid of each PSC is uniquely derived from the original one with num_parallel_conns[i]
	  OR'ed into the highest byte of nodeids[i], so existing logic still works and we can
	  always figure out which shard node a snapshot channel belongs to, and reuse connections to a target shard node.
	  These channels can be reused in later queries of later txns.

	  The head of the list never contains PSCs; all following ShardConnection elements contain only PSCs.

	  NOTE(review): num_parallel_conns is not a member of this struct as declared
	  here -- confirm where the per-node channel counter lives.
	*/
	struct ShardConnection *next_parallel_conns;
	/*
	 * Once a connection to a storage node has been made, nodeids[i],
	 * conns[i] and conn_flags[i] always belong to the same storage node.
	 * NOTE(review): there is no conn_flags array in this struct; the flags
	 * appear to be kept in AsyncStmtInfo::conn_flags -- confirm.
	 */
	Oid nodeids[MAX_NODES_PER_SHARD];  // Shard node ids, inserted in order.
	MYSQL *conns[MAX_NODES_PER_SHARD]; // conns[i] is nodeids[i]'s connection.
	MYSQL conn_objs[MAX_NODES_PER_SHARD]; // Append only.
} ShardConnection;


/*
 * A communication port with one storage node, which mostly is a master.
 * It should be reset/cleared at start of each statement.
 *
 * Besides the connection itself, it carries per-statement and per-transaction
 * bookkeeping (row counts, read/write/DDL marks), the queue of statements to
 * send, and the state used by the sql logger.
 */
typedef struct AsyncStmtInfo
{
	Oid shard_id;
	Oid node_id; // when this is a snapshot channel, GetRawNodeId(node_id) == donor_node_id
	Oid donor_node_id; /* the original (raw) node id from which this channel took its snapshot. */
	int status; // mysql async API wait&cont status.
	uint64 client_query_id;

	/*
	 * In this channel, the number of stmts executed correctly and that got
	 * their results.
	 */
	int executed_stmts;

	/* The mysql client conn; NULL means not connected (see ASIConnected()). */
	MYSQL *conn;
	//ShardConnection *shard_conn;// the one containing conn

	/*
	 * Inserted/Deleted/Modified rows of INSERT/DELETE/UPDATE stmts executed in
	 * the current txn. NOT SET to the number of returned rows for SELECT.
	 */
	uint32_t txn_wrows;

	/*
	 * Inserted/Deleted/Modified rows of INSERT/DELETE/UPDATE stmts executed in
	 * the current user stmt. NOT SET to the number of returned rows for SELECT.
	 */
	uint32_t stmt_wrows;

	/*
	 * Number of warnings from storage node query execution. This field should
	 * be returned to pg's client, and in pg's equivalent of 'show warnings'
	 * we should fetch warnings from each storage node that executed part of
	 * the last stmt, and assemble them together as final result to client. TODO
	 */
	uint32_t nwarnings;

	/*
	  OR'ed CONN_XXX flag bits above.
	  NOTE(review): stored in a plain char; all CONN_XXX bits defined in this
	  header are <= 0x10, so char signedness does not matter for them.
	*/
	char conn_flags;

	/*
	 * Whether write (INSERT/UPDATE/DELETE) and read (SELECT) commands were
	 * executed in the current pg stmt. Transaction mgr should accumulate the
	 * group of storage nodes written and read-only by collecting them from
	 * this object at end of each stmt, in order to do 2PC.
	 */
	bool did_write;
	bool did_read;
	/*
	 * Set if a DDL is executed in this shard in the current txn. Note that we use
	 * CMD_UTILITY to denote DDLs as pg does, but CMD_UTILITY includes many
	 * other types of commands, including CALL stmt. Maybe in future we need
	 * to distinguish stored proc/func CALL stmts from DDL stmts.
	 */
	bool did_ddl;
	
	/* Indicates whether an XA txn is in progress on the conn. */
	bool txn_in_progress;

	/* True if current conn is not the leader of parallel vnodes */
	bool snapshot_channel;
	/*
	 * The thread id of the leader connection which generated the snapshot used
	 * by the parallel worker process.
	 */
	uint64 snapshot_threadid;

	/*
	 * NOTE(review): presumably the global visible version that
	 * ResetAsiGlobalVisibleVersion() resets -- confirm.
	 */
	uint64 visible_version;

	/* The statement currently being worked on. */
	StmtHandle *curr_stmt;

	/* The pending statements to send */
	List *stmt_queue;

	/* The handles still in use */
	List *stmt_inuse;

	struct AsyncStmtInfo *next_asi;// to the same shard node

	/* ----- sql logger states -------- */
	struct timeval 	sqllog_starttime;
	StringInfoData sqllog_buf;
} AsyncStmtInfo;

/*
 * Cursor used with begin_asi_iter() / asi_iter_next() to walk over
 * AsyncStmtInfo objects. Callers only need to provide storage for it;
 * begin_asi_iter() initializes it.
 * NOTE(review): the exact roles of ref_idx, n_elems and elem_idx are not
 * visible in this header -- see the iterator implementation.
 */
typedef struct ASIListIterator
{
	AsyncStmtInfo *cur_asi; // presumably the ASI the iterator currently points at -- confirm
	int ref_idx;
	int n_elems;
	int elem_idx;
} ASIListIterator;

/*
  Initialize 'itr', move it to the 1st ASI, and return that ASI's pointer.
  NOTE(review): the return value when there is no ASI at all is not stated
  here; presumably NULL, as for asi_iter_next() -- confirm.
*/
AsyncStmtInfo * begin_asi_iter(ASIListIterator *itr);
/*
  Move 'itr' to the next ASI and return its pointer. If there is no more ASI,
  return NULL; 'itr' is then invalid and must not be advanced further.
*/
AsyncStmtInfo *asi_iter_next(ASIListIterator *itr);

/*
 * True if anything at all (a write, a DDL or a read) was recorded on 'asi'.
 * 'asi' must not be NULL.
 */
inline static bool ASIAccessed(AsyncStmtInfo *asi)
{
	const bool untouched = !asi->did_write && !asi->did_ddl && !asi->did_read;

	return !untouched;
}

/*
 * True if 'asi' recorded a read but neither a write nor a DDL.
 * 'asi' must not be NULL.
 */
inline static bool ASIReadOnly(AsyncStmtInfo *asi)
{
	if (asi->did_write || asi->did_ddl)
	{
		return false;
	}

	return asi->did_read != 0;
}

/*
 * True if 'asi' is non-NULL and holds a mysql client connection.
 * A NULL 'asi' is tolerated and reported as not connected.
 */
inline static bool ASIConnected(AsyncStmtInfo *asi)
{
	if (!asi)
	{
		return false;
	}

	return asi->conn != NULL;
}

/*
 * True if 'asi' is non-NULL and its txn_in_progress mark is set.
 * A NULL 'asi' is tolerated and reported as having no txn in progress.
 */
inline static bool ASITxnInProgress(AsyncStmtInfo *asi)
{
	if (!asi)
	{
		return false;
	}

	return asi->txn_in_progress ? true : false;
}

/*
 * Session / communication-hub lifecycle.
 * NOTE(review): exact semantics are defined in the implementation file;
 * 'ended_clean' presumably tells whether the statement finished without
 * error -- confirm.
 */
extern void InitShardingSession(void);

extern void ResetCommunicationHub(void);
extern void ResetCommunicationHubStmt(bool ended_clean);

/*
 * Look up the AsyncStmtInfo for a shard, or for a specific node of a shard.
 * NOTE(review): judging by the names only, 'want_master' asks for a master
 * node, 'want_psc' asks for a parallel snapshot channel and 'req_chk_onfail'
 * requests a topology check on failure -- confirm against the implementation.
 */
extern AsyncStmtInfo *GetAsyncStmtInfo(Oid shardid);
extern AsyncStmtInfo *GetAsyncStmtInfoNode(Oid shardid, Oid shardNodeId, bool req_chk_onfail, bool want_master, bool want_psc);
extern int GetAsyncStmtInfoUsed(void);

/*
 * Parallel snapshot channel helpers.
 * NOTE(review): return-value meaning and failure behavior are not visible
 * from this header.
 */
extern bool GetParallelSnapshotNodeConn(Oid shardid, Oid nodeid, bool prefer_snapshot_node);
extern AsyncStmtInfo* AllocSnapshotChannel(Oid shardid, Oid nodeid);

/*
 * In RC mode, reset the global visible version at the start of the command.
 * NOTE(review): "RC" presumably means the READ COMMITTED isolation level, and
 * the value reset is presumably AsyncStmtInfo::visible_version -- confirm.
 */
extern void ResetAsiGlobalVisibleVersion(void);

/**
 * @brief Append 'stmt' to asi's job queue. 'stmt' will be sent later when its turn comes.
 *
 *  Same as send_stmt_async(), but does not return a handle. Typically used to
 *  execute sql that does not return results.
 *
 *  NOTE(review): send_stmt_async() is not declared in this header; it is
 *  presumably declared in one of the included sharding headers. 'ownsit'
 *  presumably tells whether the queue takes ownership of the 'stmt' buffer
 *  -- confirm.
 */
extern void send_stmt_async_nowarn(AsyncStmtInfo *asi, char *stmt,
			    size_t stmt_len, CmdType cmd, bool ownsit, enum enum_sql_command sqlcom);

/**
 * @brief Append 'stmt' to asi's job queue, and wait for the completion of 'stmt'.
 *
 *  NOTE(review): 'ignore_err' presumably names an error code that must not be
 *  raised as a failure -- confirm.
 */
extern void send_remote_stmt_sync(AsyncStmtInfo *asi, char *stmt, size_t len,
				  CmdType cmdtype, bool owns_it, enum enum_sql_command sqlcom, int ignore_err);

/**
 * @brief Send the statement to all the shards currently in use, and wait for the completion of the statement.
 *
 *   @param written_only Only send to shards which were written in the current transaction.
 */
extern void send_stmt_to_all_inuse_sync(char *stmt, size_t len, CmdType cmdtype, bool owns_it,
					enum enum_sql_command sqlcom, bool written_only);
/**
 * @brief Send the statement to all of the shards in the current cluster, and wait for the completion of the statement.
 */
extern void send_stmt_to_all_conns_sync(char *stmt, size_t len, CmdType cmdtype, bool owns_it,
				  enum enum_sql_command sqlcom);

/**
 * @brief Wait for all of the statements in the queue to be completed.
 */
extern void flush_all_stmts(void);

/**
 * @brief Cancel all of the statements in the queue, and wait for the completion of the running statements.
 *
 *  Note:
 *  This is used internally, mainly when the asi is finally cleaned up. So even
 *  if the running stmt eventually reports an error, or even if the connection
 *  is disconnected, no exception will be thrown.
 */
extern void cancel_all_stmts(void);

// extern uint64_t GetRemoteAffectedRows(void);
/*
 * Miscellaneous connection queries and maintenance entry points.
 * NOTE(review): behavior is defined in the implementation file. Judging by
 * the names, IsConnMaster()/IsConnReset() presumably test the CONN_MASTER /
 * CONN_RESET bits of asi->conn_flags -- confirm.
 */
extern bool MySQLQueryExecuted(void);

extern Oid GetCurrentNodeOfShard(Oid shard);

extern bool IsConnMaster(AsyncStmtInfo *asi);
extern bool IsConnReset(AsyncStmtInfo *asi);
extern void disconnect_storage_shards(void);
extern void request_topo_checks_used_shards(void);

#endif // !SHARDING_CONN_H
