/*
 * Decompiled with CFR 0.152.
 */
package org.openzen.packetstreams.qpsp;

import java.io.IOException;
import java.net.DatagramPacket;
import java.net.InetAddress;
import java.net.SocketException;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Timer;
import java.util.TimerTask;
import org.openzen.packetstreams.ConnectionListener;
import org.openzen.packetstreams.EmptyServer;
import org.openzen.packetstreams.Host;
import org.openzen.packetstreams.NetworkLogger;
import org.openzen.packetstreams.NullLogger;
import org.openzen.packetstreams.Server;
import org.openzen.packetstreams.SigningRootValidator;
import org.openzen.packetstreams.crypto.CertificateChain;
import org.openzen.packetstreams.crypto.CryptoDecryptionException;
import org.openzen.packetstreams.crypto.CryptoProvider;
import org.openzen.packetstreams.crypto.CryptoPublicKey;
import org.openzen.packetstreams.io.BytesDataInput;
import org.openzen.packetstreams.io.BytesDataOutput;
import org.openzen.packetstreams.qpsp.QPSPConnection;
import org.openzen.packetstreams.qpsp.QPSPConnectionRequest;
import org.openzen.packetstreams.qpsp.socket.PureUDPSocket;
import org.openzen.packetstreams.qpsp.socket.UDPSocket;

/**
 * A QPSP endpoint bound to a single {@link UDPSocket}.
 *
 * <p>The endpoint acts as a client ({@link #connect}) and, through the {@link Host} it was
 * constructed with, as a server for incoming SETUP packets. Every datagram starts with a
 * variable-length connection ID: the values 1, 2 and 3 select FEEDBACK, SETUP and INIT handling,
 * any other value is routed to the {@link QPSPConnection} registered under that ID with its lowest
 * bit cleared (the lowest bit marks the packet as lossy).
 *
 * <p>Incoming packets are processed on a dedicated receive thread started by {@link #open()};
 * SETUP retransmissions run on a {@link Timer} thread.
 *
 * <p>NOTE(review): the connection maps are plain {@link HashMap}s that are touched both by the
 * receive thread and by callers of the public methods; only connection ID allocation and the list
 * of pending SETUP retransmissions are guarded here. Confirm whether callers serialize access.
 */
public class QPSPEndpoint {
    /** Delay before the first SETUP retransmission and period between later ones (milliseconds). */
    private static final int SETUP_TIMEOUT = 5000;
    /** Port used by the constructors and the {@code connect} overload that take no port. */
    private static final int DEFAULT_PORT = 1200;
    /** Leading connection ID value of a FEEDBACK packet. */
    private static final long PACKET_FEEDBACK = 1L;
    /** Leading connection ID value of a SETUP packet. */
    private static final long PACKET_SETUP = 2L;
    /** Leading connection ID value of an INIT packet. */
    private static final long PACKET_INIT = 3L;
    /** Clears the lowest connection ID bit, which flags lossy traffic for a connection. */
    private static final long CONNECTION_ID_MASK = 0xFFFFFFFFFFFFFFFEL;
    /** Number of raw public key bytes read from SETUP and INIT packets. */
    private static final int PUBLIC_KEY_LENGTH = 32;

    public final CryptoProvider crypto;
    private final Host host;
    private final UDPSocket socket;
    public final NetworkLogger logger;
    private volatile boolean closed = false;
    private final Random random = new SecureRandom();
    private final Map<Long, QPSPConnection> connectionsByID = new HashMap<Long, QPSPConnection>();
    private final Map<Long, QPSPConnectionRequest> requestedSessions = new HashMap<Long, QPSPConnectionRequest>();
    private final Map<SessionKey, QPSPConnection> connectionsByKey = new HashMap<SessionKey, QPSPConnection>();
    private final Map<Long, SessionKey> sessionKeysByID = new HashMap<Long, SessionKey>();
    private final Timer setupRetransmitTimer = new Timer();
    /** Pending SETUP retransmissions; guarded by synchronizing on the list itself. */
    private final List<SetupPacket> setups = new ArrayList<SetupPacket>();
    /** Guards {@link #connectionCounter}, which is advanced from caller and receive threads. */
    private final Object connectionCounterLock = new Object();
    /** Next local connection ID; starts above the reserved packet type IDs and stays even. */
    private long connectionCounter = 4L;
    /** Local IDs of connections for which a reconnect has already been triggered. */
    private final HashSet<Long> reconnected = new HashSet<Long>();

