// A library to load and save 24 bit Windows bitmap files.
// Written by Nils Liaaen Corneliusen 2019.
// https://www.ignorantus.com
// License: CC0 1.0 Universal (CC0 1.0) Public Domain Dedication license

#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <stdint.h>
#include <limits.h>
#include <malloc.h>
#include <Windows.h>

#include <jpeglib.h>

#include "bmp.h"
#include "vec.h"

// Allocate an image of w*h 32-bit pixels (4 bytes each, per the stride
// computation below). Rows are padded to a multiple of 64 pixels and the
// row count to a multiple of 8, so rows must always be addressed through
// bp->stride. The pixel memory is not initialised.
// Returns NULL, after printing a diagnostic, for non-positive or oversized
// dimensions or when memory runs out. Release the result with bmp_free().
bmp_argb *bmp_alloc( int w, int h )
{
    if( w <= 0 || h <= 0 ) {
        printf( "Illegal dimensions %d*%d!\n", w, h );
        return NULL;
    }

    // The stride is stored as an int-sized value, so the padded row
    // (w rounded up to 64 pixels, 4 bytes each) must fit in an int.
    if( w > INT_MAX / 4 - 63 ) {
        printf( "Dimensions %d*%d too large!\n", w, h );
        return NULL;
    }

    // Padding arithmetic is done in size_t. In int, stride*h8 overflows
    // (undefined behaviour) for large images and malloc gets a wrapped size.
    size_t h8     = ((size_t)h + 7) & ~(size_t)7;
    size_t stride = (((size_t)w + 63) & ~(size_t)63) * 4;

    if( h8 > SIZE_MAX / stride ) {
        printf( "Dimensions %d*%d too large!\n", w, h );
        return NULL;
    }

    bmp_argb *bp = (bmp_argb *)malloc( sizeof(bmp_argb) );
    if( bp == NULL ) {
        perror( "malloc" );
        return NULL;
    }

    bp->w      = w;
    bp->h      = h;
    bp->stride = (int)stride;

    bp->argb = (uint8_t *)malloc( stride * h8 );
    if( bp->argb == NULL ) {
        perror( "malloc" );
        free( bp );
        return NULL;
    }

    return bp;
}

// Release an image obtained from bmp_alloc() or bmp_load(): first the pixel
// buffer, then the descriptor. Passing NULL is a no-op, mirroring free(),
// so error paths can call it unconditionally.
void bmp_free( bmp_argb *bp )
{
    if( bp == NULL ) {
        return;
    }

    free( bp->argb );
    free( bp );
}

// Template for the 54-byte header of an uncompressed 24-bit bitmap: a
// 14-byte file header followed by a 40-byte BITMAPINFOHEADER. Multi-byte
// fields are little-endian. The file size, width and height fields are zero
// here; bmp_save() copies the template and fills them in.
uint8_t bmpheader[] = {
    /* ---- file header, 14 bytes ---- */
    'B', 'M',           /* +0x00 signature                                  */
    0, 0, 0, 0,         /* +0x02 total file size                            */
    0, 0,               /* +0x06 reserved                                   */
    0, 0,               /* +0x08 reserved                                   */
    54, 0, 0, 0,        /* +0x0a offset of pixel data (14 + 40)             */
    /* ---- BITMAPINFOHEADER, 40 bytes ---- */
    40, 0, 0, 0,        /* +0x0e info header size                           */
    0, 0, 0, 0,         /* +0x12 width in pixels                            */
    0, 0, 0, 0,         /* +0x16 height; negative means rows stored top-down */
    1, 0,               /* +0x1a colour planes                              */
    24, 0,              /* +0x1c bits per pixel                             */
    0, 0, 0, 0,         /* +0x1e compression method: none                   */
    0, 0, 0, 0,         /* +0x22 raw bitmap size, left 0                    */
    0x13, 0x0B, 0, 0,   /* +0x26 horizontal resolution, 2835 pixels/metre   */
    0x13, 0x0B, 0, 0,   /* +0x2a vertical resolution, 2835 pixels/metre     */
    0, 0, 0, 0,         /* +0x2e palette colours                            */
    0, 0, 0, 0          /* +0x32 important colours                          */
};

// Read a little-endian unsigned 16-bit field from a byte buffer.
static unsigned bmp_load_rd16( const uint8_t *p )
{
    return (unsigned)p[0] | ((unsigned)p[1] << 8);
}

