/* 
 *  WaveOutput.cpp 
 *         Original code by Timothy J. Weber.
 *
 *	Copyright (C) Alberto Vigata - January 2000 - ultraflask@yahoo.com
 *
 *  This file is part of FlasKMPEG, a free MPEG to MPEG/AVI converter
 *	
 *  FlasKMPEG is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2, or (at your option)
 *  any later version.
 *   
 *  FlasKMPEG is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *   
 *  You should have received a copy of the GNU General Public License
 *  along with FlasKMPEG; see the file COPYING.  If not, write to
 *  the Free Software Foundation, 675 Mass Ave, Cambridge, MA 02139, USA. 
 *
 */



// Compiler-dependent prerequisites.
// Non-Microsoft builds include <stdlib.h>; any unqualified `min` used in this
// file is then expected to come from elsewhere (presumably std::min reached
// through headers pulled in by WaveOutput.h plus `using namespace std`
// below -- TODO confirm).
#ifndef _MSC_VER
#include <stdlib.h>
#else
// Microsoft builds get a global min template because the compiler this file
// was written for did not provide one (per the original author).
// NOTE(review): with a newer MSVC whose <algorithm> is reachable here, an
// unqualified min(a, b) may become ambiguous with std::min under the
// `using namespace std` directive, and <windows.h>'s min macro (without
// NOMINMAX) would break this definition -- verify against the toolchain.
template<class T>
T min(T a, T b) { return a < b? a: b; }
#endif

#include "WaveOutput.h"

using namespace std;

/***************************************************************************
	macros and constants
***************************************************************************/

// Layout constants for the canonical WAVE file this class writes.
const int fmtChunkLength = 16;  // bytes in the "fmt " chunk body
// Bytes from the "WAVE" tag up to the first sample: the 4-byte "WAVE" tag,
// two 8-byte chunk headers ("fmt " and "data"), and the fmt body.  The
// 8-byte "RIFF" tag + size field that precede "WAVE" are NOT included, so
// sample data begins at file offset 8 + waveHeaderLength.
const int waveHeaderLength = 4 + 2 * 8 + fmtChunkLength;

/***************************************************************************
	typedefs and class definitions
***************************************************************************/

/***************************************************************************
	prototypes for static functions
***************************************************************************/

/***************************************************************************
	static variables
***************************************************************************/

/***************************************************************************
	public member functions for WaveFile
***************************************************************************/

// Construct an unopened WaveFile: no read or write file, every format field
// and the data length zeroed, and no error recorded.  `changed` starts true,
// meaning the on-disk header is considered out of date until
// WriteHeaderToFile() writes it to the write file.
// NOTE(review): keep this initializer order in sync with the member
// declaration order in WaveOutput.h (members are initialized in declaration
// order regardless of the order written here).
WaveFile::WaveFile():
	readFile(0),
	writeFile(0),
	formatType(0),
	numChannels(0),
	sampleRate(0),
	bytesPerSecond(0),
	bytesPerSample(0),
	bitsPerChannel(0),
	dataLength(0),
	error(0),
	changed(true)
{
}

// Destroy the object, releasing whichever file is open.  For a write file,
// Close() first rewrites the header if it is out of date.  A destructor has
// no way to report failure, so Close()'s status is deliberately discarded.
WaveFile::~WaveFile()
{
	static_cast<void>(Close());
}

