/* MIT licensed - see LICENSE in the project root directory. */
package org.openzen.packetstreams.qpsp;

import java.io.IOException;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import org.openzen.packetstreams.NetworkLogger;
import org.openzen.packetstreams.Server;
import org.openzen.packetstreams.ServiceConnector;
import org.openzen.packetstreams.ServiceMeta;
import org.openzen.packetstreams.crypto.CryptoDecryptionException;
import org.openzen.packetstreams.crypto.CryptoKeyPair;
import org.openzen.packetstreams.crypto.CryptoPublicKey;
import org.openzen.packetstreams.crypto.CryptoSharedKey;
import org.openzen.packetstreams.io.BytesDataInput;
import org.openzen.packetstreams.crypto.CryptoProvider;
import org.openzen.packetstreams.io.BytesDataOutput;
import org.openzen.packetstreams.qpsp.frames.CloseFrame;
import org.openzen.packetstreams.qpsp.frames.DataFrame;
import org.openzen.packetstreams.qpsp.frames.FinishPacketFrame;
import org.openzen.packetstreams.qpsp.frames.FragEndFrame;
import org.openzen.packetstreams.qpsp.frames.FragPartFrame;
import org.openzen.packetstreams.qpsp.frames.FragStartFrame;
import org.openzen.packetstreams.qpsp.frames.Frame;
import org.openzen.packetstreams.qpsp.frames.FrameQueue;
import org.openzen.packetstreams.qpsp.frames.OpenFrame;
import org.openzen.packetstreams.qpsp.frames.ServiceInfoFrame;
import org.openzen.packetstreams.qpsp.frames.StartFrame;
import org.openzen.packetstreams.qpsp.congestion.StandardCongestionController;
import org.openzen.packetstreams.Connection;
import org.openzen.packetstreams.qpsp.frames.AckFrame;
import org.openzen.packetstreams.qpsp.frames.StopWaitingFrame;

/**
 * Represents a single connection. Can be used both for client and server side
 * connections, as there is no fundamental difference between them.
 *
 * <p>Most entry points do not act directly: they enqueue a command on
 * {@code commandQueue}, which a dedicated command thread executes in order.
 * {@code encrypt} and {@code decrypt} assert that they run on that thread,
 * because they reuse shared nonce buffers.
 */
public class QPSPConnection implements Connection {
	/** Priority reported by the internal control stream (the highest possible int). */
	public static final int CONTROL_PRIORITY = Integer.MAX_VALUE;
	
	private static final int MAX_DECRYPTION_ERRORS = 16;
	
	private final QPSPEndpoint endpoint;
	public final Server server;
	public final NetworkLogger logger;
	public InetAddress remoteAddress;
	public int remotePort;
	public final long localID;
	public final long remoteID;
	public int maxPacketSize;
	public int maxBufferSize;
	public final int maxUDPPacketSize = 1200;
	public final boolean isServer;
	// Milliseconds without any successfully received packet before the keepalive
	// task declares the connection timed out.
	private final int sessionTimeout = 30 * 60 * 1000;
	
	public final long remoteNonce;
	public final long localNonce;
	
	private final CryptoPublicKey remotePublicKey;
	private final CryptoPublicKey localPublicKey;
	private final CryptoSharedKey key;
	private final FrameQueue incomingQueue = new FrameQueue();
	// Control frames waiting to be handed out by controlStream.
	private final Queue<FrameData> outgoingQueue = new LinkedList<>();
	private final ControlPacketStream controlStream = new ControlPacketStream();
	
	// Nonce buffers: bytes 0-7 hold the per-direction connection nonce; encrypt()
	// and decrypt() overwrite bytes 8-15 (stream ID) and 16-23 (sequence number).
	private final byte[] outgoingNonce;
	private final byte[] incomingNonce;
	
	private final Map<Integer, QPSPStream> streams = new HashMap<>();
	private final StreamMultiplexer multiplexer = new StreamMultiplexer();
	public final PacketScheduler scheduler;
	
	private final Thread commandThread;
	private final BlockingQueue<Runnable> commandQueue = new LinkedBlockingQueue<>();
	private final Timer timer = new Timer();
	private KeepaliveTimerTask keepaliveTimer = null;
	// Lossless sequence numbers received at or above stopWaiting.
	private final Set<Long> receivedPackets = new HashSet<>();
	// Lossless sequence numbers below this are no longer tracked (see deliverStopWaiting).
	private long stopWaiting = 0;
	
