/* MIT licensed - see LICENSE in the project root directory. */
package org.openzen.packetstreams.qpsp;

import java.io.IOException;
import java.net.DatagramPacket;
import java.net.InetAddress;
import java.net.SocketException;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

import org.openzen.packetstreams.ConnectionListener;
import org.openzen.packetstreams.EmptyServer;
import org.openzen.packetstreams.Host;
import org.openzen.packetstreams.NetworkLogger;
import org.openzen.packetstreams.NullLogger;
import org.openzen.packetstreams.Server;
import org.openzen.packetstreams.SigningRootValidator;
import org.openzen.packetstreams.crypto.CertificateChain;
import org.openzen.packetstreams.crypto.CryptoDecryptionException;
import org.openzen.packetstreams.crypto.CryptoProvider;
import org.openzen.packetstreams.crypto.CryptoPublicKey;
import org.openzen.packetstreams.io.BytesDataInput;
import org.openzen.packetstreams.io.BytesDataOutput;
import org.openzen.packetstreams.qpsp.socket.PureUDPSocket;
import org.openzen.packetstreams.qpsp.socket.UDPSocket;

/**
 * An endpoint represents a single client or server, connecting and accepting
 * data through a single UDP port.
 *
 * <p>Incoming datagrams are read and dispatched by a background receptor
 * thread started from {@link #open()}. {@link #connect} and the lifecycle
 * methods are called from application threads. The tables shared between
 * those threads are therefore concurrent maps, and connection IDs are
 * allocated atomically.
 *
 * <p>Local connection IDs are always even. The lowest bit of an incoming
 * connection ID marks a packet as lossy, so a connection is looked up by the
 * received ID with that bit cleared.
 */
public class QPSPEndpoint {
	/** Interval, in milliseconds, between retransmissions of an unanswered SETUP packet. */
	private static final int SETUP_TIMEOUT = 5000;
	
	public final CryptoProvider crypto;
	private final Host host;
	
	private final UDPSocket socket;
	public final NetworkLogger logger;
	private volatile boolean closed = false;
	
	private final Random random = new SecureRandom();
	// Shared between the receptor thread and application threads (connect/pause/close).
	private final Map<Long, QPSPConnection> connectionsByID = new ConcurrentHashMap<>();
	private final Map<Long, QPSPConnectionRequest> requestedSessions = new ConcurrentHashMap<>();
	// Only accessed from handleSetup, which runs on the receptor thread.
	private final Map<SessionKey, QPSPConnection> connectionsByKey = new HashMap<>();
	
	// Daemon timer: pending SETUP retransmissions must not keep the JVM alive on their own.
	private final Timer setupRetransmitTimer = new Timer("QPSP SETUP retransmit", true);
	// Unanswered SETUP packets, keyed by the local connection ID they announce.
	private final Map<Long, SetupPacket> pendingSetups = new ConcurrentHashMap<>();
	// Next local connection ID. Starts at 4, presumably to stay clear of the
	// Constants.CONNECTION_* control IDs (TODO confirm); always advanced by 2.
	private final AtomicLong connectionCounter = new AtomicLong(4);
	
	/**
	 * Creates an endpoint on an explicit socket.
	 *
	 * @param socket UDP socket to send and receive through (opened by {@link #open()})
	 * @param logger logger for setup/retransmission/stream events
	 * @param host   resolves incoming SETUP domain names to servers
	 * @param crypto crypto implementation used for keys
	 */
	public QPSPEndpoint(UDPSocket socket, NetworkLogger logger, Host host, CryptoProvider crypto) {
		this.host = host;
		this.logger = logger;
		this.crypto = crypto;
		this.socket = socket;
	}
	
	/** Creates an endpoint on UDP port 1200 with an empty host and no logging. */
	public QPSPEndpoint(CryptoProvider crypto) {
		this(1200, Host.empty(), crypto);
	}
	
	/** Creates an endpoint on UDP port 1200 serving {@code host}, with no logging. */
	public QPSPEndpoint(Host host, CryptoProvider crypto) {
		this(1200, host, crypto);
	}
	
