/* MIT licensed - see LICENSE in the project root directory. */
package org.openzen.packetstreams.io;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.time.LocalDate;
import java.util.UUID;

/**
 * DataOutput implementation which writes data to a growable in-memory byte
 * array and doesn't throw IOExceptions.
 *
 * <p>Fixed-width integers are written most significant byte first. Variable
 * length integers are written in groups of 7 bits, least significant group
 * first, with the high bit of each byte signalling that another byte follows.
 *
 * <p>No synchronization is performed by this class.
 */
public class BytesDataOutput implements DataOutput
{
	/** Size of the backing array of a freshly constructed instance. */
	private static final int INITIAL_CAPACITY = 16;

	/** Largest backing array this class will try to allocate. */
	private static final int MAX_CAPACITY = Integer.MAX_VALUE - 8;

	/** Backing storage; only the first {@link #length} bytes are valid. */
	private byte[] data;

	/** Number of bytes written so far. */
	private int length;

	/**
	 * Creates an empty output.
	 */
	public BytesDataOutput() {
		data = new byte[INITIAL_CAPACITY];
	}

	/**
	 * Returns the number of bytes written since construction or the last
	 * call to {@link #clear()}.
	 *
	 * @return number of valid bytes
	 */
	public int length() {
		return length;
	}

	/**
	 * Returns a copy of the bytes written so far.
	 *
	 * @return new array of exactly {@link #length()} bytes
	 */
	public byte[] toByteArray() {
		return Arrays.copyOf(data, length);
	}

	/**
	 * Discards everything written so far. The backing array is kept so it
	 * can be reused by subsequent writes.
	 */
	public void clear() {
		length = 0;
	}

	// ==================================
	// === IDataOutput Implementation ===
	// ==================================

	/**
	 * Writes a boolean as a single byte: 1 for true, 0 for false.
	 */
	@Override
	public void writeBoolean(boolean value) {
		writeUByte(value ? 1 : 0);
	}

	/**
	 * Writes a single byte.
	 */
	@Override
	public void writeSByte(byte value) {
		ensureCapacity(1);
		data[length++] = value;
	}

	/**
	 * Writes the low 8 bits of the given value as a single byte.
	 */
	@Override
	public void writeUByte(int value) {
		writeSByte((byte) value);
	}

	/**
	 * Writes a 16-bit value, most significant byte first.
	 */
	@Override
	public void writeShort(short value) {
		ensureCapacity(2);
		data[length++] = (byte) (value >>> 8);
		data[length++] = (byte) value;
	}

	/**
	 * Writes the low 16 bits of the given value, most significant byte first.
	 */
	@Override
	public void writeUShort(int value) {
		writeShort((short) value);
	}

	/**
	 * Writes a 32-bit value, most significant byte first.
	 */
	@Override
	public void writeInt(int value) {
		ensureCapacity(4);
		data[length++] = (byte) (value >>> 24);
		data[length++] = (byte) (value >>> 16);
		data[length++] = (byte) (value >>> 8);
		data[length++] = (byte) value;
	}

	/**
	 * Writes a 32-bit value exactly like {@link #writeInt(int)}.
	 */
	@Override
	public void writeUInt(int value) {
		writeInt(value);
	}

	/**
	 * Writes a 64-bit value, most significant byte first.
	 */
	@Override
	public void writeLong(long value) {
		ensureCapacity(8);
		data[length++] = (byte) (value >>> 56);
		data[length++] = (byte) (value >>> 48);
		data[length++] = (byte) (value >>> 40);
		data[length++] = (byte) (value >>> 32);
		data[length++] = (byte) (value >>> 24);
		data[length++] = (byte) (value >>> 16);
		data[length++] = (byte) (value >>> 8);
		data[length++] = (byte) value;
	}

	/**
	 * Writes a 64-bit value exactly like {@link #writeLong(long)}.
	 */
	@Override
	public void writeULong(long value) {
		writeLong(value);
	}