	private int streamCounter;
	private boolean closed = false;
	private long totalPacketsReceived = 0;
	// Next lossless sequence number expected to finish (see finishPacket).
	private long currentSeq = 0;
	
	private long outgoingSeq = 0;
	private long outgoingLossySeq = 0;
	
	private int decryptionErrors = 0;
	
	// Keepalive timing, in milliseconds.
	private int keepaliveInterval = 20000;
	private int maxKeepaliveInterval = 120000;
	private long lastKeepalive = System.currentTimeMillis();
	private long lastPacketReceivedTimestamp = System.currentTimeMillis();
	
	public QPSPConnection(
			QPSPEndpoint endpoint,
			InetAddress address,
			int port,
			long localID,
			long remoteID,
			int maxPacketSize,
			int maxBufferSize,
			long remoteNonce,
			long localNonce,
			Server server,
			boolean isServer,
			CryptoPublicKey remotePublicKey,
			CryptoKeyPair keyPair)
	{
		this.endpoint = endpoint;
		this.server = server;
		this.isServer = isServer;
		this.logger = endpoint.logger;
		
		this.remoteAddress = address;
		this.remotePort = port;
		this.localID = localID;
		this.remoteID = remoteID;
		
		this.remoteNonce = remoteNonce;
		this.localNonce = localNonce;
		this.remotePublicKey = remotePublicKey;
		this.localPublicKey = keyPair.publicKey;
		
		this.key = endpoint.crypto.createSharedKey(remotePublicKey, keyPair.privateKey);
		
		this.maxPacketSize = maxPacketSize;
		this.maxBufferSize = maxBufferSize;
		
		// Locally opened streams get odd IDs on a server and even IDs on a client;
		// open() advances the counter by 2.
		streamCounter = isServer ? 1 : 0;
		
		outgoingNonce = new byte[CryptoProvider.NONCE_BYTES];
		incomingNonce = new byte[CryptoProvider.NONCE_BYTES];
		setLong(outgoingNonce, 0, remoteNonce);
		setLong(incomingNonce, 0, localNonce);
		
		// Create the scheduler before the timers and the command thread start: the
		// commands they run dereference it, and a write made after Thread.start()
		// is not guaranteed to be visible to the command thread.
		scheduler = new StandardPacketScheduler(this, multiplexer, new StandardCongestionController());
		
		keepaliveTimer = new KeepaliveTimerTask();
		timer.scheduleAtFixedRate(keepaliveTimer, 1000, 1000);
		// Periodically kick the scheduler while control frames are still queued.
		timer.scheduleAtFixedRate(new TimerTask() {
			@Override
			public void run() {
				commandQueue.offer(() -> {
					if (!outgoingQueue.isEmpty()) {
						scheduler.resumeStreams();
					}
				});
			}
		}, 50, 50);
		
		commandThread = new Thread(() -> {
			while (!closed) {
				try {
					Runnable command = commandQueue.take();
					if (command != null)
						command.run();
				} catch (InterruptedException ex) {
					// Deliberately ignored: only the close command ends this loop.
				}
			}
		});
		commandThread.start();
	}
	
	/**
	 * Schedules an action to run on this connection's command thread.
	 *
	 * @param action action to execute, in submission order
	 */
	public void runOnNetworkThread(Runnable action) {
		commandQueue.offer(action);
	}
	
	@Override
	public CryptoPublicKey getRemoteKey() {
		return remotePublicKey;
	}
	
	/**
	 * Opens a new locally initiated stream to {@code path} on the command thread
	 * and schedules it for sending.
	 */
	@Override
	public void open(String path, ServiceConnector connector) {
		commandQueue.offer(() -> {
			QPSPStream stream = open(connector);
			stream.open(path, false);
			resume(stream);
		});
	}
	
	/**
	 * Like {@link #open(String, ServiceConnector)}, but passes {@code true} to
	 * {@code QPSPStream.open} and immediately connects the stream using the
	 * cached service metadata.
	 */
	@Override
	public void open(String path, ServiceMeta cached, ServiceConnector connector) {
		commandQueue.offer(() -> {
			QPSPStream stream = open(connector);
			stream.open(path, true);
			stream.connect(cached);
			resume(stream);
		});
	}
	
