/* MIT licensed - see LICENSE in the project root directory. */
package org.openzen.packetstreams.qpsp;

import org.openzen.packetstreams.Server;
import org.openzen.packetstreams.ClientSession;
import org.openzen.packetstreams.Session;
import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import org.openzen.packetstreams.SigningRootValidator;
import org.openzen.packetstreams.crypto.CryptoDecryptionException;
import org.openzen.packetstreams.crypto.CryptoKeyPair;
import org.openzen.packetstreams.crypto.CryptoPrivateKey;
import org.openzen.packetstreams.crypto.CryptoPublicKey;
import org.openzen.packetstreams.crypto.CryptoSharedKey;
import org.openzen.packetstreams.crypto.CryptoVerifyKey;
import org.openzen.packetstreams.io.BytesDataInput;
import org.openzen.packetstreams.io.BytesDataOutput;
import org.openzen.packetstreams.crypto.CryptoProvider;

/**
 * Client-side state of a QPSP session that is being established, or that has
 * been restored from a previously serialized blob.
 *
 * <p>An instance is created either with the connection parameters of a new
 * handshake (the session then becomes established once {@link #init} is
 * called), or from the byte array produced by {@link #serialize()} (the
 * session is then established immediately).
 *
 * <p>NOTE(review): this class performs no synchronization of its own; confirm
 * that callers invoke it from a single thread.
 */
public class QPSPClientSession implements ClientSession {
	private final QPSPEndpoint endpoint;
	private final CryptoKeyPair keyPair;
	private final long clientNonce;
	public final String host;
	private final InetAddress address;
	private final int port;
	private final List<Consumer<Session>> whenEstablished = new ArrayList<>();
	private final Server server;
	private final SigningRootValidator rootValidator;
	private final int keepaliveInterval;
	private final int maxKeepaliveInterval;
	
	private boolean established = false;
	private long serverNonce;
	private CryptoPublicKey serverKey;
	private QPSPSession session;
	// Remembered so that serialize() can write every field the restoring
	// constructor reads back.
	private int maxPacketSize;
	private int maxBufferSize;
	
	/**
	 * Restores an established session from the output of {@link #serialize()}.
	 *
	 * @param session serialized session state
	 * @param endpoint endpoint that owns this session and supplies the crypto provider
	 * @param server local server handed to the underlying session
	 * @param rootValidator validator consulted by {@link #isValidSigningRoot}
	 * @param keepaliveInterval keepalive interval stored on this object
	 * @param maxKeepaliveInterval maximum keepalive interval stored on this object
	 * @throws IOException if the stored host name cannot be resolved
	 */
	public QPSPClientSession(
			byte[] session,
			QPSPEndpoint endpoint,
			Server server,
			SigningRootValidator rootValidator,
			int keepaliveInterval,
			int maxKeepaliveInterval) throws IOException {
		this.endpoint = endpoint;
		this.server = server;
		this.rootValidator = rootValidator;
		this.keepaliveInterval = keepaliveInterval;
		this.maxKeepaliveInterval = maxKeepaliveInterval;
		
		// Field order must mirror serialize().
		BytesDataInput input = new BytesDataInput(session);
		CryptoPrivateKey privateKey = endpoint.crypto.decodePrivateKey(input.readByteArray());
		CryptoPublicKey publicKey = endpoint.crypto.decodePublicKey(input.readByteArray());
		keyPair = new CryptoKeyPair(privateKey, publicKey);
		serverKey = endpoint.crypto.decodePublicKey(input.readByteArray());
		clientNonce = input.readULong();
		serverNonce = input.readULong();
		host = input.readString();
		port = input.readVarUInt();
		long localFromStreamId = input.readVarULong();
		long localToStreamId = input.readVarULong();
		long remoteFromStreamId = input.readVarULong();
		long remoteToStreamId = input.readVarULong();
		long streamCounter = input.readVarULong();
		this.maxPacketSize = input.readVarUInt();
		this.maxBufferSize = input.readVarUInt();
		
		this.address = InetAddress.getByName(host);
		// Pass the server's key here, exactly as init() does; the client's own
		// public key is already part of keyPair.
		this.session = new QPSPSession(
				endpoint,
				address,
				port,
				localFromStreamId,
				localToStreamId,
				maxPacketSize,
				maxBufferSize,
				serverNonce,
				clientNonce,
				server,
				serverKey,
				keyPair);
		this.session.initFromStorage(streamCounter, remoteFromStreamId, remoteToStreamId);
		
		established = true;
	}
	
