/* MIT licensed - see LICENSE in the project root directory. */
package org.openzen.packetstreams.qpsp;

import java.util.LinkedList;
import java.util.Queue;
import java.util.Timer;
import java.util.TimerTask;
import java.util.UUID;
import java.util.concurrent.ConcurrentLinkedQueue;
import org.openzen.packetstreams.NetworkLogger;
import org.openzen.packetstreams.PacketHints;
import org.openzen.packetstreams.PacketStream;
import org.openzen.packetstreams.Server;
import org.openzen.packetstreams.Service;
import org.openzen.packetstreams.ServiceConnector;
import org.openzen.packetstreams.ServiceMeta;
import org.openzen.packetstreams.ServiceStream;
import org.openzen.packetstreams.crypto.CryptoProvider;
import org.openzen.packetstreams.io.BytesDataInput;
import org.openzen.packetstreams.io.BytesDataOutput;
import org.openzen.packetstreams.qpsp.frames.AckFrame;
import org.openzen.packetstreams.qpsp.frames.CloseFrame;
import org.openzen.packetstreams.qpsp.frames.DataFrame;
import org.openzen.packetstreams.qpsp.frames.FinishPacketFrame;
import org.openzen.packetstreams.qpsp.frames.FragEndFrame;
import org.openzen.packetstreams.qpsp.frames.FragPartFrame;
import org.openzen.packetstreams.qpsp.frames.FragStartFrame;
import org.openzen.packetstreams.qpsp.frames.FrameQueue;
import org.openzen.packetstreams.qpsp.frames.OpenFrame;
import org.openzen.packetstreams.qpsp.frames.ServiceInfoFrame;
import org.openzen.packetstreams.qpsp.frames.StartFrame;

/**
 * Represents a single stream inside a {@link QPSPSession}.
 * <p>
 * Incoming packets are split into frames by {@link #onReceived} and offered to
 * {@link #incomingQueue}. The {@code deliver*}, {@code ack}, {@code finishPacket} and
 * {@code *Fragmented} methods apply a single frame to the stream state; they are presumably
 * invoked by the queued frames once it is their turn (the queue implementation is not part of
 * this file). Outgoing packets are assembled on demand by {@link #next()}.
 */
public class QPSPStream implements PacketStream {
	public final QPSPSession session;
	public final NetworkLogger logger;
	public final long localId;
	/** Stream id used by the peer; {@code -1} while it is not yet known (see {@link #next()}). */
	public long remoteId;
	
	/** Sequence number assigned to the next outgoing non-lossy packet. */
	private long outgoingSeq = 0;
	/** Sequence number assigned to the next outgoing lossy packet. */
	private long outgoingLossySeq = 0;
	
	/** Periodically asks the session to resume this stream while output is pending. */
	protected final Timer timer = new Timer();
	private Service service;
	private final ServiceConnector connector;
	private ServiceStream serviceStream;
	
	protected final FrameQueue incomingQueue = new FrameQueue(this);
	/** Highest outgoing sequence number that has been acknowledged in order; -1 if none yet. */
	private long lastConfirmedSeq = -1;
	/** Sequence number of the incoming packet that is currently being delivered. */
	private long currentSeq = 0;
	/** Index of the next data frame to deliver inside the packet {@link #currentSeq}. */
	private int currentDataSeq = 0;
	private int priority = 0;
	
	/** Accumulates the payload of a fragmented frame; {@code null} when none is in progress. */
	private BytesDataOutput incomingFragment = null;
	
	// Read by the timer thread and mutated by whichever threads enqueue and transmit frames,
	// hence a concurrent queue rather than a plain LinkedList.
	private final Queue<RawPacket> controlPacketBuffer = new ConcurrentLinkedQueue<>();
	/** Encoded frame that is currently being transmitted (possibly in fragments), or {@code null}. */
	private byte[] outgoingPacket = null;
	/** Number of bytes of {@link #outgoingPacket} that have already been sent as fragments. */
	private int outgoingFragmentOffset = 0;
	private boolean sendServiceMeta = false;
	private boolean closed = false;
	
