#include "graphics/Image.h"

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <jpeglib.h>
#include <png.h>

#include "utilities/IOUtilities.h"

// Read cursor over an in-memory PNG file; registered as libpng's io pointer and consumed by pngReadFnMemoryBlock.
struct PNGImageAdaptorReadCallbackInfo {
	const void * data; // start of the file contents
	size_t length; // total number of bytes available at data
	size_t index; // offset of the next byte to hand to libpng (passed by address to readBytesFromMemoryBlock, presumably advanced past each read)
};

// State for writing PNG data to memory. NOTE(review): not referenced in the code visible here; presumably
// used by a PNG write callback elsewhere in the file — TODO confirm.
struct PNGImageAdaptorWriteCallbackInfo {
	void ** data; // presumably receives the pointer to the output buffer — TODO confirm
	size_t * length; // presumably receives the output size in bytes — TODO confirm
	size_t index; // presumably the current write offset — TODO confirm
};

/*
 * libpng read callback (installed with png_set_read_fn) that serves bytes from an in-memory file.
 * Copies `length` bytes into `data` using the PNGImageAdaptorReadCallbackInfo registered as the io
 * pointer. If readBytesFromMemoryBlock reports failure (presumably because too few bytes remain —
 * TODO confirm), decoding is aborted via png_error(), which does not return.
 */
static void pngReadFnMemoryBlock(png_structp pngReadStruct, png_bytep data, png_size_t length) {
	struct PNGImageAdaptorReadCallbackInfo * source = png_get_io_ptr(pngReadStruct);
	
	if (readBytesFromMemoryBlock(source->data, source->length, &source->index, length, data)) {
		return;
	}
	png_error(pngReadStruct, "Failed to read bytes in pngReadFnMemoryBlock.");
}

// Number of PNG signature bytes verified up front with png_sig_cmp() and then skipped via png_set_sig_bytes().
#define PNG_HEADER_SIZE 8

Image * Image_loadFromPNG(const char * fileName, int pixelFormat, bool flipVertical) {
	void * fileContents;
	size_t fileLength;
	png_structp pngReadStruct;
	png_infop pngInfoStruct;
	png_byte headerBytes[PNG_HEADER_SIZE];
	struct PNGImageAdaptorReadCallbackInfo callbackInfo;
	Image * image;
	int width;
	int height;
	png_int_32 bitDepth, colorType;
	int rowIndex;
	png_byte ** rows;
	
	fileContents = readFileSimple(fileName, &fileLength);
	if (fileContents == NULL) {
		return NULL;
	}
	
	if (fileLength < PNG_HEADER_SIZE) {
		free(fileContents);
		return NULL;
	}
	memcpy(headerBytes, fileContents, PNG_HEADER_SIZE);
	if (png_sig_cmp(headerBytes, 0, PNG_HEADER_SIZE) != 0) {
		free(fileContents);
		return NULL;
	}
	
	pngReadStruct = png_create_read_struct(PNG_LIBPNG_VER_STRING, NULL, NULL, NULL);
	pngInfoStruct = png_create_info_struct(pngReadStruct);
	
	if (setjmp(png_jmpbuf(pngReadStruct))) {
		png_destroy_read_struct(&pngReadStruct, &pngInfoStruct, NULL);
		free(fileContents);
		return NULL;
	}
	
	callbackInfo.data = fileContents;
	callbackInfo.length = fileLength;
	callbackInfo.index = PNG_HEADER_SIZE;
	png_set_read_fn(pngReadStruct, &callbackInfo, pngReadFnMemoryBlock);
	
	png_set_sig_bytes(pngReadStruct, PNG_HEADER_SIZE);
	png_read_info(pngReadStruct, pngInfoStruct);
	
	width = png_get_image_width(pngReadStruct, pngInfoStruct);
	height = png_get_image_height(pngReadStruct, pngInfoStruct);
	bitDepth = png_get_bit_depth(pngReadStruct, pngInfoStruct);
	colorType = png_get_color_type(pngReadStruct, pngInfoStruct);
	
	if (colorType == PNG_COLOR_TYPE_PALETTE) {
		png_set_palette_to_rgb(pngReadStruct);
	}
	if (colorType == PNG_COLOR_TYPE_GRAY && bitDepth < 8) {
		png_set_gray_1_2_4_to_8(pngReadStruct);
	}
	if (colorType == PNG_COLOR_TYPE_GRAY ||
			colorType == PNG_COLOR_TYPE_GRAY_ALPHA) {
		png_set_gray_to_rgb(pngReadStruct);
	}
	if (png_get_valid(pngReadStruct, pngInfoStruct, PNG_INFO_tRNS)) {
		png_set_tRNS_to_alpha(pngReadStruct);
	} else {
		png_set_filler(pngReadStruct, 0xFF, PNG_FILLER_AFTER);
	}
	if (bitDepth == 16) {
		png_set_strip_16(pngReadStruct);
	}
	
	png_read_update_info(pngReadStruct, pngInfoStruct);
	
	image = malloc(sizeof(Image));
	image->width = width;
	image->height = height;
	image->pixelFormat = IMAGE_PIXEL_FORMAT_RGBA;
	image->pixels = malloc(width * height * 4);
	
	rows = (png_byte **) malloc(height * sizeof(png_byte *));
	for (rowIndex = 0; rowIndex < height; rowIndex++) {
		rows[rowIndex] = (image->pixels + ((flipVertical ? height - rowIndex - 1 : rowIndex) * width * 4));
	}
	
	png_read_image(pngReadStruct, rows);
	png_read_end(pngReadStruct, NULL);
	
	png_destroy_read_struct(&pngReadStruct, &pngInfoStruct, NULL);
	free(rows);
	free(fileContents);
	
	return image;
}

