#include <string.h>
#include <stdint.h>
#include "xil_printf.h"
#include "crypto_helpers.h"

// --- I. SHA-1 IMPLEMENTATION (Minimal Context) ---

/* Number of message bytes consumed by one SHA1_Transform call. */
#define SHA1_BLOCK_SIZE 64
/* Length in bytes of a finished SHA-1 digest. */
#define SHA1_DIGEST_SIZE 20

/*
 * Streaming SHA-1 state shared by SHA1_Init / SHA1_Update / SHA1_Final.
 */
typedef struct {
    uint32_t state[5];                /* chaining words H0..H4 */
    uint32_t count[2];                /* message length in BITS: count[0] = low 32 bits, count[1] = high 32 bits */
    uint8_t buffer[SHA1_BLOCK_SIZE];  /* bytes of the current, not yet complete, block */
} SHA1Context;

/* Forward declaration: SHA1_Update needs the compression function, which is defined after SHA1_Final. */
static void SHA1_Transform(uint32_t state[5], const uint8_t buffer[SHA1_BLOCK_SIZE]);

/*
 * Prepares a context for a new message: loads the SHA-1 initial hash
 * value into the chaining words and zeroes the running bit count.
 */
static void SHA1_Init(SHA1Context *context) {
    static const uint32_t initial_hash[5] = {
        0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0
    };

    memcpy(context->state, initial_hash, sizeof initial_hash);
    context->count[1] = 0;
    context->count[0] = 0;
}

/*
 * Absorbs `len` bytes from `data` into the running hash.
 *
 * Incoming bytes are staged in context->buffer until a full 64-byte block is
 * available. Whole blocks lying entirely inside `data` are compressed
 * directly from the caller's memory, without copying. Fewer than 64
 * leftover bytes are staged for the next call.
 *
 * context->count holds the total message length in bits as a 64-bit value
 * (count[1]:count[0]). SHA1_Final appends this value as the length field.
 */
static void SHA1_Update(SHA1Context *context, const uint8_t *data, size_t len) {
    size_t i, j;

    /* Nothing to absorb. Returning here also avoids memcpy() with a NULL data pointer. */
    if (len == 0) {
        return;
    }

    /* Byte offset inside the current partial block, taken before the counter advances. */
    j = (context->count[0] >> 3) & 63;

    /*
     * Add len * 8 to the 64-bit bit counter.
     * - Low word: the bottom 32 bits of len << 3. A carry shows up as unsigned wrap-around.
     * - High word: len >> 29, computed in 64 bits so that lengths of 4 GiB or
     *   more are not truncated on hosts with a 64-bit size_t.
     */
    {
        uint32_t low_before = context->count[0];

        context->count[0] += (uint32_t)len << 3;
        if (context->count[0] < low_before) {
            context->count[1]++;
        }
        context->count[1] += (uint32_t)((uint64_t)len >> 29);
    }

    if (j + len >= SHA1_BLOCK_SIZE) {
        /* Top up the staged block and compress it. */
        i = SHA1_BLOCK_SIZE - j;
        memcpy(&context->buffer[j], data, i);
        SHA1_Transform(context->state, context->buffer);

        /* Compress any further whole blocks straight from the input. */
        for (; i + SHA1_BLOCK_SIZE <= len; i += SHA1_BLOCK_SIZE) {
            SHA1_Transform(context->state, &data[i]);
        }

        j = 0;
    } else {
        i = 0;
    }

    /* Stage the remaining tail (always fewer than 64 bytes) for later. */
    memcpy(&context->buffer[j], &data[i], len - i);
}

/*
 * Finishes the hash and writes the 20-byte digest.
 *
 * The padding is the byte 0x80, then zero bytes until the buffered length is
 * 56 mod 64, then the pre-padding message length in bits as a 64-bit
 * big-endian integer. The five chaining words are then emitted big-endian
 * into `digest`.
 */
static void SHA1_Final(uint8_t digest[SHA1_DIGEST_SIZE], SHA1Context *context) {
    static const uint8_t terminator = 0x80;
    static const uint8_t zero_byte = 0x00;
    uint8_t length_be[8];
    unsigned int n;

    /* Snapshot the bit count now, because the padding updates below advance it. */
    for (n = 0; n < 4; n++) {
        unsigned int shift = 24 - 8 * n;

        length_be[n]     = (uint8_t)(context->count[1] >> shift);  /* high word first */
        length_be[n + 4] = (uint8_t)(context->count[0] >> shift);
    }

    SHA1_Update(context, &terminator, 1);

    /*
     * Bits 3..8 of the bit count are 8 * (byte offset within the block).
     * Stop at 448 bits (56 bytes), which leaves exactly 8 bytes for the length.
     */
    while ((context->count[0] & 504) != 448) {
        SHA1_Update(context, &zero_byte, 1);
    }

    SHA1_Update(context, length_be, 8);

    for (n = 0; n < SHA1_DIGEST_SIZE; n++) {
        digest[n] = (uint8_t)(context->state[n / 4] >> (8 * (3 - n % 4)));
    }
}

/*
 * SHA-1 compression function: mixes one 64-byte block into the five
 * chaining words in `state`.
 *
 * `buffer` is interpreted as sixteen big-endian 32-bit words. Those words
 * are expanded to an 80-word message schedule, which drives 80 rounds.
 */