	/**
	 * Creates the stream and starts a timer which, every 50 ms, resumes the stream on the
	 * session if control frames or a partially sent packet are pending.
	 *
	 * @param session session this stream belongs to; also supplies the logger
	 * @param localId id of this stream on the local side
	 * @param remoteId id of this stream on the remote side, or -1 if not yet known
	 * @param connector connector used by {@link #connect(ServiceMeta)}
	 */
	public QPSPStream(QPSPSession session, long localId, long remoteId, ServiceConnector connector) {
		this.session = session;
		logger = session.logger;
		this.localId = localId;
		this.remoteId = remoteId;
		this.connector = connector;
		
		timer.scheduleAtFixedRate(new TimerTask() {
			@Override
			public void run() {
				if (!controlPacketBuffer.isEmpty() || outgoingPacket != null) {
					//System.out.println("Flush");
					session.resume(QPSPStream.this);
				}
			}
		}, 50, 50);
	}
	
	/**
	 * Enqueues an OPEN (or QUICKOPEN) frame for the given path.
	 *
	 * @param path path written into the frame
	 * @param quick if true, a QUICKOPEN frame is sent instead of an OPEN frame
	 */
	public void open(String path, boolean quick) {
		BytesDataOutput frame = new BytesDataOutput();
		frame.writeUByte(quick ? Constants.PACKET_QUICKOPEN : Constants.PACKET_OPEN);
		frame.writeString(path);
		enqueueControlFrame(frame.toByteArray());
	}
	
	/**
	 * Connects to the service described by the given metadata: obtains the init data from the
	 * connector, enqueues a START frame, creates the service stream and resumes transmission.
	 *
	 * @param meta metadata of the service to connect to
	 */
	public void connect(ServiceMeta meta) {
		byte[] init = connector.connect(meta);
		this.priority = meta.defaultPriority; // TODO: allow connectors to specify their desired priority
		
		BytesDataOutput frame = new BytesDataOutput();
		frame.writeUByte(Constants.PACKET_START);
		frame.writeUInt(meta.checksum());
		frame.writeVarInt(priority);
		frame.writeByteArray(init);
		enqueueControlFrame(frame.toByteArray());
		
		serviceStream = connector.onConnected(this);
		serviceStream.onConnected();
		
		session.resume(this);
	}
	
	/**
	 * Enqueues an already encoded control frame as a non-lossy, non-keepalive packet part.
	 *
	 * @param frame encoded frame, including its type byte
	 */
	protected void enqueueControlFrame(byte[] frame) {
		controlPacketBuffer.add(new RawPacket(frame, false, false));
	}
	
	/**
	 * Enqueues a control frame with explicit lossy/keepalive flags.
	 *
	 * @param frame frame to enqueue; must not be null
	 */
	protected void enqueueControlFrame(RawPacket frame) {
		controlPacketBuffer.add(frame);
	}
	
	public Server getServer() {
		return session.server;
	}
	
	public Service getService() {
		return service;
	}
	
	public int getPriority() {
		return priority;
	}
	
	/**
	 * Reads a sequence number from the input. Currently a plain variable-length unsigned long.
	 *
	 * @param input input positioned at the encoded sequence number
	 * @return the decoded sequence number
	 */
	public long decodeCompactedSEQ(BytesDataInput input) {
		// TODO: implement correct algorithm
		return input.readVarULong();
	}
	
	/**
	 * Writes a sequence number to the output; counterpart of {@link #decodeCompactedSEQ}.
	 *
	 * @param output output to append to
	 * @param value sequence number to encode
	 */
	public void encodeCompactedSEQ(BytesDataOutput output, long value) {
		// TODO: implement correct algorithm
		output.writeVarULong(value);
	}
	
	/**
	 * Enqueues a lossy ACK frame for the given incoming sequence number.
	 */
	private void enqueueAck(long seq) {
		BytesDataOutput ack = new BytesDataOutput();
		// the receiver reads the frame type with readUByte, so write it the same way
		ack.writeUByte(Constants.PACKET_ACK);
		encodeCompactedSEQ(ack, seq);
		enqueueControlFrame(new RawPacket(ack.toByteArray(), true, false));
	}
	
