/* MIT licensed - see LICENSE in the project root directory. */
package org.openzen.packetstreams.qpsp;

import java.io.IOException;
import java.net.DatagramPacket;
import java.net.InetAddress;
import java.net.SocketException;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Timer;
import java.util.TimerTask;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.atomic.AtomicLong;
import org.openzen.packetstreams.ClientSession;
import org.openzen.packetstreams.EmptyHost;
import org.openzen.packetstreams.EmptyServer;
import org.openzen.packetstreams.Host;
import org.openzen.packetstreams.NetworkLogger;
import org.openzen.packetstreams.NullLogger;
import org.openzen.packetstreams.Server;
import org.openzen.packetstreams.SigningRootValidator;
import org.openzen.packetstreams.crypto.CertificateChain;
import org.openzen.packetstreams.crypto.CryptoDecryptionException;
import org.openzen.packetstreams.crypto.CryptoProvider;
import org.openzen.packetstreams.crypto.CryptoPublicKey;
import org.openzen.packetstreams.io.BytesDataInput;
import org.openzen.packetstreams.io.BytesDataOutput;
import org.openzen.packetstreams.qpsp.socket.PureUDPSocket;
import org.openzen.packetstreams.qpsp.socket.UDPSocket;

/**
 * An endpoint represents a single client or server, connecting and accepting
 * data through a single UDP port.
 * 
 * <p>Incoming packets are processed on a receiver thread started by
 * {@link #open()}, SETUP retransmissions run on a timer thread, and
 * {@code connect} / {@link #allocateStreams(int)} are called from whatever
 * thread the application uses. The bookkeeping shared between those threads
 * is therefore kept in concurrent collections and an atomic counter.
 */
public class QPSPEndpoint {
	/** Delay before the first, and interval between, SETUP retransmissions (milliseconds). */
	private static final int SETUP_TIMEOUT = 5000;
	
	public final CryptoProvider crypto;
	private final Host host;
	
	private final UDPSocket socket;
	public final NetworkLogger logger;
	private volatile boolean closed = false;
	
	private final Random random = new SecureRandom();
	
	/** Sessions keyed by their first stream ID; resolved with a floor lookup on the incoming stream ID. */
	private final ConcurrentSkipListMap<Long, QPSPSession> sessionsByStream = new ConcurrentSkipListMap<>();
	/** Sessions created from a SETUP packet, keyed by that packet's contents, so a repeated SETUP reuses its session. */
	private final Map<SessionKey, QPSPSession> sessionsByRequest = new ConcurrentHashMap<>();
	/** Client sessions for which SETUP was sent and INIT has not yet been processed, keyed by client nonce. */
	private final Map<Long, QPSPClientSession> requestedSessions = new ConcurrentHashMap<>();
	
	/** Daemon timer, so that pending retransmissions never keep the JVM alive. */
	private final Timer setupRetransmitTimer = new Timer("QPSP setup retransmit", true);
	/** Scheduled SETUP retransmissions, keyed by client nonce. */
	private final Map<Long, SetupPacket> setups = new ConcurrentHashMap<>();
	/** Next unallocated stream ID. */
	private final AtomicLong streamCounter = new AtomicLong(4);
	
	/**
	 * Creates an endpoint on top of the given socket. The socket is not opened
	 * until {@link #open()} is called.
	 * 
	 * @param socket socket used to send and receive datagrams
	 * @param logger receives diagnostic messages
	 * @param host resolves the domain name of incoming SETUP packets to a server
	 * @param crypto cryptographic primitives used by this endpoint
	 */
	public QPSPEndpoint(UDPSocket socket, NetworkLogger logger, Host host, CryptoProvider crypto) {
		this.host = host;
		this.logger = logger;
		this.crypto = crypto;
		this.socket = socket;
	}
	
	/** Creates an endpoint on port 1200 with an empty host and no logging. */
	public QPSPEndpoint(CryptoProvider crypto) {
		this(1200, EmptyHost.INSTANCE, crypto);
	}
	
	/** Creates an endpoint on port 1200 for the given host, with no logging. */
	public QPSPEndpoint(Host host, CryptoProvider crypto) {
		this(1200, host, crypto);
	}
	
	/** Creates an endpoint on the given port with an empty host and no logging. */
	public QPSPEndpoint(int port, CryptoProvider crypto) {
		this(port, EmptyHost.INSTANCE, crypto);
	}
	
	/** Creates an endpoint on the given port for the given host, using a {@link PureUDPSocket} and no logging. */
	public QPSPEndpoint(int port, Host host, CryptoProvider crypto) {
		this(new PureUDPSocket(port), NullLogger.INSTANCE, host, crypto);
	}
	