// Read a little-endian unsigned 32-bit field from a byte buffer.
static uint32_t bmp_load_rd32( const uint8_t *p )
{
    return  (uint32_t)p[0]        | ((uint32_t)p[1] <<  8) |
           ((uint32_t)p[2] << 16) | ((uint32_t)p[3] << 24);
}

// Read a little-endian two's-complement 32-bit field, widened to int64_t so
// the signed conversion is well defined on every compiler.
static int64_t bmp_load_rds32( const uint8_t *p )
{
    int64_t v = bmp_load_rd32( p );
    return v >= 0x80000000LL ? v - 0x100000000LL : v;
}

// Load an uncompressed 24-bit bitmap file into a newly allocated image
// (see bmp_alloc). Each pixel is packed as 0x00RRGGBB; the top byte is 0.
//
// fname: path of the bitmap file.
// flip:  false keeps the row order of a bottom-up file (buffer row 0 is the
//        bottom row of the picture); true reverses it. Files with a negative
//        height (top-down storage) are reversed as needed so the same flip
//        value gives the same orientation.
// Returns the image (release with bmp_free) or NULL after printing a
// diagnostic.
bmp_argb *bmp_load( const char *fname, bool flip )
{
    // Largest accepted width/height. This is the range the earlier 16-bit
    // field parsing could represent, and it keeps all size arithmetic below
    // far from overflow.
    const int64_t maxdim = 65535;

    printf( "Loading bitmap %s\n", fname );

    FILE *fp = fopen( fname, "rb" );
    if( fp == NULL ) {
        printf( "Could not open %s!\n", fname );
        return NULL;
    }

    uint8_t hdr[sizeof(bmpheader)];

    size_t rc = fread( hdr, 1, sizeof(hdr), fp );
    if( rc != sizeof(hdr) ) {
        printf( "Error reading header! Read %zu, expected %zu\n", rc, sizeof(hdr) );
        fclose( fp );
        return NULL;
    }

    if( hdr[0] != 0x42 || hdr[1] != 0x4D ) {
        printf( "Wrong header!\n" );
        fclose( fp );
        return NULL;
    }

    uint32_t dataoff  = bmp_load_rd32( hdr + 0x0a );
    uint32_t infosize = bmp_load_rd32( hdr + 0x0e );
    unsigned bpp      = bmp_load_rd16( hdr + 0x1c );
    uint32_t comp     = bmp_load_rd32( hdr + 0x1e );

    // Only info headers of 40 bytes or more share this field layout, and
    // only uncompressed 24 bpp data is decoded below.
    if( infosize < 40 || bpp != 24 || comp != 0 ) {
        printf( "Unsupported bitmap format (header %u, %u bpp, compression %u)!\n",
                (unsigned)infosize, bpp, (unsigned)comp );
        fclose( fp );
        return NULL;
    }

    int64_t w64 = bmp_load_rds32( hdr + 0x12 );
    int64_t h64 = bmp_load_rds32( hdr + 0x16 );

    // A negative height marks top-down row storage.
    bool topdown = h64 < 0;
    if( topdown ) {
        h64 = -h64;
    }

    if( w64 <= 0 || h64 <= 0 || w64 > maxdim || h64 > maxdim ) {
        printf( "Illegal dimensions %lld*%lld!\n", (long long)w64, (long long)h64 );
        fclose( fp );
        return NULL;
    }

    // Pixel data need not follow a 54-byte header directly (larger info
    // headers exist), so honour the offset stored in the file header.
    if( dataoff < sizeof(hdr) ||
        (unsigned long)dataoff > (unsigned long)LONG_MAX ||
        fseek( fp, (long)dataoff, SEEK_SET ) != 0 ) {
        printf( "Bad pixel data offset %u!\n", (unsigned)dataoff );
        fclose( fp );
        return NULL;
    }

    bmp_argb *bp = bmp_alloc( (int)w64, (int)h64 );
    if( bp == NULL ) {
        fclose( fp );
        return NULL;
    }

    // Each file row holds w BGR triplets padded to a 4-byte boundary.
    size_t rowbytes = ((size_t)bp->w * 3 + 3) & ~(size_t)3;

    // Heap rather than _alloca: a wide image could overflow the stack.
    uint8_t *row = (uint8_t *)malloc( rowbytes );
    if( row == NULL ) {
        perror( "malloc" );
        bmp_free( bp );
        fclose( fp );
        return NULL;
    }

    // A top-down file stored as if it were bottom-up needs its rows reversed.
    bool reverse = flip != topdown;

    for( int y = 0; y < bp->h; y++ ) {

        int dsty = reverse ? bp->h-1-y : y;
        uint32_t *argb = (uint32_t *)(bp->argb + (size_t)dsty*bp->stride);

        rc = fread( row, 1, rowbytes, fp );
        if( rc != rowbytes ) {
            printf( "Error reading row %d! Got %zu, expected %zu\n", y, rc, rowbytes );
            free( row );
            bmp_free( bp );
            fclose( fp );
            return NULL;
        }

        // Bytes come in B, G, R order; pack them as 0x00RRGGBB.
        const uint8_t *px = row;
        for( int x = 0; x < bp->w; x++, px += 3 ) {
            argb[x] = ((uint32_t)px[2] << 16) | ((uint32_t)px[1] << 8) | px[0];
        }
    }

    free( row );
    fclose( fp );

    return bp;
}