    /**
     * Creates an endpoint on top of an existing socket.
     *
     * @param socket socket used for all traffic; opened by {@link #open()}
     * @param logger receiver of diagnostic messages
     * @param host   resolves the domain name of incoming SETUP packets to a {@link Server}
     * @param crypto cryptographic primitives used for key handling
     */
    public QPSPEndpoint(UDPSocket socket, NetworkLogger logger, Host host, CryptoProvider crypto) {
        this.host = host;
        this.logger = logger;
        this.crypto = crypto;
        this.socket = socket;
    }

    /** Creates an endpoint on the default port with {@link Host#empty()}. */
    public QPSPEndpoint(CryptoProvider crypto) {
        this(DEFAULT_PORT, Host.empty(), crypto);
    }

    /** Creates an endpoint on the default port serving the given host. */
    public QPSPEndpoint(Host host, CryptoProvider crypto) {
        this(DEFAULT_PORT, host, crypto);
    }

    /** Creates an endpoint on the given port with {@link Host#empty()}. */
    public QPSPEndpoint(int port, CryptoProvider crypto) {
        this(port, Host.empty(), crypto);
    }

    /** Creates an endpoint on the given port, using a {@link PureUDPSocket} and the null logger. */
    public QPSPEndpoint(int port, Host host, CryptoProvider crypto) {
        this(new PureUDPSocket(port), NullLogger.INSTANCE, host, crypto);
    }

    /**
     * Opens the socket, starts a receive thread and resumes every registered connection.
     * A {@link SocketException} is logged rather than propagated.
     */
    public void open() {
        try {
            this.socket.open();
            this.closed = false;
            new Receptor().start();
            for (QPSPConnection session : this.connectionsByID.values()) {
                session.resume();
            }
        }
        catch (SocketException ex) {
            this.logger.log(1, 0L, -1, ex.getMessage());
        }
    }

    /** Stops receiving, pauses every registered connection and closes the socket. */
    public void pause() {
        this.closed = true;
        for (QPSPConnection session : this.connectionsByID.values()) {
            session.pause();
        }
        this.socket.close();
    }

    /**
     * Stops receiving, cancels pending SETUP retransmissions, closes every registered connection
     * and closes the socket.
     */
    public void close() {
        this.closed = true;
        // Without this the timer would keep sending SETUP packets into the closed socket.
        this.cancelAllSetups();
        // Iterate over a snapshot: closing a connection may call back into onClosed(), which
        // removes entries from connectionsByID and would otherwise break the live iteration.
        List<QPSPConnection> snapshot = new ArrayList<QPSPConnection>(this.connectionsByID.values());
        for (QPSPConnection session : snapshot) {
            session.close(9);
        }
        this.socket.close();
    }

    /**
     * Connects to {@code host} on the default port with a freshly generated key pair and the
     * default keepalive settings.
     *
     * @throws IOException if the host cannot be resolved or the SETUP packet cannot be sent
     */
    public void connect(String host, SigningRootValidator rootValidator, ConnectionListener listener) throws IOException {
        this.connect(host, DEFAULT_PORT, new EmptyServer(this.crypto.generateKeyPair()), rootValidator, listener, 20000, 120000);
    }

    /**
     * Starts a connection attempt by sending a SETUP packet to {@code host:port}. The packet is
     * retransmitted periodically until a matching INIT packet arrives or the endpoint is closed.
     *
     * @param host                 remote host name; also sent inside the SETUP packet
     * @param port                 remote UDP port
     * @param local                local server whose key pair identifies this side
     * @param rootValidator        validator handed to the connection request
     * @param listener             listener handed to the connection request
     * @param keepaliveInterval    keepalive interval handed to the connection request
     * @param maxKeepaliveInterval maximum keepalive interval handed to the connection request
     * @throws IOException if the host cannot be resolved or the SETUP packet cannot be sent
     */
    public void connect(String host, int port, Server local, SigningRootValidator rootValidator, ConnectionListener listener, int keepaliveInterval, int maxKeepaliveInterval) throws IOException {
        long clientNonce = this.random.nextLong();
        long connectionId = this.nextConnectionId();
        QPSPConnectionRequest request = new QPSPConnectionRequest(this, local, host, port, local.getKeyPair(), clientNonce, connectionId, rootValidator, listener, keepaliveInterval, maxKeepaliveInterval);
        this.connect(request);
    }