	/**
	 * Writes a signed 32-bit value as a variable length integer.
	 *
	 * <p>Non-negative values {@code v} are written as the unsigned value
	 * {@code 2 * v}; negative values are written as {@code 2 * (1 - v) + 1}
	 * (so -1 becomes 5 and the odd codes 1 and 3 are never produced).
	 */
	@Override
	public void writeVarInt(int value) {
		// NOTE(review): a conventional zigzag encoding would map -1 to 1. This
		// mapping is part of the wire format, so it is kept bit-identical;
		// confirm against the matching reader before changing it.
		writeVarUInt(value < 0 ? ((1 - value) << 1) + 1 : value << 1);
	}

	/**
	 * Writes the given value, interpreted as an unsigned 32-bit integer, in
	 * 1 to 5 bytes: 7 bits per byte, least significant group first, with the
	 * high bit set on every byte except the last.
	 */
	@Override
	public void writeVarUInt(int value) {
		ensureCapacity(5);

		int remaining = value;
		while ((remaining & ~0x7F) != 0) {
			data[length++] = (byte) ((remaining & 0x7F) | 0x80);
			remaining >>>= 7; // unsigned shift: negative inputs terminate after 5 bytes
		}
		data[length++] = (byte) remaining;
	}

	/**
	 * Writes a signed 64-bit value as a variable length integer, using the
	 * same sign mapping as {@link #writeVarInt(int)}.
	 */
	@Override
	public void writeVarLong(long value) {
		// NOTE(review): same non-standard sign mapping as writeVarInt; kept
		// bit-identical because it is part of the wire format.
		writeVarULong(value < 0L ? ((1L - value) << 1) + 1L : value << 1);
	}

	/**
	 * Writes the given value, interpreted as an unsigned 64-bit integer, in
	 * 1 to 9 bytes. The first 8 bytes carry 7 bits each plus a continuation
	 * bit; a 9th byte, when present, carries the remaining 8 bits and has no
	 * continuation bit.
	 */
	@Override
	public void writeVarULong(long value) {
		ensureCapacity(9);

		long remaining = value;
		// At most 8 continuation bytes (56 bits); whatever is left afterwards
		// fits in the final byte.
		for (int i = 0; i < 8 && (remaining & ~0x7FL) != 0; i++) {
			data[length++] = (byte) ((remaining & 0x7F) | 0x80);
			remaining >>>= 7;
		}
		data[length++] = (byte) remaining;
	}

	/**
	 * Writes the IEEE 754 bit pattern of the value as a 32-bit integer.
	 */
	@Override
	public void writeFloat(float value) {
		writeInt(Float.floatToIntBits(value));
	}

	/**
	 * Writes the IEEE 754 bit pattern of the value as a 64-bit integer.
	 */
	@Override
	public void writeDouble(double value) {
		writeLong(Double.doubleToLongBits(value));
	}

	/**
	 * Writes a character value as an unsigned variable length integer.
	 */
	@Override
	public void writeChar(int value) {
		writeVarUInt(value);
	}

	/**
	 * Writes the array length as an unsigned variable length integer,
	 * followed by the array contents.
	 */
	@Override
	public void writeBytes(byte[] data) {
		writeVarUInt(data.length);
		writeRawBytes(data);
	}

	/**
	 * Writes {@code length} as an unsigned variable length integer, followed
	 * by {@code length} bytes of {@code data} starting at {@code offset}.
	 *
	 * @throws IndexOutOfBoundsException if the range is not within the array;
	 *         nothing is written in that case
	 */
	@Override
	public void writeBytes(byte[] data, int offset, int length) {
		// Validate before emitting the prefix so a bad range cannot leave a
		// dangling length in the output.
		checkRange(data, offset, length);
		writeVarUInt(length);
		writeRawBytes(data, offset, length);
	}

	/**
	 * Writes the UTF-8 encoding of the string, prefixed with its byte length.
	 */
	@Override
	public void writeString(String str) {
		writeBytes(str.getBytes(StandardCharsets.UTF_8));
	}

	/**
	 * Writes the array contents without a length prefix.
	 */
	@Override
	public void writeRawBytes(byte[] value) {
		ensureCapacity(value.length);
		System.arraycopy(value, 0, data, length, value.length);
		length += value.length;
	}