// Store v into p[0..3] in little-endian byte order (bitmap header fields).
static void bmp_save_wr32( uint8_t *p, uint32_t v )
{
    p[0] = (uint8_t)( v        & 0xff);
    p[1] = (uint8_t)((v >>  8) & 0xff);
    p[2] = (uint8_t)((v >> 16) & 0xff);
    p[3] = (uint8_t)((v >> 24) & 0xff);
}

// Save an image as an uncompressed 24-bit bitmap file.
//
// bp:    image to write. Bits 0-7, 8-15 and 16-23 of each 32-bit pixel
//        become the stored B, G and R bytes; the top byte is ignored.
// fname: output path (created or truncated).
// flip:  false writes buffer row 0 first (the bottom row of a bottom-up
//        bitmap); true writes the rows in reverse order.
// Returns true on success, or false after printing a diagnostic.
bool bmp_save( bmp_argb *bp, const char *fname, bool flip )
{
    // Each file row is w BGR triplets padded up to a multiple of 4 bytes.
    size_t   rowbytes = ((size_t)bp->w * 3 + 3) & ~(size_t)3;
    uint64_t datasize = (uint64_t)rowbytes * (uint64_t)bp->h;

    // The total file size is a 32-bit header field.
    if( datasize > (uint64_t)0xFFFFFFFFu - sizeof(bmpheader) ) {
        printf( "Image %d*%d too large for a bitmap file!\n", bp->w, bp->h );
        return false;
    }

    // Pack the whole image up front so it can be written with one fwrite().
    // calloc zero-fills the row padding, which would otherwise be
    // uninitialised heap memory written into the file.
    uint8_t *outbuf = (uint8_t *)calloc( (size_t)datasize, 1 );
    if( outbuf == NULL ) {
        perror( "calloc" );
        return false;
    }

    for( int y = 0; y < bp->h; y++ ) {

        int srcy = flip ? bp->h-1-y : y;
        const uint32_t *argb = (const uint32_t *)(bp->argb + (size_t)srcy*bp->stride);
        uint8_t *row = outbuf + (size_t)y*rowbytes;

        for( int x = 0; x < bp->w; x++ ) {
            uint32_t pixel = argb[x];
            row[3*x+0] = (uint8_t)( pixel        & 0xff);
            row[3*x+1] = (uint8_t)((pixel >>  8) & 0xff);
            row[3*x+2] = (uint8_t)((pixel >> 16) & 0xff);
        }
    }

    uint8_t hdr[sizeof(bmpheader)];
    memcpy( hdr, bmpheader, sizeof(hdr) );

    // Width and height are full 32-bit fields; all four bytes are written.
    bmp_save_wr32( hdr + 0x02, (uint32_t)(sizeof(hdr) + datasize) );
    bmp_save_wr32( hdr + 0x12, (uint32_t)bp->w );
    bmp_save_wr32( hdr + 0x16, (uint32_t)bp->h );
    bmp_save_wr32( hdr + 0x22, (uint32_t)datasize );

    // The file is opened only once the data is ready, so an allocation
    // failure above leaves no empty file behind.
    FILE *fp = fopen( fname, "wb" );
    if( fp == NULL ) {
        printf( "Could not create ouput file %s\n", fname );
        free( outbuf );
        return false;
    }

    size_t rc = fwrite( hdr, 1, sizeof(hdr), fp );
    if( rc != sizeof(hdr) ) {
        printf( "Error writing header! Got %zu, expected %zu.\n", rc, sizeof(hdr) );
        free( outbuf );
        fclose( fp );
        return false;
    }

    rc = fwrite( outbuf, 1, (size_t)datasize, fp );
    free( outbuf );
    if( rc != (size_t)datasize ) {
        printf( "Error writing! Got %zu, expected %zu.\n", rc, (size_t)datasize );
        fclose( fp );
        return false;
    }

    // fclose flushes buffered data, so a late write error only shows up here.
    if( fclose( fp ) != 0 ) {
        printf( "Error closing %s!\n", fname );
        return false;
    }

    return true;
}