    /** Builds and sends the SETUP packet for a request and schedules its retransmission. */
    private void connect(QPSPConnectionRequest request) throws IOException {
        BytesDataOutput output = new BytesDataOutput();
        output.writeVarULong(PACKET_SETUP);
        output.writeULong(request.clientNonce);
        output.writeVarULong(request.connectionId);
        output.writeVarUInt(1); // protocol version; handleSetup() only accepts 1
        output.writeVarUInt(0); // protocol flags
        output.writeString(request.host);
        output.writeRawBytes(request.keyPair.publicKey.encode());
        byte[] packetData = output.toByteArray();

        DatagramPacket packet = new DatagramPacket(packetData, packetData.length);
        packet.setAddress(InetAddress.getByName(request.host));
        packet.setPort(request.port);

        SetupPacket setup = new SetupPacket(packet, request.connectionId);
        synchronized (this.setups) {
            this.setups.add(setup);
        }
        this.setupRetransmitTimer.scheduleAtFixedRate(setup, SETUP_TIMEOUT, SETUP_TIMEOUT);
        this.requestedSessions.put(request.connectionId, request);
        this.socket.send(packet);
    }

    /**
     * Sends raw packet data to the remote address and port of the given connection.
     *
     * @throws IOException if the socket fails to send
     */
    public void send(QPSPConnection channel, byte[] packetData) throws IOException {
        DatagramPacket packet = new DatagramPacket(packetData, packetData.length, channel.remoteAddress, channel.remotePort);
        this.socket.send(packet);
    }

    /** Reserves the next local connection ID; IDs advance in steps of two. */
    private long nextConnectionId() {
        synchronized (this.connectionCounterLock) {
            long id = this.connectionCounter;
            this.connectionCounter += 2L;
            return id;
        }
    }

    /** Cancels and forgets every pending SETUP retransmission. */
    private void cancelAllSetups() {
        synchronized (this.setups) {
            for (SetupPacket setup : this.setups) {
                setup.cancel();
            }
            this.setups.clear();
        }
    }

    /** Cancels and forgets the pending SETUP retransmission for the given connection ID, if any. */
    private void cancelSetup(long connectionId) {
        synchronized (this.setups) {
            SetupPacket match = null;
            for (SetupPacket candidate : this.setups) {
                if (candidate.connectionId == connectionId) {
                    match = candidate;
                }
            }
            if (match != null) {
                match.cancel();
                this.setups.remove(match);
            }
        }
    }

    /** Dispatches a received datagram based on its leading connection ID. */
    private void onReceived(DatagramPacket packet) throws IOException {
        // copyOfRange takes an exclusive END index, so the payload ends at offset + length.
        int start = packet.getOffset();
        byte[] data = Arrays.copyOfRange(packet.getData(), start, start + packet.getLength());
        BytesDataInput input = new BytesDataInput(data);
        long connectionID = input.readVarULong();
        if (connectionID == PACKET_FEEDBACK) {
            this.handleFeedback(packet, input);
        } else if (connectionID == PACKET_SETUP) {
            this.handleSetup(packet, input);
        } else if (connectionID == PACKET_INIT) {
            this.handleInit(packet, input);
        } else {
            QPSPConnection connection = this.connectionsByID.get(connectionID & CONNECTION_ID_MASK);
            if (connection == null) {
                this.logger.log(2, connectionID, -1, "Connection doesn't exist");
                this.sendFeedback(packet.getAddress(), packet.getPort(), 17, connectionID);
            } else {
                boolean lossy = (connectionID & 1L) != 0L;
                connection.onReceived(packet.getAddress(), packet.getPort(), input, lossy);
            }
        }
    }

    /** Finds the connection with the given remote ID, address and port, or returns null. */
    private QPSPConnection findConnectionFromIPAndID(InetAddress address, int port, long remoteID) {
        for (QPSPConnection connection : this.connectionsByID.values()) {
            if (connection.remoteID == remoteID && connection.remoteAddress.equals(address) && connection.remotePort == port) {
                return connection;
            }
        }
        return null;
    }