	/**
	 * Writes {@code length} bytes of {@code value} starting at
	 * {@code offset}, without a length prefix.
	 *
	 * @throws IndexOutOfBoundsException if the range is not within the array
	 */
	@Override
	public void writeRawBytes(byte[] value, int offset, int length) {
		checkRange(value, offset, length);
		ensureCapacity(length);
		System.arraycopy(value, offset, data, this.length, length);
		this.length += length;
	}

	/**
	 * Reads {@code length} bytes from the buffer's current position and
	 * writes them without a length prefix. The buffer's position advances by
	 * {@code length}.
	 *
	 * @throws IndexOutOfBoundsException if {@code length} is negative
	 * @throws java.nio.BufferUnderflowException if the buffer has fewer than
	 *         {@code length} bytes remaining
	 */
	@Override
	public void writeRawBytes(ByteBuffer data, int length) {
		checkBuffer(data, length);
		ensureCapacity(length);
		data.get(this.data, this.length, length);
		this.length += length;
	}

	/**
	 * Rewinds the buffer and writes {@code capacity()} bytes from it,
	 * prefixed with that size as an unsigned variable length integer.
	 *
	 * @throws java.nio.BufferUnderflowException if the buffer's limit is
	 *         lower than its capacity; nothing is written in that case
	 */
	@Override
	public void writeAllBytes(ByteBuffer data) {
		// NOTE(review): the size is taken from capacity(), not limit(); this
		// only succeeds when limit == capacity. Confirm that is the intent.
		int size = data.capacity();
		data.rewind();

		// Validate before emitting the prefix so a failure cannot leave a
		// dangling length in the output.
		checkBuffer(data, size);
		writeVarUInt(size);
		writeRawBytes(data, size);
	}

	/**
	 * Writes {@code length} as an unsigned variable length integer, followed
	 * by {@code length} bytes read from the buffer's current position.
	 *
	 * @throws IndexOutOfBoundsException if {@code length} is negative
	 * @throws java.nio.BufferUnderflowException if the buffer has fewer than
	 *         {@code length} bytes remaining; nothing is written in that case
	 */
	@Override
	public void writeBytes(ByteBuffer data, int length) {
		checkBuffer(data, length);
		writeVarUInt(length);
		writeRawBytes(data, length);
	}

	/**
	 * Writes a length-prefixed byte array; same as {@link #writeBytes(byte[])}.
	 */
	@Override
	public void writeByteArray(byte[] data) {
		writeBytes(data);
	}

	/**
	 * Writes a length-prefixed byte array; same as {@link #writeBytes(byte[])}.
	 */
	@Override
	public void writeUByteArray(byte[] data) {
		writeBytes(data);
	}

	/**
	 * Writes the element count as an unsigned variable length integer,
	 * followed by each element as a fixed 16-bit value.
	 */
	@Override
	public void writeShortArray(short[] data) {
		writeVarUInt(data.length);
		writeShortArrayRaw(data);
	}

	/**
	 * Writes each element as a fixed 16-bit value, without an element count.
	 */
	@Override
	public void writeShortArrayRaw(short[] data) {
		for (short element : data) {
			writeShort(element);
		}
	}

	/**
	 * Same as {@link #writeShortArray(short[])}.
	 */
	@Override
	public void writeUShortArray(short[] data) {
		writeShortArray(data);
	}

	/**
	 * Same as {@link #writeShortArrayRaw(short[])}.
	 */
	@Override
	public void writeUShortArrayRaw(short[] data) {
		writeShortArrayRaw(data);
	}

	/**
	 * Writes the element count as an unsigned variable length integer,
	 * followed by each element as a signed variable length integer.
	 */
	@Override
	public void writeVarIntArray(int[] data) {
		writeVarUInt(data.length);
		writeVarIntArrayRaw(data);
	}

	/**
	 * Writes each element as a signed variable length integer, without an
	 * element count.
	 */
	@Override
	public void writeVarIntArrayRaw(int[] data) {
		for (int element : data) {
			writeVarInt(element);
		}
	}