// Link-compatibility shim. NOTE(review): this presumably lets a prebuilt
// jpeglib that was compiled against an older MSVC C runtime (which exported
// __iob_func) link against a newer runtime that no longer provides that
// symbol. Confirm against the build setup before removing it.
//
// _iob holds *copies* of the stdin/stdout/stderr FILE objects, taken during
// static initialisation. Initialising a global from non-constant values is
// valid C++ but not C.
FILE _iob[] = {*stdin, *stdout, *stderr};

// Returns the _iob copies above in place of the runtime's own stream table.
// NOTE(review): code built against a different FILE layout that indexes this
// array may not land on a valid FILE. Confirm the library only touches it in
// ways that work (or not at all).
extern "C" FILE * __cdecl __iob_func(void)
{
    return _iob;
}

// Save an image as a JPEG file using libjpeg.
//
// bp:      image; bits 16-23, 8-15 and 0-7 of each pixel are used as R, G, B.
// fname:   output path (created or truncated).
// flip:    false writes buffer row 0 as the top JPEG row; true writes the
//          rows in reverse order.
// quality: passed to jpeg_set_quality() with force_baseline = TRUE.
// The first component's sampling factors are forced to 1x1, so chroma is not
// subsampled relative to it.
// Returns true on success, or false after printing a diagnostic. NOTE: with
// the default jpeg_std_error() handler, errors raised inside libjpeg
// (including write failures it detects) terminate the process rather than
// returning here.
bool bmp_save_jpg( bmp_argb *bp, const char *fname, bool flip, int quality )
{
    // One packed RGB scanline. It is allocated before the file is created,
    // so an allocation failure leaves nothing behind.
    uint8_t *row = (uint8_t *)malloc( (size_t)bp->w * 3 );
    if( row == NULL ) {
        perror( "malloc" );
        return false;
    }

    FILE *fp = fopen( fname, "wb" );
    if( fp == NULL ) {
        printf( "Could not create ouput file %s\n", fname );
        free( row );
        return false;
    }

    struct jpeg_compress_struct cinfo;
    struct jpeg_error_mgr jerr;

    cinfo.err = jpeg_std_error( &jerr );
    jpeg_create_compress( &cinfo );
    jpeg_stdio_dest( &cinfo, fp );

    cinfo.image_width      = bp->w;
    cinfo.image_height     = bp->h;
    cinfo.input_components = 3;
    cinfo.in_color_space   = JCS_RGB;

    jpeg_set_defaults( &cinfo );
    jpeg_set_quality( &cinfo, quality, TRUE );

    cinfo.comp_info[0].v_samp_factor = 1;
    cinfo.comp_info[0].h_samp_factor = 1;

    jpeg_start_compress( &cinfo, TRUE );

    JSAMPROW row_pointer[1];
    row_pointer[0] = row;

    while( cinfo.next_scanline < cinfo.image_height ) {

        int line = (int)cinfo.next_scanline;
        int dy   = flip ? bp->h - 1 - line : line;

        const uint32_t *src32 = (const uint32_t *)(bp->argb + (size_t)bp->stride * dy);

        for( int x = 0; x < bp->w; x++ ) {
            uint32_t pix = src32[x];
            row[3*x+0] = (uint8_t)((pix >> 16) & 0xff);
            row[3*x+1] = (uint8_t)((pix >>  8) & 0xff);
            row[3*x+2] = (uint8_t)( pix        & 0xff);
        }

        jpeg_write_scanlines( &cinfo, row_pointer, 1 );
    }

    jpeg_finish_compress( &cinfo );
    jpeg_destroy_compress( &cinfo );
    free( row );

    // fclose flushes any remaining buffered output; report a failure there.
    if( fclose( fp ) != 0 ) {
        printf( "Error closing %s!\n", fname );
        return false;
    }

    return true;
}