    /**
     * Handles a FEEDBACK packet. If bit 0 of the type is set the connection is closed and, when it
     * originated from a local request, re-requested under a new ID (at most once per connection);
     * otherwise, if bit 1 is set, the connection is just closed.
     */
    private void handleFeedback(DatagramPacket packet, BytesDataInput input) {
        int type = input.readVarUInt();
        long connectionID = input.readVarULong();
        QPSPConnection connection = this.findConnectionFromIPAndID(packet.getAddress(), packet.getPort(), connectionID);
        if (connection == null) {
            this.logger.log(1, 0L, 0, "FEEDBACK with unknown connection");
            return;
        }
        if ((type & 1) != 0) {
            // HashSet.add returns false when a reconnect was already triggered for this connection.
            if (!this.reconnected.add(connection.localID)) {
                return;
            }
            connection.close(4);
            if (connection.request != null) {
                try {
                    long newID = this.nextConnectionId();
                    this.connect(connection.request.forReconnection(newID, this.random.nextLong()));
                }
                catch (IOException ex) {
                    this.logger.log(1, connection.localID, -1, "Reconnection failed: " + ex.getMessage());
                }
            }
        } else if ((type & 2) != 0) {
            connection.close(4);
        }
    }

    /**
     * Handles a SETUP packet. A repeated SETUP for an already known session only refreshes the
     * remote address and resends INIT; otherwise a new connection is created for the server that
     * the host resolves from the requested domain name. Unknown protocol versions and unknown
     * domain names are ignored.
     */
    private void handleSetup(DatagramPacket packet, BytesDataInput input) {
        long clientNonce = input.readULong();
        long remoteConnectionID = input.readVarULong();
        int protocolVersion = input.readVarUInt();
        if (protocolVersion != 1) {
            return;
        }
        int protocolFlags = input.readUByte();
        byte[] domainNameBytes = input.readBytes();
        String domainName = new String(domainNameBytes, StandardCharsets.UTF_8);
        byte[] clientPublicKey = input.readRawBytes(PUBLIC_KEY_LENGTH);
        this.logger.log(1, 2L, -1, "Received SETUP " + domainName);

        SessionKey key = new SessionKey(clientNonce, protocolVersion, protocolFlags, domainNameBytes, clientPublicKey);
        QPSPConnection existing = this.connectionsByKey.get(key);
        if (existing != null) {
            existing.setRemote(packet.getAddress(), packet.getPort());
            existing.sendInit();
            return;
        }

        Server server = this.host.getServer(domainName);
        if (server == null) {
            return;
        }
        long connectionID = this.nextConnectionId();
        QPSPConnection session = new QPSPConnection(this, packet.getAddress(), packet.getPort(), connectionID, remoteConnectionID, 1024, 4096, clientNonce, this.random.nextLong(), server, true, this.crypto.decodePublicKey(clientPublicKey), server.getKeyPair(), null);
        this.connectionsByID.put(connectionID, session);
        this.connectionsByKey.put(key, session);
        this.sessionKeysByID.put(connectionID, key);
        session.sendInit();
    }

    /**
     * Handles an INIT packet: stops the SETUP retransmission for the addressed request, decrypts
     * and validates the server's certificate chain and, if everything checks out, registers the
     * resulting connection. INIT packets for unknown requests are ignored.
     */
    private void handleInit(DatagramPacket packet, BytesDataInput input) {
        this.logger.log(1, 3L, -1, "Received INIT");
        long connectionId = input.readVarULong();
        long serverNonce = input.readULong();
        input.readVarUInt(); // options: not used, but must be consumed to reach the key
        CryptoPublicKey serverPublicKey = this.crypto.decodePublicKey(input.readRawBytes(PUBLIC_KEY_LENGTH));

        this.cancelSetup(connectionId);
        QPSPConnectionRequest session = this.requestedSessions.remove(connectionId);
        if (session == null) {
            return;
        }
        session.preInit(serverNonce, serverPublicKey);
        try {
            byte[] decrypted = session.decryptInit(input.readByteArray());
            BytesDataInput decryptedInput = new BytesDataInput(decrypted);
            decryptedInput.readVarUInt(); // flags: not used, consumed to keep the read position
            CertificateChain certificate = new CertificateChain(this.crypto, decryptedInput);
            long remoteStreamID = decryptedInput.readVarULong();
            decryptedInput.readVarUInt(); // max streams: not used, consumed to keep the read position
            int maxPacketSize = decryptedInput.readVarUInt();
            int maxBufferSize = decryptedInput.readVarUInt();
            if (!session.isValidSigningRoot(certificate.rootKey)) {
                return;
            }
            if (!certificate.validate(session.host, serverPublicKey)) {
                return;
            }
            QPSPConnection connection = session.init(packet.getAddress(), packet.getPort(), remoteStreamID, maxPacketSize, maxBufferSize);
            this.connectionsByID.put(connection.localID, connection);
        }
        catch (CryptoDecryptionException ex) {
            this.logger.log(1, -1L, -1, "Crypto exception on INIT packet");
        }
    }