	/**
	 * Creates the client side of a new, not yet established session.
	 *
	 * @param endpoint endpoint that owns this session and supplies the crypto provider
	 * @param local local server handed to the underlying session
	 * @param host remote host name; resolved immediately
	 * @param port remote port
	 * @param keyPair key pair of this client
	 * @param clientNonce nonce chosen by this client
	 * @param rootValidator validator consulted by {@link #isValidSigningRoot}
	 * @param keepaliveInterval keepalive interval passed to the session in {@link #init}
	 * @param maxKeepaliveInterval maximum keepalive interval passed to the session in {@link #init}
	 * @throws UnknownHostException if {@code host} cannot be resolved
	 */
	public QPSPClientSession(
			QPSPEndpoint endpoint,
			Server local,
			String host, int port,
			CryptoKeyPair keyPair,
			long clientNonce,
			SigningRootValidator rootValidator,
			int keepaliveInterval,
			int maxKeepaliveInterval) throws UnknownHostException {
		this.host = host;
		this.port = port;
		this.endpoint = endpoint;
		this.keyPair = keyPair;
		this.clientNonce = clientNonce;
		this.server = local;
		
		this.address = InetAddress.getByName(host);
		this.rootValidator = rootValidator;
		this.keepaliveInterval = keepaliveInterval;
		this.maxKeepaliveInterval = maxKeepaliveInterval;
	}
	
	/**
	 * Records the server nonce and server public key. Both are required by
	 * {@link #decryptInit} and {@link #init}.
	 *
	 * @param serverNonce nonce supplied by the server
	 * @param publicKey public key of the server
	 */
	public void preInit(long serverNonce, CryptoPublicKey publicKey) {
		this.serverNonce = serverNonce;
		this.serverKey = publicKey;
	}
	
	/**
	 * Asks the configured root validator whether the given key is acceptable.
	 *
	 * @param rootKey signing root key to check
	 * @return the validator's verdict
	 */
	public boolean isValidSigningRoot(CryptoVerifyKey rootKey) {
		return rootValidator.isValid(rootKey);
	}
	
	/**
	 * Creates and initializes the underlying session, marks this client
	 * session as established and notifies all registered listeners.
	 *
	 * @param fromStreamId first local stream id
	 * @param toStreamId last local stream id
	 * @param maxPacketSize maximum packet size
	 * @param maxBufferSize maximum buffer size
	 * @return the newly created session
	 */
	public QPSPSession init(
			long fromStreamId,
			long toStreamId,
			int maxPacketSize,
			int maxBufferSize) {
		QPSPSession result = new QPSPSession(
				endpoint,
				address, port,
				fromStreamId,
				toStreamId,
				maxPacketSize,
				maxBufferSize,
				serverNonce,
				clientNonce,
				server,
				serverKey,
				keyPair);
		result.init(fromStreamId, toStreamId, maxPacketSize, maxBufferSize, keepaliveInterval, maxKeepaliveInterval);
		
		// Publish the state only after the session was built successfully, and
		// keep everything serialize() needs.
		this.maxPacketSize = maxPacketSize;
		this.maxBufferSize = maxBufferSize;
		this.session = result;
		this.established = true;
		
		// Iterate over a snapshot so a listener may register further listeners.
		for (Consumer<Session> listener : new ArrayList<>(whenEstablished))
			listener.accept(result);
		return result;
	}
	
	/**
	 * Registers a listener that receives the session once it is established.
	 * If the session is already established (after {@link #init}, or when
	 * this object was restored from storage) the listener is also invoked
	 * immediately.
	 *
	 * @param listener callback receiving the established session
	 */
	@Override
	public void whenEstablished(Consumer<Session> listener) {
		whenEstablished.add(listener);
		if (established && session != null)
			listener.accept(session);
	}
	
	/**
	 * Decrypts the server's init payload with the key shared between the
	 * server key set in {@link #preInit} and this client's private key. The
	 * nonce buffer has the client nonce written at offset 0.
	 *
	 * @param init encrypted payload
	 * @return the decrypted payload
	 * @throws CryptoDecryptionException if decryption fails
	 */
	public byte[] decryptInit(byte[] init) throws CryptoDecryptionException {
		byte[] nonce = new byte[CryptoProvider.NONCE_BYTES];
		QPSPSession.setLong(nonce, 0, clientNonce);
		
		CryptoSharedKey sharedKey = endpoint.crypto.createSharedKey(serverKey, keyPair.privateKey);
		return sharedKey.decrypt(nonce, init, 0, init.length);
	}
	
	/**
	 * Serializes this session so that it can be restored with the
	 * {@code byte[]} constructor.
	 *
	 * @return the serialized state, or an empty array if the session has not
	 *         been established yet
	 */
	@Override
	public byte[] serialize() {
		if (!established || session == null)
			return new byte[0];
		
		// Field order must mirror the restoring constructor.
		BytesDataOutput output = new BytesDataOutput();
		output.writeByteArray(keyPair.privateKey.encode());
		output.writeByteArray(keyPair.publicKey.encode());
		output.writeByteArray(serverKey.encode());
		output.writeULong(clientNonce);
		output.writeULong(serverNonce);
		output.writeString(host);
		output.writeVarUInt(port);
		output.writeVarULong(session.localFromStreamID);
		output.writeVarULong(session.localToStreamID);
		output.writeVarULong(session.remoteFromStreamID);
		output.writeVarULong(session.remoteToStreamID);
		output.writeVarULong(session.getRemoteStreamCounter());
		output.writeVarUInt(maxPacketSize);
		output.writeVarUInt(maxBufferSize);
		return output.toByteArray();
	}
}