	/**
	 * Handles a decrypted incoming packet: acknowledges it (if non-lossy), splits it into
	 * frames and offers each frame to {@link #incomingQueue}.
	 * <p>
	 * Non-lossy packets more than 64 sequence numbers ahead of the packet currently being
	 * delivered are dropped without acknowledgement. Non-lossy packets that were already
	 * delivered are acknowledged again but not processed.
	 *
	 * @param seq sequence number of the packet
	 * @param lossy whether the packet was sent on the lossy channel
	 * @param data decrypted packet contents (a sequence of frames)
	 */
	public synchronized void onReceived(long seq, boolean lossy, byte[] data) {
		if (!lossy && seq > currentSeq + 64) {
			logger.log(NetworkLogger.CATEGORY_FRAMES, localId, "Too many unprocessed frames! Dropping #" + seq);
			return;
		}
		
		// acknowledge before the duplicate check: a retransmission means our previous ack got lost
		if (!lossy)
			enqueueAck(seq);
		
		if (!lossy && seq < currentSeq) {
			logger.log(NetworkLogger.CATEGORY_FRAMES, localId, "Dropping duplicate packet " + seq);
			if (closed)
				session.resume(this); // otherwise the ack doesn't get transmitted
			return;
		}
		
		if (lossy)
			logger.log(NetworkLogger.CATEGORY_FRAMES, localId, "Processing incoming lossy packet " + seq);
		else
			logger.log(NetworkLogger.CATEGORY_FRAMES, localId, "Processing incoming packet " + seq);
		
		BytesDataInput input = new BytesDataInput(data);
		int dataseq = 0; // index of the next ordered (data/close/fragment) frame inside this packet
		while (input.hasMore()) {
			int type = input.readUByte();
			
			switch (type) {
				case Constants.PACKET_OPEN: {
					String path = input.readString();
					logger.log(NetworkLogger.CATEGORY_FRAMES, localId, "<- OPEN " + path);
					incomingQueue.offer(new OpenFrame(seq, path, true));
					break;
				}
				case Constants.PACKET_QUICKOPEN: {
					String path = input.readString();
					logger.log(NetworkLogger.CATEGORY_FRAMES, localId, "<- QUICKOPEN " + path);
					incomingQueue.offer(new OpenFrame(seq, path, false));
					break;
				}
				case Constants.PACKET_SERVICEINFO: {
					UUID uuid = input.readUUID();
					int flags = input.readUByte();
					// kept local: the stream priority is only updated once the frame is
					// actually delivered (see connect), not when it is merely parsed
					int defaultPriority = input.readVarInt();
					byte[] serviceInfo = input.readByteArray();
					logger.log(NetworkLogger.CATEGORY_FRAMES, localId, "<- SERVICEINFO");
					incomingQueue.offer(new ServiceInfoFrame(new ServiceMeta(uuid, flags, defaultPriority, serviceInfo)));
					break;
				}
				case Constants.PACKET_START: {
					logger.log(NetworkLogger.CATEGORY_FRAMES, localId, "<- START");
					int checksum = input.readUInt();
					int startPriority = input.readVarInt();
					byte[] initData = input.readByteArray();
					incomingQueue.offer(new StartFrame(checksum, startPriority, initData));
					break;
				}
				case Constants.PACKET_DATA: {
					logger.log(NetworkLogger.CATEGORY_FRAMES, localId, "<- DATA");
					byte[] packetData = input.readByteArray();
					incomingQueue.offer(new DataFrame(seq, dataseq++, packetData));
					break;
				}
				case Constants.PACKET_ACK: {
					long ackseq = decodeCompactedSEQ(input);
					logger.log(NetworkLogger.CATEGORY_FRAMES, localId, "<- ACK " + ackseq);
					session.onAcknowledged(localId, ackseq); // acknowledge to scheduler out-of-order; prevent unnecessary retransmissions
					incomingQueue.offer(new AckFrame(ackseq));
					break;
				}
				case Constants.PACKET_CLOSE: {
					int reason = input.readVarUInt();
					logger.log(NetworkLogger.CATEGORY_FRAMES, localId, "<- CLOSE " + reason);
					byte[] closeData = input.readByteArray();
					incomingQueue.offer(new CloseFrame(seq, dataseq++, reason, closeData));
					break;
				}
				case Constants.PACKET_FRAGSTART: {
					logger.log(NetworkLogger.CATEGORY_FRAMES, localId, "<- FRAGSTART");
					byte[] fragment = input.readByteArray();
					incomingQueue.offer(new FragStartFrame(seq, dataseq++, fragment));
					break;
				}
				case Constants.PACKET_FRAGPART: {
					logger.log(NetworkLogger.CATEGORY_FRAMES, localId, "<- FRAGPART");
					byte[] fragment = input.readByteArray();
					incomingQueue.offer(new FragPartFrame(seq, dataseq++, fragment));
					break;
				}
				case Constants.PACKET_FRAGEND: {
					logger.log(NetworkLogger.CATEGORY_FRAMES, localId, "<- FRAGEND");
					byte[] fragment = input.readByteArray();
					incomingQueue.offer(new FragEndFrame(seq, dataseq++, fragment));
					break;
				}
				default:
					handleUnknown(type, input);
					break;
			}
		}
		
		if (!lossy)
			incomingQueue.offer(new FinishPacketFrame(seq));
	}
	