bool WaveFile::OpenRead(const char* name)
{
	if (readFile || writeFile)
		Close();

	try {
		// open the RIFF file
		readFile = new RiffFile(name);
		if (!readFile->filep())
			throw error = "Couldn't open file";

		// read the header information
		if (strcmp(readFile->chunkName(), "RIFF")
			|| strcmp(readFile->subType(), "WAVE")
			|| !readFile->push("fmt "))
			throw error = "Couldn't find RIFF, WAVE, or fmt";

		size_t dwFmtSize = size_t(readFile->chunkSize());
		char* fmtChunk = new char[dwFmtSize];
		try {
			if (fread(fmtChunk, dwFmtSize, 1, readFile->filep()) != 1)
				throw error = "Error reading format chunk";
			readFile->pop();

			// set the format attribute members
			formatType = *((short*) fmtChunk);
			numChannels = *((short*) (fmtChunk + 2));
			sampleRate = *((long*) (fmtChunk + 4));
			bytesPerSecond = *((long*) (fmtChunk + 8));
			bytesPerSample = *((short*) (fmtChunk + 12));
			bitsPerChannel = *((short*) (fmtChunk + 14));

			// position at the data chunk
			if (!readFile->push("data"))
				throw error = "Couldn't find data chunk";

			// get the size of the data chunk
			dataLength = readFile->chunkSize();

			delete[] fmtChunk;
		} catch (...) {
			delete[] fmtChunk;
			throw error;
		}
	} catch (...) {
		Close();
		return false;
	}
	return true;
}

bool WaveFile::OpenWrite(const char* name)
{
	if (readFile || writeFile)
		Close();

	// open the file
	writeFile = fopen(name, "wb");
	if (!writeFile) {
		error = "Couldn't open output file";
		return false;
	}

	// write the header
	return WriteHeaderToFile(writeFile);
}

bool WaveFile::Write(const char* data, unsigned int dataCount)
{
    if( fwrite( data, dataCount, 1, GetFile() ) )
    {
        SetDataLength( GetDataLength() + dataCount );
        return true;
    }
    else
        return false;
}
// Reposition the open file at its first sample.
// Read mode: calls RiffFile::rewind() and then descends into the "data"
// chunk again.  Write mode: seeks to the first byte after the canonical
// header.  Returns false (setting `error` in read mode) on failure, or when
// no file is open.
bool WaveFile::ResetToStart()
{
	if (readFile) {
		if (!readFile->rewind() || !readFile->push("data")) {
			error = "Couldn't find data chunk on reset";
			return false;
		}
		return true;
	}

	if (writeFile) {
		// waveHeaderLength counts from the "WAVE" tag onward; the 4-byte
		// "RIFF" tag and 4-byte size field precede it.  Seeking to
		// waveHeaderLength alone would land on the "data" chunk header, and
		// the next write would overwrite it.
		const long riffPreambleLength = 8;
		return fseek(writeFile, riffPreambleLength + waveHeaderLength,
			SEEK_SET) == 0;
	}

	return false;
}

bool WaveFile::Close()
{
	bool retval = true;

	if (readFile) {
		delete readFile;  // closes the file before it's destroyed
		readFile = 0;
	} else if (writeFile) {
		// write the header information at the start of the file, if necessary
		if (changed) {
			long currentSpot = ftell(writeFile);  // save the position
			retval = WriteHeaderToFile(writeFile);
			fseek(writeFile, currentSpot, SEEK_SET);  // restore the old position
				// this is necessary so the file gets the right length--otherwise,
				// all the data we wrote would be truncated.
		}

		// close the file
		fclose(writeFile);
		writeFile = 0;
	}

	return retval;
}

// Report whether `other` describes the same sample format as this file:
// format tag, channel count, sample rate, byte rate, bytes per sample and
// bits per channel must all be equal.  The data length is not compared.
bool WaveFile::FormatMatches(const WaveFile& other)
{
	if (formatType != other.formatType || numChannels != other.numChannels)
		return false;
	if (sampleRate != other.sampleRate || bytesPerSecond != other.bytesPerSecond)
		return false;
	return bytesPerSample == other.bytesPerSample
		&& bitsPerChannel == other.bitsPerChannel;
}

void WaveFile::CopyFormatFrom(const WaveFile& other)
{
	formatType = other.formatType;
	numChannels = other.numChannels;
	sampleRate = other.sampleRate;
	bytesPerSecond = other.bytesPerSecond;
	bytesPerSample = other.bytesPerSample;
	bitsPerChannel = other.bitsPerChannel;
}

// Restart enumeration of the extra (type, value) items that RiffFile
// reports for the open read file, and fetch the first one.
// Returns false if no file is open for reading, the rewind fails, or there
// are no extra items.
bool WaveFile::GetFirstExtraItem(string& type, string& value)
{
	if (!readFile)
		return false;
	return readFile->rewind() && readFile->getNextExtraItem(type, value);
}

