#include <errno.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include "nativeaudio/AudioOut.h"

#ifdef WIN32
#include <windows.h>
#define sleep(seconds) Sleep((seconds) * 1000)
#endif

/* Number of samples held in testSamplesInt16 and testSamplesFloat32. */
static unsigned long testSampleCount;
/* Raw 16-bit samples read from the file named on the command line. Stays NULL
 * when no file is given, in which case main plays a sine tone instead. The PCM
 * callbacks copy each sample to every output channel. */
static int16_t * testSamplesInt16;
/* The same samples as testSamplesInt16, each divided by 32768.0f, used when the
 * transport format has 4-byte (float) samples. */
static float * testSamplesFloat32;
/* Index of the next sample the PCM callbacks will output. */
static unsigned long testSampleIndex;
/* Frames produced so far by the sine callbacks, so each callback continues the
 * tone where the previous one stopped. */
static unsigned int accumulatedFrameCount;
/* Set from --host-format if given, then replaced by AudioOut_getHostFormat(). */
static AudioOut_sampleFormat hostFormat;
/* Format the callbacks write samples in; replaced by the host format unless
 * --transport-format is given.
 * NOTE(review): positional initializer - which members receive 1, 44100 and 2
 * depends on AudioOut_sampleFormat's member order in AudioOut.h (not visible
 * here); confirm the intended default. */
static AudioOut_sampleFormat transportFormat = {1, 44100, 2};

/* Stores one signed integer sample, already scaled to the range of the target
 * width, at sample position `index` of an interleaved buffer whose samples are
 * `bytesPerSample` bytes wide. Widths other than 1, 2 and 3 are ignored.
 * NOTE(review): 1-byte samples are written as signed 8-bit and 3-byte samples
 * as packed little-endian 24-bit, matching the signed ranges this harness
 * generates; confirm that is what AudioOut expects for those widths. */
static void writeIntSample(void * buffer, size_t index, unsigned int bytesPerSample, int32_t value) {
	switch (bytesPerSample) {
		case 1:
			((int8_t *) buffer)[index] = (int8_t) value;
			break;
		case 2:
			((int16_t *) buffer)[index] = (int16_t) value;
			break;
		case 3: {
			uint8_t * bytes = (uint8_t *) buffer + index * 3;
			/* Conversion to unsigned is well defined (modulo 2^32), so the
			 * two's-complement bit pattern can be split into bytes safely. */
			uint32_t bits = (uint32_t) value;
			bytes[0] = (uint8_t) (bits & 0xFF);
			bytes[1] = (uint8_t) ((bits >> 8) & 0xFF);
			bytes[2] = (uint8_t) ((bits >> 16) & 0xFF);
			break;
		}
		default:
			break;
	}
}

/* Writes `value` to every channel of frame `frameIndex` of an interleaved
 * integer buffer laid out according to transportFormat. */
static void writeIntFrame(void * buffer, unsigned int frameIndex, int32_t value) {
	size_t firstSample = (size_t) frameIndex * transportFormat.channelCount;
	for (unsigned int channelIndex = 0; channelIndex < transportFormat.channelCount; channelIndex++) {
		writeIntSample(buffer, firstSample + channelIndex, transportFormat.bytesPerSample, value);
	}
}

/* Returns the value in [-1, 1] of a 440 Hz sine tone at absolute frame number
 * `frame`, sampled at transportFormat.sampleRate. Returns 0 (silence) for a
 * zero sample rate instead of dividing by zero. */
static double sineAtFrame(unsigned int frame) {
	unsigned int sampleRate = transportFormat.sampleRate;
	if (sampleRate == 0) {
		return 0.0;
	}
	/* 440 Hz is a whole number of cycles per second, so the waveform repeats
	 * every sampleRate frames; reducing the frame number first keeps sin()'s
	 * argument small and the phase accurate as the frame count grows. */
	return sin(2.0 * M_PI * 440.0 * (double) (frame % sampleRate) / sampleRate);
}

/* Rescales a 16-bit PCM sample to the signed range of a `bytesPerSample`-byte
 * integer sample (1, 2 or 3 bytes). */
static int32_t rescalePCM16(int16_t sample, unsigned int bytesPerSample) {
	switch (bytesPerSample) {
		case 1:
			return sample / 256;
		case 3:
			/* Multiply rather than shift: left-shifting a negative value is UB. */
			return (int32_t) sample * 256;
		default:
			return sample;
	}
}

