/*
 * SPDX-FileCopyrightText: 2020-2022 Espressif Systems (Shanghai) CO LTD
 *
 * SPDX-License-Identifier: Apache-2.0
 */
#include <sys/param.h>
#include "mbedtls/error.h"
#include "esp_mbedtls_dynamic_impl.h"

int __real_mbedtls_ssl_write(mbedtls_ssl_context *ssl, unsigned char *buf, size_t len);
int __real_mbedtls_ssl_read(mbedtls_ssl_context *ssl, unsigned char *buf, size_t len);
void __real_mbedtls_ssl_free(mbedtls_ssl_context *ssl);
int __real_mbedtls_ssl_session_reset(mbedtls_ssl_context *ssl);
int __real_mbedtls_ssl_setup(mbedtls_ssl_context *ssl, const mbedtls_ssl_config *conf);
int __real_mbedtls_ssl_send_alert_message(mbedtls_ssl_context *ssl, unsigned char level, unsigned char message);
int __real_mbedtls_ssl_close_notify(mbedtls_ssl_context *ssl);

int __wrap_mbedtls_ssl_write(mbedtls_ssl_context *ssl, unsigned char *buf, size_t len);
int __wrap_mbedtls_ssl_read(mbedtls_ssl_context *ssl, unsigned char *buf, size_t len);
void __wrap_mbedtls_ssl_free(mbedtls_ssl_context *ssl);
int __wrap_mbedtls_ssl_session_reset(mbedtls_ssl_context *ssl);
int __wrap_mbedtls_ssl_setup(mbedtls_ssl_context *ssl, const mbedtls_ssl_config *conf);
int __wrap_mbedtls_ssl_send_alert_message(mbedtls_ssl_context *ssl, unsigned char level, unsigned char message);
int __wrap_mbedtls_ssl_close_notify(mbedtls_ssl_context *ssl);

static const char *TAG = "SSL TLS";

static int tx_done(mbedtls_ssl_context *ssl)
{
    if (!ssl->MBEDTLS_PRIVATE(out_left))
        return 1;

    return 0;
}

static int rx_done(mbedtls_ssl_context *ssl)
{
    if (!ssl->MBEDTLS_PRIVATE(in_msglen)) {
        return 1;
    }

    ESP_LOGD(TAG, "RX left %zu bytes", ssl->MBEDTLS_PRIVATE(in_msglen));
    return 0;
}

static int ssl_update_checksum_start( mbedtls_ssl_context *ssl,
                                       const unsigned char *buf, size_t len )
{
    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
#if defined(MBEDTLS_SHA256_C)
    ret = mbedtls_md_update( &ssl->handshake->fin_sha256, buf, len );
#endif
#if defined(MBEDTLS_SHA512_C)
    ret = mbedtls_md_update( &ssl->handshake->fin_sha384, buf, len );
#endif
    return ret;
}

static void ssl_handshake_params_init( mbedtls_ssl_handshake_params *handshake )
{
    memset( handshake, 0, sizeof( mbedtls_ssl_handshake_params ) );

#if defined(MBEDTLS_SHA256_C)
    mbedtls_md_init( &handshake->fin_sha256 );
    mbedtls_md_setup( &handshake->fin_sha256,
                    mbedtls_md_info_from_type(MBEDTLS_MD_SHA256),
                    0 );
    mbedtls_md_starts( &handshake->fin_sha256 );
#endif
#if defined(MBEDTLS_SHA512_C)
    mbedtls_md_init( &handshake->fin_sha384 );
    mbedtls_md_setup( &handshake->fin_sha384,
                    mbedtls_md_info_from_type(MBEDTLS_MD_SHA384),
                    0 );
    mbedtls_md_starts( &handshake->fin_sha384 );
#endif

    handshake->update_checksum = ssl_update_checksum_start;

#if defined(MBEDTLS_DHM_C)
    mbedtls_dhm_init( &handshake->dhm_ctx );
#endif
#if defined(MBEDTLS_ECDH_C)
    mbedtls_ecdh_init( &handshake->ecdh_ctx );
#endif
#if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED)
    mbedtls_ecjpake_init( &handshake->ecjpake_ctx );
#if defined(MBEDTLS_SSL_CLI_C)
    handshake->ecjpake_cache = NULL;
    handshake->ecjpake_cache_len = 0;
#endif
#endif

#if defined(MBEDTLS_SSL_ECP_RESTARTABLE)
    mbedtls_x509_crt_restart_init( &handshake->ecrs_ctx );
#endif

#if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
    handshake->sni_authmode = MBEDTLS_SSL_VERIFY_UNSET;
#endif

#if defined(MBEDTLS_X509_CRT_PARSE_C) && \
    !defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
    mbedtls_pk_init( &handshake->peer_pubkey );
#endif
}