// Fetch the next extra (type, value) item after one obtained from
// GetFirstExtraItem or a previous call.  Returns false when the items are
// exhausted or no file is open for reading.
bool WaveFile::GetNextExtraItem(string& type, string& value)
{
	return readFile ? readFile->getNextExtraItem(type, value) : false;
}

// Append all sample data of `other` (open for reading) to this file (open
// for writing), in chunks of at most transferBufSize bytes.
// `other` is first reset to its first sample.  dataLength grows by the
// number of bytes copied and the header is marked for rewriting.  Returns
// false with `error` set on any failure.  Like the original, it never lets
// an exception escape.
bool WaveFile::CopyFrom(WaveFile& other)
{
	const size_t transferBufSize = 4096;

	if (!writeFile) {
		error = "Copy to an unopened file";
		return false;
	}
	if (!other.readFile) {
		error = "Copy from an unopened file";
		return false;
	}

	try {
		if (!other.ResetToStart()) {
			error = "Couldn't reset input file to start";
			return false;
		}

		// Automatic buffer: nothing to allocate, free, or leak on the early
		// returns below.
		char transferBuffer[transferBufSize];
		unsigned long bytesRead = 0;

		while (bytesRead < other.dataLength) {
			// copy one buffer's worth, or whatever remains if less
			const size_t remaining = size_t(other.dataLength - bytesRead);
			const size_t bytesToRead =
				remaining < transferBufSize ? remaining : transferBufSize;

			// read the buffer
			if (fread(transferBuffer, 1, bytesToRead, other.readFile->filep())
				!= bytesToRead)
			{
				error = "Error reading samples from input file";
				return false;
			}
			bytesRead += bytesToRead;

			// write the buffer
			if (fwrite(transferBuffer, 1, bytesToRead, writeFile) != bytesToRead) {
				error = "Error writing samples to output file";
				return false;
			}
			dataLength += bytesToRead;
			changed = true;
		}
	} catch (...) {
		// RiffFile operations are not visible here and might throw; keep the
		// contract that CopyFrom reports failure only through its result.
		return false;
	}

	return true;
}

bool WaveFile::WriteHeaderToFile(FILE* fp)
{
	// seek to the start of the file
	if (fseek(fp, 0, SEEK_SET) != 0)
		return false;

	// write the file header
	unsigned long wholeLength = waveHeaderLength + dataLength;
	unsigned long chunkLength = fmtChunkLength;

	if (fputs("RIFF", fp) == EOF
		|| fwrite(&wholeLength, sizeof(wholeLength), 1, fp) != 1
		|| fputs("WAVE", fp) == EOF
		|| fputs("fmt ", fp) == EOF
		|| fwrite(&chunkLength, sizeof(chunkLength), 1, fp) != 1
		|| fwrite(&formatType, sizeof(formatType), 1, fp) != 1
		|| fwrite(&numChannels, sizeof(numChannels), 1, fp) != 1
		|| fwrite(&sampleRate, sizeof(sampleRate), 1, fp) != 1
		|| fwrite(&bytesPerSecond, sizeof(bytesPerSecond), 1, fp) != 1
		|| fwrite(&bytesPerSample, sizeof(bytesPerSample), 1, fp) != 1
		|| fwrite(&bitsPerChannel, sizeof(bitsPerChannel), 1, fp) != 1
		|| fputs("data", fp) == EOF
		|| fwrite(&dataLength, sizeof(dataLength), 1, fp) != 1)
	{
		error = "Error writing header";
		return false;
	}

	// if it's the same file, now we don't have to write it again unless it's
	// been changed.
	if (fp == writeFile)
		changed = false;

	return true;
}

/***************************************************************************
	private member functions for WaveFile
***************************************************************************/

/***************************************************************************
	main()
***************************************************************************/

#ifdef TEST_WAVE

#include <iostream>

// Flag a check whose outcome differed from what was expected.
static void reportProblem()
{
	const char* const message = "  *** ERROR: Result incorrect.";
	std::cout << message << std::endl;
}

