diff --git a/app/build.gradle.kts b/app/build.gradle.kts index 41433ad..96fb757 100644 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -22,6 +22,9 @@ android { versionName = "0.95" testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" + // Run each instrumented test in its own process with cleared app data/permissions, so that + // permission-state changes in one test can't kill the shared process during another. + testInstrumentationRunnerArguments["clearPackageData"] = "true" vectorDrawables { useSupportLibrary = true } @@ -77,6 +80,8 @@ android { buildConfig = true } testOptions { + // Isolate each instrumented test in its own process; pairs with clearPackageData above. + execution = "ANDROIDX_TEST_ORCHESTRATOR" unitTests { isIncludeAndroidResources = true // required by Robolectric isReturnDefaultValues = true // android.* stubs return defaults instead of throwing @@ -139,6 +144,10 @@ dependencies { androidTestImplementation(platform(libs.androidx.compose.bom)) androidTestImplementation(libs.androidx.ui.test.junit4) androidTestImplementation(libs.androidx.test.core) + androidTestImplementation(libs.androidx.test.rules) + androidTestImplementation(libs.androidx.test.uiautomator) + androidTestUtil(libs.androidx.test.orchestrator) + androidTestUtil(libs.androidx.test.services) debugImplementation(libs.androidx.ui.tooling) debugImplementation(libs.androidx.ui.test.manifest) implementation(libs.androidx.hilt.navigation.compose) diff --git a/app/src/androidTest/java/org/distrinet/lanshield/ExampleInstrumentedTest.kt b/app/src/androidTest/java/org/distrinet/lanshield/ExampleInstrumentedTest.kt index 378e810..d60aac1 100644 --- a/app/src/androidTest/java/org/distrinet/lanshield/ExampleInstrumentedTest.kt +++ b/app/src/androidTest/java/org/distrinet/lanshield/ExampleInstrumentedTest.kt @@ -17,8 +17,9 @@ import org.junit.Assert.* class ExampleInstrumentedTest { @Test fun useAppContext() { - // Context of the app under test. + // Context of the app under test. Debug builds carry a ".debug" applicationIdSuffix, so match + // the base package name as a prefix rather than exactly. val appContext = InstrumentationRegistry.getInstrumentation().targetContext - assertEquals("org.distrinet.lanshield", appContext.packageName) + assertTrue(appContext.packageName.startsWith("org.distrinet.lanshield")) } } \ No newline at end of file diff --git a/app/src/androidTest/java/org/distrinet/lanshield/NotificationPermissionGateTest.kt b/app/src/androidTest/java/org/distrinet/lanshield/NotificationPermissionGateTest.kt new file mode 100644 index 0000000..7e0b2c8 --- /dev/null +++ b/app/src/androidTest/java/org/distrinet/lanshield/NotificationPermissionGateTest.kt @@ -0,0 +1,163 @@ +package org.distrinet.lanshield + +import android.content.Context +import android.content.Intent +import androidx.lifecycle.MutableLiveData +import androidx.lifecycle.Observer +import androidx.test.core.app.ActivityScenario +import androidx.test.core.app.ApplicationProvider +import androidx.test.ext.junit.runners.AndroidJUnit4 +import androidx.test.platform.app.InstrumentationRegistry +import androidx.test.uiautomator.By +import androidx.test.uiautomator.UiDevice +import androidx.test.uiautomator.Until +import dagger.hilt.android.EntryPointAccessors +import org.distrinet.lanshield.vpnservice.VPNService +import org.junit.After +import org.junit.Assert.assertNotEquals +import org.junit.Assert.assertTrue +import org.junit.Assume.assumeTrue +import org.junit.Before +import org.junit.FixMethodOrder +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.MethodSorters +import java.io.FileInputStream +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit +import java.util.regex.Pattern + +/** + * End-to-end test of the notification-permission gate on enable (MainActivity.startVPNService): with + * POST_NOTIFICATIONS denied, requesting enable must drive the real system permission dialog, and the + * VPN must start only if the user grants it. + * + * The enable request is posted through the same VPN_SERVICE_ACTION.START_VPN signal the Overview + * switch emits, rather than by tapping the Compose switch — the switch's StateFlow does not propagate + * reliably under the Compose test rule, and the switch->signal wiring is not what this test covers. + * From there everything is real: MainActivity's gate, the system permission dialog (driven by + * UiAutomator), and the resulting VPN service state. VPN consent is pre-granted via the ACTIVATE_VPN + * app-op so the only dialog in play is the notification one. Where the dialog can't be driven the + * test self-skips rather than failing, matching the project's choice to keep dialog automation out + * of mandatory CI. + */ +@RunWith(AndroidJUnit4::class) +@FixMethodOrder(MethodSorters.NAME_ASCENDING) +class NotificationPermissionGateTest { + + private val context = ApplicationProvider.getApplicationContext() + private val device = UiDevice.getInstance(InstrumentationRegistry.getInstrumentation()) + private val entryPoint = + EntryPointAccessors.fromApplication(context, VpnStatusEntryPoint::class.java) + private val status: MutableLiveData = entryPoint.vpnServiceStatus() + + @Before + fun setUp() { + // The orchestrator's clearPackageData runs each test in a fresh process with app data and + // permissions reset, so POST_NOTIFICATIONS starts denied-but-askable (the dialog appears). + // Pre-authorize the VPN so consent is not the dialog under test. + shell("appops set ${context.packageName} ACTIVATE_VPN allow") + } + + @After + fun tearDown() { + try { + // Use startService (not startForegroundService) for STOP, as production does: it carries + // no "must call startForeground" promise, so the stop path won't crash the process. + context.startService( + Intent(context, VPNService::class.java).apply { action = VPNService.STOP_VPN_SERVICE } + ) + } catch (_: Exception) { + } + } + + @Test + fun test1_enableSucceeds_whenNotificationGrantedViaDialog() { + ActivityScenario.launch(MainActivity::class.java).use { + requestEnable() + assumeTrue( + "Notification permission dialog could not be driven on this image", + clickPermissionDialogButton(allow = true) + ) + awaitStatus(VPN_SERVICE_STATUS.ENABLED) + } + } + + @Test + fun test2_enableBlocked_whenNotificationDenied() { + ActivityScenario.launch(MainActivity::class.java).use { + requestEnable() + // The dialog must appear (proving the gate ran); deny it. + assumeTrue( + "Notification permission dialog could not be driven on this image", + clickPermissionDialogButton(allow = false) + ) + // The VPN must never come up without notification permission. + assertStaysDisabled(seconds = 5) + } + } + + /** Posts the same enable signal the Overview switch emits; MainActivity observes it and gates. */ + private fun requestEnable() { + runOnMain { entryPoint.vpnServiceActionRequest().value = VPN_SERVICE_ACTION.START_VPN } + } + + /** Returns true if a button was found and clicked. */ + private fun clickPermissionDialogButton(allow: Boolean): Boolean { + // Match the resource-id by suffix so it works whether the dialog is served by + // com.android.permissioncontroller or com.google.android.permissioncontroller. + val resPattern = if (allow) { + Pattern.compile(".*:id/permission_allow_button") + } else { + Pattern.compile(".*:id/permission_deny_button") + } + var button = device.wait(Until.findObject(By.res(resPattern)), 5_000) + if (button == null) { + // Fallback by label; '.' matches either a straight or curly apostrophe in "Don't". + val text = if (allow) { + Pattern.compile("allow", Pattern.CASE_INSENSITIVE) + } else { + Pattern.compile("(don.?t allow|deny)", Pattern.CASE_INSENSITIVE) + } + button = device.wait(Until.findObject(By.text(text)), 3_000) + } + button?.click() + return button != null + } + + private fun awaitStatus(expected: VPN_SERVICE_STATUS, timeoutSeconds: Long = 15) { + val latch = CountDownLatch(1) + val observer = Observer { if (it == expected) latch.countDown() } + runOnMain { status.observeForever(observer) } + try { + assertTrue( + "VPN status did not reach $expected within ${timeoutSeconds}s (was ${status.value})", + latch.await(timeoutSeconds, TimeUnit.SECONDS) + ) + } finally { + runOnMain { status.removeObserver(observer) } + } + } + + private fun assertStaysDisabled(seconds: Long) { + val deadline = System.currentTimeMillis() + seconds * 1000 + while (System.currentTimeMillis() < deadline) { + assertNotEquals( + "VPN started despite notification permission being denied", + VPN_SERVICE_STATUS.ENABLED, + status.value + ) + Thread.sleep(250) + } + } + + private fun runOnMain(block: () -> Unit) { + InstrumentationRegistry.getInstrumentation().runOnMainSync(block) + } + + private fun shell(command: String) { + val pfd = InstrumentationRegistry.getInstrumentation().uiAutomation + .executeShellCommand(command) + FileInputStream(pfd.fileDescriptor).use { it.readBytes() } + } +} diff --git a/app/src/androidTest/java/org/distrinet/lanshield/vpnservice/VpnServiceStartCommandTest.kt b/app/src/androidTest/java/org/distrinet/lanshield/vpnservice/VpnServiceStartCommandTest.kt new file mode 100644 index 0000000..564cd89 --- /dev/null +++ b/app/src/androidTest/java/org/distrinet/lanshield/vpnservice/VpnServiceStartCommandTest.kt @@ -0,0 +1,114 @@ +package org.distrinet.lanshield.vpnservice + +import android.content.Context +import android.content.Intent +import android.net.VpnService +import androidx.lifecycle.MutableLiveData +import androidx.lifecycle.Observer +import androidx.test.core.app.ApplicationProvider +import androidx.test.ext.junit.runners.AndroidJUnit4 +import androidx.test.platform.app.InstrumentationRegistry +import dagger.hilt.android.EntryPointAccessors +import org.distrinet.lanshield.VPN_SERVICE_STATUS +import org.distrinet.lanshield.VpnStatusEntryPoint +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Assume.assumeTrue +import org.junit.Test +import org.junit.runner.RunWith +import java.io.FileInputStream +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit + +/** + * On-device test for [VPNService] start/stop dispatch — the fix that keeps the VPN alive across + * OS-initiated restarts. + * + * Drives the real VpnService through start → restart → stop and asserts the VPN comes up, survives + * an action-less restart (the closest a test can get to the system's null re-delivery after a kill + * or via always-on VPN), and is torn down only on an explicit stop. Establishing a tunnel needs VPN + * consent, which it grants non-interactively via the ACTIVATE_VPN app-op (works on a debuggable + * emulator image). Where that grant is not permitted the test self-skips via [assumeTrue] instead + * of failing, matching the project's choice to keep VPN-consent automation out of mandatory CI. + */ +@RunWith(AndroidJUnit4::class) +class VpnServiceStartCommandTest { + + @Test + fun endToEnd_restartReEstablishesVpn() { + val context = ApplicationProvider.getApplicationContext() + grantVpnConsent(context.packageName) + + // If consent could not be granted on this image, there is nothing meaningful to assert. + assumeTrue("VPN consent unavailable on this device/image", VpnService.prepare(context) == null) + + val status = EntryPointAccessors + .fromApplication(context, VpnStatusEntryPoint::class.java) + .vpnServiceStatus() + + try { + // 1. A fresh start (no action) must bring the VPN up. + context.startForegroundService(Intent(context, VPNService::class.java)) + awaitStatus(status, VPN_SERVICE_STATUS.ENABLED) + + // 2. Simulate the OS restarting the still-running service with a bare, action-less + // intent (closest a test can get to the system's null re-delivery). It must remain + // ENABLED rather than being torn down. + context.startForegroundService(Intent(context, VPNService::class.java)) + Thread.sleep(500) + assertEquals(VPN_SERVICE_STATUS.ENABLED, status.value) + + // 3. An explicit stop must actually stop it (and stopSelf so it is not resurrected). + // Use startService for STOP, as production does: startForegroundService would create a + // "must call startForeground" promise that the stop path never fulfills (it stops), + // crashing the process. + context.startService( + Intent(context, VPNService::class.java).apply { action = VPNService.STOP_VPN_SERVICE } + ) + awaitStatus(status, VPN_SERVICE_STATUS.DISABLED) + } finally { + context.startService( + Intent(context, VPNService::class.java).apply { action = VPNService.STOP_VPN_SERVICE } + ) + } + } + + /** Blocks until [status] reaches [expected], asserting it does so within the timeout. */ + private fun awaitStatus( + status: MutableLiveData, + expected: VPN_SERVICE_STATUS, + timeoutSeconds: Long = 10 + ) { + val latch = CountDownLatch(1) + val observer = object : Observer { + override fun onChanged(value: VPN_SERVICE_STATUS) { + if (value == expected) latch.countDown() + } + } + val instrumentation = InstrumentationRegistry.getInstrumentation() + instrumentation.runOnMainSync { status.observeForever(observer) } + try { + assertTrue( + "VPN status did not reach $expected within ${timeoutSeconds}s (was ${status.value})", + latch.await(timeoutSeconds, TimeUnit.SECONDS) + ) + } finally { + instrumentation.runOnMainSync { status.removeObserver(observer) } + } + } + + /** + * Pre-authorizes this package as a VPN by flipping the ACTIVATE_VPN app-op, so + * [VpnService.prepare] returns null and no consent dialog is needed. + */ + private fun grantVpnConsent(packageName: String) { + executeShellCommand("appops set $packageName ACTIVATE_VPN allow") + } + + private fun executeShellCommand(command: String) { + val automation = InstrumentationRegistry.getInstrumentation().uiAutomation + val pfd = automation.executeShellCommand(command) + // Drain so the command completes before we return. + FileInputStream(pfd.fileDescriptor).use { it.readBytes() } + } +} diff --git a/app/src/main/java/org/distrinet/lanshield/LANShieldApplication.kt b/app/src/main/java/org/distrinet/lanshield/LANShieldApplication.kt index a023050..5d6a9cf 100644 --- a/app/src/main/java/org/distrinet/lanshield/LANShieldApplication.kt +++ b/app/src/main/java/org/distrinet/lanshield/LANShieldApplication.kt @@ -18,6 +18,7 @@ import androidx.lifecycle.MutableLiveData import androidx.work.Configuration import dagger.Module import dagger.Provides +import dagger.hilt.EntryPoint import dagger.hilt.InstallIn import dagger.hilt.android.HiltAndroidApp import dagger.hilt.android.qualifiers.ApplicationContext @@ -235,6 +236,16 @@ class LANShieldApplication : Application(), Configuration.Provider { } +/** + * Exposes the VPN status for testing. + */ +@EntryPoint +@InstallIn(SingletonComponent::class) +interface VpnStatusEntryPoint { + fun vpnServiceStatus(): MutableLiveData + fun vpnServiceActionRequest(): MutableLiveData +} + @Module @InstallIn(SingletonComponent::class) object AppModule { diff --git a/app/src/main/java/org/distrinet/lanshield/MainActivity.kt b/app/src/main/java/org/distrinet/lanshield/MainActivity.kt index c658e3b..a2eaa58 100644 --- a/app/src/main/java/org/distrinet/lanshield/MainActivity.kt +++ b/app/src/main/java/org/distrinet/lanshield/MainActivity.kt @@ -1,15 +1,22 @@ package org.distrinet.lanshield +import android.Manifest +import android.content.ActivityNotFoundException import android.content.Intent +import android.content.pm.PackageManager import android.graphics.Color import android.net.VpnService +import android.os.Build import android.os.Bundle +import android.provider.Settings +import android.widget.Toast import androidx.activity.ComponentActivity import androidx.activity.SystemBarStyle import androidx.activity.compose.setContent import androidx.activity.enableEdgeToEdge import androidx.activity.result.ActivityResultLauncher import androidx.activity.result.contract.ActivityResultContracts +import androidx.core.content.ContextCompat import androidx.compose.foundation.isSystemInDarkTheme import androidx.compose.runtime.DisposableEffect import androidx.datastore.core.DataStore @@ -50,6 +57,7 @@ class MainActivity : ComponentActivity() { @Inject lateinit var vpnServiceActionRequest: MutableLiveData private lateinit var vpnPermissionLauncher: ActivityResultLauncher + private lateinit var notificationPermissionLauncher: ActivityResultLauncher @Inject lateinit var dataStore: DataStore @@ -65,6 +73,16 @@ class MainActivity : ComponentActivity() { } } + notificationPermissionLauncher = + registerForActivityResult(ActivityResultContracts.RequestPermission()) { granted -> + if (granted) { + // Permission obtained; continue the enable flow that was paused to ask for it. + proceedStartVPNService() + } else { + onNotificationPermissionDenied() + } + } + vpnServiceActionRequest.observe(this) { when (it) { VPN_SERVICE_ACTION.START_VPN -> startVPNService() @@ -198,6 +216,17 @@ class MainActivity : ComponentActivity() { private fun startVPNService() { + // Notification permission can be auto-revoked while the app is unused (and is only requested + // during onboarding otherwise), yet the foreground-service banner and the LAN-traffic + // allow/block prompts depend on it. Re-check on every enable and refuse to start without it. + if (!hasNotificationPermission()) { + notificationPermissionLauncher.launch(Manifest.permission.POST_NOTIFICATIONS) + return + } + proceedStartVPNService() + } + + private fun proceedStartVPNService() { if (!hasVPNConsent()) { askVPNConsent() } else { @@ -206,7 +235,36 @@ class MainActivity : ComponentActivity() { } catch (_: SecurityException) { } } + } + + private fun hasNotificationPermission(): Boolean { + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.TIRAMISU) return true + return ContextCompat.checkSelfPermission( + this, + Manifest.permission.POST_NOTIFICATIONS + ) == PackageManager.PERMISSION_GRANTED + } + + private fun onNotificationPermissionDenied() { + // The VPN is intentionally not started; the switch stays off because it tracks the service + // status. Tell the user why, and route to settings when the permission is permanently denied + // (the request dialog no longer appears), so they aren't left with a switch that does nothing. + Toast.makeText(this, R.string.notification_permission_required, Toast.LENGTH_LONG).show() + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU && + !shouldShowRequestPermissionRationale(Manifest.permission.POST_NOTIFICATIONS) + ) { + openAppNotificationSettings() + } + } + private fun openAppNotificationSettings() { + val intent = Intent(Settings.ACTION_APP_NOTIFICATION_SETTINGS).apply { + putExtra(Settings.EXTRA_APP_PACKAGE, packageName) + } + try { + startActivity(intent) + } catch (_: ActivityNotFoundException) { + } } private fun stopVPNService() { diff --git a/app/src/main/java/org/distrinet/lanshield/backendsync/OpenPortsFinder.kt b/app/src/main/java/org/distrinet/lanshield/backendsync/OpenPortsFinder.kt index b6c739c..a014809 100644 --- a/app/src/main/java/org/distrinet/lanshield/backendsync/OpenPortsFinder.kt +++ b/app/src/main/java/org/distrinet/lanshield/backendsync/OpenPortsFinder.kt @@ -3,6 +3,7 @@ package org.distrinet.lanshield.backendsync import android.content.pm.PackageManager import android.net.ConnectivityManager import android.os.Process.INVALID_UID +import android.util.Log import android.system.OsConstants.IPPROTO_TCP import android.system.OsConstants.IPPROTO_UDP import kotlinx.coroutines.Dispatchers @@ -20,29 +21,31 @@ suspend fun findOpenPorts( pm: PackageManager, connectivityManager: ConnectivityManager ): List = withContext(Dispatchers.Default) { - coroutineScope { - val wildcardIpv4 = InetSocketAddress("0.0.0.0", 0) - val wildcardIpv6 = InetSocketAddress("::", 0) + val openPortsByUid = ConcurrentHashMap() + try { + coroutineScope { + val wildcardIpv4 = InetSocketAddress("0.0.0.0", 0) + val wildcardIpv6 = InetSocketAddress("::", 0) - val openPortsByUid = ConcurrentHashMap() - - val jobs = (1..65535).map { port -> - launch { - checkPort( - port, - connectivityManager, - pm, - openPortsByUid, - wildcardIpv4, - wildcardIpv6 - ) + val jobs = (1..65535).map { port -> + launch { + checkPort( + port, + connectivityManager, + pm, + openPortsByUid, + wildcardIpv4, + wildcardIpv6 + ) + } } - } - jobs.joinAll() - - openPortsByUid.values.sorted() + jobs.joinAll() + } + } catch (e: SecurityException) { + Log.w("OpenPortsFinder", "No longer the active VPN; aborting open-port scan", e) } + openPortsByUid.values.sorted() } diff --git a/app/src/main/java/org/distrinet/lanshield/ui/openports/OpenPortsViewModel.kt b/app/src/main/java/org/distrinet/lanshield/ui/openports/OpenPortsViewModel.kt index 26858c3..9438026 100644 --- a/app/src/main/java/org/distrinet/lanshield/ui/openports/OpenPortsViewModel.kt +++ b/app/src/main/java/org/distrinet/lanshield/ui/openports/OpenPortsViewModel.kt @@ -34,6 +34,11 @@ class OpenPortsViewModel @Inject constructor( fun refreshOpenPorts(context: Context) { + if (vpnServiceStatus.value != VPN_SERVICE_STATUS.ENABLED) { + _appsWithPorts.value = emptyList() + _isRefreshing.value = false + return + } _isRefreshing.value = true viewModelScope.launch(Dispatchers.Default) { val openPorts = findOpenPorts( diff --git a/app/src/main/java/org/distrinet/lanshield/vpnservice/LANShieldNotificationManager.kt b/app/src/main/java/org/distrinet/lanshield/vpnservice/LANShieldNotificationManager.kt index 3d7edf3..9a4fc26 100644 --- a/app/src/main/java/org/distrinet/lanshield/vpnservice/LANShieldNotificationManager.kt +++ b/app/src/main/java/org/distrinet/lanshield/vpnservice/LANShieldNotificationManager.kt @@ -181,6 +181,29 @@ class LANShieldNotificationManager(private val context: Context) { notificationManager.notify(activeNotification.notificationId, notification) } + fun postServiceErrorNotification(title: String, text: String) { + val openAppIntent = Intent(context, MainActivity::class.java).apply { + flags = Intent.FLAG_ACTIVITY_NEW_TASK or Intent.FLAG_ACTIVITY_CLEAR_TASK + } + val openAppPendingIntent = PendingIntent.getActivity( + context, + getNewIntentRequestCode(), + openAppIntent, + PendingIntent.FLAG_UPDATE_CURRENT or PendingIntent.FLAG_IMMUTABLE + ) + + val notification = NotificationCompat.Builder(context, SERVICE_NOTIFICATION_CHANNEL_ID) + .setSmallIcon(R.mipmap.logo_foreground) + .setContentTitle(title) + .setContentText(text) + .setStyle(NotificationCompat.BigTextStyle().bigText(text)) + .setVisibility(NotificationCompat.VISIBILITY_PUBLIC) + .setAutoCancel(true) + .setContentIntent(openAppPendingIntent) + .build() + notificationManager.notify(getNewNotificationId(), notification) + } + fun createNotificationChannels() { val serviceChannel = NotificationChannel( diff --git a/app/src/main/java/org/distrinet/lanshield/vpnservice/VPNService.kt b/app/src/main/java/org/distrinet/lanshield/vpnservice/VPNService.kt index f6ca56e..b0698ef 100644 --- a/app/src/main/java/org/distrinet/lanshield/vpnservice/VPNService.kt +++ b/app/src/main/java/org/distrinet/lanshield/vpnservice/VPNService.kt @@ -31,6 +31,7 @@ import org.distrinet.lanshield.Policy import org.distrinet.lanshield.R import org.distrinet.lanshield.SERVICE_NOTIFICATION_CHANNEL_ID import org.distrinet.lanshield.SYSTEM_APPS_POLICY_KEY +import org.distrinet.lanshield.crashreport.crashReporter import org.distrinet.lanshield.TAG import org.distrinet.lanshield.VPN_ALWAYS_ON_STATUS import org.distrinet.lanshield.VPN_SERVICE_STATUS @@ -42,6 +43,7 @@ import tech.httptoolkit.android.vpn.socket.IProtectSocket import tech.httptoolkit.android.vpn.socket.SocketProtector import java.net.InetAddress import java.net.NetworkInterface +import java.net.SocketException import javax.inject.Inject /* The IP address of the virtual network interface */ @@ -122,31 +124,29 @@ class VPNService : VpnService(), IProtectSocket { updateAlwaysOnStatus() - intent?.let { - when (it.action) { - STOP_VPN_SERVICE -> { - if (isVPNRunning()) { - stopVPNThread() - stopForeground(STOP_FOREGROUND_REMOVE) - } - } - - else -> { - if (!isVPNRunning()) { - LANShieldNotificationManager(this).createNotificationChannels() - if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) { - startForeground( - 1, - createNotification(), - FOREGROUND_SERVICE_TYPE_SYSTEM_EXEMPTED - ) - } else { - startForeground(1, createNotification()) - } - startVPNThread() - } - } + // Only the explicit STOP action stops the VPN. Everything else — including a null intent, + // which the OS re-delivers when START_STICKY restarts the process after a kill, and when + // Android's always-on VPN restarts us — is treated as a start request, so the tunnel is + // always re-established instead of silently staying down with the UI switch showing DISABLED. + if (intent?.action == STOP_VPN_SERVICE) { + if (isVPNRunning()) { + stopVPNThread() + stopForeground(STOP_FOREGROUND_REMOVE) } + // Fully tear down so START_STICKY won't resurrect a VPN the user explicitly stopped. + stopSelf() + } else if (!isVPNRunning()) { + LANShieldNotificationManager(this).createNotificationChannels() + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) { + startForeground( + 1, + createNotification(), + FOREGROUND_SERVICE_TYPE_SYSTEM_EXEMPTED + ) + } else { + startForeground(1, createNotification()) + } + startVPNThread() } // Return the appropriate service restart behavior @@ -157,6 +157,8 @@ class VPNService : VpnService(), IProtectSocket { super.onRevoke() stopVPNThread() stopForeground(STOP_FOREGROUND_REMOVE) + // Permission was revoked (e.g. another VPN took over); don't let START_STICKY restart us. + stopSelf() } override fun onDestroy() { @@ -305,20 +307,32 @@ class VPNService : VpnService(), IProtectSocket { // that are always installed. See how other VPNs do this and for starting points see: // - https://stackoverflow.com/questions/6169059/android-event-for-internet-connectivity-state-change // - https://medium.com/@veniamin.vynohradov/monitoring-internet-connection-state-in-android-da7ad915b5e5 - for (networkInterface in NetworkInterface.getNetworkInterfaces()) { - if (networkInterface.isLoopback) continue - - for (address in networkInterface.interfaceAddresses) { - if (address.address.isAnyLocalAddress or - address.address.isLinkLocalAddress or - address.address.isSiteLocalAddress - ) continue - val networkAddress = getNetworkAddress(address.address, address.networkPrefixLength) - builder.addRoute(networkAddress, address.networkPrefixLength.toInt()) - Log.d( - TAG, - "Also monitoring " + networkAddress.toString() + "/" + address.networkPrefixLength.toString() - ) + val interfaces = try { + NetworkInterface.getNetworkInterfaces() ?: return + } catch (e: SocketException) { + Log.w(TAG, "Could not enumerate network interfaces", e) + return + } + + for (networkInterface in interfaces) { + try { + if (networkInterface.isLoopback) continue + + for (address in networkInterface.interfaceAddresses) { + if (address.address.isAnyLocalAddress or + address.address.isLinkLocalAddress or + address.address.isSiteLocalAddress + ) continue + val networkAddress = getNetworkAddress(address.address, address.networkPrefixLength) + builder.addRoute(networkAddress, address.networkPrefixLength.toInt()) + Log.d( + TAG, + "Also monitoring " + networkAddress.toString() + "/" + address.networkPrefixLength.toString() + ) + } + } catch (e: SocketException) { + // Interface disappeared between enumeration and query (ENODEV) — skip it. + Log.w(TAG, "Skipping interface ${networkInterface.name}: ${e.message}") } } } @@ -340,8 +354,22 @@ class VPNService : VpnService(), IProtectSocket { .setMtu(MAX_PACKET_LEN) .setMetered(false) - // establish() returns null if we no longer have permissions to establish the VPN somehow - val vpnInterface = builder.establish() ?: return + val vpnInterface = try { + builder.establish() + } catch (e: IllegalStateException) { + Log.w(TAG, "Could not establish VPN interface", e) + crashReporter.recordException(e) + null + } + if (vpnInterface == null) { + stopForeground(STOP_FOREGROUND_REMOVE) + setVPNRunning(false) + vpnNotificationManager.postServiceErrorNotification( + getString(R.string.lanshield_start_failed_title), + getString(R.string.lanshield_start_failed_text) + ) + return + } this.vpnInterface = vpnInterface SocketProtector.getInstance().setProtector(this) diff --git a/app/src/main/java/tech/httptoolkit/android/vpn/Session.java b/app/src/main/java/tech/httptoolkit/android/vpn/Session.java index 0428e55..af5ea16 100644 --- a/app/src/main/java/tech/httptoolkit/android/vpn/Session.java +++ b/app/src/main/java/tech/httptoolkit/android/vpn/Session.java @@ -32,6 +32,7 @@ import java.nio.ByteBuffer; import java.nio.channels.SelectionKey; import java.nio.channels.spi.AbstractSelectableChannel; +import java.util.ArrayDeque; /** * store information about a socket connection from a VPN client. @@ -73,8 +74,12 @@ public class Session { //receiving buffer for storing data from remote host private final ByteArrayOutputStream receivingStream; - //sending buffer for storing data from vpn client to be send to destination host + //sending buffer for storing data from vpn client to be send to destination host (TCP only) private final ByteArrayOutputStream sendingStream; + + //queue of discrete datagrams to be sent to the destination host (UDP only). UDP must + //preserve datagram boundaries, so unlike TCP it cannot use a flat byte stream. + private final ArrayDeque sendingDatagrams = new ArrayDeque<>(); private boolean hasReceivedLastSegment = false; @@ -171,13 +176,21 @@ public boolean hasReceivedData(){ } /** - * set data to be sent to destination server + * set data to be sent to destination server. + * For UDP each call is queued as a discrete datagram + * For TCP the bytes are appended to the send stream. * @param data Data to be sent - * @return boolean Success or not + * @return int number of bytes accepted */ public synchronized int setSendingData(ByteBuffer data) { final int remaining = data.remaining(); - sendingStream.write(data.array(), data.position(), data.remaining()); + if (protocol == SessionProtocol.UDP) { + byte[] datagram = new byte[remaining]; + System.arraycopy(data.array(), data.position(), datagram, 0, remaining); + sendingDatagrams.addLast(datagram); + } else { + sendingStream.write(data.array(), data.position(), remaining); + } return remaining; } @@ -186,7 +199,7 @@ int getSendingDataSize(){ } /** - * dequeue data for sending to server + * dequeue all stream data for sending to the server (TCP). * @return byte[] */ public synchronized byte[] getSendingData(){ @@ -194,12 +207,30 @@ public synchronized byte[] getSendingData(){ sendingStream.reset(); return data; } + + /** + * dequeue the next datagram for sending to the server (UDP), or null if none remain. + * @return byte[] + */ + public synchronized byte[] pollSendingDatagram(){ + return sendingDatagrams.pollFirst(); + } + + /** + * return a datagram to the head of the queue when it could not be written yet (UDP). + */ + public synchronized void requeueSendingDatagram(byte[] datagram){ + sendingDatagrams.addFirst(datagram); + } + /** * buffer contains data for sending to destination server * @return boolean */ - public boolean hasDataToSend(){ - return sendingStream.size() > 0; + public synchronized boolean hasDataToSend(){ + return protocol == SessionProtocol.UDP + ? !sendingDatagrams.isEmpty() + : sendingStream.size() > 0; } public SessionProtocol getProtocol() { diff --git a/app/src/main/java/tech/httptoolkit/android/vpn/socket/SocketChannelWriter.java b/app/src/main/java/tech/httptoolkit/android/vpn/socket/SocketChannelWriter.java index 6c4657b..6fdabc6 100644 --- a/app/src/main/java/tech/httptoolkit/android/vpn/socket/SocketChannelWriter.java +++ b/app/src/main/java/tech/httptoolkit/android/vpn/socket/SocketChannelWriter.java @@ -80,7 +80,7 @@ public long write(@NonNull Session session) { private long writeUDP(Session session) { long amountBytes = 0; try { - amountBytes = writePendingData(session); + amountBytes = writePendingUDPData(session); Date dt = new Date(); session.connectionStartTime = dt.getTime(); }catch(NotYetConnectedException ex2){ @@ -117,11 +117,12 @@ private long writeTCP(Session session) { return amountBytes; } + /** TCP: a byte stream, so buffered bytes are concatenated and written as-is. */ private long writePendingData(Session session) throws IOException { if (!session.hasDataToSend()) return 0; long totalBytesWritten = 0; - AbstractSelectableChannel channel = session.getChannel(); + SocketChannel channel = (SocketChannel) session.getChannel(); byte[] data = session.getSendingData(); ByteBuffer buffer = ByteBuffer.allocate(data.length); @@ -129,9 +130,7 @@ private long writePendingData(Session session) throws IOException { buffer.flip(); while (buffer.hasRemaining()) { - int bytesWritten = channel instanceof SocketChannel - ? ((SocketChannel) channel).write(buffer) - : ((DatagramChannel) channel).write(buffer); + int bytesWritten = channel.write(buffer); if (bytesWritten == 0) { break; @@ -149,7 +148,7 @@ private long writePendingData(Session session) throws IOException { // Subscribe to WRITE events, so we know when this is ready to resume. session.subscribeKey(SelectionKey.OP_WRITE); } else { - // All done, all good -> wait until the next TCP PSH / UDP packet + // All done, all good -> wait until the next TCP PSH packet session.setDataForSendingReady(false); // We don't need to know about WRITE events any more, we've written all our data. @@ -158,4 +157,39 @@ private long writePendingData(Session session) throws IOException { } return totalBytesWritten; } + + /** + * UDP: a datagram protocol, so each queued datagram must be written with its own + * channel.write() to preserve message boundaries. We send one datagram per write cycle + * and resubscribe to OP_WRITE while more remain, mirroring the TCP backpressure pattern. + */ + private long writePendingUDPData(Session session) throws IOException { + byte[] datagram = session.pollSendingDatagram(); + if (datagram == null) { + session.setDataForSendingReady(false); + session.unsubscribeKey(SelectionKey.OP_WRITE); + return 0; + } + + DatagramChannel channel = (DatagramChannel) session.getChannel(); + // A connected non-blocking DatagramChannel writes the whole datagram or nothing + // (0 when the send buffer is full); it never sends a partial datagram. + int bytesWritten = channel.write(ByteBuffer.wrap(datagram)); + + if (bytesWritten == 0) { + // Not ready yet: put the datagram back and resume on the next OP_WRITE. + session.requeueSendingDatagram(datagram); + session.subscribeKey(SelectionKey.OP_WRITE); + return 0; + } + + if (session.hasDataToSend()) { + // More datagrams queued -> come back for the next one. + session.subscribeKey(SelectionKey.OP_WRITE); + } else { + session.setDataForSendingReady(false); + session.unsubscribeKey(SelectionKey.OP_WRITE); + } + return bytesWritten; + } } diff --git a/app/src/main/res/values/strings.xml b/app/src/main/res/values/strings.xml index 171f847..728b89e 100644 --- a/app/src/main/res/values/strings.xml +++ b/app/src/main/res/values/strings.xml @@ -58,6 +58,8 @@ LAN traffic from %1$s Here you can see an overview of all apps that have sent or received LAN traffic. In some cases, LANShield is not able to detect which app caused certain packets to be sent; these are categorized under \'Unknown\'. LANShield enabled + Couldn\'t start LANShield + The LAN firewall could not be started. Please try again. LANShield lets you control\nLAN access of other apps LANShield logo Last 24h: %1$s @@ -106,6 +108,7 @@ Unknown VPN notification VPN permission request + LANShield needs notification permission to run. Enable notifications to turn it on. You\'re all set! Export LAN Traffic Advanced diff --git a/app/src/test/java/tech/httptoolkit/android/vpn/ConnectionTrackingForwardingTest.kt b/app/src/test/java/tech/httptoolkit/android/vpn/ConnectionTrackingForwardingTest.kt new file mode 100644 index 0000000..a782c19 --- /dev/null +++ b/app/src/test/java/tech/httptoolkit/android/vpn/ConnectionTrackingForwardingTest.kt @@ -0,0 +1,219 @@ +package tech.httptoolkit.android.vpn + +import android.app.Application +import com.google.common.truth.Truth.assertThat +import org.junit.After +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner +import org.robolectric.annotation.Config +import tech.httptoolkit.android.vpn.transport.ip.IPAddress +import tech.httptoolkit.android.vpn.transport.ip.IPHeader +import tech.httptoolkit.android.vpn.transport.tcp.TCPHeader +import java.net.DatagramPacket +import java.net.DatagramSocket +import java.net.InetAddress +import java.net.ServerSocket +import java.net.Socket +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit + +/** + * End-to-end correctness checks on the engine's connection tracking, through the real + * forwarding pipeline against loopback peers. Exercises both directions: + * + * - egress (packet intercepted from a local app): the engine creates/reuses the right + * session per 5-tuple, and demultiplexes concurrent connections without crossing streams; + * - ingress (data coming back from the LAN peer): each reply is routed back to the exact + * client connection that originated it, and traffic for an unknown connection is rejected. + */ +@RunWith(RobolectricTestRunner::class) +@Config(sdk = [34], application = Application::class) +class ConnectionTrackingForwardingTest { + + private lateinit var harness: ForwardingTestHarness + + private val clientIp = "10.0.0.2" + private val peerIp = "127.0.0.1" + + @Before + fun setUp() { + harness = ForwardingTestHarness() + } + + @After + fun tearDown() { + harness.close() + } + + // --- UDP ----------------------------------------------------------------- + + @Test + fun `concurrent udp connections demultiplex replies back to the originating client`() { + val peer = DatagramSocket(0, InetAddress.getByName(peerIp)).apply { soTimeout = 3000 } + val peerPort = peer.localPort + try { + // Two client connections to the same peer, distinguished only by source port. + harness.feed(TestPackets.udpPacket(clientIp, 40001, peerIp, peerPort, "one".toByteArray())) + harness.feed(TestPackets.udpPacket(clientIp, 40002, peerIp, peerPort, "two".toByteArray())) + + // The peer sees two distinct source sockets; echo each payload straight back. + repeat(2) { + val rx = DatagramPacket(ByteArray(64), 64) + peer.receive(rx) + peer.send(DatagramPacket(rx.data, rx.length, rx.socketAddress)) + } + + // Collect both TUN replies, then assert each landed on the correct client port: + // "one" must return to 40001 and "two" to 40002 (no cross-talk). + val byClientPort = buildMap { + repeat(2) { + val (_, udp, payload) = harness.parseUdp(harness.awaitTunPacket()) + put(udp.destinationPort, String(payload)) + } + } + assertThat(byClientPort).containsExactly(40001, "one", 40002, "two") + } finally { + peer.close() + } + } + + @Test + fun `repeated udp datagrams reuse one connection and accumulate egress`() { + val peer = DatagramSocket(0, InetAddress.getByName(peerIp)).apply { soTimeout = 3000 } + val peerPort = peer.localPort + try { + harness.feed(TestPackets.udpPacket(clientIp, 40003, peerIp, peerPort, "p1".toByteArray())) + harness.feed(TestPackets.udpPacket(clientIp, 40003, peerIp, peerPort, "p2".toByteArray())) + + // Both datagrams reach the peer over the same upstream socket: the second reused + // the connection rather than opening a new one. + val first = DatagramPacket(ByteArray(64), 64).also { peer.receive(it) } + val second = DatagramPacket(ByteArray(64), 64).also { peer.receive(it) } + assertThat( + setOf(String(first.data, 0, first.length), String(second.data, 0, second.length)) + ).containsExactly("p1", "p2") + assertThat(second.socketAddress).isEqualTo(first.socketAddress) + + // Exactly one session/flow was created; egress counters accumulate across both. + val session = harness.await { + harness.sessionByKey(udpKey(40003, peerPort)) + } + assertThat(harness.flowDao.countNotSyncedFlows()).isEqualTo(1) + harness.await { session.flow.takeIf { it.packetCountEgress >= 2 } } + } finally { + peer.close() + } + } + + @Test + fun `back-to-back udp datagrams preserve message boundaries`() { + // Regression test for datagram coalescing: two datagrams sent on the same connection + // before the writer drains must NOT be merged into one upstream datagram. + val peer = DatagramSocket(0, InetAddress.getByName(peerIp)).apply { soTimeout = 3000 } + val peerPort = peer.localPort + try { + harness.feed(TestPackets.udpPacket(clientIp, 40004, peerIp, peerPort, "AAAA".toByteArray())) + harness.feed(TestPackets.udpPacket(clientIp, 40004, peerIp, peerPort, "BBBB".toByteArray())) + + // The peer must receive two distinct 4-byte datagrams, not one merged "AAAABBBB". + val first = DatagramPacket(ByteArray(64), 64).also { peer.receive(it) } + val second = DatagramPacket(ByteArray(64), 64).also { peer.receive(it) } + assertThat(first.length).isEqualTo(4) + assertThat(second.length).isEqualTo(4) + assertThat(listOf(String(first.data, 0, 4), String(second.data, 0, 4))) + .containsExactly("AAAA", "BBBB").inOrder() + } finally { + peer.close() + } + } + + // --- TCP ----------------------------------------------------------------- + + @Test + fun `concurrent tcp connections are tracked independently`() { + val server = ServerSocket(0, 50, InetAddress.getByName(peerIp)) + val peerPort = server.localPort + val executor = Executors.newFixedThreadPool(2) + val isn1 = 1000L + val isn2 = 5000L + try { + val accept1 = executor.submit { server.accept() } + val accept2 = executor.submit { server.accept() } + + harness.feed(syn(40001, peerPort, isn1)) + harness.feed(syn(40002, peerPort, isn2)) + + // Each SYN-ACK must be demultiplexed to its own client port and acknowledge that + // connection's ISN — proving the two handshakes are not conflated. + val synAcks = buildMap> { + repeat(2) { + val pkt = harness.awaitTunPacketMatching { + val (_, tcp) = harness.parseTcp(it); tcp.isSYN && tcp.isACK + } + val parsed = harness.parseTcp(pkt) + put(parsed.second.destinationPort, parsed) + } + } + + assertThat(synAcks.keys).containsExactly(40001, 40002) + assertThat(synAcks.getValue(40001).second.ackNumber).isEqualTo(isn1 + 1) + assertThat(synAcks.getValue(40002).second.ackNumber).isEqualTo(isn2 + 1) + assertThat(synAcks.getValue(40001).first.destinationIP.toString()).isEqualTo(clientIp) + assertThat(synAcks.getValue(40002).first.destinationIP.toString()).isEqualTo(clientIp) + + // Two distinct sessions and two recorded flows. + assertThat(harness.sessionByKey(tcpKey(40001, peerPort))).isNotNull() + assertThat(harness.sessionByKey(tcpKey(40002, peerPort))).isNotNull() + assertThat(harness.flowDao.countNotSyncedFlows()).isEqualTo(2) + + accept1.get(3, TimeUnit.SECONDS) + accept2.get(3, TimeUnit.SECONDS) + } finally { + executor.shutdownNow() + server.close() + } + } + + @Test + fun `tcp data for an unknown connection is rejected with RST and creates no session`() { + // An ACK with payload but no preceding SYN has no tracked connection: the engine must + // reject it with a RST rather than silently adopting it or crashing. + harness.feed( + TestPackets.tcpPacket( + clientIp, 40009, peerIp, 9, + seq = 42, ack = 99, flags = TestPackets.ACK, payload = "junk".toByteArray(), + ) + ) + + val rst = harness.awaitTunPacketMatching { harness.parseTcp(it).second.isRST } + val (ip, tcp) = harness.parseTcp(rst) + assertThat(tcp.isRST).isTrue() + assertThat(ip.destinationIP.toString()).isEqualTo(clientIp) + assertThat(tcp.destinationPort).isEqualTo(40009) + + assertThat(harness.sessionByKey(tcpKey(40009, 9))).isNull() + assertThat(harness.flowDao.countNotSyncedFlows()).isEqualTo(0) + } + + // --- helpers ------------------------------------------------------------- + + private fun syn(clientPort: Int, peerPort: Int, isn: Long): ByteArray = + TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = isn, ack = 0, flags = TestPackets.SYN, mss = 1460, + ) + + private fun udpKey(clientPort: Int, peerPort: Int): String = Session.getSessionKey( + SessionProtocol.UDP, + IPAddress(TestPackets.ip(peerIp)), peerPort, + IPAddress(TestPackets.ip(clientIp)), clientPort, + ) + + private fun tcpKey(clientPort: Int, peerPort: Int): String = Session.getSessionKey( + SessionProtocol.TCP, + IPAddress(TestPackets.ip(peerIp)), peerPort, + IPAddress(TestPackets.ip(clientIp)), clientPort, + ) +} diff --git a/app/src/test/java/tech/httptoolkit/android/vpn/SessionManagerUnitTest.kt b/app/src/test/java/tech/httptoolkit/android/vpn/SessionManagerUnitTest.kt index 5d9cd37..6edfca0 100644 --- a/app/src/test/java/tech/httptoolkit/android/vpn/SessionManagerUnitTest.kt +++ b/app/src/test/java/tech/httptoolkit/android/vpn/SessionManagerUnitTest.kt @@ -69,6 +69,67 @@ class SessionManagerUnitTest { assertThat(db.FlowDao().countNotSyncedFlows()).isEqualTo(1) } + @Test + fun `recreating the same tcp session is deduplicated and inserts only one flow`() { + // Two SYNs for the same 5-tuple (e.g. a retransmitted SYN) must map to one connection. + val server = ServerSocket(0, 50, InetAddress.getByName("127.0.0.1")) + try { + val port = server.localPort + val first = manager.createNewTCPSession(dstIp, port, srcIp, 50000, 40, "pkg") + val second = manager.createNewTCPSession(dstIp, port, srcIp, 50000, 40, "pkg") + + assertThat(second).isSameInstanceAs(first) + assertThat(db.FlowDao().countNotSyncedFlows()).isEqualTo(1) + } finally { + server.close() + } + } + + @Test + fun `udp and tcp with the same tuple are tracked as separate connections`() { + // Protocol is part of the connection identity: identical addresses/ports over UDP and + // TCP must not collide into one session. + val server = ServerSocket(0, 50, InetAddress.getByName("127.0.0.1")) + try { + val port = server.localPort + val udp = manager.createNewUDPSession( + dstIp, port, srcIp, 50000, 28, "pkg", + ByteBuffer.wrap(TestPackets.udpPacket("10.0.0.2", 50000, "127.0.0.1", port, "x".toByteArray())), + ) + val tcp = manager.createNewTCPSession(dstIp, port, srcIp, 50000, 40, "pkg") + + assertThat(tcp).isNotSameInstanceAs(udp) + assertThat(tcp.sessionKey).isNotEqualTo(udp.sessionKey) + assertThat(manager.getSession(SessionProtocol.UDP, dstIp, port, srcIp, 50000)) + .isSameInstanceAs(udp) + assertThat(manager.getSession(SessionProtocol.TCP, dstIp, port, srcIp, 50000)) + .isSameInstanceAs(tcp) + assertThat(db.FlowDao().countNotSyncedFlows()).isEqualTo(2) + } finally { + server.close() + } + } + + @Test + fun `sessions are keyed by the full 5-tuple`() { + // Same client to the same peer but a different source port, or to a different peer port, + // are distinct connections, each independently retrievable and recorded. + val base = createUdp(srcPort = 50000, dstPort = 9999) + val differentSrcPort = createUdp(srcPort = 50001, dstPort = 9999) + val differentDstPort = createUdp(srcPort = 50000, dstPort = 8888) + + assertThat(differentSrcPort).isNotSameInstanceAs(base) + assertThat(differentDstPort).isNotSameInstanceAs(base) + assertThat(db.FlowDao().countNotSyncedFlows()).isEqualTo(3) + + assertThat(manager.getSession(SessionProtocol.UDP, dstIp, 9999, srcIp, 50000)) + .isSameInstanceAs(base) + assertThat(manager.getSession(SessionProtocol.UDP, dstIp, 9999, srcIp, 50001)) + .isSameInstanceAs(differentSrcPort) + assertThat(manager.getSession(SessionProtocol.UDP, dstIp, 8888, srcIp, 50000)) + .isSameInstanceAs(differentDstPort) + } + @Test fun `closeSession removes the session and closes its channel`() { val session = createUdp() diff --git a/app/src/test/java/tech/httptoolkit/android/vpn/SourceAddressForwardingTest.kt b/app/src/test/java/tech/httptoolkit/android/vpn/SourceAddressForwardingTest.kt new file mode 100644 index 0000000..e1bbfa7 --- /dev/null +++ b/app/src/test/java/tech/httptoolkit/android/vpn/SourceAddressForwardingTest.kt @@ -0,0 +1,204 @@ +package tech.httptoolkit.android.vpn + +import android.app.Application +import com.google.common.truth.Truth.assertThat +import org.distrinet.lanshield.database.model.LANFlow +import org.junit.After +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.ParameterizedRobolectricTestRunner +import org.robolectric.annotation.Config +import tech.httptoolkit.android.vpn.transport.ip.IPAddress +import java.net.DatagramPacket +import java.net.DatagramSocket +import java.net.InetAddress +import java.net.ServerSocket +import java.net.Socket +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit + +/** + * The tun interface is assigned a synthetic VPN address (10.215.173.1), but because the + * kernel picks a packet's source address from the destination, intercepted packets arrive + * with two different source IPs: + * + * - the VPN-tun IP (10.215.173.1) for destinations outside the device's own subnet, and + * - the device's real wlan IP (e.g. 192.168.1.100) for destinations on the local subnet. + * + * The forwarding engine handles the source IP opaquely, so both forms must round-trip + * symmetrically: the peer's reply must return to the TUN addressed to whatever source the + * client used, and the recorded [LANFlow.localEndpoint] must carry that same source. These + * parameterized tests lock that in for both source-IP forms, for UDP and TCP. + * + * Only the client source IP is varied; the peer stays on loopback (127.0.0.1) since a JVM + * test can't bind a peer on a real 192.168.x subnet and the engine connects to the + * destination regardless of source. + */ +@RunWith(ParameterizedRobolectricTestRunner::class) +@Config(sdk = [34], application = Application::class) +class SourceAddressForwardingTest( + private val caseLabel: String, + private val clientIp: String, +) { + + private lateinit var harness: ForwardingTestHarness + + private val clientPort = 50000 + private val peerIp = "127.0.0.1" + + @Before + fun setUp() { + harness = ForwardingTestHarness() + } + + @After + fun tearDown() { + harness.close() + } + + // --- UDP ----------------------------------------------------------------- + + @Test + fun `udp reply returns to the client source ip and the flow records it`() { + val peer = DatagramSocket(0, InetAddress.getByName(peerIp)).apply { soTimeout = 3000 } + val peerPort = peer.localPort + try { + harness.feed( + TestPackets.udpPacket(clientIp, clientPort, peerIp, peerPort, "ping".toByteArray()) + ) + + // egress: the real peer receives the exact payload + val received = DatagramPacket(ByteArray(64), 64) + peer.receive(received) + assertThat(String(received.data, 0, received.length)).isEqualTo("ping") + + // ingress: the reply is emitted to the TUN addressed back to the client source IP + peer.send(DatagramPacket("pong".toByteArray(), 4, received.socketAddress)) + val (ip, udp, payload) = harness.parseUdp(harness.awaitTunPacket()) + assertThat(String(payload)).isEqualTo("pong") + assertThat(ip.sourceIP.toString()).isEqualTo(peerIp) + assertThat(ip.destinationIP.toString()).isEqualTo(clientIp) + assertThat(udp.sourcePort).isEqualTo(peerPort) + assertThat(udp.destinationPort).isEqualTo(clientPort) + + // the recorded flow's local endpoint carries the client source IP + val session = harness.await { harness.sessionByKey(sessionKey(SessionProtocol.UDP, peerPort)) } + val flowUuid = session.flow.uuid + assertThat(session.flow.transportLayerProtocol).isEqualTo("UDP") + + val persisted: LANFlow = harness.await { + harness.flowDao.getFlowById(flowUuid)?.takeIf { it.dataIngress >= 4 } + } + assertThat(persisted.localEndpoint.address.hostAddress).isEqualTo(clientIp) + assertThat(persisted.localEndpoint.port).isEqualTo(clientPort) + } finally { + peer.close() + } + } + + // --- TCP ----------------------------------------------------------------- + + @Test + fun `tcp reply returns to the client source ip and the flow records it`() { + val server = ServerSocket(0, 50, InetAddress.getByName(peerIp)) + val peerPort = server.localPort + val executor = Executors.newSingleThreadExecutor() + val clientIsn = 1000L + try { + val acceptedFuture = executor.submit { server.accept() } + + // 1. SYN -> SYN-ACK, addressed back to the client source IP + harness.feed( + TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = clientIsn, ack = 0, flags = TestPackets.SYN, mss = 1460, + ) + ) + val synAck = harness.awaitTunPacketMatching { + val (_, tcp) = harness.parseTcp(it); tcp.isSYN && tcp.isACK + } + val (synAckIp, synAckTcp) = harness.parseTcp(synAck) + assertThat(synAckTcp.ackNumber).isEqualTo(clientIsn + 1) + assertThat(synAckIp.sourceIP.toString()).isEqualTo(peerIp) + assertThat(synAckIp.destinationIP.toString()).isEqualTo(clientIp) + val serverIsn = synAckTcp.sequenceNumber + + val session = harness.await { harness.sessionByKey(sessionKey(SessionProtocol.TCP, peerPort)) } + val flowUuid = session.flow.uuid + assertThat(session.flow.transportLayerProtocol).isEqualTo("TCP") + + // 2. ACK completes the handshake + harness.feed( + TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = clientIsn + 1, ack = serverIsn + 1, flags = TestPackets.ACK, + ) + ) + val accepted = acceptedFuture.get(3, TimeUnit.SECONDS) + accepted.soTimeout = 3000 + harness.await { harness.flowDao.getFlowById(flowUuid)?.takeIf { it.tcpEstablishedReached } } + + // 3. PSH+ACK with payload -> peer receives it (egress) + val request = "GET /" + harness.feed( + TestPackets.tcpPacket( + clientIp, clientPort, peerIp, peerPort, + seq = clientIsn + 1, ack = serverIsn + 1, + flags = TestPackets.PSH or TestPackets.ACK, payload = request.toByteArray(), + ) + ) + val fromClient = ByteArray(request.length) + readFully(accepted, fromClient) + assertThat(String(fromClient)).isEqualTo(request) + + // 4. Peer responds -> the data segment is relayed to the TUN, addressed to the client + val response = "HTTP/1.1 200" + accepted.getOutputStream().apply { write(response.toByteArray()); flush() } + val dataPacket = harness.awaitTunPacketMatching { packet -> + packet.size > 40 && TestPackets.payloadString(packet, response.length) == response + } + val (dataIp, _) = harness.parseTcp(dataPacket) + assertThat(dataIp.sourceIP.toString()).isEqualTo(peerIp) + assertThat(dataIp.destinationIP.toString()).isEqualTo(clientIp) + + // the recorded flow's local endpoint carries the client source IP + val persisted: LANFlow = harness.await { harness.flowDao.getFlowById(flowUuid) } + assertThat(persisted.localEndpoint.address.hostAddress).isEqualTo(clientIp) + assertThat(persisted.localEndpoint.port).isEqualTo(clientPort) + + accepted.close() + } finally { + executor.shutdownNow() + server.close() + } + } + + private fun sessionKey(protocol: SessionProtocol, peerPort: Int): String = + Session.getSessionKey( + protocol, + IPAddress(TestPackets.ip(peerIp)), peerPort, + IPAddress(TestPackets.ip(clientIp)), clientPort, + ) + + private fun readFully(socket: Socket, buffer: ByteArray) { + val input = socket.getInputStream() + var read = 0 + while (read < buffer.size) { + val n = input.read(buffer, read, buffer.size - read) + if (n < 0) break + read += n + } + } + + companion object { + @JvmStatic + @ParameterizedRobolectricTestRunner.Parameters(name = "{0}") + fun parameters(): Collection> = listOf( + // VPN-tun source: destination outside the device's own subnet. + arrayOf("vpn-tun-source", "10.215.173.1"), + // wlan source: destination on the device's own local subnet. + arrayOf("wlan-source", "192.168.1.100"), + ) + } +} diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 0e6b116..a78d00f 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -30,6 +30,10 @@ turbine = "1.2.1" androidxTestCore = "1.7.0" androidxArchCore = "2.2.0" truth = "1.4.5" +uiautomator = "2.3.0" +androidxTestRules = "1.7.0" +testOrchestrator = "1.5.1" +testServices = "1.5.0" [libraries] accompanist-permissions = { module = "com.google.accompanist:accompanist-permissions", version.ref = "accompanistPermissions" } @@ -74,6 +78,10 @@ androidx-core-testing = { module = "androidx.arch.core:core-testing", version.re androidx-test-core = { module = "androidx.test:core", version.ref = "androidxTestCore" } androidx-test-core-ktx = { module = "androidx.test:core-ktx", version.ref = "androidxTestCore" } truth = { module = "com.google.truth:truth", version.ref = "truth" } +androidx-test-uiautomator = { module = "androidx.test.uiautomator:uiautomator", version.ref = "uiautomator" } +androidx-test-rules = { module = "androidx.test:rules", version.ref = "androidxTestRules" } +androidx-test-orchestrator = { module = "androidx.test:orchestrator", version.ref = "testOrchestrator" } +androidx-test-services = { module = "androidx.test.services:test-services", version.ref = "testServices" } [bundles] unit-test = [