	/** Creates an endpoint on the given UDP port with an empty host and no logging. */
	public QPSPEndpoint(int port, CryptoProvider crypto) {
		this(port, Host.empty(), crypto);
	}
	
	/** Creates an endpoint on the given UDP port serving {@code host}, with no logging. */
	public QPSPEndpoint(int port, Host host, CryptoProvider crypto) {
		this(new PureUDPSocket(port), NullLogger.INSTANCE, host, crypto);
	}
	
	/**
	 * Opens the socket, starts the receptor thread and resumes every known
	 * connection. A socket failure is logged, not thrown.
	 */
	public void open() {
		try {
			socket.open();
			closed = false;
			new Receptor().start();
			
			for (QPSPConnection session : connectionsByID.values())
				session.resume();
		} catch (SocketException ex) {
			logger.log(NetworkLogger.CATEGORY_SETUP, 0, -1, ex.getMessage());
		}
	}
	
	/**
	 * Stops receiving, pauses every connection and closes the socket. The
	 * connections stay registered and are resumed by a later {@link #open()}.
	 */
	public void pause() {
		closed = true;
		
		for (QPSPConnection session : connectionsByID.values())
			session.pause();
		
		socket.close();
	}
	
	/**
	 * Stops receiving, closes every connection, stops all pending SETUP
	 * retransmissions and closes the socket.
	 */
	public void close() {
		closed = true;
		
		for (QPSPConnection session : connectionsByID.values())
			session.close();
		
		// Otherwise unanswered SETUPs would be retransmitted forever.
		for (SetupPacket setup : pendingSetups.values())
			setup.cancel();
		pendingSetups.clear();
		setupRetransmitTimer.purge();
		
		socket.close();
	}
	
	/**
	 * Connects to {@code host} on port 1200 using a freshly generated key
	 * pair, a 20000 keepalive interval and a 120000 maximum keepalive interval.
	 *
	 * @throws IOException if the host cannot be resolved or the SETUP cannot be sent
	 */
	public void connect(
			String host,
			SigningRootValidator rootValidator,
			ConnectionListener listener) throws IOException
	{
		connect(host, 1200, new EmptyServer(crypto.generateKeyPair()), rootValidator, listener, 20000, 120000);
	}
	
	/**
	 * Starts a connection to a remote endpoint by sending a SETUP packet.
	 * The SETUP is retransmitted every {@value #SETUP_TIMEOUT} ms until the
	 * matching INIT arrives or the endpoint is closed.
	 *
	 * @throws IOException if the host cannot be resolved or the SETUP cannot be sent
	 */
	public void connect(
			String host,
			int port,
			Server local,
			SigningRootValidator rootValidator,
			ConnectionListener listener,
			int keepaliveInterval,
			int maxKeepaliveInterval) throws IOException
	{
		long clientNonce = random.nextLong();
		long connectionId = nextConnectionId();
		
		QPSPConnectionRequest result = new QPSPConnectionRequest(
				this,
				local,
				host,
				port,
				local.getKeyPair(), 
				clientNonce,
				connectionId,
				rootValidator,
				listener,
				keepaliveInterval,
				maxKeepaliveInterval);
		connect(result);
	}
	
	/**
	 * Encodes and sends the SETUP packet for {@code request}, registers it for
	 * periodic retransmission and records the request so the INIT can be matched.
	 */
	private void connect(QPSPConnectionRequest request) throws IOException {
		BytesDataOutput output = new BytesDataOutput();
		output.writeVarULong(Constants.CONNECTION_SETUP);
		output.writeULong(request.clientNonce);
		output.writeVarULong(request.connectionId);
		output.writeVarUInt(1); // protocol version
		output.writeVarUInt(0); // protocol options
		output.writeString(request.host);
		output.writeRawBytes(request.keyPair.publicKey.encode());
		
		byte[] packetData = output.toByteArray();
		DatagramPacket packet = new DatagramPacket(packetData, packetData.length);
		packet.setAddress(InetAddress.getByName(request.host));
		packet.setPort(request.port);
		
		SetupPacket setup = new SetupPacket(packet, request.connectionId);
		pendingSetups.put(request.connectionId, setup);
		setupRetransmitTimer.scheduleAtFixedRate(setup, SETUP_TIMEOUT, SETUP_TIMEOUT);
		
		requestedSessions.put(request.connectionId, request);
		
		socket.send(packet);
	}
	