// Print whether an operation succeeded, then flag it if that outcome is not
// the one the test expected.
static void checkResult(bool got, bool expected)
{
	cout << (got ? "success." : "fail.") << endl;

	if (got != expected)
		reportProblem();
}

// Wait for the user to press Enter so the preceding report can be read.
static void pause()
{
	const char* const prompt = "Press Enter to continue.";
	cout << prompt << endl;
	cin.get();
}

// Print the pending error of the input and/or output file, or "Success."
// when neither has one.
static void ShowErrors(WaveFile& from, WaveFile& to)
{
	const bool inputFailed = from.GetError() ? true : false;
	const bool outputFailed = to.GetError() ? true : false;

	if (inputFailed)
		cout << "Error on input: " << from.GetError() << "." << endl;

	if (outputFailed)
		cout << "Error on output: " << to.GetError() << "." << endl;

	if (!inputFailed && !outputFailed)
		cout << "Success." << endl;
}

// Dump a WaveFile's format fields and derived sizes, followed by any extra
// items it carries, then wait for Enter.  If extra items were listed, the
// file is reset to its first sample afterwards.
static void ShowFormat(WaveFile& wave)
{
	const char* const encoding = wave.IsCompressed() ? " (compressed)" : " (PCM)";
	const char* const fileState = wave.GetFile() ? "good" : "null";

	cout << "Format:           " << wave.GetFormatType() << encoding << endl;
	cout << "Channels:         " << wave.GetNumChannels() << endl;
	cout << "Sample rate:      " << wave.GetSampleRate() << endl;
	cout << "Bytes per second: " << wave.GetBytesPerSecond() << endl;
	cout << "Bytes per sample: " << wave.GetBytesPerSample() << endl;
	cout << "Bits per channel: " << wave.GetBitsPerChannel() << endl;
	cout << "Bytes:            " << wave.GetDataLength() << endl;
	cout << "Samples:          " << wave.GetNumSamples() << endl;
	cout << "Seconds:          " << wave.GetNumSeconds() << endl;
	cout << "File pointer:     " << fileState << endl;

	string itemType, itemValue;
	bool haveItem = wave.GetFirstExtraItem(itemType, itemValue);
	if (haveItem) {
		cout << "Extra data:" << endl;
		while (haveItem) {
			cout << "  " << itemType << ": " << itemValue << endl;
			haveItem = wave.GetNextExtraItem(itemType, itemValue);
		}

		// enumerating the items moved the read position; go back to the data
		wave.ResetToStart();
	}

	pause();
}

// Test driver: copy argv[1] (a WAVE file) to argv[2] in canonical form,
// reporting each step.  Exits with 0 on success and 1 on a usage error or
// any failed open/copy/close, so scripts can detect failure.  No output
// file is created when the input cannot be opened.
int main(int argc, const char* argv[])
{
	if (argc < 3) {
		cout << "Copies one WAVE file to another, in canonical form." << endl;
		cout << "Usage: <input.wav> <output.wav>" << endl;
		return 1;
	}

	WaveFile From, To;

	cout << "Opening input..." << endl;
	const bool inputOpened = From.OpenRead(argv[1]);
	ShowErrors(From, To);
	if (!inputOpened)
		return 1;  // nothing to copy; don't create an empty output file
	ShowFormat(From);

	cout << "Setting formats..." << endl;
	To.CopyFormatFrom(From);
	ShowFormat(To);

	cout << "Opening output..." << endl;
	const bool outputOpened = To.OpenWrite(argv[2]);
	ShowErrors(From, To);
	if (!outputOpened)
		return 1;

	cout << "Copying..." << endl;
	const bool copied = To.CopyFrom(From);
	ShowErrors(From, To);
	cout << "Resulting format: " << endl;
	ShowFormat(To);
	cout << "Source format: " << endl;
	ShowFormat(From);

	cout << "Closing..." << endl;
	const bool outputClosed = To.Close();  // always close, even after a failed copy
	From.Close();
	ShowErrors(From, To);

	return (copied && outputClosed) ? 0 : 1;
}

#endif