static int ssl_handshake_init( mbedtls_ssl_context *ssl )
{
    /* Clear old handshake information if present */
    if( ssl->transform_negotiate )
        mbedtls_ssl_transform_free( ssl->transform_negotiate );
    if( ssl->session_negotiate )
        mbedtls_ssl_session_free( ssl->session_negotiate );
    if( ssl->handshake )
        mbedtls_ssl_handshake_free( ssl );

    /*
     * Either the pointers are now NULL or cleared properly and can be freed.
     * Now allocate missing structures.
     */
    if( ssl->transform_negotiate == NULL )
    {
        ssl->transform_negotiate = mbedtls_calloc( 1, sizeof(mbedtls_ssl_transform) );
    }

    if( ssl->session_negotiate == NULL )
    {
        ssl->session_negotiate = mbedtls_calloc( 1, sizeof(mbedtls_ssl_session) );
    }

    if( ssl->handshake == NULL )
    {
        ssl->handshake = mbedtls_calloc( 1, sizeof(mbedtls_ssl_handshake_params) );
    }
#if defined(MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH)
    /* If the buffers are too small - reallocate */

    handle_buffer_resizing( ssl, 0, MBEDTLS_SSL_IN_BUFFER_LEN,
                                    MBEDTLS_SSL_OUT_BUFFER_LEN );
#endif

    /* All pointers should exist and can be directly freed without issue */
    if( ssl->handshake == NULL ||
        ssl->transform_negotiate == NULL ||
        ssl->session_negotiate == NULL )
    {
        ESP_LOGD(TAG, "alloc() of ssl sub-contexts failed");

        mbedtls_free( ssl->handshake );
        mbedtls_free( ssl->transform_negotiate );
        mbedtls_free( ssl->session_negotiate );

        ssl->handshake = NULL;
        ssl->transform_negotiate = NULL;
        ssl->session_negotiate = NULL;

        return( MBEDTLS_ERR_SSL_ALLOC_FAILED );
    }

    /* Initialize structures */
    mbedtls_ssl_session_init( ssl->session_negotiate );
    mbedtls_ssl_transform_init( ssl->transform_negotiate );
    ssl_handshake_params_init( ssl->handshake );

/*
 * curve_list is translated to IANA TLS group identifiers here because
 * mbedtls_ssl_conf_curves returns void and so can't return
 * any error codes.
 */
#if defined(MBEDTLS_ECP_C)
#if !defined(MBEDTLS_DEPRECATED_REMOVED)
    /* Heap allocate and translate curve_list from internal to IANA group ids */
    if ( ssl->conf->curve_list != NULL )
    {
        size_t length;
        const mbedtls_ecp_group_id *curve_list = ssl->conf->curve_list;

        for( length = 0;  ( curve_list[length] != MBEDTLS_ECP_DP_NONE ) &&
                          ( length < MBEDTLS_ECP_DP_MAX ); length++ ) {}

        /* Leave room for zero termination */
        uint16_t *group_list = mbedtls_calloc( length + 1, sizeof(uint16_t) );
        if ( group_list == NULL )
            return( MBEDTLS_ERR_SSL_ALLOC_FAILED );

        for( size_t i = 0; i < length; i++ )
        {
            const mbedtls_ecp_curve_info *info =
                        mbedtls_ecp_curve_info_from_grp_id( curve_list[i] );
            if ( info == NULL )
            {
                mbedtls_free( group_list );
                return( MBEDTLS_ERR_SSL_BAD_CONFIG );
            }
            group_list[i] = info->tls_id;
        }

        group_list[length] = 0;

        ssl->handshake->group_list = group_list;
        ssl->handshake->group_list_heap_allocated = 1;
    }
    else
    {
        ssl->handshake->group_list = ssl->conf->group_list;
        ssl->handshake->group_list_heap_allocated = 0;
    }
#endif /* MBEDTLS_DEPRECATED_REMOVED */
#endif /* MBEDTLS_ECP_C */

#if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
#if !defined(MBEDTLS_DEPRECATED_REMOVED)
#if defined(MBEDTLS_SSL_PROTO_TLS1_2)
    /* Heap allocate and translate sig_hashes from internal hash identifiers to
       signature algorithms IANA identifiers.  */
    if ( mbedtls_ssl_conf_is_tls12_only( ssl->conf ) &&
         ssl->conf->sig_hashes != NULL )
    {
        const int *md;
        const int *sig_hashes = ssl->conf->sig_hashes;
        size_t sig_algs_len = 0;
        uint16_t *p;

#if defined(static_assert)
        static_assert( MBEDTLS_SSL_MAX_SIG_ALG_LIST_LEN
                       <= ( SIZE_MAX - ( 2 * sizeof(uint16_t) ) ),
                       "MBEDTLS_SSL_MAX_SIG_ALG_LIST_LEN too big" );
#endif

        for( md = sig_hashes; *md != MBEDTLS_MD_NONE; md++ )
        {
            if( mbedtls_ssl_hash_from_md_alg( *md ) == MBEDTLS_SSL_HASH_NONE )
                continue;
#if defined(MBEDTLS_ECDSA_C)
            sig_algs_len += sizeof( uint16_t );
#endif

#if defined(MBEDTLS_RSA_C)
            sig_algs_len += sizeof( uint16_t );
#endif
            if( sig_algs_len > MBEDTLS_SSL_MAX_SIG_ALG_LIST_LEN )
                return( MBEDTLS_ERR_SSL_BAD_CONFIG );
        }

        if( sig_algs_len < MBEDTLS_SSL_MIN_SIG_ALG_LIST_LEN )
            return( MBEDTLS_ERR_SSL_BAD_CONFIG );

        ssl->handshake->sig_algs = mbedtls_calloc( 1, sig_algs_len +
                                                      sizeof( uint16_t ));
        if( ssl->handshake->sig_algs == NULL )
            return( MBEDTLS_ERR_SSL_ALLOC_FAILED );

        p = (uint16_t *)ssl->handshake->sig_algs;
        for( md = sig_hashes; *md != MBEDTLS_MD_NONE; md++ )
        {
            unsigned char hash = mbedtls_ssl_hash_from_md_alg( *md );
            if( hash == MBEDTLS_SSL_HASH_NONE )
                continue;
#if defined(MBEDTLS_ECDSA_C)
            *p = (( hash << 8 ) | MBEDTLS_SSL_SIG_ECDSA);
            p++;
#endif
#if defined(MBEDTLS_RSA_C)
            *p = (( hash << 8 ) | MBEDTLS_SSL_SIG_RSA);
            p++;
#endif
        }
        *p = MBEDTLS_TLS_SIG_NONE;
        ssl->handshake->sig_algs_heap_allocated = 1;
    }
    else
#endif /* MBEDTLS_SSL_PROTO_TLS1_2 */
    {
        ssl->handshake->sig_algs_heap_allocated = 0;
    }
#endif /* !MBEDTLS_DEPRECATED_REMOVED */
#endif /* MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED */

    return( 0 );
}