	@Override
	public synchronized void resume() {
		session.resume(this);
	}

	/**
	 * Enqueues a CLOSE frame with the given reason and info and resumes transmission.
	 * Local state is only torn down in {@link #deliverClose}.
	 */
	@Override
	public synchronized void close(int reason, byte[] info) {
		BytesDataOutput close = new BytesDataOutput();
		close.writeUByte(Constants.PACKET_CLOSE);
		close.writeVarUInt(reason);
		close.writeByteArray(info);
		enqueueControlFrame(close.toByteArray());
		session.resume(this);
	}
	
	@Override
	public CryptoProvider getCrypto() {
		return session.getCrypto();
	}
	
	@Override
	public int getEstimatedRTTInMillis() {
		return session.getEstimatedRTTInMillis();
	}
	
	/**
	 * Called for a frame type that {@link #onReceived} does not know. The default
	 * implementation closes the stream with a protocol error.
	 *
	 * @param type the unknown frame type
	 * @param input input positioned right after the type byte
	 */
	public void handleUnknown(int type, BytesDataInput input) {
		close(Constants.CLOSE_PROTOCOL_ERROR, new byte[0]);
	}
	
	public void setService(Service service) {
		this.service = service;
	}
	
	/**
	 * Requests that the metadata of the current service is sent with the next outgoing packet.
	 */
	public void transmitServiceMeta() {
		sendServiceMeta = true;
		session.resume(this);
	}
	
	/**
	 * Opens the service stream on the current service and notifies it of the connection.
	 *
	 * @param priority priority to use for this stream from now on
	 * @param initData init data passed to the service
	 */
	public void start(int priority, byte[] initData) {
		serviceStream = service.open(this, initData);
		serviceStream.onConnected();
		this.priority = priority;
	}
	
	public boolean isStarted() {
		return serviceStream != null;
	}
	
	public void deliverServiceInfo(ServiceMeta meta) {
		connect(meta);
	}
	
	/**
	 * Delivers a data packet to the service stream, unless it is not the frame that is
	 * currently expected (in which case it is logged and ignored).
	 *
	 * @param seq sequence number of the packet that contained the frame
	 * @param dataseq index of the frame inside that packet
	 * @param packet payload to deliver
	 */
	public void deliver(long seq, int dataseq, byte[] packet) {
		if (seq != currentSeq || dataseq != currentDataSeq) {
			session.logger.log(NetworkLogger.CATEGORY_FRAMES, localId, "Preventing duplicate delivery of " + seq + "::" + dataseq);
			return;
		}
		
		serviceStream.onReceived(packet);
		currentDataSeq++;
	}
	