// Per-decode state reachable from the libjpeg callbacks through jpeg_decompress_struct.client_data.
struct JPEGContext {
	// setjmp() target that jpegErrorExitCallback longjmp()s to on a fatal libjpeg error
	jmp_buf * jmpEnv;
	// Copy of the requested pixel format that is read after setjmp() instead of the pixelFormat
	// parameter, to silence GCC's "might be clobbered by longjmp" (-Wclobbered) warning
	int pixelFormat;
};

/*
 * libjpeg error_exit hook, used instead of libjpeg's default, which would terminate the process.
 * Releases the decoder's libjpeg-owned memory with jpeg_destroy(), then jumps back to the setjmp()
 * point whose jmp_buf is stored in the JPEGContext attached as client_data. Never returns.
 */
static void jpegErrorExitCallback(j_common_ptr context) {
	const struct JPEGContext * jpegContext = context->client_data;
	jmp_buf * target = jpegContext->jmpEnv;
	
	jpeg_destroy(context);
	longjmp(*target, 1);
}

// libjpeg output_message hook: discards warning/trace messages instead of printing them the way libjpeg's default does.
static void jpegOutputMessageCallback(j_common_ptr context) {
	(void) context;
}

// Source-manager init_source hook. Nothing to prepare: the whole file is already in the buffer set up by the caller.
static void jpegInitSourceCallback(j_decompress_ptr context) {
	(void) context;
}

/*
 * Source-manager fill_input_buffer hook. Since the entire file is handed to libjpeg up front, this is
 * only called once that buffer is exhausted, i.e. the data is truncated. Returning TRUE without
 * supplying bytes would make libjpeg read past the end of the buffer. Instead, feed it a fake EOI
 * marker, as libjpeg's own memory source does, so decoding ends cleanly with whatever data was present.
 */
static boolean jpegFillInputBufferCallback(j_decompress_ptr context) {
	static const JOCTET fakeEndOfImage[2] = {0xFF, JPEG_EOI};
	
	context->src->next_input_byte = fakeEndOfImage;
	context->src->bytes_in_buffer = sizeof(fakeEndOfImage);
	return true;
}

/*
 * Source-manager skip_input_data hook: advances past numberOfBytes of uninteresting data (e.g. APPn
 * markers). Per libjpeg's contract, zero or negative counts are no-ops. A corrupt or truncated file can
 * declare a length running past the end of the buffer; the skip is clamped there, so the next read
 * goes through fill_input_buffer instead of outside the buffer.
 */
static void jpegSkipInputDataCallback(j_decompress_ptr context, long numberOfBytes) {
	struct jpeg_source_mgr * source = context->src;
	size_t skipCount;
	
	if (numberOfBytes <= 0) {
		return;
	}
	skipCount = (size_t) numberOfBytes;
	if (skipCount > source->bytes_in_buffer) {
		skipCount = source->bytes_in_buffer;
	}
	source->next_input_byte += skipCount;
	source->bytes_in_buffer -= skipCount;
}

// Source-manager resync_to_restart hook: defers to libjpeg's stock recovery routine for damaged restart markers.
static boolean jpegResyncToRestartCallback(j_decompress_ptr context, int desired) {
	boolean recovered = jpeg_resync_to_restart(context, desired);
	
	return recovered;
}

// Source-manager term_source hook. Nothing to release: the buffer being decoded belongs to Image_loadFromJPEG.
static void jpegTermSourceCallback(j_decompress_ptr context) {
	(void) context;
}

