/*-------------------------------------------------------------------------
 *
 * mysql_auth.h
 *		Kunlun Database MySQL protocol server side implementation.
 *		Authentication types and symbols.
 *
 * Copyright (c) 2019-2022 ZettaDB inc. All rights reserved.
 *
 * This source code is licensed under Apache 2.0 License,
 * combined with Common Clause Condition 1.0, as detailed in the NOTICE file.
 *
 * IDENTIFICATION
 *	  src/include/libmysql/mysql_auth.h
 *
 *-------------------------------------------------------------------------
 */
#ifndef MYSQL_AUTH_INCLUDED
#define MYSQL_AUTH_INCLUDED
#include "postgres.h"
#include <openssl/rsa.h>
#include <stddef.h>
#include <sys/types.h>

#include "lib/stringinfo.h"

#include "libmysql/com_common.h"
#include "libmysql/plugin_auth.h"  // MYSQL_SERVER_AUTH_INFO
#include "libmysql/plugin_auth_common.h"
#include "libmysql/mysql_com.h"
#include "libmysql/mysql_conn.h"
#include "libmysql/string_utils.h"

/* struct Port is only used through a pointer in this header (MySQLAuth.pg_port). */
struct Port;
struct st_mysql_auth;
/* Handle type used throughout this header to refer to an authentication plugin. */
typedef struct st_mysql_auth* plugin_ref;

/* NOTE(review): presumably the MySQL handshake protocol version advertised to clients -- confirm. */
#define PROTOCOL_VERSION 10
/* Global variables declared here and defined in some other translation unit. */
extern int protocol_version;
extern char *opt_protocol_compression_algorithms;
extern char server_version[SERVER_VERSION_LENGTH];

/* Session type; only passed by pointer in the prototypes at the end of this header. */
struct MySQLSession;

/*
 * Credential data of one user account: the authentication string plus a
 * fixed-size salt buffer and its length. Use init_acl_credential() to put an
 * instance into its initial (empty) state.
 */
typedef struct Acl_credential {
  StringInfoData m_auth_string;// the user account's password
  /**
    The salt variable is used as the password hash for
    native_password_authentication.
  */
  uint8_t m_salt[SCRAMBLE_LENGTH + 1];  // scrambled password in binary form
  /**
    In the old protocol the salt_len indicated what type of authentication
    protocol was used: 0 - no password, 4 - 3.20, 8 - 4.0,  20 - 4.1.1
  */
  uint8_t m_salt_len;
} Acl_credential;

/*
 * Put an Acl_credential into its initial, empty state.
 *
 * ac      credential to initialize; previous contents are overwritten
 * memctx  forwarded to initStringInfo2() for the m_auth_string buffer
 *
 * After the call the salt buffer is all zero bytes and m_salt_len is 0.
 */
static inline void
init_acl_credential(Acl_credential *ac, MemoryContext memctx)
{
  initStringInfo2(&ac->m_auth_string, SMALL_STRBUF_SZ, memctx);

  /* m_salt is a uint8_t array, so sizeof covers all SCRAMBLE_LENGTH + 1 bytes. */
  memset(ac->m_salt, 0, sizeof(ac->m_salt));
  ac->m_salt_len = 0;
}

// Primarily fields of a pg_hba.conf record.
// init_acl_user() zero-fills the whole struct and then sets up the credential
// and the StringInfoData members.
typedef struct ACL_USER
{
  // Password credential, set up through init_acl_credential().
  Acl_credential credential;
  uint32_t int_ip;
  // Copied from the use_ssl argument of init_acl_user().
  Choice use_ssl;
  enum SSL_type ssl_type;
  // Only zeroed by init_acl_user(); nothing in this header assigns them.
  const char *ssl_cipher, *x509_issuer, *x509_subject;
  StringInfoData user, db, host, ip;
  // Preset to "mysql_native_password" by init_acl_user().
  StringInfoData client_auth_plugin_name;
} ACL_USER;

/*
 * Bring an ACL_USER into a known initial state.
 *
 * The struct is zero-filled first, so every field that is not explicitly
 * assigned below (int_ip, ssl_type, the ssl/x509 pointers) starts out as
 * zero. The credential and each StringInfoData member are then initialized,
 * and client_auth_plugin_name is preset to "mysql_native_password".
 *
 * u        record to initialize; previous contents are overwritten
 * memctx   forwarded to init_acl_credential() and every initStringInfo2() call
 * use_ssl  stored into u->use_ssl
 */