	/**
	 * Handles a CLOSE frame from the peer: answers with a CLOSE (unless the peer was itself
	 * answering our close), stops the timer, marks the stream closed and notifies the session
	 * and the service stream. Frames other than the currently expected one are ignored.
	 *
	 * @param seq sequence number of the packet that contained the frame
	 * @param dataseq index of the frame inside that packet
	 * @param reason close reason sent by the peer
	 * @param info additional close data sent by the peer
	 */
	public void deliverClose(long seq, int dataseq, int reason, byte[] info) {
		// same guard as deliver(): both coordinates must match the expected frame
		if (seq != currentSeq || dataseq != currentDataSeq) {
			session.logger.log(NetworkLogger.CATEGORY_FRAMES, localId, "Preventing duplicate delivery of " + seq + "::" + dataseq);
			return;
		}
		
		if (reason != Constants.CLOSE_REQUESTED_BY_PEER)
			close(Constants.CLOSE_REQUESTED_BY_PEER, new byte[0]);
		
		timer.cancel();
		service = null;
		currentDataSeq++;
		closed = true;
		
		session.resume(this); // makes sure the final ack gets transmitted
		session.onClosed(localId);
		if (serviceStream != null)
			serviceStream.onConnectionClosed(reason, info);
	}
	
	/**
	 * Marks the incoming packet with the given sequence number as completely delivered.
	 *
	 * @param seq sequence number of the finished packet
	 * @return true if the packet is done (or was a duplicate), false if it is not yet its turn
	 */
	public boolean finishPacket(long seq) {
		if (seq < currentSeq)
			return true; // duplicate packet; skip
		if (seq != currentSeq)
			return false;
		
		logger.log(NetworkLogger.CATEGORY_FRAMES, localId, "Finished " + seq);
		currentSeq++;
		currentDataSeq = 0;
		return true;
	}
	
	/**
	 * Starts reassembly of a fragmented frame. Closes the stream with a protocol error if a
	 * fragmented frame is already in progress.
	 *
	 * @param data payload of the first fragment
	 */
	public void beginFragmented(byte[] data) {
		if (incomingFragment != null) {
			close(Constants.CLOSE_PROTOCOL_ERROR, new byte[0]);
			return;
		}
		
		incomingFragment = new BytesDataOutput();
		incomingFragment.writeRawBytes(data);
	}
	
	/**
	 * Appends an intermediate fragment. Closes the stream with a protocol error if no
	 * fragmented frame is in progress.
	 *
	 * @param data payload of the fragment
	 */
	public void appendFragmented(byte[] data) {
		if (incomingFragment == null) {
			close(Constants.CLOSE_PROTOCOL_ERROR, new byte[0]);
			return;
		}
		
		incomingFragment.writeRawBytes(data);
	}
	
	/**
	 * Appends the final fragment, decodes the reassembled frame and processes it. Closes the
	 * stream with a protocol error if no fragmented frame is in progress, if the reassembled
	 * frame is empty or if its type may not be fragmented.
	 *
	 * @param seq sequence number of the packet that contained the final fragment
	 * @param dataseq index of the final fragment inside that packet
	 * @param data payload of the final fragment
	 */
	public void finishFragmented(long seq, int dataseq, byte[] data) {
		if (incomingFragment == null) {
			close(Constants.CLOSE_PROTOCOL_ERROR, new byte[0]);
			return;
		}
		
		// the frame to decode is the concatenation of all fragments, not just the last one
		incomingFragment.writeRawBytes(data);
		byte[] assembled = incomingFragment.toByteArray();
		incomingFragment = null;
		
		BytesDataInput input = new BytesDataInput(assembled);
		if (!input.hasMore()) {
			close(Constants.CLOSE_PROTOCOL_ERROR, new byte[0]);
			return;
		}
		
		int fragmentType = input.readUByte();
		switch (fragmentType) {
			case Constants.PACKET_SERVICEINFO: {
				UUID uuid = input.readUUID();
				int flags = input.readUByte();
				int defaultPriority = input.readVarInt();
				byte[] serviceInfo = input.readByteArray();
				deliverServiceInfo(new ServiceMeta(uuid, flags, defaultPriority, serviceInfo));
				break;
			}
			case Constants.PACKET_START: {
				// same encoding as the unfragmented START frame (see connect / onReceived)
				int checksum = input.readUInt();
				int startPriority = input.readVarInt();
				byte[] initData = input.readByteArray();
				incomingQueue.offer(new StartFrame(checksum, startPriority, initData));
				break;
			}
			case Constants.PACKET_DATA: {
				byte[] packetData = input.readByteArray();
				deliver(seq, dataseq, packetData);
				break;
			}
			case Constants.PACKET_CLOSE: {
				int reason = input.readVarUInt();
				byte[] closeData = input.readByteArray();
				deliverClose(seq, dataseq, reason, closeData);
				break;
			}
			default:
				close(Constants.CLOSE_PROTOCOL_ERROR, new byte[0]);
				break;
		}
	}
	