	/**
	 * Writes the element count as an unsigned variable length integer,
	 * followed by each element as an unsigned variable length integer.
	 */
	@Override
	public void writeVarUIntArray(int[] data) {
		writeVarUInt(data.length);
		writeVarUIntArrayRaw(data);
	}

	/**
	 * Writes each element as an unsigned variable length integer, without an
	 * element count.
	 */
	@Override
	public void writeVarUIntArrayRaw(int[] data) {
		for (int element : data) {
			writeVarUInt(element);
		}
	}

	/**
	 * Writes the element count as an unsigned variable length integer,
	 * followed by each element as a fixed 32-bit value.
	 */
	@Override
	public void writeIntArray(int[] data) {
		writeVarUInt(data.length);
		writeIntArrayRaw(data);
	}

	/**
	 * Writes each element as a fixed 32-bit value, without an element count.
	 */
	@Override
	public void writeIntArrayRaw(int[] data) {
		for (int element : data) {
			writeInt(element);
		}
	}

	/**
	 * Writes the element count as an unsigned variable length integer,
	 * followed by each element via {@link #writeUInt(int)}.
	 */
	@Override
	public void writeUIntArray(int[] data) {
		writeVarUInt(data.length);
		writeUIntArrayRaw(data);
	}

	/**
	 * Writes each element via {@link #writeUInt(int)}, without an element
	 * count.
	 */
	@Override
	public void writeUIntArrayRaw(int[] data) {
		for (int element : data) {
			writeUInt(element);
		}
	}

	/**
	 * Writes the element count as an unsigned variable length integer,
	 * followed by each element as a signed variable length long.
	 */
	@Override
	public void writeVarLongArray(long[] data) {
		writeVarUInt(data.length);
		writeVarLongArrayRaw(data);
	}

	/**
	 * Writes each element as a signed variable length long, without an
	 * element count.
	 */
	@Override
	public void writeVarLongArrayRaw(long[] data) {
		for (long element : data) {
			writeVarLong(element);
		}
	}

	/**
	 * Writes the element count as an unsigned variable length integer,
	 * followed by each element as an unsigned variable length long.
	 */
	@Override
	public void writeVarULongArray(long[] data) {
		writeVarUInt(data.length);
		writeVarULongArrayRaw(data);
	}

	/**
	 * Writes each element as an unsigned variable length long, without an
	 * element count.
	 */
	@Override
	public void writeVarULongArrayRaw(long[] data) {
		for (long element : data) {
			writeVarULong(element);
		}
	}

	/**
	 * Writes the element count as an unsigned variable length integer,
	 * followed by each element as a fixed 64-bit value.
	 */
	@Override
	public void writeLongArray(long[] data) {
		writeVarUInt(data.length);
		writeLongArrayRaw(data);
	}

	/**
	 * Writes each element as a fixed 64-bit value, without an element count.
	 */
	@Override
	public void writeLongArrayRaw(long[] data) {
		for (long element : data) {
			writeLong(element);
		}
	}

	/**
	 * Writes the element count as an unsigned variable length integer,
	 * followed by each element via {@link #writeULong(long)}.
	 */
	@Override
	public void writeULongArray(long[] data) {
		writeVarUInt(data.length);
		writeULongArrayRaw(data);
	}

	/**
	 * Writes each element via {@link #writeULong(long)}, without an element
	 * count.
	 */
	@Override
	public void writeULongArrayRaw(long[] data) {
		for (long element : data) {
			writeULong(element);
		}
	}

	/**
	 * Writes the element count as an unsigned variable length integer,
	 * followed by each element via {@link #writeFloat(float)}.
	 */
	@Override
	public void writeFloatArray(float[] data) {
		writeVarUInt(data.length);
		writeFloatArrayRaw(data);
	}

	/**
	 * Writes each element via {@link #writeFloat(float)}, without an element
	 * count.
	 */
	@Override
	public void writeFloatArrayRaw(float[] data) {
		for (float element : data) {
			writeFloat(element);
		}
	}

	/**
	 * Writes the element count as an unsigned variable length integer,
	 * followed by each element via {@link #writeDouble(double)}.
	 */
	@Override
	public void writeDoubleArray(double[] data) {
		writeVarUInt(data.length);
		writeDoubleArrayRaw(data);
	}