static inline void
init_acl_user(ACL_USER *u, MemoryContext memctx, Choice use_ssl)
{
  /* Zero everything up front; this already leaves u->int_ip == 0. */
  memset(u, 0, sizeof *u);

  init_acl_credential(&u->credential, memctx);
  u->use_ssl = use_ssl;

  /* TODO: the fields that are only zeroed above may need real values here. */
  initStringInfo2(&u->user, SMALL_STRBUF_SZ, memctx);
  initStringInfo2(&u->db, SMALL_STRBUF_SZ, memctx);
  initStringInfo2(&u->host, SMALL_STRBUF_SZ, memctx);
  initStringInfo2(&u->ip, SMALL_STRBUF_SZ, memctx);
  initStringInfo2(&u->client_auth_plugin_name, SMALL_STRBUF_SZ, memctx);

  /* mysql_native_password is the only client plugin supported for now. */
  appendBinaryStringInfo(&u->client_auth_plugin_name,
                         STRING_WITH_LEN("mysql_native_password"));
}

/*
 * Outcome codes recorded in MySQLAuth.status.
 * NOTE(review): MySQLAuth_RESTART presumably requests re-running the
 * authentication exchange (see the cached_* members of MySQLAuth) -- confirm
 * against the implementation.
 */
typedef enum MySQLAuthResult
  { MySQLAuth_SUCCESS, MySQLAuth_FAILURE, MySQLAuth_RESTART } MySQLAuthResult;

/**
  The internal version of what plugins know as MYSQL_PLUGIN_VIRTIO,
  basically the context of the authentication session.

  comm_channel is the first member, so a pointer to a MySQLAuth and a pointer
  to its comm_channel refer to the same address.
*/
typedef struct MySQLAuth
{
  MYSQL_PLUGIN_VIRTIO comm_channel;
  MYSQL_SERVER_AUTH_INFO auth_info;
  
  MemoryContext memctx; // all pointer fields below are allocated in this memctx; it is valid only during the authentication phase.
  ACL_USER *acl_user;

  StringInfoData db;      ///< db name from the handshake packet
  StringInfoData schema;      ///< schema name from the handshake packet, may be empty even if db is filled.

  /** when restarting a plugin this caches the last client reply */
  struct {
    const char *plugin, *pkt;
	uint32_t pkt_len;
  } cached_client_reply;

  /** this caches the first plugin packet for restart request on the client */
  struct {
    char *pkt;
    unsigned int pkt_len;
  } cached_server_packet;

  int packets_read, packets_written;  ///< counters for send/received packets
  /** when plugin returns a failure this tells us what really happened */
  MySQLAuthResult status;
  /** plugin handle; client_auth_plugin_name() returns the address of its plugin_name */
  plugin_ref auth_plugin;
  /* encapsulation members */
  StringInfoData scramble;
  struct rand_struct *rand;
  pid_t thread_id;
  unsigned int *server_status;
  struct MySQLProto*protocol;
  
  unsigned long max_client_packet_length;
  StringInfoData ip;
  StringInfoData host;
  int virtio_is_encrypted;
  /*
    To access pg's user/hba facilities for access control and authentication,
    we have to use pg_port after filling its fields using fields in this struct.
  */
  struct Port *pg_port;
} MySQLAuth;

/*
 * Return the address of the plugin_name member of the plugin currently
 * referenced by auth->auth_plugin.
 *
 * Both auth and auth->auth_plugin are dereferenced without a NULL check,
 * so the caller must make sure they are valid.
 */
static inline StringInfo
client_auth_plugin_name(MySQLAuth *auth)
{
  return &(auth->auth_plugin->plugin_name);
}

/*
 * Return the address of the client_auth_plugin_name buffer of the ACL_USER
 * attached to this authentication context.
 *
 * Both auth and auth->acl_user are dereferenced without a NULL check,
 * so the caller must make sure they are valid.
 */
static inline StringInfo
acl_user_auth_plugin_name(MySQLAuth *auth)
{
  return &(auth->acl_user->client_auth_plugin_name);
}

/*
 * RSA key support, compiled only when SUPPORT_RSA_AUTH evaluates to non-zero
 * (an undefined macro evaluates to 0 in #if, which disables this section).
 *
 * NOTE(review): everything inside this guard uses C++ syntax (class, access
 * specifiers, nullptr, "= default") and will not compile as C. It looks like
 * code carried over from the MySQL server sources -- confirm before enabling
 * the macro.
 */