	CryptoProvider getCrypto() {
		return endpoint.crypto;
	}
	
	int getEstimatedRTTInMillis() {
		return scheduler.getEstimatedRTTInMillis();
	}
	
	long getTotalPacketsReceived() {
		return totalPacketsReceived;
	}
	
	long getLastPacketSentTimestamp() {
		return scheduler.getLastSentPacketTimestamp();
	}
	
	/**
	 * Returns the stream with the given ID, creating and registering one (without
	 * a connector) if it is not known yet.
	 *
	 * @param streamID stream identifier
	 * @return the existing or newly created stream; never {@code null}
	 */
	public QPSPStream getStream(int streamID) {
		QPSPStream stream = streams.get(streamID);
		if (stream == null) {
			stream = new QPSPStream(this, streamID, null);
			streams.put(streamID, stream);
		}
		return stream;
	}
	
	/** Adds a frame to the incoming frame queue. */
	void offer(Frame frame) {
		incomingQueue.offer(frame);
	}
	
	/**
	 * Queues a control frame for sending on the control stream.
	 *
	 * @param frame encoded frame
	 * @param immediately if {@code true}, asks the scheduler to send right away
	 */
	void send(FrameData frame, boolean immediately) {
		outgoingQueue.add(frame);
		multiplexer.resume(controlStream);
		
		if (immediately)
			scheduler.resumeStreams();
	}
	
	/**
	 * Encrypts and sends a lossless packet. The header is the remote connection
	 * ID followed by the compacted sequence number.
	 *
	 * @param data plaintext packet payload
	 * @return the sequence number assigned to this packet
	 * @throws IOException propagated from the endpoint's send
	 */
	public long sendLosslessPacket(byte[] data) throws IOException {
		long seq = outgoingSeq++;
		
		BytesDataOutput output = new BytesDataOutput();
		output.writeVarULong(remoteID);
		encodeCompactedSEQ(output, seq);
		output.writeRawBytes(encrypt(data, remoteID, seq));
		endpoint.send(this, output.toByteArray());
		
		return seq;
	}
	
	/**
	 * Encrypts and sends a lossy packet. The header uses {@code remoteID | 1} as
	 * the ID, and the packet uses a sequence counter separate from lossless packets.
	 *
	 * @param data plaintext packet payload
	 * @return the lossy sequence number assigned to this packet
	 * @throws IOException propagated from the endpoint's send
	 */
	public long sendLossyPacket(byte[] data) throws IOException {
		long seq = outgoingLossySeq++;
		long lossyID = remoteID | 1;
		
		BytesDataOutput output = new BytesDataOutput();
		output.writeVarULong(lossyID);
		encodeCompactedSEQ(output, seq);
		output.writeRawBytes(encrypt(data, lossyID, seq));
		endpoint.send(this, output.toByteArray());
		
		return seq;
	}
	
	/** Queues (without forcing an immediate send) a STOP_WAITING frame for {@code seq}. */
	public void stopWaiting(long seq) {
		send(new StopWaitingFrame(this, seq).encode(), false);
	}
	
	/**
	 * Lists the runs of lossless sequence numbers in {@code [stopWaiting, lastReceived)}
	 * that have not been received, in ascending order.
	 */
	private List<NackRange> listNacks(long lastReceived) {
		List<NackRange> result = new ArrayList<>();
		long start = -1;
		int length = 0;
		for (long l = stopWaiting; l < lastReceived; l++) {
			if (receivedPackets.contains(l)) {
				if (length > 0)
					result.add(new NackRange(start, length));
				
				length = 0;
				continue;
			} else if (length == 0) {
				start = l;
			}
			
			length++;
		}
		if (length > 0)
			result.add(new NackRange(start, length));
		
		return result;
	}
	
	void pause() {
		commandQueue.offer(() -> scheduler.pause());
	}
	
	void resume() {
		commandQueue.offer(() -> scheduler.resume());
	}
	