static void SHA1_Transform(uint32_t state[5], const uint8_t buffer[SHA1_BLOCK_SIZE]) {
    uint32_t a = state[0], b = state[1], c = state[2], d = state[3], e = state[4];
    uint32_t w[80];
    uint32_t f, k, temp;
    int t;

/* Function-local helper; #undef'd below so it does not leak into the rest of the file. */
#define ROTL(x, n) (((x) << (n)) | ((x) >> (32 - (n))))

    /*
     * Each byte is widened to uint32_t BEFORE shifting. A bare uint8_t
     * promotes to signed int, and shifting a byte >= 0x80 left by 24 would
     * overflow int, which is undefined behavior.
     */
    for (t = 0; t < 16; t++) {
        w[t] = ((uint32_t)buffer[t * 4]     << 24) |
               ((uint32_t)buffer[t * 4 + 1] << 16) |
               ((uint32_t)buffer[t * 4 + 2] << 8)  |
                (uint32_t)buffer[t * 4 + 3];
    }
    for (t = 16; t < 80; t++) {
        w[t] = ROTL(w[t - 3] ^ w[t - 8] ^ w[t - 14] ^ w[t - 16], 1);
    }

    for (t = 0; t < 80; t++) {
        /* The round function and constant change every 20 rounds. */
        if (t < 20) {
            f = (b & c) | (~b & d);             /* choose */
            k = 0x5A827999;
        } else if (t < 40) {
            f = b ^ c ^ d;                      /* parity */
            k = 0x6ED9EBA1;
        } else if (t < 60) {
            f = (b & c) | (b & d) | (c & d);    /* majority */
            k = 0x8F1BBCDC;
        } else {
            f = b ^ c ^ d;                      /* parity */
            k = 0xCA62C1D6;
        }

        temp = ROTL(a, 5) + f + e + k + w[t];
        e = d;
        d = c;
        c = ROTL(b, 30);
        b = a;
        a = temp;
    }

#undef ROTL

    state[0] += a;
    state[1] += b;
    state[2] += c;
    state[3] += d;
    state[4] += e;
}


// --- II. BASE64 IMPLEMENTATION (Minimal Context) ---

/* Standard Base64 alphabet: index = 6-bit value. */
static const char base64_chars[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/*
 * Base64-encodes `length` bytes of `input` into `output` using the standard
 * alphabet with '=' padding, and NUL-terminates the result.
 *
 * `output` must have room for 4 * ceil(length / 3) + 1 bytes, for example
 * 29 bytes for a 20-byte SHA-1 digest. A length <= 0 produces "".
 *
 * Returns `output`.
 */
static char *base64_encode_internal(const unsigned char *input, int length, char *output) {
    char *out = output;
    int remaining = length;

    /*
     * A negative length means there is nothing to encode. Counting it down
     * would walk far out of bounds, so clamp it to zero.
     */
    if (remaining < 0) {
        remaining = 0;
    }

    /* Each full 3-byte group becomes a 24-bit value and is split into four 6-bit indices. */
    while (remaining >= 3) {
        uint32_t group = ((uint32_t)input[0] << 16) |
                         ((uint32_t)input[1] << 8)  |
                          (uint32_t)input[2];

        *out++ = base64_chars[(group >> 18) & 0x3F];
        *out++ = base64_chars[(group >> 12) & 0x3F];
        *out++ = base64_chars[(group >> 6) & 0x3F];
        *out++ = base64_chars[group & 0x3F];

        input += 3;
        remaining -= 3;
    }

    /*
     * A 1- or 2-byte tail is zero-extended to 24 bits. It yields 2 or 3
     * significant characters, and '=' pads the group to 4 characters.
     */
    if (remaining > 0) {
        uint32_t group = (uint32_t)input[0] << 16;

        if (remaining == 2) {
            group |= (uint32_t)input[1] << 8;
        }

        *out++ = base64_chars[(group >> 18) & 0x3F];
        *out++ = base64_chars[(group >> 12) & 0x3F];
        *out++ = (remaining == 2) ? base64_chars[(group >> 6) & 0x3F] : '=';
        *out++ = '=';
    }

    *out = '\0';
    return output;
}

// --- III. MAIN WEBSOCKET HELPER FUNCTION ---

/**
 * Computes a WebSocket Sec-WebSocket-Accept value.
 *
 * The string is hashed with SHA-1, giving a 20-byte digest. The digest is
 * then Base64-encoded, giving 28 characters, which are written NUL-terminated
 * to @p output.
 *
 * @param input  NUL-terminated string to hash. The caller must already have
 *               appended the handshake magic string (the RFC 6455 GUID) to the
 *               client's Sec-WebSocket-Key; this function does not add it.
 * @param output Destination buffer of at least 29 bytes
 *               (28 Base64 characters + terminating NUL).
 */
void sha1_and_base64_encode(const char *input, char *output) {
    const size_t input_len = strlen(input);
    SHA1Context ctx;
    uint8_t digest[SHA1_DIGEST_SIZE];

    /* Hash the whole string (excluding the NUL terminator). */
    SHA1_Init(&ctx);
    SHA1_Update(&ctx, (const uint8_t *)input, input_len);
    SHA1_Final(digest, &ctx);

    /* 20 digest bytes -> 28 '='-padded Base64 characters + NUL in output. */
    (void)base64_encode_internal(digest, SHA1_DIGEST_SIZE, output);
}
