/* MIT licensed - see LICENSE in the project root directory. */
package org.openzen.packetstreams.qpsp;

import java.io.IOException;
import java.net.InetAddress;
import java.util.HashMap;
import java.util.Map;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import org.openzen.packetstreams.NetworkLogger;
import org.openzen.packetstreams.Server;
import org.openzen.packetstreams.ServiceConnector;
import org.openzen.packetstreams.ServiceMeta;
import org.openzen.packetstreams.Session;
import org.openzen.packetstreams.crypto.CryptoDecryptionException;
import org.openzen.packetstreams.crypto.CryptoKeyPair;
import org.openzen.packetstreams.crypto.CryptoPublicKey;
import org.openzen.packetstreams.crypto.CryptoSharedKey;
import org.openzen.packetstreams.io.BytesDataInput;
import org.openzen.packetstreams.crypto.CryptoProvider;
import org.openzen.packetstreams.io.BytesDataOutput;
import org.openzen.packetstreams.qpsp.scheduler.StandardCongestionController;

/**
 * Represents a single session. The same class is used for both client-side
 * and server-side sessions, since there is no fundamental difference between
 * them.
 *
 * <p>Most work is funneled through {@link #commandQueue} and executed on a
 * dedicated command thread owned by this session. {@link #encrypt},
 * {@link #decrypt} and {@link #resume(QPSPStream)} assert that they run on that
 * thread.
 */
public class QPSPSession implements Session {
	private static final int MAX_DECRYPTION_ERRORS = 16;
	
	private final QPSPEndpoint endpoint;
	public final Server server;
	public final NetworkLogger logger;
	public InetAddress remoteAddress;
	public int remotePort;
	public final long localFromStreamID;
	public final long localToStreamID;
	/** First stream ID of the remote range; -1 until one of the init methods is called. */
	public long remoteFromStreamID;
	/** End of the remote stream ID range; -1 until one of the init methods is called. */
	public long remoteToStreamID;
	public int maxPacketSize;
	public int maxBufferSize;
	public final int maxUDPPacketSize = 1200;
	
	public final long remoteNonce;
	public final long localNonce;
	
	private final CryptoPublicKey remotePublicKey;
	private final CryptoKeyPair keyPair;
	private final CryptoSharedKey key;
	
	// Per-direction nonce buffers. Bytes 0-7 hold the session nonce (set once
	// in the constructor); bytes 8-15 and 16-23 are overwritten with the stream
	// ID and sequence number before every encrypt/decrypt call.
	private final byte[] outgoingNonce;
	private final byte[] incomingNonce;
	
	/** Streams keyed by their local stream ID. */
	private final Map<Long, QPSPStream> streams = new HashMap<>();
	private final ControlStream controlStream;
	private final StreamMultiplexer multiplexer;
	private final PacketScheduler scheduler;
	
	private final Thread commandThread;
	public final BlockingQueue<Runnable> commandQueue = new LinkedBlockingQueue<>();
	
	/** Remote stream ID handed out to the next locally-opened stream; advances in steps of 4. */
	private long remoteStreamCounter;
	private boolean closed = false;
	private long totalPacketsReceived = 0;
	
	private int decryptionErrors = 0;
	
	public QPSPSession(
			QPSPEndpoint endpoint,
			InetAddress address,
			int port,
			long localFromStreamID,
			long localToStreamID,
			int maxPacketSize,
			int maxBufferSize,
			long remoteNonce,
			long localNonce,
			Server server,
			CryptoPublicKey remotePublicKey,
			CryptoKeyPair keyPair)
	{
		this.endpoint = endpoint;
		this.server = server;
		this.logger = endpoint.logger;
		
		this.remoteAddress = address;
		this.remotePort = port;
		this.localFromStreamID = localFromStreamID;
		this.localToStreamID = localToStreamID;
		
		this.remoteFromStreamID = -1;
		this.remoteToStreamID = -1;
		
		this.remoteNonce = remoteNonce;
		this.localNonce = localNonce;
		this.remotePublicKey = remotePublicKey;
		this.keyPair = keyPair;
		
		key = endpoint.crypto.createSharedKey(remotePublicKey, keyPair.privateKey);
		multiplexer = new StreamMultiplexer(this);
		
		this.maxPacketSize = maxPacketSize;
		this.maxBufferSize = maxBufferSize;
		
		outgoingNonce = new byte[CryptoProvider.NONCE_BYTES];
		incomingNonce = new byte[CryptoProvider.NONCE_BYTES];
		setLong(outgoingNonce, 0, remoteNonce);
		setLong(incomingNonce, 0, localNonce);
		
		controlStream = new ControlStream(this, localFromStreamID, -1);
		streams.put(localFromStreamID, controlStream); // remote stream ID will be filled later
		
		commandThread = new Thread(this::runCommandLoop, "QPSPSession-commands");
		scheduler = new StandardPacketScheduler(this, multiplexer, endpoint, new StandardCongestionController());
		// Start the command thread only after every field is assigned, so it can
		// never observe a partially-constructed session (e.g. a null scheduler).
		commandThread.start();
	}
	