	/**
	 * Opens the socket, starts the receiver thread and resumes all known
	 * sessions. A failure to open the socket is logged, not thrown.
	 */
	public void open() {
		try {
			socket.open();
			closed = false;
			new Receptor().start();
			
			for (QPSPSession session : sessionsByStream.values())
				session.resume();
		} catch (SocketException ex) {
			logger.log(NetworkLogger.CATEGORY_SETUP, 0, ex.getMessage());
		}
	}
	
	/**
	 * Pauses all known sessions and closes the socket. Sessions are kept, so
	 * that a later {@link #open()} can resume them.
	 */
	public void pause() {
		closed = true;
		
		for (QPSPSession session : sessionsByStream.values())
			session.pause();
		
		socket.close();
	}
	
	/**
	 * Cancels all pending SETUP retransmissions, closes all known sessions and
	 * closes the socket.
	 */
	public void close() {
		closed = true;
		
		// Handshakes that never completed must not keep retransmitting on a closed socket.
		for (SetupPacket setup : setups.values())
			setup.cancel();
		setups.clear();
		
		for (QPSPSession session : sessionsByStream.values())
			session.close();
		
		socket.close();
	}
	
	/**
	 * Connects to the given host on port 1200, using an {@link EmptyServer}
	 * with a freshly generated key pair as local side and keepalive settings
	 * of 20000 and 120000.
	 * 
	 * @param host name of the remote host
	 * @param rootValidator passed on to the created client session
	 * @return the client session for this connection attempt
	 * @throws IOException if the host cannot be resolved or the SETUP packet cannot be sent
	 */
	public ClientSession connect(String host, SigningRootValidator rootValidator) throws IOException {
		return connect(host, 1200, new EmptyServer(crypto.generateKeyPair()), rootValidator, 20000, 120000);
	}
	
	/**
	 * Sends a SETUP packet to the given host and registers a client session
	 * that is completed when the matching INIT packet arrives. The SETUP packet
	 * is retransmitted every {@code SETUP_TIMEOUT} milliseconds until then.
	 * 
	 * @param host name of the remote host; also sent inside the SETUP packet
	 * @param port remote UDP port
	 * @param local local server whose public key is sent to the remote side
	 * @param rootValidator passed on to the created client session
	 * @param keepaliveInterval passed on to the created client session
	 * @param maxKeepaliveInterval passed on to the created client session
	 * @return the client session for this connection attempt
	 * @throws IOException if the host cannot be resolved or the SETUP packet
	 *         cannot be sent; in the latter case the attempt is fully unregistered
	 */
	public ClientSession connect(String host, int port, Server local, SigningRootValidator rootValidator, int keepaliveInterval, int maxKeepaliveInterval)
			throws IOException {
		long clientNonce = random.nextLong();
		
		// NOTE(review): handleSetup reads the options with readUByte() and the host
		// with readBytes(); confirm these match writeVarUInt(0) / writeString(host).
		BytesDataOutput output = new BytesDataOutput();
		output.writeVarUInt(Constants.STREAM_SETUP);
		output.writeULong(clientNonce);
		output.writeVarUInt(1); // protocol version
		output.writeVarUInt(0); // protocol options
		output.writeString(host);
		output.writeRawBytes(local.getKeyPair().publicKey.encode());
		
		byte[] packetData = output.toByteArray();
		DatagramPacket packet = new DatagramPacket(packetData, packetData.length);
		packet.setAddress(InetAddress.getByName(host));
		packet.setPort(port);
		
		// Everything is registered before the first send, so that even an
		// immediate INIT reply finds both the retransmit task and the session.
		SetupPacket setup = new SetupPacket(packet);
		setups.put(clientNonce, setup);
		setupRetransmitTimer.scheduleAtFixedRate(setup, SETUP_TIMEOUT, SETUP_TIMEOUT);
		
		try {
			QPSPClientSession result = new QPSPClientSession(this, local, host, port, local.getKeyPair(), clientNonce, rootValidator, keepaliveInterval, maxKeepaliveInterval);
			requestedSessions.put(clientNonce, result);
			
			socket.send(packet);
			return result;
		} catch (IOException | RuntimeException ex) {
			// Roll back, otherwise the task would retransmit forever and the session would leak.
			setup.cancel();
			setups.remove(clientNonce);
			requestedSessions.remove(clientNonce);
			throw ex;
		}
	}
	
	/**
	 * Reserves a contiguous range of stream IDs.
	 * 
	 * @param streams number of streams to reserve; four IDs are reserved per stream
	 * @return the first ID of the reserved range
	 */
	public long allocateStreams(int streams) {
		// long arithmetic: streams * 4 must not overflow int before widening
		return streamCounter.getAndAdd(streams * 4L);
	}
	