	/**
	 * Closes the connection locally on the command thread. The command loop
	 * exits after this command, and the timers are cancelled.
	 */
	void close() {
		commandQueue.offer(() -> {
			closed = true;
			// Nothing drains commandQueue once closed, so stop the timer tasks that
			// feed it; Timer's default thread is non-daemon and would otherwise
			// keep running (and keep the JVM alive) forever.
			timer.cancel();
			scheduler.onConnectionClosed();
		});
	}
	
	void assertOnNetworkThread() {
		if (Thread.currentThread() != commandThread)
			throw new AssertionError();
	}
	
	int getRemoteStreamCounter() {
		return streamCounter;
	}
	
	void initClient(int keepaliveInterval, int maxKeepaliveInterval) {
		this.keepaliveInterval = keepaliveInterval;
		this.maxKeepaliveInterval = maxKeepaliveInterval;
	}
	
	void initServer(int maxPacketSize, int maxBufferSize) {
		this.maxPacketSize = maxPacketSize;
		this.maxBufferSize = maxBufferSize;
	}
	
	void initFromStorage(int streamCounter) {
		this.streamCounter = streamCounter;
	}
	
	/** Updates the remote address and port on the command thread. */
	void setRemote(InetAddress address, int port) {
		commandQueue.offer(() -> {
			this.remoteAddress = address;
			this.remotePort = port;
		});
	}
	
	/**
	 * Marks lossless packet {@code seq} as finished.
	 *
	 * @return {@code true} if {@code seq} is below the STOP_WAITING mark or is the
	 *         next expected packet (which then advances {@code currentSeq});
	 *         {@code false} if earlier packets must finish first
	 */
	public boolean finishPacket(long seq) {
		if (seq < stopWaiting)
			return true;
		if (seq != currentSeq)
			return false;
		
		currentSeq++;
		logger.log(NetworkLogger.CATEGORY_PACKETS, localID, -1, "Finished " + seq);
		return true;
	}
	
	/** Allocates the next locally initiated stream ID and registers a new stream. */
	private QPSPStream open(ServiceConnector connector) {
		int streamID = streamCounter;
		streamCounter += 2;
		
		QPSPStream stream = new QPSPStream(
				this,
				streamID,
				connector);
		streams.put(streamID, stream);
		return stream;
	}
	
	/**
	 * Sends the CONNECTION_INIT packet on the command thread. The plaintext part
	 * carries the remote ID, local nonce and local public key; the encrypted part
	 * carries the certificate, local ID and server limits.
	 */
	void sendInit() {
		commandQueue.offer(() -> {
			try {
				BytesDataOutput output = new BytesDataOutput();
				output.writeVarULong(Constants.CONNECTION_INIT);
				output.writeVarULong(remoteID);
				output.writeULong(localNonce);
				output.writeVarUInt(0); // INIT options, always 0
				output.writeRawBytes(localPublicKey.encode());

				BytesDataOutput encrypted = new BytesDataOutput();
				encrypted.writeVarUInt(0);
				server.getCertificate().serialize(encrypted);
				encrypted.writeVarULong(localID);
				encrypted.writeVarUInt(server.getMaxStreams());
				encrypted.writeVarUInt(server.getMaxPacketSize());
				encrypted.writeVarUInt(server.getMaxBufferSize());
				output.writeByteArray(encrypt(encrypted.toByteArray(), 1, 0));

				endpoint.send(this, output.toByteArray());
			} catch (IOException ex) {
				ex.printStackTrace();
			}
		});
	}
	
	/**
	 * Encrypts {@code data} with the outgoing nonce for the given stream ID and
	 * sequence number. Must be called on the command thread (shared nonce buffer).
	 */
	byte[] encrypt(byte[] data, long streamId, long seq) {
		assertOnNetworkThread();
		
		setLong(outgoingNonce, 8, streamId);
		setLong(outgoingNonce, 16, seq);
		return key.encrypt(outgoingNonce, data);
	}
	
	/**
	 * Decrypts {@code data} with the incoming nonce for the given stream ID and
	 * sequence number. Must be called on the command thread (shared nonce buffer).
	 *
	 * @throws CryptoDecryptionException if decryption fails
	 */
	byte[] decrypt(byte[] data, long streamId, long seq) throws CryptoDecryptionException {
		assertOnNetworkThread();
		
		setLong(incomingNonce, 8, streamId);
		setLong(incomingNonce, 16, seq);
		return key.decrypt(incomingNonce, data);
	}
	