	/**
	 * Body of the command thread: executes queued commands one at a time until
	 * a command sets {@code closed}.
	 */
	private void runCommandLoop() {
		while (!closed) {
			Runnable command;
			try {
				command = commandQueue.take();
			} catch (InterruptedException ignored) {
				// Interruption is not used as a shutdown signal here; close() is.
				continue;
			}
			
			try {
				command.run();
			} catch (RuntimeException ex) {
				// A single failing command (e.g. a malformed packet) must not kill
				// the only thread servicing this session.
				ex.printStackTrace();
			}
		}
	}
	
	@Override
	public CryptoPublicKey getRemoteKey() {
		return remotePublicKey;
	}
	
	/** Queues opening of a new stream to {@code path} without cached service metadata. */
	@Override
	public void open(String path, ServiceConnector connector) {
		commandQueue.offer(() -> {
			QPSPStream stream = open(connector);
			stream.open(path, false);
			resume(stream);
		});
	}
	
	/** Queues opening of a new stream to {@code path}, connecting it with previously cached metadata. */
	@Override
	public void open(String path, ServiceMeta cached, ServiceConnector connector) {
		commandQueue.offer(() -> {
			QPSPStream stream = open(connector);
			stream.open(path, true);
			stream.connect(cached);
			resume(stream);
		});
	}
	
	public CryptoProvider getCrypto() {
		return endpoint.crypto;
	}
	
	public int getEstimatedRTTInMillis() {
		return scheduler.getEstimatedRTTInMillis();
	}
	
	public long getTotalPacketsReceived() {
		return totalPacketsReceived;
	}
	
	public long getLastPacketSentTimestamp() {
		return scheduler.getLastSentPacketTimestamp();
	}
	
	/**
	 * Allocates a new stream, taking the next remote stream ID from
	 * {@code remoteStreamCounter} and mapping it into the local ID range.
	 */
	private QPSPStream open(ServiceConnector connector) {
		long remoteStreamId = remoteStreamCounter;
		long localStreamId = remoteStreamId - remoteFromStreamID + localFromStreamID;
		remoteStreamCounter += 4;
		
		QPSPStream stream = new QPSPStream(
				this,
				localStreamId,
				remoteStreamId,
				connector);
		streams.put(localStreamId, stream);
		return stream;
	}
	
	public void pause() {
		commandQueue.offer(() -> scheduler.pause());
	}
	
	public void resume() {
		commandQueue.offer(() -> scheduler.resume());
	}
	
	/**
	 * Queues closing of the session. Once the command runs, the command thread
	 * stops taking further commands.
	 */
	public void close() {
		commandQueue.offer(() -> {
			closed = true;
			scheduler.onSessionClosed();
		});
	}
	
	/** @throws AssertionError if the caller is not this session's command thread */
	public void assertOnNetworkThread() {
		if (Thread.currentThread() != commandThread)
			throw new AssertionError();
	}
	
	public long getRemoteStreamCounter() {
		return remoteStreamCounter;
	}
	