void bmp_blend( bmp_argb *src, bmp_argb *dst, float wt )
{
	uint32_t *src32 = (uint32_t *)src->argb;
	uint32_t *dst32 = (uint32_t *)dst->argb;
	int stride32 = dst->stride/4;

	for( int y = 0; y < dst->h; y++ ) {

		uint32_t *srcrow = src32 + y*stride32;
		uint32_t *dstrow = dst32 + y*stride32;

		for( int x = 0; x < dst->w; x++ ) {

			uint32_t srcpixel8 = *srcrow++;
			float srcr = ((srcpixel8>>16)&0xff)/255.0f;
			float srcg = ((srcpixel8>> 8)&0xff)/255.0f;
			float srcb = ((srcpixel8>> 0)&0xff)/255.0f;
			v3 src3 = v3_set( srcr, srcg, srcb );

			uint32_t dstpixel8 = *dstrow;
			float dstr = ((dstpixel8>>16)&0xff)/255.0f;
			float dstg = ((dstpixel8>> 8)&0xff)/255.0f;
			float dstb = ((dstpixel8>> 0)&0xff)/255.0f;
			v3 dst3 = v3_set( dstr, dstg, dstb );

			v3 pixel = v3_mix( dst3, src3, wt );
			int r = (int)(pixel.x * 255.0f);
			int g = (int)(pixel.y * 255.0f);
			int b = (int)(pixel.z * 255.0f);
		    uint32_t pix = (r<<16)|(g<<8)|(b<<0);

		    *dstrow++ = pix;

		}


	}

}

void bmp_blend_2( bmp_argb *src, bmp_argb *dst, float wt_src )
{
	uint32_t *src32 = (uint32_t *)src->argb;
	uint32_t *dst32 = (uint32_t *)dst->argb;
	int stride32 = dst->stride/4;

	for( int y = 0; y < dst->h; y++ ) {

		uint32_t *srcrow = src32 + y*stride32;
		uint32_t *dstrow = dst32 + y*stride32;

		for( int x = 0; x < dst->w; x++ ) {

			uint32_t srcpixel8 = *srcrow++;
			float srcr = ((srcpixel8>>16)&0xff)/255.0f;
			float srcg = ((srcpixel8>> 8)&0xff)/255.0f;
			float srcb = ((srcpixel8>> 0)&0xff)/255.0f;
			float srca = ((srcpixel8>>24)&0xff)/255.0f;
			v3 src3 = v3_set( srcr, srcg, srcb );

			src3 = v3_mul1( src3, wt_src );

			uint32_t dstpixel8 = *dstrow;
			float dstr = ((dstpixel8>>16)&0xff)/255.0f;
			float dstg = ((dstpixel8>> 8)&0xff)/255.0f;
			float dstb = ((dstpixel8>> 0)&0xff)/255.0f;
			v3 dst3 = v3_set( dstr, dstg, dstb );

			v3 pixel = v3_mix( dst3, src3, srca );
			int r = (int)(pixel.x * 255.0f);
			int g = (int)(pixel.y * 255.0f);
			int b = (int)(pixel.z * 255.0f);
		    uint32_t pix = (r<<16)|(g<<8)|(b<<0);

		    *dstrow++ = pix;

		}


	}

}

void bmp_fade( bmp_argb *src, bmp_argb *dst, float wt )
{
	uint32_t *src32 = (uint32_t *)src->argb;
	uint32_t *dst32 = (uint32_t *)dst->argb;
	int stride32 = dst->stride/4;

	for( int y = 0; y < dst->h; y++ ) {

		uint32_t *srcrow = src32 + y*stride32;
		uint32_t *dstrow = dst32 + y*stride32;

		for( int x = 0; x < dst->w; x++ ) {

			uint32_t srcpixel8 = *srcrow++;
			float srcr = ((srcpixel8>>16)&0xff)/255.0f;
			float srcg = ((srcpixel8>> 8)&0xff)/255.0f;
			float srcb = ((srcpixel8>> 0)&0xff)/255.0f;
			v3 src3 = v3_set( srcr, srcg, srcb );

			v3 pixel = v3_mul1( src3, wt );

			int r = (int)(pixel.x * 255.0f);
			int g = (int)(pixel.y * 255.0f);
			int b = (int)(pixel.z * 255.0f);
		    uint32_t pix = (r<<16)|(g<<8)|(b<<0);

		    *dstrow++ = pix;

		}


	}

}