	/**
	 * Registers an acknowledgement for an outgoing packet, advancing the in-order
	 * confirmation counter when the acknowledgement is the next expected one.
	 *
	 * @param seq acknowledged outgoing sequence number
	 * @return true if the ack was a duplicate or the next one in order; false if it is ahead
	 *         of the next expected acknowledgement
	 */
	public boolean ack(long seq) {
		if (seq <= lastConfirmedSeq)
			return true; // duplicate ack
		
		if (seq == lastConfirmedSeq + 1) {
			lastConfirmedSeq++;
			return true;
		} else {
			return false;
		}
	}
	
	/**
	 * @return true if the given frame is the one currently expected or belongs to a packet
	 *         that has already been finished
	 */
	public boolean hasReached(long seq, int dataseq) {
		return (seq == currentSeq && dataseq == currentDataSeq) || (seq < currentSeq);
	}
	
	/**
	 * Builds the next outgoing packet for this stream: collects pending frames, assigns a
	 * sequence number, encrypts the contents and prepends stream id and sequence number.
	 *
	 * @return the next packet, or null if the remote id is unknown or nothing is pending
	 */
	public QPSPPacket next() {
		if (remoteId == -1) {
			logger.log(NetworkLogger.CATEGORY_FRAMES, localId, "Remote id not yet known, delaying transmission");
			return null; // can't transmit until remote id received
		}
		
		RawPacket nextRaw = nextRaw();
		if (nextRaw == null)
			return null;
		
		// lossy packets use their own sequence space and are marked with bit 1 of the stream id
		long seq = nextRaw.lossy ? outgoingLossySeq++ : outgoingSeq++;
		long streamId = nextRaw.lossy ? (remoteId | 2) : remoteId;
		byte[] encrypted = session.encrypt(nextRaw.data, streamId, seq);
		
		BytesDataOutput packet = new BytesDataOutput();
		packet.writeVarULong(streamId);
		encodeCompactedSEQ(packet, seq);
		packet.writeRawBytes(encrypted);
		return new QPSPPacket(priority, nextRaw.lossy ? (localId | 2) : localId, seq, packet.toByteArray(), nextRaw.keepalive);
	}
	
	/**
	 * Encoded frame data together with the flags that determine how it is transmitted.
	 */
	protected class RawPacket {
		public final byte[] data;
		public final boolean lossy;
		public final boolean keepalive;
		
		public RawPacket(byte[] data, boolean lossy, boolean keepalive) {
			this.data = data;
			this.lossy = lossy;
			this.keepalive = keepalive;
		}
	}
	