Image * Image_loadFromJPEG(const char * fileName, int pixelFormat, bool flipVertical) {
	void * fileContents;
	size_t fileLength;
	Image * image = NULL;
	struct jpeg_decompress_struct jpegDecompressStruct;
	struct jpeg_error_mgr jpegErrorManager;
	struct jpeg_source_mgr jpegSourceManager;
	jmp_buf jmpEnv;
	unsigned int rowIndex, columnIndex;
	JSAMPROW row = NULL;
	struct JPEGContext contextStruct;
	int outputRowIndex;
	
	if (pixelFormat == IMAGE_PIXEL_FORMAT_DEFAULT) {
		pixelFormat = IMAGE_PIXEL_FORMAT_RGB;
	}
	
	fileContents = readFileSimple(fileName, &fileLength);
	if (fileContents == NULL) {
		return NULL;
	}
	
	contextStruct.jmpEnv = &jmpEnv;
	contextStruct.pixelFormat = pixelFormat;
	
	jpegErrorManager.error_exit = jpegErrorExitCallback;
	jpegErrorManager.output_message = jpegOutputMessageCallback;
	jpegDecompressStruct.err = jpeg_std_error(&jpegErrorManager);
	jpegDecompressStruct.client_data = &contextStruct;
	jpeg_create_decompress(&jpegDecompressStruct);
	
	if (setjmp(jmpEnv) != 0) {
		if (image != NULL) {
			if (image->pixels != NULL) {
				free(image->pixels);
			}
			free(image);
		}
		if (row != NULL) {
			free(row);
		}
		return NULL;
	}
	
	jpegSourceManager.next_input_byte = fileContents;
	jpegSourceManager.bytes_in_buffer = fileLength;
	jpegSourceManager.init_source = jpegInitSourceCallback;
	jpegSourceManager.fill_input_buffer = jpegFillInputBufferCallback;
	jpegSourceManager.skip_input_data = jpegSkipInputDataCallback;
	jpegSourceManager.resync_to_restart = jpegResyncToRestartCallback;
	jpegSourceManager.term_source = jpegTermSourceCallback;
	jpegDecompressStruct.src = &jpegSourceManager;
	
	jpeg_read_header(&jpegDecompressStruct, true);
	
	jpegDecompressStruct.out_color_space = JCS_RGB;
	
	jpeg_start_decompress(&jpegDecompressStruct);
	
	if (jpegDecompressStruct.output_components != 3) {
		longjmp(jmpEnv, 1);
	}
	
	image = malloc(sizeof(Image));
	image->width = jpegDecompressStruct.output_width;
	image->height = jpegDecompressStruct.output_height;
	switch (contextStruct.pixelFormat) {
		case IMAGE_PIXEL_FORMAT_RGB:
			image->pixelFormat = IMAGE_PIXEL_FORMAT_RGB;
			image->pixels = malloc(image->width * image->height * 3);
			break;
		case IMAGE_PIXEL_FORMAT_RGBA:
			image->pixelFormat = IMAGE_PIXEL_FORMAT_RGBA;
			image->pixels = malloc(image->width * image->height * 4);
			break;
		default:
			longjmp(jmpEnv, 1);
			break;
	}
	
	row = malloc(sizeof(JSAMPLE) * image->width * 3);
	for (rowIndex = 0; rowIndex < image->height; rowIndex++) {
		jpeg_read_scanlines(&jpegDecompressStruct, &row, 1);
		outputRowIndex = (flipVertical ? image->height - rowIndex - 1 : rowIndex);
		switch (contextStruct.pixelFormat) {
			case IMAGE_PIXEL_FORMAT_RGB:
				for (columnIndex = 0; columnIndex < image->width; columnIndex++) {
					image->pixels[outputRowIndex * image->width * 3 + columnIndex * 3 + 0] = row[columnIndex * 3 + 0];
					image->pixels[outputRowIndex * image->width * 3 + columnIndex * 3 + 1] = row[columnIndex * 3 + 1];
					image->pixels[outputRowIndex * image->width * 3 + columnIndex * 3 + 2] = row[columnIndex * 3 + 2];
				}
				break;
			case IMAGE_PIXEL_FORMAT_RGBA:
				for (columnIndex = 0; columnIndex < image->width; columnIndex++) {
					image->pixels[outputRowIndex * image->width * 4 + columnIndex * 4 + 0] = row[columnIndex * 3 + 0];
					image->pixels[outputRowIndex * image->width * 4 + columnIndex * 4 + 1] = row[columnIndex * 3 + 1];
					image->pixels[outputRowIndex * image->width * 4 + columnIndex * 4 + 2] = row[columnIndex * 3 + 2];
					image->pixels[outputRowIndex * image->width * 4 + columnIndex * 4 + 3] = 0xFF;
				}
				break;
			default:
				break;
		}
	}
	free(row);
	
	jpeg_finish_decompress(&jpegDecompressStruct);
	jpeg_destroy_decompress(&jpegDecompressStruct);
	
	return image;
}

/*
 * Releases an Image returned by Image_loadFromPNG() or Image_loadFromJPEG(), including its pixel
 * buffer. Passing NULL is a no-op, mirroring free(), so the result of a failed load can be passed
 * straight through.
 */
void Image_dispose(Image * image) {
	if (image == NULL) {
		return;
	}
	free(image->pixels);
	free(image);
}