int __wrap_mbedtls_ssl_setup(mbedtls_ssl_context *ssl, const mbedtls_ssl_config *conf)
{
    ssl->conf = conf;
    ssl->tls_version = ssl->conf->max_tls_version;

    CHECK_OK(ssl_handshake_init(ssl));

    mbedtls_free(ssl->MBEDTLS_PRIVATE(out_buf));
    ssl->MBEDTLS_PRIVATE(out_buf) = NULL;
    CHECK_OK(esp_mbedtls_setup_tx_buffer(ssl));

    mbedtls_free(ssl->MBEDTLS_PRIVATE(in_buf));
    ssl->MBEDTLS_PRIVATE(in_buf) = NULL;
    esp_mbedtls_setup_rx_buffer(ssl);

    return 0;
}

int __wrap_mbedtls_ssl_write(mbedtls_ssl_context *ssl, unsigned char *buf, size_t len)
{
    int ret;

    CHECK_OK(esp_mbedtls_add_tx_buffer(ssl, 0));

    ret = __real_mbedtls_ssl_write(ssl, buf, len);

    if (tx_done(ssl)) {
        CHECK_OK(esp_mbedtls_free_tx_buffer(ssl));
    }

    return ret;
}

int __wrap_mbedtls_ssl_read(mbedtls_ssl_context *ssl, unsigned char *buf, size_t len)
{
    int ret;

    ESP_LOGD(TAG, "add mbedtls RX buffer");
    ret = esp_mbedtls_add_rx_buffer(ssl);
    if (ret == MBEDTLS_ERR_SSL_CONN_EOF) {
        ESP_LOGD(TAG, "fail, the connection indicated an EOF");
        return 0;
    } else if (ret < 0) {
        ESP_LOGD(TAG, "fail, error=%d", -ret);
        return ret;
    }
    ESP_LOGD(TAG, "end");

    ret = __real_mbedtls_ssl_read(ssl, buf, len);

    if (rx_done(ssl)) {
        CHECK_OK(esp_mbedtls_free_rx_buffer(ssl));
    }

    return ret;
}

