diff --git a/ServiceWiring.php b/ServiceWiring.php index b0c42b93..921d5b76 100644 --- a/ServiceWiring.php +++ b/ServiceWiring.php @@ -11,6 +11,7 @@ return [ 'OATHAuthModuleRegistry' => static function ( MediaWikiServices $services ): OATHAuthModuleRegistry { return new OATHAuthModuleRegistry( $services->getDBLoadBalancerFactory(), + $services->getObjectFactory(), ExtensionRegistry::getInstance()->getAttribute( 'OATHAuthModules' ), ); }, diff --git a/extension.json b/extension.json index 590c6190..0132ac72 100644 --- a/extension.json +++ b/extension.json @@ -17,7 +17,12 @@ "attributes": { "OATHAuth": { "Modules": { - "totp": "\\MediaWiki\\Extension\\OATHAuth\\Module\\TOTP::factory" + "totp": { + "class": "\\MediaWiki\\Extension\\OATHAuth\\Module\\TOTP", + "services": [ + "OATHUserRepository" + ] + } } } }, diff --git a/i18n/en.json b/i18n/en.json index f81e232c..c10d3805 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -66,7 +66,6 @@ "log-action-filter-oath-verify": "Checking if two-factor authentication is enabled", "log-action-filter-oath-disable-other": "Disabling two-factor authentication for another user", "oathauth-ui-no-module": "None enabled", - "oathauth-module-invalid": "The OATHAuth module that the user has registered is invalid.", "oathauth-module-totp-label": "TOTP (one-time token)", "oathauth-ui-manage": "Manage", "oathmanage": "Manage Two-factor authentication", diff --git a/i18n/qqq.json b/i18n/qqq.json index b54779bf..fbd03e48 100644 --- a/i18n/qqq.json +++ b/i18n/qqq.json @@ -81,7 +81,6 @@ "log-action-filter-oath-verify": "{{doc-log-action-filter-action|oath|verify}}", "log-action-filter-oath-disable-other": "{{doc-log-action-filter-action|oath|disable-other}}", "oathauth-ui-no-module": "User preference value for the type of two-factor authentication operation {{msg-mw|Log-action-filter-oath}} when no 2FA module is enabled.", - "oathauth-module-invalid": "Error message when the OATHAuth module registered by user is invalid", "oathauth-module-totp-label": "User preference value when the TOTP module is enabled", "oathauth-ui-manage": "Button on Special:Preferences, that leads to [[Special:OATHManage]]", "oathmanage": "{{doc-special|OATHManage}}", diff --git a/src/Auth/TOTPSecondaryAuthenticationProvider.php b/src/Auth/TOTPSecondaryAuthenticationProvider.php index 0791c773..42bca642 100644 --- a/src/Auth/TOTPSecondaryAuthenticationProvider.php +++ b/src/Auth/TOTPSecondaryAuthenticationProvider.php @@ -23,7 +23,7 @@ use MediaWiki\Auth\AuthenticationRequest; use MediaWiki\Auth\AuthenticationResponse; use MediaWiki\Auth\AuthManager; use MediaWiki\Extension\OATHAuth\Module\TOTP; -use MediaWiki\MediaWikiServices; +use MediaWiki\Extension\OATHAuth\OATHUserRepository; use MediaWiki\Message\Message; use MediaWiki\User\User; @@ -38,12 +38,11 @@ use MediaWiki\User\User; */ class TOTPSecondaryAuthenticationProvider extends AbstractSecondaryAuthenticationProvider { private TOTP $module; + private OATHUserRepository $userRepository; - /** - * @param TOTP $module - */ - public function __construct( TOTP $module ) { + public function __construct( TOTP $module, OATHUserRepository $userRepository ) { $this->module = $module; + $this->userRepository = $userRepository; } /** @@ -66,6 +65,12 @@ class TOTPSecondaryAuthenticationProvider extends AbstractSecondaryAuthenticatio * @return AuthenticationResponse */ public function beginSecondaryAuthentication( $user, array $reqs ) { + $authUser = $this->userRepository->findByUser( $user ); + + if ( !( $authUser->getModule() instanceof TOTP ) ) { + return AuthenticationResponse::newAbstain(); + } + return AuthenticationResponse::newUI( [ new TOTPAuthenticationRequest() ], wfMessage( 'oathauth-auth-ui' ), @@ -84,8 +89,7 @@ class TOTPSecondaryAuthenticationProvider extends AbstractSecondaryAuthenticatio wfMessage( 'oathauth-login-failed' ), 'error' ); } - $userRepo = MediaWikiServices::getInstance()->getService( 'OATHUserRepository' ); - $authUser = $userRepo->findByUser( $user ); + $authUser = $this->userRepository->findByUser( $user ); $token = $request->OATHToken; // Don't increase pingLimiter, just check for limit exceeded. diff --git a/src/Module/TOTP.php b/src/Module/TOTP.php index 1e120fe0..0c0dfb63 100644 --- a/src/Module/TOTP.php +++ b/src/Module/TOTP.php @@ -15,8 +15,10 @@ use MediaWiki\Extension\OATHAuth\Special\OATHManage; use MWException; class TOTP implements IModule { - public static function factory() { - return new static(); + private OATHUserRepository $userRepository; + + public function __construct( OATHUserRepository $userRepository ) { + $this->userRepository = $userRepository; } /** @inheritDoc */ @@ -49,7 +51,8 @@ class TOTP implements IModule { */ public function getSecondaryAuthProvider() { return new TOTPSecondaryAuthenticationProvider( - $this + $this, + $this->userRepository ); } diff --git a/src/OATHAuthModuleRegistry.php b/src/OATHAuthModuleRegistry.php index b3a477ef..c8bf0a20 100644 --- a/src/OATHAuthModuleRegistry.php +++ b/src/OATHAuthModuleRegistry.php @@ -21,36 +21,50 @@ namespace MediaWiki\Extension\OATHAuth; use InvalidArgumentException; +use Wikimedia\ObjectFactory\ObjectFactory; use Wikimedia\Rdbms\IConnectionProvider; class OATHAuthModuleRegistry { private IConnectionProvider $dbProvider; + private ObjectFactory $objectFactory; /** @var array */ - private $modules; + private array $modules; /** @var array|null */ - private $moduleIds; + private ?array $moduleIds = null; public function __construct( IConnectionProvider $dbProvider, + ObjectFactory $objectFactory, array $modules ) { $this->dbProvider = $dbProvider; + $this->objectFactory = $objectFactory; $this->modules = $modules; } - public function getModuleByKey( string $key ): ?IModule { - if ( isset( $this->getModules()[$key] ) ) { - $module = call_user_func_array( $this->getModules()[$key], [] ); - if ( !$module instanceof IModule ) { - return null; - } - return $module; + public function moduleExists( string $moduleKey ): bool { + return isset( $this->getModules()[$moduleKey] ); + } + + public function getModuleByKey( string $key ): IModule { + if ( !isset( $this->getModules()[$key] ) ) { + throw new InvalidArgumentException( "No such two-factor module $key" ); } - return null; + $data = $this->getModules()[$key]; + if ( is_string( $data ) ) { + $module = call_user_func_array( $this->getModules()[$key], [] ); + } else { + $module = $this->objectFactory->createObject( + $data, + [ 'assertClass' => IModule::class ] + ); + } + + return $module; } /** @@ -61,11 +75,7 @@ class OATHAuthModuleRegistry { public function getAllModules(): array { $modules = []; foreach ( $this->getModules() as $key => $callback ) { - $module = $this->getModuleByKey( $key ); - if ( !( $module instanceof IModule ) ) { - continue; - } - $modules[$key] = $module; + $modules[$key] = $this->getModuleByKey( $key ); } return $modules; } diff --git a/src/OATHUserRepository.php b/src/OATHUserRepository.php index 85ea997e..251883b4 100644 --- a/src/OATHUserRepository.php +++ b/src/OATHUserRepository.php @@ -77,6 +77,7 @@ class OATHUserRepository implements LoggerAwareInterface { ->centralIdFromLocalUser( $user ); $oathUser = new OATHUser( $user, $uid ); $this->loadKeysFromDatabase( $oathUser ); + $this->cache->set( $user->getName(), $oathUser ); } return $oathUser; @@ -164,14 +165,13 @@ class OATHUserRepository implements LoggerAwareInterface { ); } - $userId = $this->centralIdLookupFactory->getLookup()->centralIdFromLocalUser( $user->getUser() ); $moduleId = $this->moduleRegistry->getModuleId( $module->getName() ); $dbw = $this->dbProvider->getPrimaryDatabase( 'virtual-oathauth' ); $dbw->newInsertQueryBuilder() ->insertInto( 'oathauth_devices' ) ->row( [ - 'oad_user' => $userId, + 'oad_user' => $user->getCentralId(), 'oad_type' => $moduleId, 'oad_data' => FormatJson::encode( $keyData ), 'oad_created' => $dbw->timestamp(), diff --git a/src/Special/OATHManage.php b/src/Special/OATHManage.php index c1b17383..0b675cdc 100644 --- a/src/Special/OATHManage.php +++ b/src/Special/OATHManage.php @@ -53,10 +53,7 @@ class OATHManage extends SpecialPage { */ protected $action; - /** - * @var IModule|null - */ - protected $requestedModule; + protected ?IModule $requestedModule; /** * Initializes a page to manage available 2FA modules @@ -147,7 +144,9 @@ class OATHManage extends SpecialPage { private function setModule(): void { $moduleKey = $this->getRequest()->getVal( 'module', '' ); - $this->requestedModule = $this->moduleRegistry->getModuleByKey( $moduleKey ); + $this->requestedModule = ( $moduleKey && $this->moduleRegistry->moduleExists( $moduleKey ) ) + ? $this->moduleRegistry->getModuleByKey( $moduleKey ) + : null; } private function addEnabledHTML(): void { diff --git a/tests/phpunit/OATHAuthModuleRegistryTest.php b/tests/phpunit/OATHAuthModuleRegistryTest.php index b951bf48..4897de28 100644 --- a/tests/phpunit/OATHAuthModuleRegistryTest.php +++ b/tests/phpunit/OATHAuthModuleRegistryTest.php @@ -19,6 +19,7 @@ */ use MediaWiki\Extension\OATHAuth\OATHAuthModuleRegistry; +use Wikimedia\ObjectFactory\ObjectFactory; use Wikimedia\Rdbms\IConnectionProvider; /** @@ -26,10 +27,7 @@ use Wikimedia\Rdbms\IConnectionProvider; * @group Database */ class OATHAuthModuleRegistryTest extends MediaWikiIntegrationTestCase { - /** - * @covers \MediaWiki\Extension\OATHAuth\OATHAuthModuleRegistry::getModuleIds - */ - public function testGetModuleIds() { + private function makeTestRegistry(): OATHAuthModuleRegistry { $this->getDb()->newInsertQueryBuilder() ->insertInto( 'oathauth_types' ) ->row( [ 'oat_name' => 'first' ] ) @@ -37,17 +35,34 @@ class OATHAuthModuleRegistryTest extends MediaWikiIntegrationTestCase { ->execute(); $database = $this->createMock( IConnectionProvider::class ); - $database->method( 'getPrimaryDatabase' )->with( 'virtual-oathauth' )->willReturn( $this->db ); - $database->method( 'getReplicaDatabase' )->with( 'virtual-oathauth' )->willReturn( $this->db ); + $database->method( 'getPrimaryDatabase' )->with( 'virtual-oathauth' )->willReturn( $this->getDb() ); + $database->method( 'getReplicaDatabase' )->with( 'virtual-oathauth' )->willReturn( $this->getDb() ); - $registry = new OATHAuthModuleRegistry( + return new OATHAuthModuleRegistry( $database, + $this->createNoOpMock( ObjectFactory::class ), [ 'first' => 'does not matter', 'second' => 'does not matter', 'third' => 'does not matter', ] ); + } + + /** + * @covers \MediaWiki\Extension\OATHAuth\OATHAuthModuleRegistry::moduleExists + */ + public function testModuleExists() { + $registry = $this->makeTestRegistry(); + $this->assertTrue( $registry->moduleExists( 'first' ) ); + $this->assertFalse( $registry->moduleExists( 'nonexistent' ) ); + } + + /** + * @covers \MediaWiki\Extension\OATHAuth\OATHAuthModuleRegistry::getModuleIds + */ + public function testGetModuleIds() { + $registry = $this->makeTestRegistry(); $this->assertEquals( [ 'first', 'second', 'third' ], diff --git a/tests/phpunit/integration/OATHUserRepositoryTest.php b/tests/phpunit/integration/OATHUserRepositoryTest.php index 747af1a3..dec8ae3b 100644 --- a/tests/phpunit/integration/OATHUserRepositoryTest.php +++ b/tests/phpunit/integration/OATHUserRepositoryTest.php @@ -47,8 +47,8 @@ class OATHUserRepositoryTest extends MediaWikiIntegrationTestCase { $user = $this->getTestUser()->getUser(); $dbProvider = $this->createMock( IConnectionProvider::class ); - $dbProvider->method( 'getPrimaryDatabase' )->with( 'virtual-oathauth' )->willReturn( $this->db ); - $dbProvider->method( 'getReplicaDatabase' )->with( 'virtual-oathauth' )->willReturn( $this->db ); + $dbProvider->method( 'getPrimaryDatabase' )->with( 'virtual-oathauth' )->willReturn( $this->getDb() ); + $dbProvider->method( 'getReplicaDatabase' )->with( 'virtual-oathauth' )->willReturn( $this->getDb() ); $moduleRegistry = OATHAuthServices::getInstance( $this->getServiceContainer() )->getModuleRegistry(); $module = $moduleRegistry->getModuleByKey( 'totp' ); diff --git a/tests/phpunit/integration/Special/OATHManageTest.php b/tests/phpunit/integration/Special/OATHManageTest.php new file mode 100644 index 00000000..a5aaac7e --- /dev/null +++ b/tests/phpunit/integration/Special/OATHManageTest.php @@ -0,0 +1,54 @@ + + * @group Database + * @coversDefaultClass \MediaWiki\Extension\OATHAuth\Special\OATHManage + */ +class OATHManageTest extends SpecialPageTestBase { + protected function newSpecialPage() { + $services = OATHAuthServices::getInstance( $this->getServiceContainer() ); + return new OATHManage( + $services->getUserRepository(), + $services->getModuleRegistry(), + ); + } + + /** + * @covers ::execute + */ + public function testPageLoads() { + $this->executeSpecialPage( + '', + null, + null, + $this->getTestUser()->getAuthority(), + ); + + $this->addToAssertionCount( 1 ); + } +}