	/** Marks {@code stream} as having data to send and wakes the scheduler. */
	void resume(QPSPStream stream) {
		assertOnNetworkThread();
		
		multiplexer.resume(stream);
		doResume();
	}
	
	void doResume() {
		if (closed)
			return;
		
		scheduler.resumeStreams();
	}
	
	CryptoPublicKey getPublicKey() {
		return localPublicKey;
	}
	
	/** Hands a received datagram (positioned after the connection ID) to the command thread. */
	void onReceived(InetAddress address, int port, BytesDataInput input, boolean lossy) {
		commandQueue.offer(() -> doReceive(address, port, input, lossy));
	}
	
	/** Reads a sequence number; currently a plain var-ulong. */
	public long decodeCompactedSEQ(BytesDataInput input) {
		return input.readVarULong();
	}
	
	/** Returns the encoded length of {@code seq}; currently its var-ulong length. */
	public int getCompactedSEQLength(long seq) {
		return BytesDataInput.getVarULongLength(seq);
	}
	
	/** Writes a sequence number; currently a plain var-ulong. */
	public void encodeCompactedSEQ(BytesDataOutput output, long seq) {
		output.writeVarULong(seq);
	}
	
	/**
	 * Decrypts and processes one received packet on the command thread. The
	 * remote address is only updated once the packet decrypts successfully.
	 */
	private void doReceive(InetAddress address, int port, BytesDataInput input, boolean lossy) {
		totalPacketsReceived++;
		
		long seq = decodeCompactedSEQ(input);
		try {
			// NOTE(review): the third argument of Arrays.copyOfRange is an exclusive
			// end index; this relies on getAvailable() returning an end offset rather
			// than a count of remaining bytes. Confirm against BytesDataInput.
			byte[] data = Arrays.copyOfRange(input.getData(), input.getCurrentOffset(), input.getAvailable());
			byte[] decrypted = decrypt(data, lossy ? localID | 1 : localID, seq);
			
			// Follow an address change only after the packet decrypted successfully;
			// otherwise a spoofed datagram carrying our connection ID could redirect
			// all outgoing traffic.
			remoteAddress = address;
			remotePort = port;
			
			onReceived(seq, lossy, decrypted);
			scheduler.onPacketReceived();
			
			Map<QPSPStream, Long> result = incomingQueue.getBlockingSeq();
			for (Map.Entry<QPSPStream, Long> entry : result.entrySet()) {
				logger.log(NetworkLogger.CATEGORY_FRAMES, localID, entry.getKey().id, "Waiting for frame " + entry.getValue());
			}
		} catch (CryptoDecryptionException ex) {
			logger.log(NetworkLogger.CATEGORY_STREAMS, localID, -1, "Decryption error for #" + seq);
			decryptionErrors++;
			// TODO: act once decryptionErrors exceeds MAX_DECRYPTION_ERRORS (e.g. tear
			// the connection down). Currently the count is only recorded.
		}
	}
	