	/** Timer task that resends a SETUP packet until it is cancelled. */
	private class SetupPacket extends TimerTask {
		private final DatagramPacket packet;
		
		public SetupPacket(DatagramPacket packet) {
			this.packet = packet;
		}

		@Override
		public void run() {
			try {
				logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, 0, "Retransmitting SETUP");
				socket.send(packet);
			} catch (IOException | RuntimeException ex) {
				// Best effort: the next tick retries. An exception escaping from here
				// would terminate the timer thread and with it all other retransmissions.
				logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, 0, "Retransmitting SETUP failed: " + ex);
			}
		}
	}
	
	/**
	 * Sends raw packet data to the remote address and port of the given session.
	 * 
	 * @param channel session whose remote side receives the packet
	 * @param packetData complete contents of the datagram
	 * @throws IOException if the socket fails to send the packet
	 */
	public void send(QPSPSession channel, byte[] packetData) throws IOException {
		DatagramPacket packet = new DatagramPacket(packetData, packetData.length);
		packet.setAddress(channel.remoteAddress);
		packet.setPort(channel.remotePort);
		socket.send(packet);
	}
	
	/**
	 * Tells whether the given exception is the expected result of the socket
	 * having been closed by {@link #pause()} or {@link #close()}.
	 */
	private boolean isExpectedShutdown(IOException ex) {
		if (closed)
			return true;
		
		// The flag may already have been reset by a subsequent open(); fall back
		// to the message, compared null-safe and case-insensitive.
		return ex instanceof SocketException && "socket closed".equalsIgnoreCase(ex.getMessage());
	}
	
	/** Receiver thread: reads datagrams from the socket until the endpoint is paused or closed. */
	private class Receptor extends Thread {
		@Override
		public void run() {
			while (!closed) {
				DatagramPacket packet;
				try {
					packet = socket.receive();
				} catch (IOException ex) {
					if (!isExpectedShutdown(ex))
						ex.printStackTrace();
					return;
				}
				
				try {
					onReceived(packet);
				} catch (IOException ex) {
					if (isExpectedShutdown(ex))
						return;
					
					ex.printStackTrace();
				} catch (RuntimeException ex) {
					// Packet contents are untrusted: a single malformed packet must
					// not stop reception for all other sessions.
					ex.printStackTrace();
				}
			}
		}
	}
	
	/**
	 * Dispatches a received datagram based on the stream ID at its start:
	 * SETUP and INIT are handled by the endpoint, anything else is forwarded
	 * to the session owning that stream ID. Packets for unknown streams are
	 * dropped.
	 */
	private void onReceived(DatagramPacket packet) throws IOException {
		BytesDataInput input = new BytesDataInput(packet.getData(), packet.getOffset(), packet.getLength());
		long stream = input.readVarULong();
		if (stream == Constants.STREAM_SETUP) {
			handleSetup(packet, input);
			return;
		}
		if (stream == Constants.STREAM_INIT) {
			handleInit(input);
			return;
		}
		
		// The owner is the session with the highest starting ID not above the
		// stream ID, provided the stream ID lies below that session's end.
		Entry<Long, QPSPSession> sessionEntry = sessionsByStream.floorEntry(stream);
		if (sessionEntry == null)
			return; // unknown session
		
		QPSPSession session = sessionEntry.getValue();
		if (stream >= session.localToStreamID)
			return; // unknown session
		
		session.onReceived(stream, input);
	}
	
	/**
	 * Handles a SETUP packet. A repeated SETUP (identical contents) only
	 * updates the remote address of the existing session and resends INIT;
	 * a new SETUP for a known domain creates a session and sends INIT.
	 */
	private void handleSetup(DatagramPacket packet, BytesDataInput input) {
		long clientNonce = input.readULong();

		int protocolVersion = input.readVarUInt();
		if (protocolVersion != Constants.VERSION)
			return; // rejected version

		int protocolFlags = input.readUByte(); // ignored
		byte[] domainNameBytes = input.readBytes();
		String domainName = new String(domainNameBytes, StandardCharsets.UTF_8);
		byte[] clientPublicKey = input.readRawBytes(CryptoProvider.PUBLIC_KEY_BYTES);
		logger.log(NetworkLogger.CATEGORY_SETUP, Constants.STREAM_SETUP, "SETUP " + domainName);
		
		SessionKey key = new SessionKey(
				clientNonce,
				protocolVersion,
				protocolFlags,
				domainNameBytes,
				clientPublicKey);
		QPSPSession existing = sessionsByRequest.get(key);
		if (existing != null) {
			existing.setRemote(packet.getAddress(), packet.getPort());
			existing.sendInit();
			return;
		}
		
		Server server = host.getServer(domainName);
		if (server == null)
			return; // no such server
		
		// NOTE(review): two IDs per stream are reserved here, whereas
		// allocateStreams and handleInit use four per stream - confirm.
		long reserved = 2L * server.getSessionStreamCount();
		long fromStreamID = streamCounter.getAndAdd(reserved);
		long toStreamID = fromStreamID + reserved;
		QPSPSession session = new QPSPSession(
				this, 
				packet.getAddress(),
				packet.getPort(),
				fromStreamID,
				toStreamID,
				1024,
				4096,
				clientNonce,
				random.nextLong(),
				server,
				crypto.decodePublicKey(clientPublicKey),
				server.getKeyPair());

		sessionsByStream.put(session.localFromStreamID, session);
		sessionsByRequest.put(key, session);
		session.sendInit();
	}
	
	/**
	 * Handles an INIT packet: stops retransmitting the matching SETUP packet,
	 * decrypts and validates the payload, and on success registers the
	 * resulting session. INIT packets without a matching pending client
	 * session are ignored.
	 */
	private void handleInit(BytesDataInput input) {
		logger.log(NetworkLogger.CATEGORY_SETUP, Constants.STREAM_INIT, "INIT");
		
		long clientNonce = input.readULong();
		long serverNonce = input.readULong();
		input.readVarUInt(); // options, currently unused; still read to advance the input
		CryptoPublicKey serverPublicKey = crypto.decodePublicKey(input.readRawBytes(CryptoProvider.PUBLIC_KEY_BYTES));
		
		SetupPacket setup = setups.remove(clientNonce);
		if (setup != null)
			setup.cancel();
		
		QPSPClientSession session = requestedSessions.remove(clientNonce);
		if (session == null)
			return; // ignore
		
		session.preInit(serverNonce, serverPublicKey);
		try {
			byte[] decrypted = session.decryptInit(input.readByteArray());
			BytesDataInput decryptedInput = new BytesDataInput(decrypted);
			decryptedInput.readVarUInt(); // flags, currently unused; still read to advance the input
			CertificateChain certificate = new CertificateChain(crypto, decryptedInput);
			long fromStreamId = decryptedInput.readVarULong();
			long toStreamId = fromStreamId + decryptedInput.readVarUInt() * 4L;
			int maxPacketSize = decryptedInput.readVarUInt();
			int maxBufferSize = decryptedInput.readVarUInt();

			if (!session.isValidSigningRoot(certificate.rootKey))
				return; // TODO: report error
			if (!certificate.validate(session.host, serverPublicKey))
				return; // TODO: report error

			QPSPSession qpspSession = session.init(fromStreamId, toStreamId, maxPacketSize, maxBufferSize);
			sessionsByStream.put(fromStreamId, qpspSession);
		} catch (CryptoDecryptionException ex) {
			logger.log(NetworkLogger.CATEGORY_SETUP, -1, "Crypto exception on INIT packet");
		}
	}
	
	/**
	 * Identifies a SETUP request by its complete contents. Arrays are compared
	 * by content; they are not copied and must not be modified afterwards.
	 */
	private static final class SessionKey {
		private final long clientNonce;
		private final int protocolVersion;
		private final int protocolFlags;
		private final byte[] domainName;
		private final byte[] clientPublicKey;
		
		public SessionKey(
				long clientNonce,
				int protocolVersion,
				int protocolFlags,
				byte[] domainName,
				byte[] clientPublicKey)
		{
			this.clientNonce = clientNonce;
			this.protocolVersion = protocolVersion;
			this.protocolFlags = protocolFlags;
			this.domainName = domainName;
			this.clientPublicKey = clientPublicKey;
		}

		@Override
		public int hashCode() {
			int hash = 3;
			hash = 97 * hash + (int) (this.clientNonce ^ (this.clientNonce >>> 32));
			hash = 97 * hash + this.protocolVersion;
			hash = 97 * hash + this.protocolFlags;
			hash = 97 * hash + Arrays.hashCode(this.domainName);
			hash = 97 * hash + Arrays.hashCode(this.clientPublicKey);
			return hash;
		}

		@Override
		public boolean equals(Object obj) {
			if (this == obj)
				return true;
			if (obj == null)
				return false;
			if (getClass() != obj.getClass())
				return false;
			
			final SessionKey other = (SessionKey) obj;
			return this.clientNonce == other.clientNonce
					&& this.protocolVersion == other.protocolVersion
					&& this.protocolFlags == other.protocolFlags
					&& Arrays.equals(this.domainName, other.domainName)
					&& Arrays.equals(this.clientPublicKey, other.clientPublicKey);
		}
	}
}