	/**
	 * Completes client-side setup with the remote stream range. Also
	 * initializes the control stream and recomputes the remote ID of every
	 * existing stream.
	 */
	public void initClient(
			long remoteFromStreamId,
			long remoteToStreamId,
			int keepaliveInterval,
			int maxKeepaliveInterval) {
		this.remoteFromStreamID = remoteFromStreamId;
		this.remoteToStreamID = remoteToStreamId;
		this.remoteStreamCounter = remoteFromStreamID + 4;
		
		controlStream.initialize(keepaliveInterval, maxKeepaliveInterval);
		System.out.println("Init client session; remote " + remoteFromStreamId + ", local " + localFromStreamID);
		
		for (QPSPStream stream : streams.values()) {
			stream.remoteId = stream.localId - localFromStreamID + remoteFromStreamID;
		}
	}
	
	/**
	 * Completes server-side setup with the remote stream range and the packet
	 * and buffer limits. Also recomputes the remote ID of every existing stream.
	 */
	public void initServer(
			long remoteFromStreamId,
			long remoteToStreamId,
			int maxPacketSize,
			int maxBufferSize) {
		this.remoteFromStreamID = remoteFromStreamId;
		this.remoteToStreamID = remoteToStreamId;
		this.remoteStreamCounter = remoteFromStreamId + 4;
		this.maxPacketSize = maxPacketSize;
		this.maxBufferSize = maxBufferSize;
		System.out.println("Init server session; remote " + remoteFromStreamId + ", local " + localFromStreamID);
		
		for (QPSPStream stream : streams.values()) {
			stream.remoteId = stream.localId - localFromStreamID + remoteFromStreamId;
		}
	}
	
	/**
	 * Restores the remote stream range and stream counter. The values are
	 * presumably previously persisted (confirm with callers).
	 */
	public void initFromStorage(
			long remoteFromStreamId,
			long remoteToStreamId,
			long remoteStreamCounter) {
		this.remoteFromStreamID = remoteFromStreamId;
		this.remoteToStreamID = remoteToStreamId;
		this.remoteStreamCounter = remoteStreamCounter;
		
		for (QPSPStream stream : streams.values()) {
			stream.remoteId = stream.localId - localFromStreamID + remoteFromStreamId;
		}
	}
	
	/** Queues an update of the remote address and port on the command thread. */
	public void setRemote(InetAddress address, int port) {
		commandQueue.offer(() -> {
			this.remoteAddress = address;
			this.remotePort = port;
		});
	}
	
	/**
	 * Queues sending of the INIT packet. The packet carries nonces and the
	 * local public key in the clear. Certificate, stream range and size limits
	 * are encrypted (stream ID 0, sequence 0).
	 */
	public void sendInit() {
		commandQueue.offer(() -> {
			try {
				BytesDataOutput output = new BytesDataOutput();
				output.writeVarULong(Constants.STREAM_INIT);
				output.writeULong(remoteNonce);
				output.writeULong(localNonce);
				output.writeVarUInt(0); // INIT options, always 0
				output.writeRawBytes(keyPair.publicKey.encode());

				BytesDataOutput encrypted = new BytesDataOutput();
				encrypted.writeVarUInt(0);
				server.getCertificate().serialize(encrypted);
				encrypted.writeVarULong(localFromStreamID);
				// Stream IDs advance in steps of 4, so this is the size of the local range in streams.
				encrypted.writeVarULong((localToStreamID - localFromStreamID) / 4);
				encrypted.writeVarUInt(server.getMaxPacketSize());
				encrypted.writeVarUInt(server.getMaxBufferSize());
				output.writeByteArray(encrypt(encrypted.toByteArray(), 0, 0));

				endpoint.send(this, output.toByteArray());
			} catch (IOException ex) {
				ex.printStackTrace();
			}
		});
	}
	
	/** Encrypts all of {@code data}; see {@link #encrypt(byte[], int, int, long, long)}. */
	public byte[] encrypt(byte[] data, long streamId, long seq) {
		return encrypt(data, 0, data.length, streamId, seq);
	}
	
	/**
	 * Encrypts {@code data[offset, offset + length)} with the shared key. The
	 * nonce is built from the remote session nonce, {@code streamId} and
	 * {@code seq}. Must run on the command thread, since it mutates a shared
	 * nonce buffer.
	 */
	public byte[] encrypt(byte[] data, int offset, int length, long streamId, long seq) {
		assertOnNetworkThread();
		
		setLong(outgoingNonce, 8, streamId);
		setLong(outgoingNonce, 16, seq);
		return key.encrypt(outgoingNonce, data, offset, length);
	}
	