/* AudioOut callback: fills `outSamples` with `frameCount` frames of a 440 Hz
 * sine tone as signed integer samples of transportFormat.bytesPerSample bytes,
 * the same value on every channel, and advances accumulatedFrameCount so the
 * next call continues the waveform. Only valid for 1-3 byte widths; main
 * selects it only for those. */
void outputCallbackSineInt(void * outSamples, unsigned int frameCount, void * context) {
	/* Largest positive value of a signed integer of the transport width. */
	int32_t rangeMultiplier = (INT32_C(1) << (transportFormat.bytesPerSample * 8 - 1)) - 1;
	(void) context;
	
	for (unsigned int frameIndex = 0; frameIndex < frameCount; frameIndex++) {
		double sample = sineAtFrame(accumulatedFrameCount + frameIndex);
		writeIntFrame(outSamples, frameIndex, (int32_t) (sample * rangeMultiplier));
	}
	accumulatedFrameCount += frameCount;
}

/* AudioOut callback: fills `outSamples` with `frameCount` frames of a 440 Hz
 * sine tone as float samples in [-1, 1], the same value on every channel, and
 * advances accumulatedFrameCount so the next call continues the waveform. */
void outputCallbackSineFloat(void * outSamples, unsigned int frameCount, void * context) {
	float * samples = outSamples;
	(void) context;
	
	for (unsigned int frameIndex = 0; frameIndex < frameCount; frameIndex++) {
		float sample = (float) sineAtFrame(accumulatedFrameCount + frameIndex);
		for (unsigned int channelIndex = 0; channelIndex < transportFormat.channelCount; channelIndex++) {
			samples[(size_t) frameIndex * transportFormat.channelCount + channelIndex] = sample;
		}
	}
	accumulatedFrameCount += frameCount;
}

/* AudioOut callback: outputs the next `frameCount` samples of
 * testSamplesFloat32, each copied to every channel. Frames past the end of
 * the data are filled with silence rather than left with stale contents. */
void outputCallbackPCMFloat(void * outSamples, unsigned int frameCount, void * context) {
	float * samples = outSamples;
	(void) context;
	
	for (unsigned int frameIndex = 0; frameIndex < frameCount; frameIndex++) {
		float sample = 0.0f;
		if (testSampleIndex < testSampleCount) {
			sample = testSamplesFloat32[testSampleIndex++];
		}
		for (unsigned int channelIndex = 0; channelIndex < transportFormat.channelCount; channelIndex++) {
			samples[(size_t) frameIndex * transportFormat.channelCount + channelIndex] = sample;
		}
	}
}

/* AudioOut callback: outputs the next `frameCount` samples of
 * testSamplesInt16, rescaled to the transport's integer width (1-3 bytes) and
 * copied to every channel. Frames past the end of the data are filled with
 * silence on all channels. */
void outputCallbackPCMInt(void * outSamples, unsigned int frameCount, void * context) {
	(void) context;
	
	for (unsigned int frameIndex = 0; frameIndex < frameCount; frameIndex++) {
		int32_t value = 0;
		if (testSampleIndex < testSampleCount) {
			value = rescalePCM16(testSamplesInt16[testSampleIndex++], transportFormat.bytesPerSample);
		}
		writeIntFrame(outSamples, frameIndex, value);
	}
}

/* Writes the command-line synopsis to stderr. */
static void printUsage(void) {
	/* Plain text with no conversion specifiers, so fputs suffices. */
	static const char usageText[] =
		"Usage: testharness [--host-format <channels> <bytes per sample> <sample rate>]\n"
		"                   [--transport-format <channels> <bytes per sample> <sample rate>]\n"
		"                   [--file-output <file path to dump raw wave data>]\n"
		"                   [file path to play as raw wave data]\n";
	
	fputs(usageText, stderr);
}