#if SUPPORT_RSA_AUTH
bool init_rsa_keys(void);
void deinit_rsa_keys(void);
int show_rsa_public_key(THD *thd, SHOW_VAR *var, char *buff);

typedef struct rsa_st RSA;
/*
 * Holder for an RSA public/private key pair, the PEM text of the public key
 * and pointers to the two key-file path settings.
 */
class Rsa_auth_keys {
 private:
  RSA *m_public_key;
  RSA *m_private_key;
  int m_cipher_len;
  char *m_pem_public_key;
  char **m_private_key_path;
  char **m_public_key_path;

  void get_key_file_path(char *key, String *key_file_path);
  bool read_key_file(RSA **key_ptr, bool is_priv_key, char **key_text_buffer);

 public:
  /* Stores the two path pointers; keys, PEM text and cipher length start out empty. */
  Rsa_auth_keys(char **private_key_path, char **public_key_path)
      : m_public_key(nullptr),
        m_private_key(nullptr),
        m_cipher_len(0),
        m_pem_public_key(nullptr),
        m_private_key_path(private_key_path),
        m_public_key_path(public_key_path) {}
  /* Defaulted destructor: nothing is released here (free_memory() is a separate call). */
  ~Rsa_auth_keys() = default;

  void free_memory();
  void *allocate_pem_buffer(size_t buffer_len);
  /* Plain accessors: return the stored pointers without copying. */
  RSA *get_private_key() { return m_private_key; }

  RSA *get_public_key() { return m_public_key; }

  int get_cipher_length();
  bool read_rsa_keys();
  const char *get_public_key_as_pem(void) { return m_pem_public_key; }
};
#endif // SUPPORT_RSA_AUTH


/* Data Structures */

/*
 * Identifiers of the cached authentication plugins. PLUGIN_LAST is not a
 * plugin: it must stay the final enumerator because it is used as the array
 * length in Cached_auth_plugins.
 */
typedef enum {
  PLUGIN_MYSQL_NATIVE_PASSWORD,
  /* Add new plugin before this */
  PLUGIN_LAST
} cached_plugins_enum;

/*
 * Table of cached authentication plugins: two parallel arrays with one slot
 * per cached_plugins_enum value below PLUGIN_LAST.
 */
typedef struct Cached_auth_plugins
{
  /* Plugin handle for each slot. */
  plugin_ref cached_plugins[(unsigned int)PLUGIN_LAST];
  /* Plugin name for each slot. */
  StringInfoData cached_plugins_names[(unsigned int)PLUGIN_LAST];
  /* NOTE(review): presumably true once the table has been populated -- confirm in the implementation. */
  bool m_valid;
} Cached_auth_plugins;


/*
 * Cached authentication plugin API. All functions below are only declared
 * here; their definitions live in another translation unit.
 */
extern bool compare_plugin(Cached_auth_plugins *plugins, cached_plugins_enum plugin_index,
                             StringInfo plugin);
extern const char *get_plugin_name(Cached_auth_plugins *plugins, cached_plugins_enum plugin_index);
extern bool auth_plugin_is_built_in(Cached_auth_plugins *plugins, StringInfo plugin);
/*
 * NOTE(review): StringInfo is a pointer typedef, so `const StringInfo` makes
 * the pointer itself const, not the data it points to.
 */
extern plugin_ref get_cached_plugin_ref_by_name(Cached_auth_plugins *plugins, const StringInfo plugin);
extern plugin_ref get_cached_plugin_ref(Cached_auth_plugins *plugins, cached_plugins_enum plugin_index);
extern void Deinit_Cached_auth_plugins(void);
extern void Init_Cached_auth_plugins(void);

/* Global plugin table (defined elsewhere) and the remaining authentication entry points. */
extern Cached_auth_plugins *g_cached_authentication_plugins;
extern int set_default_auth_plugin(char *plugin_name, size_t plugin_name_length);
extern int mysql_authenticate(struct MySQLSession *thd, enum enum_server_command command);
extern bool acl_check_host(struct MySQLSession *thd, const char *host, const char *ip);
extern bool valid_mysql_compression_algo(const char*algo);
extern bool push_top_search_path(struct MySQLSession *thd, const char *schm, bool missingOK);
/* NOTE(review): the (newval, extra) signature looks like a PostgreSQL GUC assign hook -- confirm. */
extern void update_declared_mysql_server_version(const char*newval, void *extra);

#endif /* MYSQL_AUTH_INCLUDED */