void __wrap_mbedtls_ssl_free(mbedtls_ssl_context *ssl)
{
    if (ssl->MBEDTLS_PRIVATE(out_buf)) {
        esp_mbedtls_free_buf(ssl->MBEDTLS_PRIVATE(out_buf));
        ssl->MBEDTLS_PRIVATE(out_buf) = NULL;
    }

    if (ssl->MBEDTLS_PRIVATE(in_buf)) {
        esp_mbedtls_free_buf(ssl->MBEDTLS_PRIVATE(in_buf));
        ssl->MBEDTLS_PRIVATE(in_buf) = NULL;
    }

    __real_mbedtls_ssl_free(ssl);
}

int __wrap_mbedtls_ssl_session_reset(mbedtls_ssl_context *ssl)
{
    CHECK_OK(esp_mbedtls_reset_add_tx_buffer(ssl));

    CHECK_OK(esp_mbedtls_reset_add_rx_buffer(ssl));

    CHECK_OK(__real_mbedtls_ssl_session_reset(ssl));

    CHECK_OK(esp_mbedtls_reset_free_tx_buffer(ssl));

    esp_mbedtls_reset_free_rx_buffer(ssl);

    return 0;
}

int __wrap_mbedtls_ssl_send_alert_message(mbedtls_ssl_context *ssl, unsigned char level, unsigned char message)
{
    int ret;

    CHECK_OK(esp_mbedtls_add_tx_buffer(ssl, 0));

    ret = __real_mbedtls_ssl_send_alert_message(ssl, level, message);

    if (tx_done(ssl)) {
        CHECK_OK(esp_mbedtls_free_tx_buffer(ssl));
    }

    return ret;
}

int __wrap_mbedtls_ssl_close_notify(mbedtls_ssl_context *ssl)
{
    int ret;

    CHECK_OK(esp_mbedtls_add_tx_buffer(ssl, 0));

    ret = __real_mbedtls_ssl_close_notify(ssl);

    if (tx_done(ssl)) {
        CHECK_OK(esp_mbedtls_free_tx_buffer(ssl));
    }

    return ret;
}