	/**
	 * Allocates a new local connection ID. IDs advance by 2 so that the lowest
	 * bit stays free for the lossy flag checked in {@code onReceived}.
	 */
	private long nextConnectionId() {
		return connectionCounter.getAndAdd(2);
	}
	
	/** Timer task that periodically resends one SETUP packet until cancelled. */
	private class SetupPacket extends TimerTask {
		private final DatagramPacket packet;
		private final long connectionId;
		
		public SetupPacket(DatagramPacket packet, long connectionId) {
			this.packet = packet;
			this.connectionId = connectionId;
		}

		@Override
		public void run() {
			try {
				logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, 0, -1, "Retransmitting SETUP");
				socket.send(packet);
			} catch (IOException ex) {
				// Best effort: the next tick retries; record why this attempt failed.
				logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, connectionId, -1, "SETUP retransmission failed: " + ex.getMessage());
			}
		}
	}
	
	/** Sends raw packet data to the remote address and port of {@code channel}. */
	public void send(QPSPConnection channel, byte[] packetData) throws IOException {
		DatagramPacket packet = new DatagramPacket(
				packetData, packetData.length,
				channel.remoteAddress,
				channel.remotePort);
		socket.send(packet);
	}
	
	/** Receive loop: reads datagrams until the endpoint is paused or closed. */
	private class Receptor extends Thread {
		@Override
		public void run() {
			try {
				while (!closed) {
					DatagramPacket packet = socket.receive();
					if (packet == null)
						continue;
					
					try {
						onReceived(packet);
					} catch (RuntimeException ex) {
						// Datagrams come from the network. A malformed one must not
						// terminate the receive loop for the whole endpoint.
						logger.log(NetworkLogger.CATEGORY_SETUP, 0, -1, "Failed to process packet: " + ex);
					}
				}
			} catch (SocketException ex) {
				// pause() and close() close the socket, which aborts receive(); that is expected.
				if (closed || "socket closed".equals(ex.getMessage())) // TODO: a better way to handle this
					return;
				
				ex.printStackTrace();
			} catch (IOException ex) {
				ex.printStackTrace();
			}
		}
	}
	
	/**
	 * Dispatches one received datagram by its leading connection ID: control
	 * packets (FEEDBACK, SETUP, INIT) go to their handlers. Everything else goes
	 * to the addressed connection, or triggers an unknown-connection FEEDBACK.
	 */
	private void onReceived(DatagramPacket packet) throws IOException {
		// copyOfRange takes an exclusive END index, not a length.
		byte[] data = Arrays.copyOfRange(
				packet.getData(),
				packet.getOffset(),
				packet.getOffset() + packet.getLength());
		BytesDataInput input = new BytesDataInput(data);
		long connectionID = input.readVarULong();
		if (connectionID == Constants.CONNECTION_FEEDBACK) {
			handleFeedback(packet, input);
		} else if (connectionID == Constants.CONNECTION_SETUP) {
			handleSetup(packet, input);
		} else if (connectionID == Constants.CONNECTION_INIT) {
			handleInit(packet, input);
		} else {
			QPSPConnection connection = connectionsByID.get(connectionID & ~1);
			if (connection == null) {
				// unknown session
				logger.log(NetworkLogger.CATEGORY_STREAMS, connectionID, -1, "Connection doesn't exist");
				sendFeedback(packet.getAddress(), packet.getPort(), Constants.FEEDBACK_UNKNOWN_CONNECTION, connectionID);
			} else {
				boolean lossy = (connectionID & 1) > 0;
				connection.onReceived(packet.getAddress(), packet.getPort(), input, lossy);
			}
		}
	}
	
	/** Finds the connection whose remote side is {@code address:port} with remote ID {@code remoteID}, or null. */
	private QPSPConnection findConnectionFromIPAndID(InetAddress address, int port, long remoteID) {
		for (QPSPConnection connection : connectionsByID.values()) {
			if (connection.remoteID == remoteID && connection.remoteAddress.equals(address) && connection.remotePort == port)
				return connection;
		}
		
		return null;
	}
	
	/**
	 * Handles a FEEDBACK packet. The reconnect flag closes the connection and,
	 * for connections this endpoint initiated, starts a new connection attempt.
	 * The close flag just closes the connection.
	 */
	private void handleFeedback(DatagramPacket packet, BytesDataInput input) {
		int type = input.readVarUInt();
		long connectionID = input.readVarULong();
		
		QPSPConnection connection = findConnectionFromIPAndID(packet.getAddress(), packet.getPort(), connectionID);
		if (connection == null) {
			logger.log(NetworkLogger.CATEGORY_SETUP, 0, 0, "FEEDBACK with unknown connection");
			return;
		}

		if ((type & Constants.FEEDBACK_FLAG_RECONNECT) > 0) {
			connection.close();
			
			if (connection.request != null) {
				try {
					connect(connection.request.forReconnection(nextConnectionId(), random.nextLong()));
				} catch (IOException ex) {
					logger.log(NetworkLogger.CATEGORY_SETUP, connection.localID, -1, "Reconnection failed: " + ex.getMessage());
				}
			}
		} else if ((type & Constants.FEEDBACK_FLAG_CLOSE) > 0) {
			connection.close();
		}
	}
	
	/**
	 * Handles an incoming SETUP (server side). A repeated SETUP with the same
	 * parameters re-sends INIT from the existing session and updates its remote
	 * address. Otherwise a new session is created for the requested domain's
	 * server, if that server exists.
	 */
	private void handleSetup(DatagramPacket packet, BytesDataInput input) {
		long clientNonce = input.readULong();
		long remoteConnectionID = input.readVarULong();

		int protocolVersion = input.readVarUInt();
		if (protocolVersion != Constants.VERSION)
			return; // rejected version

		// Written with writeVarUInt by connect(); currently ignored.
		int protocolFlags = input.readVarUInt();
		byte[] domainNameBytes = input.readBytes();
		String domainName = new String(domainNameBytes, StandardCharsets.UTF_8);
		byte[] clientPublicKey = input.readRawBytes(CryptoProvider.PUBLIC_KEY_BYTES);
		logger.log(NetworkLogger.CATEGORY_SETUP, Constants.CONNECTION_SETUP, -1, "Received SETUP " + domainName);
		
		SessionKey key = new SessionKey(
				clientNonce,
				protocolVersion,
				protocolFlags,
				domainNameBytes,
				clientPublicKey);
		QPSPConnection existing = connectionsByKey.get(key);
		if (existing != null) {
			// Retransmitted SETUP: our INIT was probably lost, so send it again.
			existing.setRemote(packet.getAddress(), packet.getPort());
			existing.sendInit();
			return;
		}
		
		Server server = host.getServer(domainName);
		if (server == null)
			return; // no such server
		
		long connectionID = nextConnectionId();
		QPSPConnection session = new QPSPConnection(
				this, 
				packet.getAddress(),
				packet.getPort(),
				connectionID,
				remoteConnectionID,
				1024,
				4096,
				clientNonce,
				random.nextLong(),
				server,
				true,
				crypto.decodePublicKey(clientPublicKey),
				server.getKeyPair(),
				null);

		connectionsByID.put(connectionID, session);
		connectionsByKey.put(key, session);
		session.sendInit();
	}
	
	/**
	 * Handles an incoming INIT (client side). It stops the SETUP
	 * retransmission, decrypts and validates the server's answer, and
	 * registers the established connection.
	 */
	private void handleInit(DatagramPacket packet, BytesDataInput input) {
		logger.log(NetworkLogger.CATEGORY_SETUP, Constants.CONNECTION_INIT, -1, "Received INIT");
		
		long connectionId = input.readVarULong();
		long serverNonce = input.readULong();
		input.readVarUInt(); // options; currently unused
		CryptoPublicKey serverPublicKey = crypto.decodePublicKey(input.readRawBytes(CryptoProvider.PUBLIC_KEY_BYTES));
		
		SetupPacket setup = pendingSetups.remove(connectionId);
		if (setup != null)
			setup.cancel();
		
		QPSPConnectionRequest session = requestedSessions.remove(connectionId);
		if (session == null)
			return; // not (or no longer) a pending request; ignore
		
		session.preInit(serverNonce, serverPublicKey);
		try {
			byte[] decrypted = session.decryptInit(input.readByteArray());
			BytesDataInput decryptedInput = new BytesDataInput(decrypted);
			decryptedInput.readVarUInt(); // flags; currently unused
			CertificateChain certificate = new CertificateChain(crypto, decryptedInput);
			long remoteStreamID = decryptedInput.readVarULong();
			decryptedInput.readVarUInt(); // max streams; currently unused
			int maxPacketSize = decryptedInput.readVarUInt();
			int maxBufferSize = decryptedInput.readVarUInt();

			if (!session.isValidSigningRoot(certificate.rootKey)) {
				logger.log(NetworkLogger.CATEGORY_SETUP, connectionId, -1, "INIT rejected: signing root not accepted");
				return; // TODO: report error to the listener
			}
			if (!certificate.validate(session.host, serverPublicKey)) {
				logger.log(NetworkLogger.CATEGORY_SETUP, connectionId, -1, "INIT rejected: certificate validation failed");
				return; // TODO: report error to the listener
			}

			QPSPConnection connection = session.init(
					packet.getAddress(),
					packet.getPort(),
					remoteStreamID,
					maxPacketSize,
					maxBufferSize);
			connectionsByID.put(connection.localID, connection);
		} catch (CryptoDecryptionException ex) {
			logger.log(NetworkLogger.CATEGORY_SETUP, -1, -1, "Crypto exception on INIT packet");
		}
	}
	
	/**
	 * Sends a FEEDBACK packet of the given type about {@code connectionID}
	 * (lossy bit cleared) to {@code address:port}. Send failures are printed, not thrown.
	 */
	void sendFeedback(InetAddress address, int port, int type, long connectionID) {
		BytesDataOutput output = new BytesDataOutput();
		output.writeVarULong(Constants.CONNECTION_FEEDBACK);
		output.writeVarUInt(type);
		output.writeVarULong(connectionID & ~1);
		
		byte[] packetData = output.toByteArray();
		try {
			DatagramPacket packet = new DatagramPacket(
					packetData, packetData.length,
					address,
					port);
			socket.send(packet);
		} catch (IOException ex) {
			ex.printStackTrace();
		}
	}
	
	/**
	 * Identity of an incoming SETUP, used to recognize retransmitted SETUPs.
	 * The byte arrays are compared by content.
	 */
	private static final class SessionKey {
		private final long clientNonce;
		private final int protocolVersion;
		private final int protocolFlags;
		private final byte[] domainName;
		private final byte[] clientPublicKey;
		
		public SessionKey(
				long clientNonce,
				int protocolVersion,
				int protocolFlags,
				byte[] domainName,
				byte[] clientPublicKey)
		{
			this.clientNonce = clientNonce;
			this.protocolVersion = protocolVersion;
			this.protocolFlags = protocolFlags;
			this.domainName = domainName;
			this.clientPublicKey = clientPublicKey;
		}

		@Override
		public int hashCode() {
			int hash = 3;
			hash = 97 * hash + (int) (this.clientNonce ^ (this.clientNonce >>> 32));
			hash = 97 * hash + this.protocolVersion;
			hash = 97 * hash + this.protocolFlags;
			hash = 97 * hash + Arrays.hashCode(this.domainName);
			hash = 97 * hash + Arrays.hashCode(this.clientPublicKey);
			return hash;
		}

		@Override
		public boolean equals(Object obj) {
			if (this == obj)
				return true;
			if (obj == null)
				return false;
			if (getClass() != obj.getClass())
				return false;
			
			final SessionKey other = (SessionKey) obj;
			return this.clientNonce == other.clientNonce
					&& this.protocolVersion == other.protocolVersion
					&& this.protocolFlags == other.protocolFlags
					&& Arrays.equals(this.domainName, other.domainName)
					&& Arrays.equals(this.clientPublicKey, other.clientPublicKey);
		}
	}
}