	/**
	 * Writes each element via {@link #writeDouble(double)}, without an
	 * element count.
	 */
	@Override
	public void writeDoubleArrayRaw(double[] data) {
		for (double element : data) {
			writeDouble(element);
		}
	}

	/**
	 * Writes the element count as an unsigned variable length integer,
	 * followed by each element via {@link #writeString(String)}.
	 */
	@Override
	public void writeStringArray(String[] data) {
		writeVarUInt(data.length);
		writeStringArrayRaw(data);
	}

	/**
	 * Writes each element via {@link #writeString(String)}, without an
	 * element count.
	 */
	@Override
	public void writeStringArrayRaw(String[] data) {
		for (String element : data) {
			writeString(element);
		}
	}

	/**
	 * Writes a date as a signed variable length integer computed as
	 * {@code (year - 2000) * 12 * 31 + (month - 1) * 31 + (day - 1)}.
	 *
	 * <p>A null date is written as -32. Under the formula above that value
	 * would correspond to 31 November 1999, which is not a valid calendar
	 * date, so it cannot collide with a real one.
	 *
	 * @param value date to write, may be null
	 */
	@Override
	public void writeDate(LocalDate value) {
		if (value == null) {
			writeVarInt(-32);
			return;
		}

		int ivalue = value.getYear() - 2000;
		ivalue = ivalue * 12 * 31
				+ (value.getMonthValue() - 1) * 31
				+ value.getDayOfMonth() - 1;
		writeVarInt(ivalue);
	}

	/**
	 * Writes a UUID as two fixed 64-bit values: the most significant bits
	 * followed by the least significant bits.
	 */
	@Override
	public void writeUUID(UUID uuid) {
		writeLong(uuid.getMostSignificantBits());
		writeLong(uuid.getLeastSignificantBits());
	}

	/**
	 * Does nothing; all data is already in memory.
	 */
	@Override
	public void flush() {
		// nothing to do
	}

	/**
	 * Does nothing; there are no resources to release.
	 */
	@Override
	public void close() {
		// nothing to do
	}

	// =======================
	// === Private Methods ===
	// =======================

	/**
	 * Makes sure at least {@code additional} more bytes fit in the backing
	 * array, doubling its size as often as needed.
	 *
	 * @param additional number of bytes about to be written; non-negative
	 * @throws OutOfMemoryError if the required size exceeds the largest
	 *         array this class will allocate
	 */
	private void ensureCapacity(int additional) {
		// Computed as long so that length + additional cannot wrap around.
		long required = (long) length + additional;
		if (required <= data.length) {
			return;
		}
		if (required > MAX_CAPACITY) {
			throw new OutOfMemoryError("Required buffer size " + required + " exceeds the maximum array size");
		}

		long capacity = data.length;
		while (capacity < required) {
			capacity <<= 1;
		}
		data = Arrays.copyOf(data, (int) Math.min(capacity, MAX_CAPACITY));
	}

	/**
	 * Verifies that {@code [offset, offset + length)} lies within the array.
	 * Written so that the comparison cannot overflow.
	 *
	 * @throws IndexOutOfBoundsException if the range is invalid
	 */
	private static void checkRange(byte[] source, int offset, int length) {
		if (offset < 0 || length < 0 || length > source.length - offset) {
			throw new IndexOutOfBoundsException(
					"offset " + offset + ", length " + length + ", array length " + source.length);
		}
	}

	/**
	 * Verifies that {@code length} bytes can be read from the buffer's
	 * current position, throwing the same exception types a bulk
	 * {@code ByteBuffer.get} would.
	 *
	 * @throws IndexOutOfBoundsException if {@code length} is negative
	 * @throws java.nio.BufferUnderflowException if fewer than {@code length}
	 *         bytes remain
	 */
	private static void checkBuffer(ByteBuffer source, int length) {
		if (length < 0) {
			throw new IndexOutOfBoundsException("Negative length: " + length);
		}
		if (length > source.remaining()) {
			throw new java.nio.BufferUnderflowException();
		}
	}
}