	/**
	 * Decrypts {@code data[offset, offset + length)} with the shared key. The
	 * nonce is built from the local session nonce, {@code streamId} and
	 * {@code seq}. Must run on the command thread.
	 *
	 * @throws CryptoDecryptionException if decryption fails
	 */
	public byte[] decrypt(byte[] data, int offset, int length, long streamId, long seq) throws CryptoDecryptionException {
		assertOnNetworkThread();
		
		setLong(incomingNonce, 8, streamId);
		setLong(incomingNonce, 16, seq);
		return key.decrypt(incomingNonce, data, offset, length);
	}
	
	public void resume(QPSPStream stream) {
		assertOnNetworkThread();
		
		multiplexer.resume(stream);
	}
	
	/** Asks the scheduler to resume streams, unless the session is closed. */
	public void doResume() {
		if (closed)
			return;
		
		scheduler.resumeStreams();
	}
	
	/**
	 * Returns true if {@code other} targets the same remote address and port.
	 *
	 * <p>Note: this is an overload, not an override of {@link Object#equals};
	 * hash-based collections do not use it.
	 */
	public boolean equals(QPSPSession other) {
		return other != null
				&& remoteAddress.equals(other.remoteAddress)
				&& remotePort == other.remotePort;
	}
	
	public CryptoPublicKey getPublicKey() {
		return keyPair.publicKey;
	}
	
	/** Queues processing of a received packet for {@code streamId} on the command thread. */
	public void onReceived(long streamId, BytesDataInput input) {
		commandQueue.offer(() -> doReceive(streamId, input));
	}
	
	/**
	 * Decrypts a received packet and hands it to its stream. Bit value 2 of
	 * {@code streamId} is cleared to find the stream and passed on as a flag.
	 * Its meaning is defined by QPSPStream.onReceived — confirm there.
	 */
	private void doReceive(long streamId, BytesDataInput input) {
		totalPacketsReceived++;
		
		QPSPStream stream = getLocalStream(streamId & ~2);
		long seq = stream.decodeCompactedSEQ(input);
		try {
			byte[] decrypted = decrypt(input.getData(), input.getCurrentOffset(), input.getAvailable(), streamId, seq);
			stream.onReceived(seq, (streamId & 2) == 2, decrypted);
			scheduler.onPacketReceived();
		} catch (CryptoDecryptionException ex) {
			logger.log(NetworkLogger.CATEGORY_STREAMS, streamId, "Decryption error for #" + seq);
			decryptionErrors++;
			if (decryptionErrors > MAX_DECRYPTION_ERRORS) {
				// TODO: too many undecryptable packets; the session should go away.
			}
		}
	}
	
	public void onAcknowledged(long streamId, long seq) {
		scheduler.onAcknowledged(streamId, seq);
	}
	
	public void onClosed(long streamId) {
		scheduler.onStreamClosed(streamId);
	}
	
	/**
	 * Returns the stream with the given local ID, creating and registering it
	 * if it does not exist. The remote ID of a new stream is -1 while the
	 * remote range is still unknown.
	 */
	private QPSPStream getLocalStream(long streamId) {
		QPSPStream stream = streams.get(streamId);
		if (stream == null) {
			logger.log(NetworkLogger.CATEGORY_SETUP, streamId, "New stream created for stream " + streamId);
			
			long remoteStreamId = remoteFromStreamID == -1 ? -1 : streamId - localFromStreamID + remoteFromStreamID;
			stream = new QPSPStream(this, streamId, remoteStreamId, null);
			streams.put(streamId, stream);
		}
		
		return stream;
	}
	
	/** Writes {@code value} big-endian into {@code data[offset, offset + 8)}. */
	public static void setLong(byte[] data, int offset, long value) {
		data[offset + 0] = (byte)(value >> 56);
		data[offset + 1] = (byte)(value >> 48);
		data[offset + 2] = (byte)(value >> 40);
		data[offset + 3] = (byte)(value >> 32);
		data[offset + 4] = (byte)(value >> 24);
		data[offset + 5] = (byte)(value >> 16);
		data[offset + 6] = (byte)(value >> 8);
		data[offset + 7] = (byte)value;
	}
}
