From fe4445cf9a8bc6652440d2a29356cf948bc265a9 Mon Sep 17 00:00:00 2001 From: Jeroen Robben Date: Tue, 16 Jun 2026 01:44:54 +0200 Subject: [PATCH 1/2] fix: tcp flow-control --- .../vpn/SlowClientDownloadInstrumentedTest.kt | 215 ++++++++++++++++++ .../tech/httptoolkit/android/vpn/Session.java | 30 ++- .../android/vpn/SessionHandler.java | 28 ++- .../vpn/socket/SocketChannelReader.java | 101 +++++--- .../vpn/socket/SocketNIODataService.java | 10 + .../android/vpn/TcpDownloadFlowControlTest.kt | 106 +++++++++ .../android/vpn/TcpWindowedDownloadTest.kt | 214 +++++++++++++++++ .../httptoolkit/android/vpn/TestPackets.kt | 3 +- 8 files changed, 654 insertions(+), 53 deletions(-) create mode 100644 app/src/androidTest/java/tech/httptoolkit/android/vpn/SlowClientDownloadInstrumentedTest.kt create mode 100644 app/src/test/java/tech/httptoolkit/android/vpn/TcpDownloadFlowControlTest.kt create mode 100644 app/src/test/java/tech/httptoolkit/android/vpn/TcpWindowedDownloadTest.kt diff --git a/app/src/androidTest/java/tech/httptoolkit/android/vpn/SlowClientDownloadInstrumentedTest.kt b/app/src/androidTest/java/tech/httptoolkit/android/vpn/SlowClientDownloadInstrumentedTest.kt new file mode 100644 index 0000000..5f71d32 --- /dev/null +++ b/app/src/androidTest/java/tech/httptoolkit/android/vpn/SlowClientDownloadInstrumentedTest.kt @@ -0,0 +1,215 @@ +package tech.httptoolkit.android.vpn + +import android.app.Application +import androidx.room.Room +import androidx.test.core.app.ApplicationProvider +import androidx.test.ext.junit.runners.AndroidJUnit4 +import org.distrinet.lanshield.database.AppDatabase +import org.junit.After +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import tech.httptoolkit.android.vpn.socket.IProtectSocket +import tech.httptoolkit.android.vpn.socket.SocketNIODataService +import tech.httptoolkit.android.vpn.socket.SocketProtector +import tech.httptoolkit.android.vpn.transport.ip.IPPacketFactory +import tech.httptoolkit.android.vpn.transport.tcp.TCPPacketFactory +import java.io.File +import java.io.FileOutputStream +import java.net.DatagramSocket +import java.net.InetAddress +import java.net.ServerSocket +import java.net.Socket +import java.nio.ByteBuffer +import java.util.concurrent.Executors +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.TimeUnit + +/** + * On-device (instrumented) end-to-end test of the TCP download flow-control fix, running the + * REAL forwarding engine (SessionHandler / SessionManager / SocketNIODataService) against a + * real loopback server on the device. + * + * It models a slow-draining client (like Firefox) by acting as the client's TCP stack: it + * advertises a small receive window and acks the contiguous prefix as it consumes each + * segment. Because the engine now respects that window, it never over-sends, so nothing is + * dropped and the whole file is delivered in order — the on-device proof that the download + * completes instead of stalling. + * + * Self-contained (the unit-test ForwardingTestHarness/TestPackets are not visible to + * androidTest), so it inlines a minimal capturing writer and TCP packet builder. + */ +@RunWith(AndroidJUnit4::class) +class SlowClientDownloadInstrumentedTest { + + private lateinit var db: AppDatabase + private lateinit var writer: CapturingWriter + private lateinit var nioService: SocketNIODataService + private lateinit var nioThread: Thread + private lateinit var sessionManager: SessionManager + private lateinit var sessionHandler: SessionHandler + private lateinit var tempTun: File + + private lateinit var server: ServerSocket + private val executor = Executors.newSingleThreadExecutor() + private var peerPort = 0 + + private val clientIp = "10.0.0.2" + private val clientPort = 50000 + private val peerIp = "127.0.0.1" + private val clientIsn = 1000L + private val window = 8 * 1024 // small receive window the "client" advertises + private val responseSize = 256 * 1024 // server pushes far more than one window + + private class CapturingWriter(out: FileOutputStream) : ClientPacketWriter(out) { + val queue = LinkedBlockingQueue() + override fun write(data: ByteArray) { queue.add(data) } + } + + @Before + fun setUp() { + SocketProtector.getInstance().setProtector(object : IProtectSocket { + override fun protect(socket: Socket): Boolean = true + override fun protect(socket: DatagramSocket): Boolean = true + }) + val ctx = ApplicationProvider.getApplicationContext() + db = Room.inMemoryDatabaseBuilder(ctx, AppDatabase::class.java).allowMainThreadQueries().build() + + tempTun = File.createTempFile("tun", ".bin").apply { deleteOnExit() } + writer = CapturingWriter(FileOutputStream(tempTun)) + nioService = SocketNIODataService(writer, db) + nioThread = Thread(nioService, "nio-instrumented").apply { isDaemon = true; start() } + sessionManager = SessionManager(db) + sessionHandler = SessionHandler(sessionManager, nioService, writer, db) + + server = ServerSocket(0, 50, InetAddress.getByName(peerIp)) + peerPort = server.localPort + } + + @After + fun tearDown() { + executor.shutdownNow() + runCatching { server.close() } + runCatching { nioService.shutdown() } + runCatching { nioThread.join(1000) } + runCatching { db.close() } + runCatching { tempTun.delete() } + } + + @Test + fun windowedClientReceivesTheWholeDownloadInOrder() { + val accepted = executor.submit { server.accept() } + + // Handshake: SYN -> SYN-ACK -> ACK, advertising the small window. + feed(syn()) + val synAck = awaitMatching { val (_, t) = parseTcp(it); t.isSYN && t.isACK } + val serverIsn = parseTcp(synAck).second.sequenceNumber + feed(ack(serverIsn + 1, window)) + val serverSocket = accepted.get(5, TimeUnit.SECONDS) + + // Server streams a response many windows in size, on a background thread (the engine + // backpressures the upstream, so a single blocking write drains gradually as we ack). + executor.submit { + runCatching { serverSocket.getOutputStream().apply { write(ByteArray(responseSize)); flush() } } + } + + // Conformant windowed client: accept each in-order segment, verify it never exceeds the + // advertised window, and ack the contiguous prefix to keep the window open. + val firstByte = serverIsn + 1 + var received = 0L + var expectedSeq = firstByte and 0xFFFFFFFFL + while (received < responseSize) { + val pkt = writer.queue.poll(5000, TimeUnit.MILLISECONDS) + ?: throw AssertionError("download stalled after $received / $responseSize bytes") + val (_, tcp) = parseTcp(pkt) + val len = tcpPayloadLength(pkt) + if (len <= 0) continue + assertEquals("segments must arrive strictly in order", expectedSeq, tcp.sequenceNumber and 0xFFFFFFFFL) + assertTrue("segment ($len B) exceeded the advertised window ($window B)", len <= window) + expectedSeq = (expectedSeq + len) and 0xFFFFFFFFL + received += len + feed(ack(firstByte + received, window)) // ack the contiguous prefix, reopening the window + } + + assertEquals("the windowed client must receive the whole file", responseSize.toLong(), received) + } + + // --- engine plumbing helpers --------------------------------------------- + + private fun feed(packet: ByteArray) { + sessionHandler.handlePacket(ByteBuffer.wrap(packet), "test.app") + } + + private fun parseTcp(packet: ByteArray) = ByteBuffer.wrap(packet).let { b -> + val ip = IPPacketFactory.createIPHeader(b) + val tcp = TCPPacketFactory.createTCPHeader(b) + ip to tcp + } + + private fun awaitMatching(timeoutMs: Long = 5000, predicate: (ByteArray) -> Boolean): ByteArray { + val deadline = System.nanoTime() + timeoutMs * 1_000_000 + while (System.nanoTime() < deadline) { + val remaining = ((deadline - System.nanoTime()) / 1_000_000).coerceAtLeast(1) + val pkt = writer.queue.poll(remaining, TimeUnit.MILLISECONDS) ?: break + if (predicate(pkt)) return pkt + } + throw AssertionError("No matching TUN packet within ${timeoutMs}ms") + } + + private fun tcpPayloadLength(packet: ByteArray): Int { + val ihl = (packet[0].toInt() and 0x0F) * 4 + val totalLength = ((packet[2].toInt() and 0xFF) shl 8) or (packet[3].toInt() and 0xFF) + val dataOffset = ((packet[ihl + 12].toInt() shr 4) and 0x0F) * 4 + return totalLength - ihl - dataOffset + } + + // --- raw IPv4 TCP packet builders (minimal, mirrors unit-test TestPackets) ------ + + private val syn0 = 0x02 + private val ack0 = 0x10 + + private fun syn(): ByteArray = tcpPacket(clientIsn, 0, syn0, mss = 1460, win = window) + private fun ack(ackNum: Long, win: Int): ByteArray = + tcpPacket(clientIsn + 1, ackNum, ack0, mss = null, win = win) + + private fun tcpPacket(seq: Long, ack: Long, flags: Int, mss: Int?, win: Int): ByteArray { + val optionBytes = if (mss != null) 4 else 0 + val tcpHeaderLen = 20 + optionBytes + val total = 20 + tcpHeaderLen + val buf = ByteArray(total) + // IPv4 header + buf[0] = 0x45 + putShort(buf, 2, total) + buf[8] = 64 // TTL + buf[9] = 6 // protocol TCP + ipBytes(clientIp).copyInto(buf, 12) + ipBytes(peerIp).copyInto(buf, 16) + // TCP header + val t = 20 + putShort(buf, t, clientPort) + putShort(buf, t + 2, peerPort) + putInt(buf, t + 4, seq) + putInt(buf, t + 8, ack) + buf[t + 12] = ((tcpHeaderLen / 4) shl 4).toByte() + buf[t + 13] = flags.toByte() + putShort(buf, t + 14, win) + if (mss != null) { + buf[t + 20] = 0x02; buf[t + 21] = 0x04; putShort(buf, t + 22, mss) + } + return buf + } + + private fun ipBytes(dotted: String) = + ByteArray(4) { dotted.split(".")[it].toInt().toByte() } + + private fun putShort(b: ByteArray, off: Int, v: Int) { + b[off] = ((v ushr 8) and 0xFF).toByte(); b[off + 1] = (v and 0xFF).toByte() + } + + private fun putInt(b: ByteArray, off: Int, v: Long) { + b[off] = ((v ushr 24) and 0xFF).toByte(); b[off + 1] = ((v ushr 16) and 0xFF).toByte() + b[off + 2] = ((v ushr 8) and 0xFF).toByte(); b[off + 3] = (v and 0xFF).toByte() + } +} diff --git a/app/src/main/java/tech/httptoolkit/android/vpn/Session.java b/app/src/main/java/tech/httptoolkit/android/vpn/Session.java index af5ea16..d2f98ba 100644 --- a/app/src/main/java/tech/httptoolkit/android/vpn/Session.java +++ b/app/src/main/java/tech/httptoolkit/android/vpn/Session.java @@ -93,10 +93,12 @@ public class Session { //indicate data from client is ready for sending to destination private volatile boolean isDataForSendingReady = false; - - //store data for retransmission - private byte[] unackData = null; - + + //client-direction TCP flow control (download path) + private long clientWindow = 0; // advertised receive window, scaled to bytes + private int clientWindowScale = 0; // window scale from the client SYN (RFC 7323) + private boolean upstreamEof = false; // upstream hit EOF; FIN deferred until staging drains + //in ACK packet from client, if the previous packet was corrupted, client will send flag in options field private boolean packetCorrupted = false; @@ -171,10 +173,22 @@ public synchronized byte[] getReceivedData(int maxSize){ * buffer has more data for vpn client * @return boolean */ - public boolean hasReceivedData(){ + public synchronized boolean hasReceivedData(){ return receivingStream.size() > 0; } + public synchronized int receivingStreamSize(){ + return receivingStream.size(); + } + + public synchronized long getClientWindow(){ return clientWindow; } + public synchronized void setClientWindow(long clientWindow){ this.clientWindow = clientWindow; } + public synchronized int getClientWindowScale(){ return clientWindowScale; } + public synchronized void setClientWindowScale(int clientWindowScale){ this.clientWindowScale = clientWindowScale; } + + public synchronized boolean isUpstreamEof(){ return upstreamEof; } + public synchronized void setUpstreamEof(boolean upstreamEof){ this.upstreamEof = upstreamEof; } + /** * set data to be sent to destination server. * For UDP each call is queued as a discrete datagram @@ -245,7 +259,7 @@ public int getDestPort() { return destPort; } - long getSendUnack() { + public long getSendUnack() { return sendUnack; } @@ -342,10 +356,6 @@ public boolean isDataForSendingReady() { public void setDataForSendingReady(boolean isDataForSendingReady) { this.isDataForSendingReady = isDataForSendingReady; } - public void setUnackData(byte[] unackData) { - this.unackData = unackData; - } - void setPacketCorrupted(boolean packetCorrupted) { this.packetCorrupted = packetCorrupted; } diff --git a/app/src/main/java/tech/httptoolkit/android/vpn/SessionHandler.java b/app/src/main/java/tech/httptoolkit/android/vpn/SessionHandler.java index 0840334..cf3e5dc 100644 --- a/app/src/main/java/tech/httptoolkit/android/vpn/SessionHandler.java +++ b/app/src/main/java/tech/httptoolkit/android/vpn/SessionHandler.java @@ -173,6 +173,11 @@ private void handleTCPPacket(ByteBuffer clientPacketData, IPHeader ipHeader, Str session.setLastTcpHeader(tcpheader); LANFlow lanFlow = session.getFlow(); + // Update the advertised window on every ACK (unconditionally, since pure + // window-update ACKs don't pass acceptAck's gate below). The 16-bit window is + // parsed as a signed short, so mask to unsigned before scaling. + session.setClientWindow((long) (tcpheader.getWindowSize() & 0xFFFF) << session.getClientWindowScale()); + //any data from client? if (dataLength > 0) { @@ -220,6 +225,9 @@ private void handleTCPPacket(ByteBuffer clientPacketData, IPHeader ipHeader, Str if (!session.isAbortingConnection()) { manager.keepSessionAlive(session); + // This ACK may have reopened the window: flush staged data and resume reads. + nioService.pumpToClient(session); + nioService.refreshSelect(session); } } } else if(tcpheader.isFIN()){ @@ -367,13 +375,18 @@ private void acceptAck(TCPHeader tcpHeader, Session session){ Log.e(TAG,"prev packet was corrupted, last ack# " + tcpHeader.getAckNumber()); } - if ( - tcpHeader.getAckNumber() > session.getSendUnack() || - tcpHeader.getAckNumber() == session.getSendNext() - ) { + // Reconstruct the ACK in absolute sequence space. The unacked span (sendNext - ack) is at + // most one window, so taking it mod 2^32 and subtracting recovers the absolute ack even + // after the 32-bit sequence wraps (e.g. partway through a multi-GB download) — where the + // old signed comparison would have stopped accepting ACKs and stalled. + long sendNextAbs = session.getSendNext(); + long unackedSpan = (sendNextAbs - (tcpHeader.getAckNumber() & 0xFFFFFFFFL)) & 0xFFFFFFFFL; + long ackAbs = sendNextAbs - unackedSpan; + + if (ackAbs > session.getSendUnack() || ackAbs == sendNextAbs) { session.setAcked(true); - session.setSendUnack(tcpHeader.getAckNumber()); + session.setSendUnack(ackAbs); session.setRecSequence(tcpHeader.getSequenceNumber()); session.setTimestampReplyto(tcpHeader.getTimeStampSender()); session.setTimestampSender((int) System.currentTimeMillis()); @@ -433,6 +446,11 @@ private void replySynAck(IPHeader ip, TCPHeader tcp, String packageName) throws //client initial sequence has been incremented by 1 and set to ack session.setRecSequence(tcpheader.getAckNumber()); + // Capture the client's flow-control params. The SYN window is unscaled (RFC 7323) and + // parsed as a signed short, so mask to unsigned. + session.setClientWindowScale(tcp.getWindowScale()); + session.setClientWindow(tcp.getWindowSize() & 0xFFFF); + session.setLastIpHeader(ip); session.setLastTcpHeader(tcp); diff --git a/app/src/main/java/tech/httptoolkit/android/vpn/socket/SocketChannelReader.java b/app/src/main/java/tech/httptoolkit/android/vpn/socket/SocketChannelReader.java index f6b85ed..818cbd8 100644 --- a/app/src/main/java/tech/httptoolkit/android/vpn/socket/SocketChannelReader.java +++ b/app/src/main/java/tech/httptoolkit/android/vpn/socket/SocketChannelReader.java @@ -37,6 +37,10 @@ public SocketChannelReader(ClientPacketWriter writer) { this.writer = writer; } + // When staged-but-unsent upstream bytes reach this, stop reading so the upstream TCP window + // closes and the sender backs off, instead of pulling a multi-GB download into memory. + private static final int STAGING_CAP = 2 * DataConst.MAX_RECEIVE_BUFFER_SIZE; + public long read(Session session) { AbstractSelectableChannel channel = session.getChannel(); long bytesRead = 0; @@ -48,8 +52,11 @@ public long read(Session session) { return 0; } - // Resubscribe to reads, so that we're triggered again if more data arrives later. - session.subscribeKey(SelectionKey.OP_READ); + // Resubscribe to reads, unless backpressure is holding us off (staging full): pumpToClient + // re-subscribes once it drains below the cap. + if (!(channel instanceof SocketChannel) || session.receivingStreamSize() < STAGING_CAP) { + session.subscribeKey(SelectionKey.OP_READ); + } if (session.isAbortingConnection()) { Log.d(TAG,"removing aborted connection -> "+ session); @@ -90,16 +97,21 @@ private long readTCP(@NonNull Session session) { try { do { + // Backpressure: stop reading once the staging buffer is full. + if (session.receivingStreamSize() >= STAGING_CAP) { + session.unsubscribeKey(SelectionKey.OP_READ); + break; + } len = channel.read(buffer); if (len > 0) { //-1 mean it reach the end of stream sendToRequester(buffer, len, session); buffer.clear(); bytesRead += len; } else if (len == -1) { - Log.d(TAG,"End of data from remote server, will send FIN to client"); - Log.d(TAG,"send FIN to: " + session); - sendFin(session); - session.setAbortingConnection(true); + // EOF: defer the FIN to pumpToClient so it can't overtake unsent staged data. + Log.d(TAG,"End of data from remote server, will FIN once drained: " + session); + session.setUpstreamEof(true); + pumpToClient(session); } } while (len > 0); }catch(NotYetConnectedException e){ @@ -126,47 +138,62 @@ private void sendToRequester(ByteBuffer buffer, int dataSize, @NonNull Session s byte[] data = new byte[dataSize]; System.arraycopy(buffer.array(), 0, data, 0, dataSize); session.addReceivedData(data); - //pushing all data to vpn client - while(session.hasReceivedData()){ - pushDataToClient(session); - } + pumpToClient(session); } + + private int maxSegment(@NonNull Session session){ + // TODO What does 60 mean? Leaves room for IP + TCP options below the MSS. + int max = session.getMaxSegmentSize() - 60; + return max < 1 ? 1024 : max; + } + /** - * create packet data and send it to VPN client - * @param session Session + * Send staged upstream data to the VPN client, keeping at most one client window in flight + * (clientWindow - (sendNext - sendUnack) bytes). Runs under the session monitor, from the NIO + * thread after a read and from the SessionHandler thread after a window-opening ACK. Sends the + * deferred FIN once the upstream is done and all staged data has drained. Sequence math is + * unsigned 32-bit so a multi-GB transfer (which wraps the sequence number) stays correct. */ - private void pushDataToClient(@NonNull Session session){ - if (!session.hasReceivedData()) { - //no data to send - Log.d(TAG,"no data for vpn client"); - } - + void pumpToClient(@NonNull Session session){ IPHeader ipHeader = session.getLastIpHeader(); TCPHeader tcpheader = session.getLastTcpHeader(); - // TODO What does 60 mean? - int max = session.getMaxSegmentSize() - 60; + if (ipHeader == null || tcpheader == null) return; + + final int segMax = maxSegment(session); + + while (session.hasReceivedData()) { + long inFlight = unsigned32(session.getSendNext() - session.getSendUnack()); + long room = session.getClientWindow() - inFlight; + if (room <= 0) break; // window full (or zero window) + int chunk = (int) Math.min(room, segMax); + + byte[] packetBody = session.getReceivedData(chunk); + if (packetBody == null || packetBody.length == 0) break; + + long seq = unsigned32(session.getSendNext()); + session.setSendNext(session.getSendNext() + packetBody.length); - if(max < 1) { - max = 1024; + boolean psh = session.hasReceivedLastSegment() && !session.hasReceivedData(); + writer.write(TCPPacketFactory.createResponsePacketData(ipHeader, tcpheader, packetBody, + psh, session.getRecSequence(), seq, + session.getTimestampSender(), session.getTimestampReplyto())); } - byte[] packetBody = session.getReceivedData(max); - if(packetBody != null && packetBody.length > 0) { - long unAck = session.getSendNext(); - long nextUnAck = session.getSendNext() + packetBody.length; - session.setSendNext(nextUnAck); - //we need this data later on for retransmission - session.setUnackData(packetBody); - session.setResendPacketCounter(0); - - byte[] data = TCPPacketFactory.createResponsePacketData(ipHeader, - tcpheader, packetBody, session.hasReceivedLastSegment(), - session.getRecSequence(), unAck, - session.getTimestampSender(), session.getTimestampReplyto()); - - writer.write(data); + if (session.isUpstreamEof() && !session.hasReceivedData() && !session.isAbortingConnection()) { + sendFin(session); + session.setAbortingConnection(true); + return; + } + + // Resume upstream reads if backpressure stopped them and staging has drained. + if (!session.isUpstreamEof() && session.receivingStreamSize() < STAGING_CAP) { + session.subscribeKey(SelectionKey.OP_READ); } } + + private static long unsigned32(long value){ + return value & 0xFFFFFFFFL; + } private void sendFin(Session session){ final IPHeader ipHeader = session.getLastIpHeader(); final TCPHeader tcpheader = session.getLastTcpHeader(); diff --git a/app/src/main/java/tech/httptoolkit/android/vpn/socket/SocketNIODataService.java b/app/src/main/java/tech/httptoolkit/android/vpn/socket/SocketNIODataService.java index 07914f9..76ff74d 100644 --- a/app/src/main/java/tech/httptoolkit/android/vpn/socket/SocketNIODataService.java +++ b/app/src/main/java/tech/httptoolkit/android/vpn/socket/SocketNIODataService.java @@ -109,6 +109,16 @@ public void refreshSelect(Session session) { } } + /** + * Run the client-direction pump for a session. The SessionHandler thread calls this after a + * window-opening ACK to flush staged data the window now permits (the upstream socket has no + * new readiness, so the selector wouldn't re-run the pump). Caller holds the session monitor; + * follow with {@link #refreshSelect} to pick up the resumed OP_READ. + */ + public void pumpToClient(Session session){ + reader.pumpToClient(session); + } + /** * Shut down the NIO thread */ diff --git a/app/src/test/java/tech/httptoolkit/android/vpn/TcpDownloadFlowControlTest.kt b/app/src/test/java/tech/httptoolkit/android/vpn/TcpDownloadFlowControlTest.kt new file mode 100644 index 0000000..0588116 --- /dev/null +++ b/app/src/test/java/tech/httptoolkit/android/vpn/TcpDownloadFlowControlTest.kt @@ -0,0 +1,106 @@ +package tech.httptoolkit.android.vpn + +import android.app.Application +import com.google.common.truth.Truth.assertThat +import org.junit.After +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner +import org.robolectric.annotation.Config +import java.net.InetAddress +import java.net.ServerSocket +import java.net.Socket +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit + +/** + * Verifies TCP receive-window flow control on the download (server -> app) path: the engine + * must never have more unacknowledged data in flight than the client's advertised window. + * + * Previously the engine ignored the window and pushed the whole response unacknowledged, + * which is why large downloads failed in slow-draining clients (Firefox while Chrome/curl + * succeeded). This drives a download where the client advertises a tiny window and sends NO + * further ACKs, and asserts the engine stops after ~one window instead of over-sending. + */ +@RunWith(RobolectricTestRunner::class) +@Config(sdk = [34], application = Application::class) +class TcpDownloadFlowControlTest { + + private lateinit var harness: ForwardingTestHarness + private lateinit var server: ServerSocket + private val executor = Executors.newSingleThreadExecutor() + private var peerPort = 0 + + private val clientIp = "10.0.0.2" + private val clientPort = 50000 + private val peerIp = "127.0.0.1" + private val clientIsn = 1000L + private val advertisedWindow = 4096 + private val mss = 1460 + + @Before + fun setUp() { + harness = ForwardingTestHarness() + server = ServerSocket(0, 50, InetAddress.getByName(peerIp)) + peerPort = server.localPort + } + + @After + fun tearDown() { + executor.shutdownNow() + server.close() + harness.close() + } + + @Test + fun `engine does not send more than the client receive window without ACKs`() { + val acceptedFuture = executor.submit { server.accept() } + + // Handshake, advertising a deliberately tiny receive window. + harness.feed( + TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = clientIsn, ack = 0, flags = TestPackets.SYN, mss = mss, + windowSize = advertisedWindow, + ) + ) + val synAck = harness.awaitTunPacketMatching { + val (_, tcp) = harness.parseTcp(it); tcp.isSYN && tcp.isACK + } + val serverIsn = harness.parseTcp(synAck).second.sequenceNumber + harness.feed( + TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = clientIsn + 1, ack = serverIsn + 1, flags = TestPackets.ACK, + windowSize = advertisedWindow, + ) + ) + val accepted = acceptedFuture.get(3, TimeUnit.SECONDS) + + // The server streams a large response. The client (this test) deliberately sends NO + // further ACKs and never opens its window beyond the 4 KB advertised above. + val responseSize = 256 * 1024 + accepted.getOutputStream().apply { write(ByteArray(responseSize)); flush() } + + // Drain everything the engine pushes to the TUN until it goes quiet, summing payloads. + var unackedBytes = 0L + while (true) { + val pkt = harness.pollTunPacket(1000) ?: break + unackedBytes += tcpPayloadLength(pkt) + } + + // With zero ACKs the engine may keep at most one advertised window in flight, and should + // have filled that window (to within one segment) rather than stalling early. + assertThat(unackedBytes).isAtMost(advertisedWindow.toLong()) + assertThat(unackedBytes).isAtLeast(advertisedWindow.toLong() - mss) + } + + /** Payload length of a captured IPv4 TCP packet, from its header-length fields. */ + private fun tcpPayloadLength(packet: ByteArray): Int { + val ihl = (packet[0].toInt() and 0x0F) * 4 + val totalLength = ((packet[2].toInt() and 0xFF) shl 8) or (packet[3].toInt() and 0xFF) + val dataOffset = ((packet[ihl + 12].toInt() shr 4) and 0x0F) * 4 + return totalLength - ihl - dataOffset + } +} diff --git a/app/src/test/java/tech/httptoolkit/android/vpn/TcpWindowedDownloadTest.kt b/app/src/test/java/tech/httptoolkit/android/vpn/TcpWindowedDownloadTest.kt new file mode 100644 index 0000000..5a19bd8 --- /dev/null +++ b/app/src/test/java/tech/httptoolkit/android/vpn/TcpWindowedDownloadTest.kt @@ -0,0 +1,214 @@ +package tech.httptoolkit.android.vpn + +import android.app.Application +import com.google.common.truth.Truth.assertThat +import org.junit.After +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner +import org.robolectric.annotation.Config +import tech.httptoolkit.android.vpn.transport.ip.IPAddress +import java.net.InetAddress +import java.net.ServerSocket +import java.net.Socket +import java.util.concurrent.Executors +import java.util.concurrent.Future +import java.util.concurrent.TimeUnit + +/** + * End-to-end tests of the windowed download path against a real loopback server: the engine + * must never have more than the client's advertised window of unacknowledged data in flight, + * must resume sending when an ACK reopens the window, must backpressure the upstream so a + * large file isn't pulled into memory, and must deliver the whole file in order to a + * conformant (windowed) client. + */ +@RunWith(RobolectricTestRunner::class) +@Config(sdk = [34], application = Application::class) +class TcpWindowedDownloadTest { + + private lateinit var harness: ForwardingTestHarness + private lateinit var server: ServerSocket + private lateinit var accepted: Socket + private val executor = Executors.newCachedThreadPool() + private var peerPort = 0 + + private val clientIp = "10.0.0.2" + private val clientPort = 50000 + private val peerIp = "127.0.0.1" + private val clientIsn = 1000L + private val mss = 1460 + + @Before + fun setUp() { + harness = ForwardingTestHarness() + server = ServerSocket(0, 50, InetAddress.getByName(peerIp)) + peerPort = server.localPort + } + + @After + fun tearDown() { + executor.shutdownNow() + server.close() + harness.close() + } + + @Test + fun `an ACK that reopens the window resumes sending the next in-order segment`() { + val serverIsn = handshake(window = 2000) + accepted.getOutputStream().apply { write(ByteArray(8000)); flush() } + + // Initial burst is capped at one window; nothing more until the client ACKs. + val burst = drain() + val burstBytes = burst.sumOf { it.len.toLong() } + assertThat(burstBytes).isAtMost(2000L) + assertThat(burstBytes).isAtLeast(2000L - mss) + val nextSeq = (serverIsn + 1 + burstBytes) and 0xFFFFFFFFL + + // Open the window by acking the burst. No new upstream data is written, so resumption + // can only come from the handler-driven pump flushing already-staged bytes. + clientAck(serverIsn + 1 + burstBytes, window = 2000) + + val resumed = drain() + assertThat(resumed).isNotEmpty() + assertThat(resumed.first().seq).isEqualTo(nextSeq) + } + + @Test + fun `backpressure bounds the engine's staging buffer instead of pulling the whole file into memory`() { + // A slow client that never acks: the engine must stop reading upstream once its staging + // buffer is full, rather than pulling the entire response into memory. + val window = 4 * 1024 + val total = 4 * 1024 * 1024 // far larger than any bounded staging buffer + val serverIsn = handshake(window) + val key = Session.getSessionKey( + SessionProtocol.TCP, + IPAddress(TestPackets.ip(peerIp)), peerPort, + IPAddress(TestPackets.ip(clientIp)), clientPort, + ) + + executor.submit { + runCatching { accepted.getOutputStream().apply { write(ByteArray(total)); flush() } } + } + + val session = harness.await { harness.sessionByKey(key) } + // The engine fills its staging buffer toward the cap, then stops reading. + harness.await { session.takeIf { it.receivingStreamSize() >= 100 * 1024 } } + Thread.sleep(200) // let any in-flight read settle + + // Staging stays bounded (~STAGING_CAP + one read chunk), nowhere near the 4 MB written. + assertThat(session.receivingStreamSize()).isAtMost(256 * 1024) + } + + @Test + fun `a windowed client receives the whole download in order with in-flight bounded by the window`() { + val window = 32 * 1024 + val total = 1024 * 1024 + val serverIsn = handshake(window, mss) + + val writeFuture: Future<*> = executor.submit { + accepted.getOutputStream().apply { write(ByteArray(total)); flush() } + } + + var received = 0L + var lastAck = 0L + var peakInFlight = 0L + var expectedSeq = (serverIsn + 1) and 0xFFFFFFFFL + var guard = 0 + while (received < total && guard++ < 1000) { + for (seg in drain()) { + assertThat(seg.seq).isEqualTo(expectedSeq) // strictly in order, no gaps + expectedSeq = (expectedSeq + seg.len) and 0xFFFFFFFFL + received += seg.len + } + peakInFlight = maxOf(peakInFlight, received - lastAck) + clientAck(serverIsn + 1 + received, window) + lastAck = received + } + + assertThat(received).isEqualTo(total.toLong()) + assertThat(peakInFlight).isAtMost(window.toLong()) + writeFuture.get(3, TimeUnit.SECONDS) // upstream completes once drained + } + + @Test + fun `FIN is not sent ahead of unsent data and follows the last byte`() { + val window = 1000 + val total = 3000 + val serverIsn = handshake(window) + + accepted.getOutputStream().apply { write(ByteArray(total)); flush() } + accepted.close() // upstream EOF, while staged data still exceeds the window + + val segs = ArrayList() + segs += drain() // initial burst: one window, no FIN (data still staged) + var received = segs.sumOf { it.len.toLong() } + var guard = 0 + while (received < total && guard++ < 100) { + clientAck(serverIsn + 1 + received, window) + val batch = drain() + segs += batch + received += batch.sumOf { it.len.toLong() } + } + + val finIdx = segs.indexOfFirst { it.fin } + assertThat(finIdx).isAtLeast(0) // a FIN was eventually sent + // All payload bytes were delivered before the FIN (FIN never overtook unsent data). + assertThat(segs.take(finIdx).sumOf { it.len.toLong() }).isEqualTo(total.toLong()) + // The FIN's sequence sits right after the last data byte. + assertThat(segs[finIdx].seq).isEqualTo((serverIsn + 1 + total) and 0xFFFFFFFFL) + } + + // --- helpers ------------------------------------------------------------- + + private data class Seg(val seq: Long, val len: Int, val fin: Boolean) + + private fun handshake(window: Int, mss: Int = this.mss): Long { + val acceptedFuture = executor.submit { server.accept() } + harness.feed( + TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = clientIsn, ack = 0, flags = TestPackets.SYN, mss = mss, windowSize = window, + ) + ) + val synAck = harness.awaitTunPacketMatching { + val (_, tcp) = harness.parseTcp(it); tcp.isSYN && tcp.isACK + } + val serverIsn = harness.parseTcp(synAck).second.sequenceNumber + harness.feed( + TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = clientIsn + 1, ack = serverIsn + 1, flags = TestPackets.ACK, windowSize = window, + ) + ) + accepted = acceptedFuture.get(3, TimeUnit.SECONDS) + return serverIsn + } + + private fun clientAck(ackNumber: Long, window: Int) { + harness.feed( + TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = clientIsn + 1, ack = ackNumber, flags = TestPackets.ACK, windowSize = window, + ) + ) + } + + /** Drain TUN packets until the queue is quiet, returning each as a parsed segment. */ + private fun drain(timeoutMs: Long = 1000): List { + val out = ArrayList() + while (true) { + val pkt = harness.pollTunPacket(timeoutMs) ?: break + val (_, tcp) = harness.parseTcp(pkt) + out.add(Seg(tcp.sequenceNumber and 0xFFFFFFFFL, tcpPayloadLength(pkt), tcp.isFIN)) + } + return out + } + + private fun tcpPayloadLength(packet: ByteArray): Int { + val ihl = (packet[0].toInt() and 0x0F) * 4 + val totalLength = ((packet[2].toInt() and 0xFF) shl 8) or (packet[3].toInt() and 0xFF) + val dataOffset = ((packet[ihl + 12].toInt() shr 4) and 0x0F) * 4 + return totalLength - ihl - dataOffset + } +} diff --git a/app/src/test/java/tech/httptoolkit/android/vpn/TestPackets.kt b/app/src/test/java/tech/httptoolkit/android/vpn/TestPackets.kt index c04d5bf..ada7b17 100644 --- a/app/src/test/java/tech/httptoolkit/android/vpn/TestPackets.kt +++ b/app/src/test/java/tech/httptoolkit/android/vpn/TestPackets.kt @@ -73,6 +73,7 @@ object TestPackets { fun tcpPacket( srcIp: String, srcPort: Int, dstIp: String, dstPort: Int, seq: Long, ack: Long, flags: Int, payload: ByteArray = ByteArray(0), mss: Int? = null, + windowSize: Int = 65535, ): ByteArray { val optionBytes = if (mss != null) 4 else 0 // MSS option = kind(1)+len(1)+value(2) val tcpHeaderLen = 20 + optionBytes @@ -88,7 +89,7 @@ object TestPackets { buf.putInt(t + 8, ack) buf[t + 12] = (dataOffsetWords shl 4).toByte() // data offset, NS=0 buf[t + 13] = flags.toByte() - buf.putShort(t + 14, 65535) // window size + buf.putShort(t + 14, windowSize) // window size buf.putShort(t + 16, 0) // checksum (not verified) buf.putShort(t + 18, 0) // urgent pointer if (mss != null) { From c541f685c8a6f6d046c13429491213106e137104 Mon Sep 17 00:00:00 2001 From: Jeroen Robben Date: Tue, 16 Jun 2026 02:37:30 +0200 Subject: [PATCH 2/2] add more packet forwarding tests --- .../vpn/ConcurrentDownloadInstrumentedTest.kt | 259 ++++++++++++++++++ .../android/vpn/SessionHandler.java | 3 + .../android/vpn/IcmpForwardingTest.kt | 211 ++++++++++++++ .../vpn/MixedProtocolConcurrencyTest.kt | 218 +++++++++++++++ .../android/vpn/TcpConcurrentWindowsTest.kt | 234 ++++++++++++++++ .../android/vpn/TcpEdgeCaseForwardingTest.kt | 251 +++++++++++++++++ .../httptoolkit/android/vpn/TestPackets.kt | 81 +++++- .../android/vpn/UdpConcurrentFlowsTest.kt | 150 ++++++++++ 8 files changed, 1402 insertions(+), 5 deletions(-) create mode 100644 app/src/androidTest/java/tech/httptoolkit/android/vpn/ConcurrentDownloadInstrumentedTest.kt create mode 100644 app/src/test/java/tech/httptoolkit/android/vpn/IcmpForwardingTest.kt create mode 100644 app/src/test/java/tech/httptoolkit/android/vpn/MixedProtocolConcurrencyTest.kt create mode 100644 app/src/test/java/tech/httptoolkit/android/vpn/TcpConcurrentWindowsTest.kt create mode 100644 app/src/test/java/tech/httptoolkit/android/vpn/TcpEdgeCaseForwardingTest.kt create mode 100644 app/src/test/java/tech/httptoolkit/android/vpn/UdpConcurrentFlowsTest.kt diff --git a/app/src/androidTest/java/tech/httptoolkit/android/vpn/ConcurrentDownloadInstrumentedTest.kt b/app/src/androidTest/java/tech/httptoolkit/android/vpn/ConcurrentDownloadInstrumentedTest.kt new file mode 100644 index 0000000..506e269 --- /dev/null +++ b/app/src/androidTest/java/tech/httptoolkit/android/vpn/ConcurrentDownloadInstrumentedTest.kt @@ -0,0 +1,259 @@ +package tech.httptoolkit.android.vpn + +import android.app.Application +import androidx.room.Room +import androidx.test.core.app.ApplicationProvider +import androidx.test.ext.junit.runners.AndroidJUnit4 +import org.distrinet.lanshield.database.AppDatabase +import org.junit.After +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import tech.httptoolkit.android.vpn.socket.IProtectSocket +import tech.httptoolkit.android.vpn.socket.SocketNIODataService +import tech.httptoolkit.android.vpn.socket.SocketProtector +import tech.httptoolkit.android.vpn.transport.ip.IPPacketFactory +import tech.httptoolkit.android.vpn.transport.tcp.TCPPacketFactory +import java.io.File +import java.io.FileOutputStream +import java.net.DatagramSocket +import java.net.InetAddress +import java.net.ServerSocket +import java.net.Socket +import java.nio.ByteBuffer +import java.util.concurrent.Executors +import java.util.concurrent.Future +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.TimeUnit + +/** + * On-device end-to-end test running several windowed TCP downloads through the real engine at + * once, each with a different receive window. Proves on a real Android TCP stack + NIO selector + * that concurrent flows are tracked and flow-controlled independently: each delivered in order, + * in full, with unacked data never exceeding its own window, and no bytes leaking across ports. + * + * Self-contained (the unit-test harness isn't visible to androidTest): it inlines a capturing + * writer and minimal TCP packet builders. + */ +@RunWith(AndroidJUnit4::class) +class ConcurrentDownloadInstrumentedTest { + + private lateinit var db: AppDatabase + private lateinit var writer: CapturingWriter + private lateinit var nioService: SocketNIODataService + private lateinit var nioThread: Thread + private lateinit var sessionManager: SessionManager + private lateinit var sessionHandler: SessionHandler + private lateinit var tempTun: File + + private val executor = Executors.newCachedThreadPool() + + private val clientIp = "10.0.0.2" + private val peerIp = "127.0.0.1" + private val mss = 1460 + + private class CapturingWriter(out: FileOutputStream) : ClientPacketWriter(out) { + val queue = LinkedBlockingQueue() + override fun write(data: ByteArray) { queue.add(data) } + } + + @Before + fun setUp() { + SocketProtector.getInstance().setProtector(object : IProtectSocket { + override fun protect(socket: Socket): Boolean = true + override fun protect(socket: DatagramSocket): Boolean = true + }) + val ctx = ApplicationProvider.getApplicationContext() + db = Room.inMemoryDatabaseBuilder(ctx, AppDatabase::class.java).allowMainThreadQueries().build() + + tempTun = File.createTempFile("tun", ".bin").apply { deleteOnExit() } + writer = CapturingWriter(FileOutputStream(tempTun)) + nioService = SocketNIODataService(writer, db) + nioThread = Thread(nioService, "nio-concurrent").apply { isDaemon = true; start() } + sessionManager = SessionManager(db) + sessionHandler = SessionHandler(sessionManager, nioService, writer, db) + } + + @After + fun tearDown() { + executor.shutdownNow() + runCatching { nioService.shutdown() } + runCatching { nioThread.join(1000) } + runCatching { db.close() } + runCatching { tempTun.delete() } + } + + @Test + fun concurrentWindowedDownloadsAreEachDeliveredInOrderWithinTheirWindow() { + val total = 128 * 1024 + val flows = listOf( + Flow(clientPort = 50001, window = 8 * 1024, total = total), + Flow(clientPort = 50002, window = 16 * 1024, total = total), + Flow(clientPort = 50003, window = 32 * 1024, total = total), + ) + val byPort = flows.associateBy { it.clientPort } + + try { + flows.forEach { it.open() } + + // Handshake all flows; demultiplex SYN-ACKs back to each by client port. + flows.forEach { feed(syn(it)) } + repeat(flows.size) { + val pkt = awaitMatching { val (_, t) = parseTcp(it); t.isSYN && t.isACK } + val (_, tcp) = parseTcp(pkt) + val flow = byPort.getValue(tcp.destinationPort) + flow.serverIsn = tcp.sequenceNumber + flow.expectedSeq = (tcp.sequenceNumber + 1) and 0xFFFFFFFFL + } + flows.forEach { feed(ack(it, it.serverIsn + 1)) } + flows.forEach { it.accept() } + + flows.forEach { flow -> + executor.submit { + runCatching { flow.accepted.getOutputStream().apply { write(ByteArray(total)); flush() } } + } + } + + // Drain loop, demuxing by client port and acking each flow at half its window. + var idle = 0 + while (flows.any { !it.complete }) { + val pkt = writer.queue.poll(2000, TimeUnit.MILLISECONDS) + if (pkt == null) { + if (++idle > 3) break + flows.filter { !it.complete }.forEach { feed(ack(it, it.serverIsn + 1 + it.received)); it.lastAck = it.received } + continue + } + idle = 0 + val (ip, tcp) = parseTcp(pkt) + val flow = byPort.getValue(tcp.destinationPort) + assertEquals("reply addressed to wrong client", clientIp, ip.destinationIP.toString()) + val len = tcpPayloadLength(pkt) + if (len <= 0) continue + + assertEquals("segments must arrive in order", flow.expectedSeq, tcp.sequenceNumber and 0xFFFFFFFFL) + flow.expectedSeq = (flow.expectedSeq + len) and 0xFFFFFFFFL + flow.received += len + flow.peakInFlight = maxOf(flow.peakInFlight, flow.received - flow.lastAck) + assertTrue("in-flight ${flow.received - flow.lastAck} exceeded window ${flow.window}", + flow.received - flow.lastAck <= flow.window) + + if (flow.received - flow.lastAck >= flow.window / 2 || flow.received == total.toLong()) { + feed(ack(flow, flow.serverIsn + 1 + flow.received)); flow.lastAck = flow.received + } + } + + flows.forEach { flow -> + assertEquals("flow on ${flow.clientPort} must receive the whole file", + total.toLong(), flow.received) + assertTrue("peak in-flight ${flow.peakInFlight} exceeded window ${flow.window}", + flow.peakInFlight <= flow.window) + } + } finally { + flows.forEach { it.close() } + } + } + + // --- engine plumbing ----------------------------------------------------- + + private fun feed(packet: ByteArray) = sessionHandler.handlePacket(ByteBuffer.wrap(packet), "test.app") + + private fun parseTcp(packet: ByteArray) = ByteBuffer.wrap(packet).let { b -> + IPPacketFactory.createIPHeader(b) to TCPPacketFactory.createTCPHeader(b) + } + + private fun awaitMatching(timeoutMs: Long = 5000, predicate: (ByteArray) -> Boolean): ByteArray { + val deadline = System.nanoTime() + timeoutMs * 1_000_000 + while (System.nanoTime() < deadline) { + val remaining = ((deadline - System.nanoTime()) / 1_000_000).coerceAtLeast(1) + val pkt = writer.queue.poll(remaining, TimeUnit.MILLISECONDS) ?: break + if (predicate(pkt)) return pkt + } + throw AssertionError("No matching TUN packet within ${timeoutMs}ms") + } + + private fun tcpPayloadLength(packet: ByteArray): Int { + val ihl = (packet[0].toInt() and 0x0F) * 4 + val totalLength = ((packet[2].toInt() and 0xFF) shl 8) or (packet[3].toInt() and 0xFF) + val dataOffset = ((packet[ihl + 12].toInt() shr 4) and 0x0F) * 4 + return totalLength - ihl - dataOffset + } + + // --- per-flow state ------------------------------------------------------ + + private inner class Flow(val clientPort: Int, val window: Int, val total: Int) { + lateinit var server: ServerSocket + var peerPort = 0 + lateinit var accepted: Socket + private lateinit var acceptFuture: Future + var serverIsn = 0L + var expectedSeq = 0L + var received = 0L + var lastAck = 0L + var peakInFlight = 0L + val complete get() = received >= total + + fun open() { + server = ServerSocket(0, 50, InetAddress.getByName(peerIp)) + peerPort = server.localPort + acceptFuture = executor.submit { server.accept() } + } + + fun accept() { accepted = acceptFuture.get(5, TimeUnit.SECONDS) } + + fun close() { + runCatching { if (this::accepted.isInitialized) accepted.close() } + runCatching { server.close() } + } + } + + private fun syn(flow: Flow): ByteArray = tcpPacket( + flow.clientPort, flow.peerPort, seq = 1000L, ack = 0, flags = SYN, mss = mss, win = flow.window, + ) + + private fun ack(flow: Flow, ackNum: Long): ByteArray = tcpPacket( + flow.clientPort, flow.peerPort, seq = 1001L, ack = ackNum, flags = ACK, mss = null, win = flow.window, + ) + + // --- raw IPv4 TCP packet builder ----------------------------------------- + + private val SYN = 0x02 + private val ACK = 0x10 + + private fun tcpPacket(srcPort: Int, dstPort: Int, seq: Long, ack: Long, flags: Int, mss: Int?, win: Int): ByteArray { + val optionBytes = if (mss != null) 4 else 0 + val tcpHeaderLen = 20 + optionBytes + val total = 20 + tcpHeaderLen + val buf = ByteArray(total) + buf[0] = 0x45 + putShort(buf, 2, total) + buf[8] = 64 + buf[9] = 6 // protocol TCP + ipBytes(clientIp).copyInto(buf, 12) + ipBytes(peerIp).copyInto(buf, 16) + val t = 20 + putShort(buf, t, srcPort) + putShort(buf, t + 2, dstPort) + putInt(buf, t + 4, seq) + putInt(buf, t + 8, ack) + buf[t + 12] = ((tcpHeaderLen / 4) shl 4).toByte() + buf[t + 13] = flags.toByte() + putShort(buf, t + 14, win) + if (mss != null) { + buf[t + 20] = 0x02; buf[t + 21] = 0x04; putShort(buf, t + 22, mss) + } + return buf + } + + private fun ipBytes(dotted: String) = ByteArray(4) { dotted.split(".")[it].toInt().toByte() } + + private fun putShort(b: ByteArray, off: Int, v: Int) { + b[off] = ((v ushr 8) and 0xFF).toByte(); b[off + 1] = (v and 0xFF).toByte() + } + + private fun putInt(b: ByteArray, off: Int, v: Long) { + b[off] = ((v ushr 24) and 0xFF).toByte(); b[off + 1] = ((v ushr 16) and 0xFF).toByte() + b[off + 2] = ((v ushr 8) and 0xFF).toByte(); b[off + 3] = (v and 0xFF).toByte() + } +} diff --git a/app/src/main/java/tech/httptoolkit/android/vpn/SessionHandler.java b/app/src/main/java/tech/httptoolkit/android/vpn/SessionHandler.java index cf3e5dc..211cb22 100644 --- a/app/src/main/java/tech/httptoolkit/android/vpn/SessionHandler.java +++ b/app/src/main/java/tech/httptoolkit/android/vpn/SessionHandler.java @@ -276,6 +276,9 @@ private void ackFinAck(IPHeader ip, TCPHeader tcp, Session session){ if(session != null){ session.cancelKey(); manager.closeSession(session); + // Abort so the keepSessionAlive block below doesn't re-insert this just-closed + // session as a dead-channel zombie (which would block reuse of the client port). + session.setAbortingConnection(true); Log.d(TAG,"ACK to client's FIN and close session => "+ip.getDestinationIP().toString()+":"+tcp.getDestinationPort() +"-"+ip.getSourceIP().toString()+":"+tcp.getSourcePort()); } diff --git a/app/src/test/java/tech/httptoolkit/android/vpn/IcmpForwardingTest.kt b/app/src/test/java/tech/httptoolkit/android/vpn/IcmpForwardingTest.kt new file mode 100644 index 0000000..a29f7be --- /dev/null +++ b/app/src/test/java/tech/httptoolkit/android/vpn/IcmpForwardingTest.kt @@ -0,0 +1,211 @@ +package tech.httptoolkit.android.vpn + +import android.app.Application +import com.google.common.truth.Truth.assertThat +import org.junit.After +import org.junit.Assert.assertThrows +import org.junit.Assume.assumeTrue +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner +import org.robolectric.annotation.Config +import tech.httptoolkit.android.vpn.transport.PacketHeaderException +import tech.httptoolkit.android.vpn.transport.icmp.ICMPPacketFactory +import tech.httptoolkit.android.vpn.transport.ip.IPPacketFactory +import java.nio.ByteBuffer + +/** + * Tests the engine's ICMP handling: echo requests proxied to a reachable host and the reply + * returned to the client (live, against loopback), reply construction (factory round-trip), + * types we can't proxy dropped/rejected, and that ICMP is never connection-tracked. + */ +@RunWith(RobolectricTestRunner::class) +@Config(sdk = [34], application = Application::class) +class IcmpForwardingTest { + + private lateinit var harness: ForwardingTestHarness + + private val clientIp = "10.0.0.2" + private val clientIpv6 = "fd00::2" + + @Before + fun setUp() { + harness = ForwardingTestHarness() + } + + @After + fun tearDown() { + harness.close() + } + + // Live echo depends on isReachable, so these skip if the env blocks loopback reachability. + + @Test + fun `ipv4 echo request to a reachable host returns an echo reply to the client`() { + val payload = "ping-payload".toByteArray() + harness.feed( + TestPackets.icmpPacket( + clientIp, "127.0.0.1", TestPackets.ICMP_V4_ECHO_REQUEST, 0, + identifier = 0x1234, seq = 7, payload = payload, + ) + ) + + val reply = harness.pollTunPacket(5000) + assumeTrue("no ICMP reply — isReachable(127.0.0.1) likely blocked in this env", reply != null) + + val icmp = parseIcmp(reply!!) + assertThat(icmp.type).isEqualTo(TestPackets.ICMP_V4_ECHO_REPLY) + assertThat(icmp.srcIp).isEqualTo("127.0.0.1") // src/dst swapped + assertThat(icmp.dstIp).isEqualTo(clientIp) + assertThat(icmp.identifier).isEqualTo(0x1234) + assertThat(icmp.seq).isEqualTo(7) + assertThat(icmp.payload).isEqualTo(payload) + } + + @Test + fun `ipv6 echo request to a reachable host returns an echo reply`() { + val payload = "v6".toByteArray() + harness.feed( + TestPackets.icmpv6Packet( + clientIpv6, "::1", TestPackets.ICMP_V6_ECHO_REQUEST, 0, + identifier = 0x55, seq = 3, payload = payload, + ) + ) + + val reply = harness.pollTunPacket(5000) + assumeTrue("no ICMPv6 reply — isReachable(::1) likely blocked in this env", reply != null) + + val icmp = parseIcmp(reply!!) + assertThat(icmp.type).isEqualTo(TestPackets.ICMP_V6_ECHO_REPLY) + assertThat(icmp.identifier).isEqualTo(0x55) + assertThat(icmp.seq).isEqualTo(3) + assertThat(icmp.payload).isEqualTo(payload) + } + + @Test + fun `buildSuccessPacket and packetToBuffer produce a valid echo reply that echoes the request`() { + val payload = "abcd".toByteArray() + val requestBytes = TestPackets.icmpPacket( + clientIp, "127.0.0.1", TestPackets.ICMP_V4_ECHO_REQUEST, 0, + identifier = 0xBEEF, seq = 42, payload = payload, + ) + val buffer = ByteBuffer.wrap(requestBytes) + val ipHeader = IPPacketFactory.createIPHeader(buffer) + val request = ICMPPacketFactory.parseICMPPacket(4, buffer) + + val reply = ICMPPacketFactory.buildSuccessPacket(4, request) + val replyBytes = ICMPPacketFactory.packetToBuffer(ipHeader, reply) + + val parsed = parseIcmp(replyBytes) + assertThat(parsed.type).isEqualTo(TestPackets.ICMP_V4_ECHO_REPLY) + assertThat(parsed.identifier).isEqualTo(0xBEEF) + assertThat(parsed.seq).isEqualTo(42) + assertThat(parsed.payload).isEqualTo(payload) + assertThat(icmpChecksumIsValid(replyBytes)).isTrue() + } + + @Test + fun `destination-unreachable icmp is dropped silently`() { + harness.feed( + TestPackets.icmpPacket( + clientIp, "127.0.0.1", TestPackets.ICMP_V4_DEST_UNREACHABLE, 1, + identifier = 1, seq = 1, + ) + ) + assertThat(harness.pollTunPacket(300)).isNull() + } + + @Test + fun `ipv6 router solicitation is dropped silently`() { + harness.feed( + TestPackets.icmpv6Packet( + clientIpv6, "ff02::2", TestPackets.ICMP_V6_ROUTER_SOLICITATION, 0, + identifier = 0, seq = 0, + ) + ) + assertThat(harness.pollTunPacket(300)).isNull() + } + + @Test + fun `an unsupported icmp type is rejected with a PacketHeaderException`() { + // type 11 (time exceeded): not unreachable, router-solicit, or echo request. + assertThrows(PacketHeaderException::class.java) { + harness.feed( + TestPackets.icmpPacket(clientIp, "127.0.0.1", 11, 0, identifier = 0, seq = 0) + ) + } + } + + @Test + fun `icmp is not connection-tracked - no session or flow is created`() { + harness.feed( + TestPackets.icmpPacket( + clientIp, "127.0.0.1", TestPackets.ICMP_V4_ECHO_REQUEST, 0, + identifier = 9, seq = 9, + ) + ) + assertThat(harness.flowDao.countNotSyncedFlows()).isEqualTo(0) + } + + @Test + fun `concurrent echo requests each get a correlated reply`() { + val ids = listOf(0xA1 to 1, 0xA2 to 2, 0xA3 to 3) + ids.forEach { (id, seq) -> + harness.feed( + TestPackets.icmpPacket( + clientIp, "127.0.0.1", TestPackets.ICMP_V4_ECHO_REQUEST, 0, + identifier = id, seq = seq, payload = "p$seq".toByteArray(), + ) + ) + } + + // The ping pool may discard under flood, so require only that every reply seen + // correlates to a request (and at least one comes back). + val seen = mutableSetOf>() + while (true) { + val pkt = harness.pollTunPacket(3000) ?: break + val icmp = parseIcmp(pkt) + assertThat(icmp.type).isEqualTo(TestPackets.ICMP_V4_ECHO_REPLY) + val key = icmp.identifier to icmp.seq + assertThat(ids).contains(key) + seen.add(key) + if (seen.size == ids.size) break + } + assumeTrue("no ICMP replies — isReachable(127.0.0.1) likely blocked in this env", seen.isNotEmpty()) + } + + // --- helpers ------------------------------------------------------------- + + private data class ParsedIcmp( + val srcIp: String, val dstIp: String, + val type: Int, val identifier: Int, val seq: Int, val payload: ByteArray, + ) + + private fun parseIcmp(packet: ByteArray): ParsedIcmp { + val ip = IPPacketFactory.createIPHeader(ByteBuffer.wrap(packet)) + val version = packet[0].toInt() shr 4 and 0x0F + val off = if (version == 4) (packet[0].toInt() and 0x0F) * 4 else 40 + val type = packet[off].toInt() and 0xFF + val identifier = ((packet[off + 4].toInt() and 0xFF) shl 8) or (packet[off + 5].toInt() and 0xFF) + val seq = ((packet[off + 6].toInt() and 0xFF) shl 8) or (packet[off + 7].toInt() and 0xFF) + val payload = packet.copyOfRange(off + 8, packet.size) + return ParsedIcmp(ip.sourceIP.toString(), ip.destinationIP.toString(), type, identifier, seq, payload) + } + + /** True when the one's-complement checksum over the ICMP message folds to zero. */ + private fun icmpChecksumIsValid(packet: ByteArray): Boolean { + val version = packet[0].toInt() shr 4 and 0x0F + val off = if (version == 4) (packet[0].toInt() and 0x0F) * 4 else 40 + var sum = 0L + var i = off + while (i < packet.size) { + val hi = packet[i].toInt() and 0xFF + val lo = if (i + 1 < packet.size) packet[i + 1].toInt() and 0xFF else 0 + sum += ((hi shl 8) or lo).toLong() + i += 2 + } + while (sum shr 16 != 0L) sum = (sum and 0xFFFF) + (sum shr 16) + return sum.toInt() and 0xFFFF == 0xFFFF + } +} diff --git a/app/src/test/java/tech/httptoolkit/android/vpn/MixedProtocolConcurrencyTest.kt b/app/src/test/java/tech/httptoolkit/android/vpn/MixedProtocolConcurrencyTest.kt new file mode 100644 index 0000000..736f6ac --- /dev/null +++ b/app/src/test/java/tech/httptoolkit/android/vpn/MixedProtocolConcurrencyTest.kt @@ -0,0 +1,218 @@ +package tech.httptoolkit.android.vpn + +import android.app.Application +import com.google.common.truth.Truth.assertThat +import org.junit.After +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner +import org.robolectric.annotation.Config +import tech.httptoolkit.android.vpn.transport.ip.IPAddress +import java.net.DatagramPacket +import java.net.DatagramSocket +import java.net.InetAddress +import java.net.ServerSocket +import java.net.Socket +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit + +/** + * Stress / consistency check: TCP downloads, UDP echo flows and ICMP pings run through the one + * engine at the same time, exercising the shared NIO thread and session table under mixed load. + * It asserts the protocols never bleed into each other — each TCP download in order on its own + * port, each UDP reply on its own port, ICMP answered but never connection-tracked. + */ +@RunWith(RobolectricTestRunner::class) +@Config(sdk = [34], application = Application::class) +class MixedProtocolConcurrencyTest { + + private lateinit var harness: ForwardingTestHarness + private val executor = Executors.newCachedThreadPool() + + private val clientIp = "10.0.0.2" + private val peerIp = "127.0.0.1" + private val mss = 1460 + private val tcpWindow = 16 * 1024 + private val tcpTotal = 64 * 1024 + + @Before + fun setUp() { + harness = ForwardingTestHarness() + } + + @After + fun tearDown() { + executor.shutdownNow() + harness.close() + } + + @Test + fun `tcp udp and icmp flows run concurrently without crossing streams`() { + val tcpPorts = listOf(54000, 54001, 54002) + val udpPorts = listOf(42000, 42001, 42002) + val pingIds = listOf(0xC1, 0xC2, 0xC3) + + val tcpFlows = tcpPorts.map { TcpFlow(it) } + val tcpByPort = tcpFlows.associateBy { it.clientPort } + val udpPeer = DatagramSocket(0, InetAddress.getByName(peerIp)).apply { soTimeout = 3000 } + val udpPeerPort = udpPeer.localPort + + try { + // Handshake the TCP flows first (clean SYN-ACK demux), then accept their sockets. + tcpFlows.forEach { it.open() } + tcpFlows.forEach { feedSyn(it) } + repeat(tcpFlows.size) { + val (ip, tcp) = harness.parseTcp( + harness.awaitTunPacketMatching { val (_, t) = harness.parseTcp(it); t.isSYN && t.isACK } + ) + val flow = tcpByPort.getValue(tcp.destinationPort) + assertThat(ip.destinationIP.toString()).isEqualTo(clientIp) + flow.serverIsn = tcp.sequenceNumber + flow.expectedSeq = (tcp.sequenceNumber + 1) and 0xFFFFFFFFL + } + tcpFlows.forEach { feedAck(it, it.serverIsn + 1) } + tcpFlows.forEach { it.accept() } + + // Now drive all three protocols at once: TCP server writes, UDP datagrams, ICMP pings. + val echo = Thread { + repeat(udpPorts.size) { + val rx = DatagramPacket(ByteArray(128), 128) + udpPeer.receive(rx) + udpPeer.send(DatagramPacket(rx.data, rx.length, rx.socketAddress)) + } + }.apply { isDaemon = true; start() } + + tcpFlows.forEach { flow -> + executor.submit { + runCatching { flow.accepted.getOutputStream().apply { write(ByteArray(tcpTotal)); flush() } } + } + } + udpPorts.forEach { port -> harness.feed(udp(port, udpPeerPort, "u-$port")) } + pingIds.forEach { id -> + harness.feed(TestPackets.icmpPacket(clientIp, peerIp, TestPackets.ICMP_V4_ECHO_REQUEST, 0, id, 1)) + } + + // One drain loop classifies every captured packet by protocol and routes it. + val udpReplies = HashMap() + val icmpReplies = HashSet() + var idle = 0 + while ((tcpFlows.any { !it.complete } || udpReplies.size < udpPorts.size)) { + val pkt = harness.pollTunPacket(1000) + if (pkt == null) { + if (++idle > 3) break + tcpFlows.filter { !it.complete }.forEach { feedAck(it, it.serverIsn + 1 + it.received); it.lastAck = it.received } + continue + } + idle = 0 + when (ipProtocol(pkt)) { + 6 -> { + val (_, tcp) = harness.parseTcp(pkt) + val flow = tcpByPort.getValue(tcp.destinationPort) // only TCP client ports + val len = tcpPayloadLength(pkt) + if (len <= 0) continue + assertThat(tcp.sequenceNumber and 0xFFFFFFFFL).isEqualTo(flow.expectedSeq) + flow.expectedSeq = (flow.expectedSeq + len) and 0xFFFFFFFFL + flow.received += len + if (flow.received - flow.lastAck >= flow.window / 2 || flow.received == tcpTotal.toLong()) { + feedAck(flow, flow.serverIsn + 1 + flow.received); flow.lastAck = flow.received + } + } + 17 -> { + val (_, udp, payload) = harness.parseUdp(pkt) + assertThat(udpPorts).contains(udp.destinationPort) // only UDP client ports + udpReplies[udp.destinationPort] = String(payload) + } + 1 -> { + val off = (pkt[0].toInt() and 0x0F) * 4 + assertThat(pkt[off].toInt() and 0xFF).isEqualTo(TestPackets.ICMP_V4_ECHO_REPLY) + icmpReplies.add(((pkt[off + 4].toInt() and 0xFF) shl 8) or (pkt[off + 5].toInt() and 0xFF)) + } + } + } + + // TCP delivered in full and in order (ordering guarded by the seq checks above). + tcpFlows.forEach { assertThat(it.received).isEqualTo(tcpTotal.toLong()) } + assertThat(udpReplies).isEqualTo(udpPorts.associateWith { "u-$it" }) + // Any ICMP replies that arrived correlate only to our pings. + assertThat(pingIds).containsAtLeastElementsIn(icmpReplies) + + // One session/flow per TCP and UDP flow; ICMP contributed none. + tcpFlows.forEach { assertThat(harness.sessionByKey(tcpKey(it.clientPort, it.peerPort))).isNotNull() } + udpPorts.forEach { assertThat(harness.sessionByKey(udpKey(it, udpPeerPort))).isNotNull() } + assertThat(harness.flowDao.countNotSyncedFlows()) + .isEqualTo((tcpPorts.size + udpPorts.size).toLong()) + + echo.join(2000) + } finally { + tcpFlows.forEach { it.close() } + udpPeer.close() + } + } + + // --- helpers ------------------------------------------------------------- + + private inner class TcpFlow(val clientPort: Int) { + lateinit var server: ServerSocket + var peerPort = 0 + lateinit var accepted: Socket + private lateinit var acceptFuture: java.util.concurrent.Future + val window = tcpWindow + var serverIsn = 0L + var expectedSeq = 0L + var received = 0L + var lastAck = 0L + val complete get() = received >= tcpTotal + + fun open() { + server = ServerSocket(0, 50, InetAddress.getByName(peerIp)) + peerPort = server.localPort + acceptFuture = executor.submit { server.accept() } + } + + fun accept() { accepted = acceptFuture.get(3, TimeUnit.SECONDS) } + + fun close() { + runCatching { if (this::accepted.isInitialized) accepted.close() } + runCatching { server.close() } + } + } + + private fun feedSyn(flow: TcpFlow) = harness.feed( + TestPackets.tcpPacket( + clientIp, flow.clientPort, peerIp, flow.peerPort, + seq = 1000L, ack = 0, flags = TestPackets.SYN, mss = mss, windowSize = flow.window, + ) + ) + + private fun feedAck(flow: TcpFlow, ackNumber: Long) = harness.feed( + TestPackets.tcpPacket( + clientIp, flow.clientPort, peerIp, flow.peerPort, + seq = 1001L, ack = ackNumber, flags = TestPackets.ACK, windowSize = flow.window, + ) + ) + + private fun udp(clientPort: Int, peerPort: Int, payload: String): ByteArray = + TestPackets.udpPacket(clientIp, clientPort, peerIp, peerPort, payload.toByteArray()) + + private fun ipProtocol(packet: ByteArray): Int = packet[9].toInt() and 0xFF + + private fun tcpKey(clientPort: Int, peerPort: Int): String = Session.getSessionKey( + SessionProtocol.TCP, + IPAddress(TestPackets.ip(peerIp)), peerPort, + IPAddress(TestPackets.ip(clientIp)), clientPort, + ) + + private fun udpKey(clientPort: Int, peerPort: Int): String = Session.getSessionKey( + SessionProtocol.UDP, + IPAddress(TestPackets.ip(peerIp)), peerPort, + IPAddress(TestPackets.ip(clientIp)), clientPort, + ) + + private fun tcpPayloadLength(packet: ByteArray): Int { + val ihl = (packet[0].toInt() and 0x0F) * 4 + val totalLength = ((packet[2].toInt() and 0xFF) shl 8) or (packet[3].toInt() and 0xFF) + val dataOffset = ((packet[ihl + 12].toInt() shr 4) and 0x0F) * 4 + return totalLength - ihl - dataOffset + } +} diff --git a/app/src/test/java/tech/httptoolkit/android/vpn/TcpConcurrentWindowsTest.kt b/app/src/test/java/tech/httptoolkit/android/vpn/TcpConcurrentWindowsTest.kt new file mode 100644 index 0000000..c0d8699 --- /dev/null +++ b/app/src/test/java/tech/httptoolkit/android/vpn/TcpConcurrentWindowsTest.kt @@ -0,0 +1,234 @@ +package tech.httptoolkit.android.vpn + +import android.app.Application +import com.google.common.truth.Truth.assertThat +import org.junit.After +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner +import org.robolectric.annotation.Config +import tech.httptoolkit.android.vpn.transport.ip.IPAddress +import java.net.InetAddress +import java.net.ServerSocket +import java.net.Socket +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit + +/** + * Drives many concurrent TCP downloads through the one engine, each advertising a *different* + * receive window. All flows share one capture queue and are demultiplexed by destination port, + * so the test also proves connection tracking keeps the streams separate: each flow is + * delivered strictly in order, in full, with unacked data never exceeding its own window, and + * a client port can be reused once the previous connection is torn down. + */ +@RunWith(RobolectricTestRunner::class) +@Config(sdk = [34], application = Application::class) +class TcpConcurrentWindowsTest { + + private lateinit var harness: ForwardingTestHarness + private val executor = Executors.newCachedThreadPool() + + private val clientIp = "10.0.0.2" + private val peerIp = "127.0.0.1" + private val mss = 1460 + + @Before + fun setUp() { + harness = ForwardingTestHarness() + } + + @After + fun tearDown() { + executor.shutdownNow() + harness.close() + } + + @Test + fun `concurrent downloads with distinct windows are each delivered in order within their window`() { + // Distinct windows, all within the unscaled 16-bit TCP window field (>64K would need the + // window-scale option, which these packets don't carry). + val windows = listOf(2 * 1024, 4 * 1024, 8 * 1024, 16 * 1024, 32 * 1024, 65535) + val total = 128 * 1024 + val flows = windows.mapIndexed { i, window -> Flow(clientPort = 51000 + i, window = window, total = total) } + val byPort = flows.associateBy { it.clientPort } + + try { + // Each flow gets its own loopback server, so an accepted socket maps to one flow. + flows.forEach { it.open() } + + // Handshake all flows, demultiplexing the SYN-ACKs back to each by client port. + flows.forEach { feedSyn(it) } + repeat(flows.size) { + val pkt = harness.awaitTunPacketMatching { val (_, t) = harness.parseTcp(it); t.isSYN && t.isACK } + val (ip, tcp) = harness.parseTcp(pkt) + val flow = byPort.getValue(tcp.destinationPort) + assertThat(ip.destinationIP.toString()).isEqualTo(clientIp) + flow.serverIsn = tcp.sequenceNumber + flow.expectedSeq = (tcp.sequenceNumber + 1) and 0xFFFFFFFFL + } + flows.forEach { feedAck(it, it.serverIsn + 1) } + flows.forEach { it.accept() } + + flows.forEach { flow -> + executor.submit { + runCatching { flow.accepted.getOutputStream().apply { write(ByteArray(flow.total)); flush() } } + } + } + + // Drain loop, demuxing by client port. Ack a flow only at ~half its window so + // in-flight grows toward (never past) the window; nudge idle flows in case the + // engine is waiting on an ACK. + var idle = 0 + while (flows.any { !it.complete }) { + val pkt = harness.pollTunPacket(1000) + if (pkt == null) { + if (++idle > 3) break + flows.filter { !it.complete }.forEach { + feedAck(it, it.serverIsn + 1 + it.received); it.lastAck = it.received + } + continue + } + idle = 0 + val (ip, tcp) = harness.parseTcp(pkt) + val flow = byPort.getValue(tcp.destinationPort) + assertThat(ip.destinationIP.toString()).isEqualTo(clientIp) // no cross-flow leakage + val len = tcpPayloadLength(pkt) + if (len <= 0) continue + + assertThat(tcp.sequenceNumber and 0xFFFFFFFFL).isEqualTo(flow.expectedSeq) + flow.expectedSeq = (flow.expectedSeq + len) and 0xFFFFFFFFL + flow.received += len + flow.peakInFlight = maxOf(flow.peakInFlight, flow.received - flow.lastAck) + + if (flow.received - flow.lastAck >= flow.window / 2 || flow.received == flow.total.toLong()) { + feedAck(flow, flow.serverIsn + 1 + flow.received) + flow.lastAck = flow.received + } + } + + flows.forEach { flow -> + assertThat(flow.received).isEqualTo(flow.total.toLong()) + assertThat(flow.peakInFlight).isAtMost(flow.window.toLong()) + assertThat(harness.sessionByKey(tcpKey(flow.clientPort, flow.peerPort))).isNotNull() + } + assertThat(harness.flowDao.countNotSyncedFlows()).isEqualTo(flows.size.toLong()) + } finally { + flows.forEach { it.close() } + } + } + + @Test + fun `a client port is reused for a fresh connection after the previous one is torn down`() { + val server1 = ServerSocket(0, 50, InetAddress.getByName(peerIp)) + val peerPort = server1.localPort + val clientPort = 52000 + try { + val accept1 = executor.submit { server1.accept() } + handshake(clientPort, peerPort, isn = 1000L) + val socket1 = accept1.get(3, TimeUnit.SECONDS) + assertThat(harness.sessionByKey(tcpKey(clientPort, peerPort))).isNotNull() + + // RST marks the session aborting; closing the upstream then lets the NIO read path + // tear it down and drop it from the table. + harness.feed( + TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = 1001L, ack = 0, flags = TestPackets.RST, + ) + ) + socket1.close() + harness.await { if (harness.sessionByKey(tcpKey(clientPort, peerPort)) == null) Unit else null } + + // A new SYN on the same client port opens a fresh, independent session. + val accept2 = executor.submit { server1.accept() } + val isn2 = 9000L + harness.feed( + TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = isn2, ack = 0, flags = TestPackets.SYN, mss = mss, + ) + ) + val synAck = harness.awaitTunPacketMatching { val (_, t) = harness.parseTcp(it); t.isSYN && t.isACK } + assertThat(harness.parseTcp(synAck).second.ackNumber).isEqualTo(isn2 + 1) + accept2.get(3, TimeUnit.SECONDS) + assertThat(harness.sessionByKey(tcpKey(clientPort, peerPort))).isNotNull() + } finally { + server1.close() + } + } + + // --- helpers ------------------------------------------------------------- + + private inner class Flow(val clientPort: Int, val window: Int, val total: Int) { + lateinit var server: ServerSocket + var peerPort = 0 + lateinit var accepted: Socket + private lateinit var acceptFuture: java.util.concurrent.Future + var serverIsn = 0L + var expectedSeq = 0L + var received = 0L + var lastAck = 0L + var peakInFlight = 0L + val complete get() = received >= total + + fun open() { + server = ServerSocket(0, 50, InetAddress.getByName(peerIp)) + peerPort = server.localPort + acceptFuture = executor.submit { server.accept() } + } + + fun accept() { accepted = acceptFuture.get(3, TimeUnit.SECONDS) } + + fun close() { + runCatching { if (this::accepted.isInitialized) accepted.close() } + runCatching { server.close() } + } + } + + private fun feedSyn(flow: Flow) = harness.feed( + TestPackets.tcpPacket( + clientIp, flow.clientPort, peerIp, flow.peerPort, + seq = 1000L, ack = 0, flags = TestPackets.SYN, mss = mss, windowSize = flow.window, + ) + ) + + private fun feedAck(flow: Flow, ackNumber: Long) = harness.feed( + TestPackets.tcpPacket( + clientIp, flow.clientPort, peerIp, flow.peerPort, + seq = 1001L, ack = ackNumber, flags = TestPackets.ACK, windowSize = flow.window, + ) + ) + + /** Full handshake on a single (port, peer); returns the engine's server ISN. */ + private fun handshake(clientPort: Int, peerPort: Int, isn: Long): Long { + harness.feed( + TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = isn, ack = 0, flags = TestPackets.SYN, mss = mss, + ) + ) + val synAck = harness.awaitTunPacketMatching { val (_, t) = harness.parseTcp(it); t.isSYN && t.isACK } + val serverIsn = harness.parseTcp(synAck).second.sequenceNumber + harness.feed( + TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = isn + 1, ack = serverIsn + 1, flags = TestPackets.ACK, + ) + ) + return serverIsn + } + + private fun tcpKey(clientPort: Int, peerPort: Int): String = Session.getSessionKey( + SessionProtocol.TCP, + IPAddress(TestPackets.ip(peerIp)), peerPort, + IPAddress(TestPackets.ip(clientIp)), clientPort, + ) + + private fun tcpPayloadLength(packet: ByteArray): Int { + val ihl = (packet[0].toInt() and 0x0F) * 4 + val totalLength = ((packet[2].toInt() and 0xFF) shl 8) or (packet[3].toInt() and 0xFF) + val dataOffset = ((packet[ihl + 12].toInt() shr 4) and 0x0F) * 4 + return totalLength - ihl - dataOffset + } +} diff --git a/app/src/test/java/tech/httptoolkit/android/vpn/TcpEdgeCaseForwardingTest.kt b/app/src/test/java/tech/httptoolkit/android/vpn/TcpEdgeCaseForwardingTest.kt new file mode 100644 index 0000000..bbe7cb9 --- /dev/null +++ b/app/src/test/java/tech/httptoolkit/android/vpn/TcpEdgeCaseForwardingTest.kt @@ -0,0 +1,251 @@ +package tech.httptoolkit.android.vpn + +import android.app.Application +import com.google.common.truth.Truth.assertThat +import org.junit.After +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner +import org.robolectric.annotation.Config +import tech.httptoolkit.android.vpn.transport.ip.IPAddress +import java.net.InetAddress +import java.net.ServerSocket +import java.net.Socket +import java.net.SocketTimeoutException +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit + +/** + * Adversarial / edge-case checks on the TCP path: 32-bit sequence wraparound in the ACK + * accounting, a zero receive window that later reopens, RST teardown, FIN acknowledgement, + * out-of-order (duplicate) client data, and a retransmitted SYN. These are the corners where + * the sequence arithmetic and connection-tracking are most likely to misbehave. + */ +@RunWith(RobolectricTestRunner::class) +@Config(sdk = [34], application = Application::class) +class TcpEdgeCaseForwardingTest { + + private lateinit var harness: ForwardingTestHarness + private lateinit var server: ServerSocket + private val executor = Executors.newCachedThreadPool() + private var peerPort = 0 + + private val clientIp = "10.0.0.2" + private val peerIp = "127.0.0.1" + private val clientPort = 53000 + private val clientIsn = 1000L + private val mss = 1460 + + @Before + fun setUp() { + harness = ForwardingTestHarness() + server = ServerSocket(0, 50, InetAddress.getByName(peerIp)) + peerPort = server.localPort + } + + @After + fun tearDown() { + executor.shutdownNow() + server.close() + harness.close() + } + + @Test + fun `acceptAck handles a sequence number that has wrapped past 2^32`() { + val accept = executor.submit { server.accept() } + handshake(window = 65535) + accept.get(3, TimeUnit.SECONDS) + val session = harness.sessionByKey(tcpKey())!! + + // Seed the send accounting straddling the wrap (sendUnack at 2^32, sendNext just past). + // The client acks the absolute sendNext, whose wire value 0x100 is below the prior wire + // ack — the case the old signed compare mishandled. + session.setSendUnack(0x1_0000_0000L) + session.setSendNext(0x1_0000_0100L) + + feedAck(ackNumber = 0x0000_0100L, window = 65535) + + assertThat(session.getSendUnack()).isEqualTo(0x1_0000_0100L) + } + + @Test + fun `a zero window holds back data until an ACK reopens it`() { + val accept = executor.submit { server.accept() } + val serverIsn = handshake(window = 0) + val socket = accept.get(3, TimeUnit.SECONDS) + + // Server has data, but the client advertised a zero window: nothing may be sent. + socket.getOutputStream().apply { write(ByteArray(8000)); flush() } + assertThat(firstDataSegment(timeoutMs = 500)).isNull() + + // Opening the window releases the staged data, starting at the first unsent byte. + feedAck(ackNumber = serverIsn + 1, window = 4096) + val seg = harness.await { firstDataSegment(timeoutMs = 1000) } + assertThat(seg.seq).isEqualTo((serverIsn + 1) and 0xFFFFFFFFL) + } + + @Test + fun `a client RST marks the connection aborting`() { + val accept = executor.submit { server.accept() } + handshake(window = 65535) + accept.get(3, TimeUnit.SECONDS) + val session = harness.sessionByKey(tcpKey())!! + + harness.feed( + TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = clientIsn + 1, ack = 0, flags = TestPackets.RST, + ) + ) + assertThat(session.isAbortingConnection()).isTrue() + } + + @Test + fun `a client FIN is acked, tears down the session, and frees the port for reuse`() { + val accept = executor.submit { server.accept() } + val serverIsn = handshake(window = 65535) + accept.get(3, TimeUnit.SECONDS) + assertThat(harness.sessionByKey(tcpKey())).isNotNull() + + // Client closes with FIN+ACK; the engine acks it to the right client endpoint... + harness.feed( + TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = clientIsn + 1, ack = serverIsn + 1, flags = TestPackets.FIN or TestPackets.ACK, + ) + ) + val finAck = harness.awaitTunPacketMatching { val (_, t) = harness.parseTcp(it); t.isFIN && t.isACK } + val (ip, tcp) = harness.parseTcp(finAck) + assertThat(tcp.destinationPort).isEqualTo(clientPort) + assertThat(ip.destinationIP.toString()).isEqualTo(clientIp) + + // ...and fully removes the session, rather than re-adding it as a closed-channel zombie. + assertThat(harness.sessionByKey(tcpKey())).isNull() + + // The same client port now establishes a fresh connection (the old entry is gone). + val accept2 = executor.submit { server.accept() } + val isn2 = 7000L + harness.feed( + TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = isn2, ack = 0, flags = TestPackets.SYN, mss = mss, + ) + ) + val synAck = harness.awaitTunPacketMatching { val (_, t) = harness.parseTcp(it); t.isSYN && t.isACK } + assertThat(harness.parseTcp(synAck).second.ackNumber).isEqualTo(isn2 + 1) + accept2.get(3, TimeUnit.SECONDS) + assertThat(harness.sessionByKey(tcpKey())).isNotNull() + } + + @Test + fun `duplicate out-of-order client data is not re-delivered upstream`() { + val accept = executor.submit { server.accept() } + handshake(window = 65535) + val socket = accept.get(3, TimeUnit.SECONDS) + val session = harness.sessionByKey(tcpKey())!! + + // In-order segment: pushed upstream, advancing recSequence. + feedData(seq = clientIsn + 1, payload = "AAAA") + val upstream = socket.getInputStream() + val first = ByteArray(4).also { readFully(upstream, it) } + assertThat(String(first)).isEqualTo("AAAA") + val advancedRecSeq = session.getRecSequence() + assertThat(advancedRecSeq).isEqualTo(clientIsn + 1 + 4) + + // The same segment again sits below recSequence: a duplicate, neither forwarded upstream + // again nor allowed to move recSequence. + feedData(seq = clientIsn + 1, payload = "AAAA") + socket.soTimeout = 500 + assertThat(runCatching { upstream.read() }.exceptionOrNull()) + .isInstanceOf(SocketTimeoutException::class.java) + assertThat(session.getRecSequence()).isEqualTo(advancedRecSeq) + } + + @Test + fun `a retransmitted SYN does not create a second session or flow`() { + val accept = executor.submit { server.accept() } + val syn = TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = clientIsn, ack = 0, flags = TestPackets.SYN, mss = mss, + ) + + harness.feed(syn) + harness.awaitTunPacketMatching { val (_, t) = harness.parseTcp(it); t.isSYN && t.isACK } + accept.get(3, TimeUnit.SECONDS) + assertThat(harness.flowDao.countNotSyncedFlows()).isEqualTo(1) + + // Retransmitted SYN: the engine re-acks rather than opening a second session/flow. + harness.feed(syn) + val reply = harness.awaitTunPacket() + assertThat(harness.parseTcp(reply).second.isSYN).isFalse() + assertThat(harness.flowDao.countNotSyncedFlows()).isEqualTo(1) + assertThat(harness.sessionByKey(tcpKey())).isNotNull() + } + + // --- helpers ------------------------------------------------------------- + + private data class Seg(val seq: Long, val len: Int) + + private fun handshake(window: Int): Long { + harness.feed( + TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = clientIsn, ack = 0, flags = TestPackets.SYN, mss = mss, windowSize = window, + ) + ) + val synAck = harness.awaitTunPacketMatching { val (_, t) = harness.parseTcp(it); t.isSYN && t.isACK } + val serverIsn = harness.parseTcp(synAck).second.sequenceNumber + feedAck(ackNumber = serverIsn + 1, window = window) + return serverIsn + } + + private fun feedAck(ackNumber: Long, window: Int) = harness.feed( + TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = clientIsn + 1, ack = ackNumber, flags = TestPackets.ACK, windowSize = window, + ) + ) + + private fun feedData(seq: Long, payload: String) = harness.feed( + TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = seq, ack = 0, flags = TestPackets.ACK or TestPackets.PSH, + payload = payload.toByteArray(), windowSize = 65535, + ) + ) + + /** Next captured segment carrying payload, or null if none arrives in time. */ + private fun firstDataSegment(timeoutMs: Long): Seg? { + val deadline = System.nanoTime() + timeoutMs * 1_000_000 + while (System.nanoTime() < deadline) { + val remaining = ((deadline - System.nanoTime()) / 1_000_000).coerceAtLeast(1) + val pkt = harness.pollTunPacket(remaining) ?: return null + val len = tcpPayloadLength(pkt) + if (len > 0) return Seg(harness.parseTcp(pkt).second.sequenceNumber and 0xFFFFFFFFL, len) + } + return null + } + + private fun readFully(input: java.io.InputStream, buf: ByteArray) { + var off = 0 + while (off < buf.size) { + val n = input.read(buf, off, buf.size - off) + if (n < 0) break + off += n + } + } + + private fun tcpKey(): String = Session.getSessionKey( + SessionProtocol.TCP, + IPAddress(TestPackets.ip(peerIp)), peerPort, + IPAddress(TestPackets.ip(clientIp)), clientPort, + ) + + private fun tcpPayloadLength(packet: ByteArray): Int { + val ihl = (packet[0].toInt() and 0x0F) * 4 + val totalLength = ((packet[2].toInt() and 0xFF) shl 8) or (packet[3].toInt() and 0xFF) + val dataOffset = ((packet[ihl + 12].toInt() shr 4) and 0x0F) * 4 + return totalLength - ihl - dataOffset + } +} diff --git a/app/src/test/java/tech/httptoolkit/android/vpn/TestPackets.kt b/app/src/test/java/tech/httptoolkit/android/vpn/TestPackets.kt index ada7b17..ef39904 100644 --- a/app/src/test/java/tech/httptoolkit/android/vpn/TestPackets.kt +++ b/app/src/test/java/tech/httptoolkit/android/vpn/TestPackets.kt @@ -1,13 +1,13 @@ package tech.httptoolkit.android.vpn +import tech.httptoolkit.android.vpn.util.PacketUtil +import java.net.InetAddress import java.nio.ByteBuffer /** - * Builders for raw IPv4 TCP/UDP packets used to drive the forwarding engine in tests. - * - * The engine parses inbound packets but does NOT verify their checksums, so we leave - * checksum fields zero. Length fields, however, must be correct (the parsers and the - * session code rely on them). + * Builders for raw IPv4/IPv6 TCP/UDP/ICMP packets used to drive the forwarding engine in + * tests. Checksums aren't verified by the engine so TCP/UDP leave them zero; length fields + * must be correct. */ object TestPackets { @@ -18,12 +18,26 @@ object TestPackets { const val PSH = 0x08 const val ACK = 0x10 + const val ICMP_V4_ECHO_REQUEST = 8 + const val ICMP_V4_ECHO_REPLY = 0 + const val ICMP_V4_DEST_UNREACHABLE = 3 + const val ICMP_V6_ECHO_REQUEST = 128 + const val ICMP_V6_ECHO_REPLY = 129 + const val ICMP_V6_ROUTER_SOLICITATION = 133 + fun ip(dotted: String): ByteArray { val parts = dotted.split(".") require(parts.size == 4) { "Only IPv4 literals supported: $dotted" } return ByteArray(4) { parts[it].toInt().toByte() } } + /** 16-byte IPv6 address from a literal (e.g. "::1", "fd00::2"). */ + fun ipv6(literal: String): ByteArray { + val bytes = InetAddress.getByName(literal).address + require(bytes.size == 16) { "Not an IPv6 literal: $literal" } + return bytes + } + private fun ByteArray.putShort(offset: Int, value: Int) { this[offset] = ((value ushr 8) and 0xFF).toByte() this[offset + 1] = (value and 0xFF).toByte() @@ -104,4 +118,61 @@ object TestPackets { /** Tail [n] bytes of a captured packet as a String (the transport payload). */ fun payloadString(packet: ByteArray, n: Int): String = String(packet, packet.size - n, n) + + /** IPv4 ICMP packet (protocol 1): 8-byte ICMP header + [payload], with a real checksum. */ + fun icmpPacket( + srcIp: String, dstIp: String, type: Int, code: Int, + identifier: Int, seq: Int, payload: ByteArray = ByteArray(0), + ): ByteArray { + val icmpLen = 8 + payload.size + val total = 20 + icmpLen + val buf = ByteArray(total) + writeIpv4Header(buf, total, 1, ip(srcIp), ip(dstIp)) + + val i = 20 + buf[i] = type.toByte() + buf[i + 1] = code.toByte() + buf.putShort(i + 4, identifier) + buf.putShort(i + 6, seq) + System.arraycopy(payload, 0, buf, i + 8, payload.size) + + val checksum = PacketUtil.calculateChecksum(buf, i, icmpLen) + buf[i + 2] = checksum[0] + buf[i + 3] = checksum[1] + return buf + } + + /** IPv6 ICMP packet (next header 58): 40-byte IPv6 header + the same ICMP layout. */ + fun icmpv6Packet( + srcIp: String, dstIp: String, type: Int, code: Int, + identifier: Int, seq: Int, payload: ByteArray = ByteArray(0), + ): ByteArray { + val icmpLen = 8 + payload.size + val total = 40 + icmpLen + val buf = ByteArray(total) + writeIpv6Header(buf, icmpLen, 58, ipv6(srcIp), ipv6(dstIp)) + + val i = 40 + buf[i] = type.toByte() + buf[i + 1] = code.toByte() + buf.putShort(i + 4, identifier) + buf.putShort(i + 6, seq) + System.arraycopy(payload, 0, buf, i + 8, payload.size) + + val checksum = PacketUtil.calculateChecksum(buf, i, icmpLen) + buf[i + 2] = checksum[0] + buf[i + 3] = checksum[1] + return buf + } + + private fun writeIpv6Header( + buf: ByteArray, payloadLength: Int, nextHeader: Int, src: ByteArray, dst: ByteArray, + ) { + buf[0] = 0x60.toByte() // version 6 + buf.putShort(4, payloadLength) + buf[6] = nextHeader.toByte() + buf[7] = 64 // hop limit + System.arraycopy(src, 0, buf, 8, 16) + System.arraycopy(dst, 0, buf, 24, 16) + } } diff --git a/app/src/test/java/tech/httptoolkit/android/vpn/UdpConcurrentFlowsTest.kt b/app/src/test/java/tech/httptoolkit/android/vpn/UdpConcurrentFlowsTest.kt new file mode 100644 index 0000000..800ee32 --- /dev/null +++ b/app/src/test/java/tech/httptoolkit/android/vpn/UdpConcurrentFlowsTest.kt @@ -0,0 +1,150 @@ +package tech.httptoolkit.android.vpn + +import android.app.Application +import com.google.common.truth.Truth.assertThat +import org.junit.After +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner +import org.robolectric.annotation.Config +import tech.httptoolkit.android.vpn.transport.ip.IPAddress +import java.net.DatagramPacket +import java.net.DatagramSocket +import java.net.InetAddress + +/** + * Connection-tracking and forwarding checks for UDP under concurrency: many simultaneous + * flows demultiplex back to the right client, the same source port to different destinations + * is tracked as separate connections, and a burst of datagrams on one connection keeps its + * message boundaries and ordering. Egress counters accumulate per connection. + */ +@RunWith(RobolectricTestRunner::class) +@Config(sdk = [34], application = Application::class) +class UdpConcurrentFlowsTest { + + private lateinit var harness: ForwardingTestHarness + private val clientIp = "10.0.0.2" + private val peerIp = "127.0.0.1" + + @Before + fun setUp() { + harness = ForwardingTestHarness() + } + + @After + fun tearDown() { + harness.close() + } + + @Test + fun `many concurrent udp flows each get their reply back on the originating port`() { + val peer = DatagramSocket(0, InetAddress.getByName(peerIp)).apply { soTimeout = 3000 } + val peerPort = peer.localPort + val ports = (41000 until 41008).toList() + try { + ports.forEach { port -> harness.feed(udp(port, peerPort, "msg-$port")) } + + // The peer sees one source socket per flow; echo each payload back. + repeat(ports.size) { + val rx = DatagramPacket(ByteArray(128), 128) + peer.receive(rx) + peer.send(DatagramPacket(rx.data, rx.length, rx.socketAddress)) + } + + // Each reply must return on the client port that originated it. + val byPort = buildMap { + repeat(ports.size) { + val (_, udp, payload) = harness.parseUdp(harness.awaitTunPacket()) + put(udp.destinationPort, String(payload)) + } + } + assertThat(byPort).isEqualTo(ports.associateWith { "msg-$it" }) + + ports.forEach { assertThat(harness.sessionByKey(udpKey(it, peerPort))).isNotNull() } + assertThat(harness.flowDao.countNotSyncedFlows()).isEqualTo(ports.size.toLong()) + } finally { + peer.close() + } + } + + @Test + fun `the same source port to different destinations is tracked as separate connections`() { + val peerA = DatagramSocket(0, InetAddress.getByName(peerIp)).apply { soTimeout = 3000 } + val peerB = DatagramSocket(0, InetAddress.getByName(peerIp)).apply { soTimeout = 3000 } + val clientPort = 41100 + try { + harness.feed(udp(clientPort, peerA.localPort, "to-A")) + harness.feed(udp(clientPort, peerB.localPort, "to-B")) + + echoOnce(peerA) + echoOnce(peerB) + + // Both replies share the client port, so distinguish them by source (peer) port. + val bySource = buildMap { + repeat(2) { + val (_, udp, payload) = harness.parseUdp(harness.awaitTunPacket()) + assertThat(udp.destinationPort).isEqualTo(clientPort) + put(udp.sourcePort, String(payload)) + } + } + assertThat(bySource).isEqualTo(mapOf(peerA.localPort to "to-A", peerB.localPort to "to-B")) + + // Two independent sessions, keyed by the full tuple. + assertThat(harness.sessionByKey(udpKey(clientPort, peerA.localPort))).isNotNull() + assertThat(harness.sessionByKey(udpKey(clientPort, peerB.localPort))).isNotNull() + assertThat(harness.flowDao.countNotSyncedFlows()).isEqualTo(2) + } finally { + peerA.close() + peerB.close() + } + } + + @Test + fun `a burst of datagrams on one flow preserves boundaries and ordering`() { + val peer = DatagramSocket(0, InetAddress.getByName(peerIp)).apply { soTimeout = 3000 } + val peerPort = peer.localPort + val clientPort = 41200 + // Varied sizes (incl. near-MTU), each tagged by index in its first byte. + val sizes = listOf(1, 4, 16, 100, 500, 1400, 7, 64, 1200, 3) + try { + sizes.forEachIndexed { i, size -> + harness.feed(udpBytes(clientPort, peerPort, ByteArray(size) { i.toByte() })) + } + + // The peer must receive these datagrams: same sizes, same order, unmerged. + sizes.forEachIndexed { i, size -> + val rx = DatagramPacket(ByteArray(2048), 2048) + peer.receive(rx) + assertThat(rx.length).isEqualTo(size) + assertThat(rx.data[0].toInt()).isEqualTo(i) // datagram i not merged with i+1 + } + + val session = harness.await { harness.sessionByKey(udpKey(clientPort, peerPort)) } + assertThat(harness.flowDao.countNotSyncedFlows()).isEqualTo(1) + harness.await { session.flow.takeIf { it.packetCountEgress >= sizes.size } } + } finally { + peer.close() + } + } + + // --- helpers ------------------------------------------------------------- + + private fun echoOnce(peer: DatagramSocket) { + val rx = DatagramPacket(ByteArray(128), 128) + peer.receive(rx) + peer.send(DatagramPacket(rx.data, rx.length, rx.socketAddress)) + } + + private fun udp(clientPort: Int, peerPort: Int, payload: String): ByteArray = + udpBytes(clientPort, peerPort, payload.toByteArray()) + + private fun udpBytes(clientPort: Int, peerPort: Int, payload: ByteArray): ByteArray = + TestPackets.udpPacket(clientIp, clientPort, peerIp, peerPort, payload) + + private fun udpKey(clientPort: Int, peerPort: Int): String = Session.getSessionKey( + SessionProtocol.UDP, + IPAddress(TestPackets.ip(peerIp)), peerPort, + IPAddress(TestPackets.ip(clientIp)), clientPort, + ) +}