	/**
	 * Records lossless packet {@code seq} as received and queues an ACK frame
	 * covering {@code currentSeq..seq}, with NACK ranges sorted by descending seq.
	 */
	private void enqueueAck(long seq) {
		receivedPackets.add(seq);
		List<NackRange> nacks = listNacks(seq);
		Collections.sort(nacks, (a, b) -> Long.compare(b.seq, a.seq));
		
		AckFrame ack = new AckFrame(this, currentSeq, seq, nacks.toArray(new NackRange[nacks.size()]));
		send(ack.encode(), false);
		
		StringBuilder nacksString = new StringBuilder();
		for (int i = 0; i < nacks.size(); i++) {
			if (i > 0) nacksString.append(", ");
			
			NackRange nack = nacks.get(i);
			if (nack.length > 1)
				nacksString.append(nack.seq).append("+").append(nack.length);
			else
				nacksString.append(nack.seq);
		}
		
		if (nacks.size() > 0)
			logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, localID, -1, "-> ACK " + currentSeq + "-" + seq + " NACK " + nacksString);
		else
			logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, localID, -1, "-> ACK " + currentSeq + "-" + seq);
	}
	
	/**
	 * Parses the frames of a decrypted packet and offers them to the incoming
	 * frame queue. Lossless packets are acknowledged and followed by a
	 * {@link FinishPacketFrame}.
	 */
	private void onReceived(long seq, boolean lossy, byte[] data) {
		lastPacketReceivedTimestamp = System.currentTimeMillis();
		
		if (!lossy && seq < currentSeq) {
			logger.log(NetworkLogger.CATEGORY_PACKETS, localID, -1, "Dropping duplicate packet " + seq);
			return;
		}
		if (!lossy && receivedPackets.contains(seq)) {
			// Already parsed but not yet finished. Re-acknowledge it (our earlier ACK
			// may have been lost), but do not offer its frames a second time.
			logger.log(NetworkLogger.CATEGORY_PACKETS, localID, -1, "Dropping duplicate packet " + seq);
			enqueueAck(seq);
			return;
		}
		
		if (lossy)
			logger.log(NetworkLogger.CATEGORY_PACKETS, localID, -1, "Processing incoming lossy packet " + seq);
		else
			logger.log(NetworkLogger.CATEGORY_PACKETS, localID, -1, "Processing incoming packet " + seq);
		
		try {
			BytesDataInput input = new BytesDataInput(data);
			frames:
			while (input.hasMore()) {
				int type = input.readUByte();

				switch (type) {
					case Constants.PACKET_OPEN: {
						incomingQueue.offer(OpenFrame.deserializeOpen(input, this));
						break;
					}
					case Constants.PACKET_QUICKOPEN: {
						incomingQueue.offer(OpenFrame.deserializeQuickOpen(input, this));
						break;
					}
					case Constants.PACKET_SERVICEINFO: {
						incomingQueue.offer(ServiceInfoFrame.deserialize(input, this));
						break;
					}
					case Constants.PACKET_START: {
						incomingQueue.offer(StartFrame.deserialize(input, this));
						break;
					}
					case Constants.PACKET_DATA: {
						incomingQueue.offer(DataFrame.deserialize(input, this));
						break;
					}
					case Constants.PACKET_ACK: {
						incomingQueue.offer(AckFrame.deserialize(input, this));
						break;
					}
					case Constants.PACKET_CLOSE_STREAM: {
						incomingQueue.offer(CloseFrame.deserialize(input, this));
						break;
					}
					case Constants.PACKET_FRAGSTART: {
						incomingQueue.offer(FragStartFrame.deserialize(input, this));
						break;
					}
					case Constants.PACKET_FRAGPART: {
						incomingQueue.offer(FragPartFrame.deserialize(input, this));
						break;
					}
					case Constants.PACKET_FRAGEND: {
						incomingQueue.offer(FragEndFrame.deserialize(input, this));
						break;
					}
					case Constants.PACKET_STOP_WAITING: {
						incomingQueue.offer(StopWaitingFrame.deserialize(input, this));
						break;
					}
					case Constants.PACKET_CLOSE_CONNECTION: {
						int reason = input.readVarUInt();
						deliverCloseConnection(reason);
						break;
					}
					case Constants.PACKET_INITIALIZED: {
						int maxPacketSize = input.readVarUInt();
						int maxBufferSize = input.readVarUInt();
						logger.log(NetworkLogger.CATEGORY_FRAMES, localID, -1, "<- INITIALIZED");
						initServer(maxPacketSize, maxBufferSize);
						break;
					}
					case Constants.PACKET_KEEPALIVE: {
						logger.log(NetworkLogger.CATEGORY_FRAMES, localID, -1, "<- KEEPALIVE");
						input.readVarULong(); // peer's received-packet count; not used yet
						break;
					}
					default:
						// This loop has no frame length to skip by, so an unknown type
						// would make every following byte be misparsed as frames.
						logger.log(NetworkLogger.CATEGORY_FRAMES, localID, -1, "Unknown frame type " + type);
						closeConnection(Constants.CLOSE_PROTOCOL_ERROR);
						break frames;
				}
			}
		} catch (ArrayIndexOutOfBoundsException ex) {
			logger.log(NetworkLogger.CATEGORY_FRAMES, localID, -1, "Reading past end of stream");
			closeConnection(Constants.CLOSE_PROTOCOL_ERROR);
		}
		
		if (!lossy) {
			enqueueAck(seq);
			incomingQueue.offer(new FinishPacketFrame(this, seq));
		}
	}
	
	/** Queues a CLOSE_CONNECTION control frame with {@code reason} and asks for an immediate send. */
	private void closeConnection(int reason) {
		BytesDataOutput output = new BytesDataOutput();
		output.writeUByte(Constants.PACKET_CLOSE_CONNECTION);
		output.writeVarUInt(reason);
		send(new FrameData(Integer.MAX_VALUE, -1, output.toByteArray(), false, false), true);
	}
	
	/**
	 * Handles a CLOSE_CONNECTION frame from the peer. Unless the peer was itself
	 * answering our close, reply with CLOSE_REQUESTED_BY_PEER; then close locally.
	 */
	private void deliverCloseConnection(int reason) {
		if (reason != Constants.CLOSE_REQUESTED_BY_PEER)
			closeConnection(Constants.CLOSE_REQUESTED_BY_PEER);
		
		close();
	}
	
	/**
	 * Applies a STOP_WAITING mark from the peer. It forgets received-packet
	 * records below {@code seq} and advances {@code currentSeq} if it was lower.
	 * Marks lower than the current one are ignored.
	 */
	public void deliverStopWaiting(long seq) {
		if (seq < stopWaiting)
			return;
		
		for (long l = stopWaiting; l < seq; l++)
			receivedPackets.remove(l);
		
		stopWaiting = seq;
		if (stopWaiting > currentSeq)
			currentSeq = stopWaiting;
		
		logger.log(NetworkLogger.CATEGORY_FRAMES, localID, -1, "<- STOP_WAITING " + seq);
	}
	
	public void onClosed(int streamId) {
		scheduler.onStreamClosed(streamId);
	}
	
	/**
	 * Writes {@code value} big-endian into {@code data[offset..offset + 7]}.
	 */
	public static void setLong(byte[] data, int offset, long value) {
		data[offset + 0] = (byte)(value >> 56);
		data[offset + 1] = (byte)(value >> 48);
		data[offset + 2] = (byte)(value >> 40);
		data[offset + 3] = (byte)(value >> 32);
		data[offset + 4] = (byte)(value >> 24);
		data[offset + 5] = (byte)(value >> 16);
		data[offset + 6] = (byte)(value >> 8);
		data[offset + 7] = (byte)value;
	}
	
	/**
	 * Runs every second and hands its check to the command thread. When nothing
	 * was sent recently and/or the last keepalive is too old, it either times the
	 * connection out or sends a KEEPALIVE frame.
	 */
	private class KeepaliveTimerTask extends TimerTask {
		@Override
		public void run() {
			commandQueue.offer(() -> {
				long now = System.currentTimeMillis();
				boolean sentRecently = (now - getLastPacketSentTimestamp()) < keepaliveInterval;
				boolean keepaliveRecently = (now - lastKeepalive) < maxKeepaliveInterval;
				if (sentRecently && keepaliveRecently)
					return;
				
				if ((now - lastPacketReceivedTimestamp) > sessionTimeout) {
					System.out.println("Connection timed out");
					closeConnection(Constants.CLOSE_CONNECTION_TIMEOUT);
					// Also close locally: a silent peer will never answer and close us.
					// Without this, the CLOSE frame is resent on later ticks and the
					// connection is never released.
					close();
					return;
				}

				lastKeepalive = now;
				BytesDataOutput output = new BytesDataOutput();
				// The frame type is read back with readUByte(), so write exactly one byte
				// (as closeConnection does).
				output.writeUByte(Constants.PACKET_KEEPALIVE);
				output.writeVarULong(getTotalPacketsReceived());
				send(new FrameData(Integer.MAX_VALUE, 0, output.toByteArray(), true, true), true);
			});
		}
	}
	
	/**
	 * Multiplexer stream (ID -1, highest priority) that hands out queued control
	 * frames one at a time.
	 */
	private class ControlPacketStream implements StreamMultiplexer.Stream {
		
		@Override
		public int getId() {
			return -1;
		}

		@Override
		public int getPriority() {
			return CONTROL_PRIORITY;
		}

		@Override
		public FrameData next(int availableSpaceInPacket) {
			// NOTE(review): availableSpaceInPacket is ignored; presumably control
			// frames are small enough to always fit. Confirm with the scheduler.
			return outgoingQueue.poll();
		}
	}
}