	/**
	 * Assembles the unencrypted contents of the next outgoing packet: first as many queued
	 * control frames as fit, then service metadata and/or service data (fragmented if needed).
	 * The result is lossy only if every frame in it is lossy.
	 *
	 * @return the packet contents, or null if there is nothing to send
	 * @throws IllegalStateException if a single control frame is too large for an empty packet
	 */
	private RawPacket nextRaw() {
		boolean lossy = true;
		boolean hasData = false;
		boolean keepalive = false;
		int available = session.maxUDPPacketSize;
		BytesDataOutput output = new BytesDataOutput();
		while (!controlPacketBuffer.isEmpty()) {
			RawPacket next = controlPacketBuffer.peek();
			if (next == null) // queue was drained concurrently between isEmpty() and peek()
				break;
			
			if (output.length() + next.data.length < session.maxUDPPacketSize) {
				controlPacketBuffer.poll();
				// only frames that actually go into this packet may influence its flags
				lossy &= next.lossy;
				keepalive |= next.keepalive;
				available -= next.data.length;
				output.writeRawBytes(next.data);
				hasData = true;
			} else {
				if (!hasData)
					throw new IllegalStateException("Control packet doesn't fit in the output; length: " + next.data.length);
				
				return new RawPacket(output.toByteArray(), lossy, keepalive);
			}
		}
		
		if (outgoingPacket == null && sendServiceMeta) {
			lossy = false;
			BytesDataOutput meta = new BytesDataOutput();
			ServiceMeta metadata = service.getMeta();
			meta.writeUByte(Constants.PACKET_SERVICEINFO);
			meta.writeUUID(metadata.uuid);
			meta.writeUByte(metadata.flags);
			meta.writeVarInt(metadata.defaultPriority);
			meta.writeByteArray(metadata.serviceInfo);
			outgoingPacket = meta.toByteArray();
			outgoingFragmentOffset = 0;
			sendServiceMeta = false;
		}
		
		while (available > 32) {
			if (outgoingPacket == null && serviceStream != null) {
				outgoingPacket = nextData(Math.max(512, available - 5));
				outgoingFragmentOffset = 0;
			}
			
			if (outgoingPacket != null) {
				lossy = false;
				hasData = true;
				if (outgoingFragmentOffset != 0 || outgoingPacket.length + 5 > available) {
					available -= nextFragment(output, available);
				} else {
					output.writeRawBytes(outgoingPacket);
					available -= outgoingPacket.length;
					outgoingPacket = null;
				}
			} else {
				break;
			}
		}
		
		if (!hasData)
			return null;
		
		return new RawPacket(output.toByteArray(), lossy, keepalive);
	}
	
	/**
	 * Asks the service stream for its next packet and wraps it in a DATA frame.
	 *
	 * @param availableSpaceInPacket space the encoded frame should preferably fit in
	 * @return the encoded DATA frame, or null if the service stream has nothing to send
	 */
	private byte[] nextData(int availableSpaceInPacket) {
		int recommendedLength = availableSpaceInPacket - getSizeLength(availableSpaceInPacket) - 1;
		byte[] packet = serviceStream.next(new PacketHints(recommendedLength));
		if (packet == null)
			return null;
		
		BytesDataOutput output = new BytesDataOutput();
		output.writeUByte(Constants.PACKET_DATA);
		output.writeByteArray(packet);
		return output.toByteArray();
	}
	
	/**
	 * @return the number of bytes needed to store the given size with 7 bits per byte
	 */
	private int getSizeLength(int size) {
		if (size < (1 << 7))
			return 1;
		else if (size < (1 << 14))
			return 2;
		else if (size < (1 << 21))
			return 3;
		else if (size < (1 << 28))
			return 4;
		else
			return 5;
	}
	
	/**
	 * Writes the next fragment of {@link #outgoingPacket} to the output and advances the
	 * fragment offset; clears {@link #outgoingPacket} once the final fragment is written.
	 * <p>
	 * Assumes {@code writeBytes(array, offset, length)} emits a variable-length size prefix
	 * followed by the bytes (the receiver reads fragments with {@code readByteArray}) - TODO
	 * confirm against BytesDataOutput.
	 *
	 * @param output packet contents to append to
	 * @param size space left in the packet
	 * @return number of bytes appended (type byte + size prefix + payload)
	 */
	private int nextFragment(BytesDataOutput output, int size) {
		int remaining = outgoingPacket.length - outgoingFragmentOffset;
		// payload that fits next to the type byte and the size prefix; never more than remains
		int capacity = Math.max(1, size - 1 - getSizeLength(size));
		
		if (outgoingFragmentOffset == 0) {
			int chunk = Math.min(capacity, remaining);
			output.writeUByte(Constants.PACKET_FRAGSTART);
			output.writeBytes(outgoingPacket, 0, chunk);
			outgoingFragmentOffset += chunk;
			return 1 + getSizeLength(chunk) + chunk;
		} else if (remaining > capacity) {
			output.writeUByte(Constants.PACKET_FRAGPART);
			output.writeBytes(outgoingPacket, outgoingFragmentOffset, capacity);
			outgoingFragmentOffset += capacity;
			return 1 + getSizeLength(capacity) + capacity;
		} else {
			output.writeUByte(Constants.PACKET_FRAGEND);
			output.writeBytes(outgoingPacket, outgoingFragmentOffset, remaining);
			outgoingPacket = null;
			return 1 + getSizeLength(remaining) + remaining;
		}
	}
}