int main(int argc, char ** argv) {
	bool hostFormatSpecified = false, transportFormatSpecified = false;
	for (int argIndex = 1; argIndex < argc; argIndex++) {
		if (!strcmp(argv[argIndex], "--help")) {
			printUsage();
			return EXIT_SUCCESS;
		}
		if (!strcmp(argv[argIndex], "--host-format")) {
			if (argIndex + 3 >= argc) {
				printUsage();
				return EXIT_FAILURE;
			}
			if (!sscanf(argv[argIndex + 1], "%u", &hostFormat.channelCount) ||
			    !sscanf(argv[argIndex + 2], "%u", &hostFormat.bytesPerSample) ||
			    !sscanf(argv[argIndex + 3], "%u", &hostFormat.sampleRate)) {
				printUsage();
				return EXIT_FAILURE;
			}
			hostFormatSpecified = true;
			argIndex += 3;
			
		} else if (!strcmp(argv[argIndex], "--transport-format")) {
			if (argIndex + 3 >= argc) {
				printUsage();
				return EXIT_FAILURE;
			}
			if (!sscanf(argv[argIndex + 1], "%u", &transportFormat.channelCount) ||
			    !sscanf(argv[argIndex + 2], "%u", &transportFormat.bytesPerSample) ||
			    !sscanf(argv[argIndex + 3], "%u", &transportFormat.sampleRate)) {
				printUsage();
				return EXIT_FAILURE;
			}
			transportFormatSpecified = true;
			argIndex += 3;
			
		} else if (!strcmp(argv[argIndex], "--file-output")) {
			if (argIndex + 1 >= argc) {
				printUsage();
				return EXIT_FAILURE;
			}
			AudioOut_setFileOutput(argv[++argIndex]);
			
		} else if (testSamplesInt16 == NULL) {
			size_t fileLength;
			unsigned int sampleIndex;
			FILE * file = fopen(argv[1], "rb");
			if (file == NULL) {
				fprintf(stderr, "Couldn't open \"%s\" (errno = %d)\n", argv[1], errno);
			}
			fseek(file, 0, SEEK_END);
			fileLength = ftell(file);
			fseek(file, 0, SEEK_SET);
			testSamplesInt16 = malloc(fileLength);
			fread(testSamplesInt16, 1, fileLength, file);
			fclose(file);
			
			testSampleCount = fileLength / 2;
			testSamplesFloat32 = malloc(sizeof(*testSamplesFloat32) * testSampleCount);
			for (sampleIndex = 0; sampleIndex < testSampleCount; sampleIndex++) {
				testSamplesFloat32[sampleIndex] = testSamplesInt16[sampleIndex] / 32768.0f;
			}
			
		} else {
			fprintf(stderr, "Unknown argument: %s\n", argv[argIndex]);
			printUsage();
			return EXIT_FAILURE;
		}
	}
	
	if (hostFormatSpecified) {
		printf("Requested host format: %u channel%s, %u samples per second, %u byte%s per sample\n", hostFormat.channelCount, hostFormat.channelCount == 1 ? "" : "s", hostFormat.sampleRate, hostFormat.bytesPerSample, hostFormat.bytesPerSample == 1 ? "" : "s");
		AudioOut_setHostFormat(hostFormat);
	}
	AudioOut_init("nativeaudio test harness");
	hostFormat = AudioOut_getHostFormat();
	printf("Got host format: %u channel%s, %u samples per second, %u byte%s per sample\n", hostFormat.channelCount, hostFormat.channelCount == 1 ? "" : "s", hostFormat.sampleRate, hostFormat.bytesPerSample, hostFormat.bytesPerSample == 1 ? "" : "s");
	
	if (!transportFormatSpecified) {
		transportFormat = hostFormat;
	}
	AudioOut_setTransportFormat(transportFormat);
	printf("Set transport format: %u channel%s, %u samples per second, %u byte%s per sample\n", transportFormat.channelCount, transportFormat.channelCount == 1 ? "" : "s", transportFormat.sampleRate, transportFormat.bytesPerSample, transportFormat.bytesPerSample == 1 ? "" : "s");
	
	if (testSamplesInt16 == NULL) {
		AudioOutCallback callback;
		switch (transportFormat.bytesPerSample) {
			case 1:
			case 2:
			case 3:
				callback = outputCallbackSineInt;
				break;
			case 4:
				callback = outputCallbackSineFloat;
				break;
			default:
				fprintf(stderr, "Unknown bytes per sample value: %u\n", transportFormat.bytesPerSample);
				return EXIT_FAILURE;
		}
		AudioOut_startOutput(callback, NULL);
		//usleep(100000);
		sleep(1);
	} else {
		AudioOutCallback callback;
		switch (transportFormat.bytesPerSample) {
			case 1:
			case 2:
			case 3:
				callback = outputCallbackPCMInt;
				break;
			case 4:
				callback = outputCallbackPCMFloat;
				break;
			default:
				fprintf(stderr, "Unknown bytes per sample value: %u\n", transportFormat.bytesPerSample);
				return EXIT_FAILURE;
		}
		AudioOut_startOutput(callback, NULL);
		usleep(testSampleCount * 10000 / 441);
		sleep(1);
	}
	AudioOut_stopOutput();
	return EXIT_SUCCESS;
}