    /**
     * Sends a FEEDBACK packet of the given type for the given connection ID (lowest bit cleared)
     * to the given address. Send failures are printed and otherwise ignored.
     */
    void sendFeedback(InetAddress address, int port, int type, long connectionID) {
        BytesDataOutput output = new BytesDataOutput();
        output.writeVarULong(PACKET_FEEDBACK);
        output.writeVarUInt(type);
        output.writeVarULong(connectionID & CONNECTION_ID_MASK);
        byte[] packetData = output.toByteArray();
        try {
            this.socket.send(new DatagramPacket(packetData, packetData.length, address, port));
        }
        catch (IOException ex) {
            ex.printStackTrace();
        }
    }

    /** Unregisters a closed connection from the ID map and, if present, from the session key maps. */
    void onClosed(QPSPConnection connection) {
        this.connectionsByID.remove(connection.localID);
        SessionKey key = this.sessionKeysByID.remove(connection.localID);
        if (key != null) {
            this.connectionsByKey.remove(key);
        }
    }

    /**
     * Identity of a server-side session as derived from the fields of a SETUP packet, used to
     * recognise retransmitted SETUP packets. The byte arrays are compared by content.
     */
    private static final class SessionKey {
        private final long clientNonce;
        private final int protocolVersion;
        private final int protocolFlags;
        private final byte[] domainName;
        private final byte[] clientPublicKey;

        public SessionKey(long clientNonce, int protocolVersion, int protocolFlags, byte[] domainName, byte[] clientPublicKey) {
            this.clientNonce = clientNonce;
            this.protocolVersion = protocolVersion;
            this.protocolFlags = protocolFlags;
            this.domainName = domainName;
            this.clientPublicKey = clientPublicKey;
        }

        @Override
        public int hashCode() {
            int hash = 3;
            hash = 97 * hash + (int)(this.clientNonce ^ this.clientNonce >>> 32);
            hash = 97 * hash + this.protocolVersion;
            hash = 97 * hash + this.protocolFlags;
            hash = 97 * hash + Arrays.hashCode(this.domainName);
            hash = 97 * hash + Arrays.hashCode(this.clientPublicKey);
            return hash;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || this.getClass() != obj.getClass()) {
                return false;
            }
            SessionKey other = (SessionKey)obj;
            return this.clientNonce == other.clientNonce
                    && this.protocolVersion == other.protocolVersion
                    && this.protocolFlags == other.protocolFlags
                    && Arrays.equals(this.domainName, other.domainName)
                    && Arrays.equals(this.clientPublicKey, other.clientPublicKey);
        }
    }

    /** Receive thread: reads datagrams from the socket until the endpoint is paused or closed. */
    private class Receptor
    extends Thread {
        private Receptor() {
        }

        @Override
        public void run() {
            try {
                while (!QPSPEndpoint.this.closed) {
                    DatagramPacket packet = QPSPEndpoint.this.socket.receive();
                    if (packet == null) continue;
                    try {
                        QPSPEndpoint.this.onReceived(packet);
                    }
                    catch (RuntimeException ex) {
                        // Packet contents come from the network; an unchecked failure while
                        // processing one datagram must not terminate the receive thread.
                        QPSPEndpoint.this.logger.log(1, 0L, -1, "Failed to process received packet: " + ex);
                    }
                }
            }
            catch (SocketException ex) {
                // Closing the socket from pause()/close() is the expected way out of receive().
                // The message may be null or differently capitalised, so check the flag first.
                if (QPSPEndpoint.this.closed || "socket closed".equalsIgnoreCase(ex.getMessage())) {
                    return;
                }
                ex.printStackTrace();
            }
            catch (IOException ex) {
                ex.printStackTrace();
            }
        }
    }

    /** Timer task that resends the SETUP packet of a connection request that is still pending. */
    private class SetupPacket
    extends TimerTask {
        private final DatagramPacket packet;
        private final long connectionId;

        public SetupPacket(DatagramPacket packet, long connectionId) {
            this.packet = packet;
            this.connectionId = connectionId;
        }

        @Override
        public void run() {
            try {
                QPSPEndpoint.this.logger.log(16, 0L, -1, "Retransmitting SETUP");
                QPSPEndpoint.this.socket.send(this.packet);
            }
            catch (IOException ex) {
                // Best effort: the task stays scheduled and tries again on the next period.
                QPSPEndpoint.this.logger.log(1, this.connectionId, -1, "SETUP retransmission failed: " + ex.getMessage());
            }
        }
    }
}

