diff --git a/ServiceWiring.php b/ServiceWiring.php index ee07228c..9df370b6 100644 --- a/ServiceWiring.php +++ b/ServiceWiring.php @@ -18,7 +18,10 @@ return [ ); }, 'OATHAuthModuleRegistry' => static function ( MediaWikiServices $services ) { - return new OATHAuthModuleRegistry(); + return new OATHAuthModuleRegistry( + $services->getService( 'OATHAuthDatabase' ), + ExtensionRegistry::getInstance()->getAttribute( 'OATHAuthModules' ), + ); }, 'OATHUserRepository' => static function ( MediaWikiServices $services ) { return new OATHUserRepository( diff --git a/extension.json b/extension.json index 34babb91..6e43167f 100644 --- a/extension.json +++ b/extension.json @@ -21,7 +21,8 @@ } }, "AutoloadNamespaces": { - "MediaWiki\\Extension\\OATHAuth\\": "src/" + "MediaWiki\\Extension\\OATHAuth\\": "src/", + "MediaWiki\\Extension\\OATHAuth\\Maintenance\\": "maintenance/" }, "TestAutoloadNamespaces": { "MediaWiki\\Extension\\OATHAuth\\Tests\\": "tests/phpunit/" @@ -84,6 +85,9 @@ }, "OATHRequiredForGroups": { "value": [] + }, + "OATHAuthMultipleDevicesMigrationStage": { + "value": 768 } }, "ResourceModules": { diff --git a/maintenance/UpdateForMultipleDevicesSupport.php b/maintenance/UpdateForMultipleDevicesSupport.php new file mode 100644 index 00000000..44d5b8fd --- /dev/null +++ b/maintenance/UpdateForMultipleDevicesSupport.php @@ -0,0 +1,127 @@ + + */ +class UpdateForMultipleDevicesSupport extends LoggedUpdateMaintenance { + public function __construct() { + parent::__construct(); + $this->requireExtension( 'OATHAuth' ); + $this->setBatchSize( 500 ); + } + + protected function doDBUpdates() { + $database = OATHAuthServices::getInstance()->getDatabase(); + $dbw = $database->getDB( DB_PRIMARY ); + + $maxId = $dbw->newSelectQueryBuilder() + ->select( 'MAX(id)' ) + ->from( 'oathauth_users' ) + ->caller( __METHOD__ ) + ->fetchField(); + + $typeIds = OATHAuthServices::getInstance()->getModuleRegistry()->getModuleIds(); + + $updated = 0; + + for ( $min = 0; $min <= $maxId; $min += $this->getBatchSize() ) { + $max = $min + $this->getBatchSize(); + $this->output( "Now processing rows with id between $min and $max... (updated $updated users so far)\n" ); + + $res = $dbw->newSelectQueryBuilder() + ->select( [ + 'id', + 'module', + 'data', + ] ) + ->from( 'oathauth_users' ) + ->where( [ + $dbw->buildComparison( '>=', [ 'id' => $min ] ), + $dbw->buildComparison( '<', [ 'id' => $max ] ), + ] ) + ->caller( __METHOD__ ) + ->fetchResultSet(); + + $toDelete = []; + $toAdd = []; + + foreach ( $res as $row ) { + $decodedData = FormatJson::decode( $row->data, true ); + if ( isset( $decodedData['keys'] ) ) { + $toDelete[] = (int)$row->id; + + $updated += 1; + + foreach ( $decodedData['keys'] as $keyData ) { + $toAdd[] = [ + 'oad_user' => (int)$row->id, + 'oad_type' => $typeIds[$row->module], + 'oad_data' => FormatJson::encode( $keyData ), + ]; + } + } + } + + if ( $toAdd ) { + $dbw->startAtomic( __METHOD__ ); + $dbw->insert( + 'oathauth_devices', + $toAdd, + __METHOD__ + ); + $dbw->delete( + 'oathauth_users', + [ 'id' => $toDelete ], + __METHOD__ + ); + $dbw->endAtomic( __METHOD__ ); + } + + $database->waitForReplication(); + } + + $this->output( "Done, updated data for $updated users.\n" ); + return true; + } + + /** + * @return string + */ + protected function getUpdateKey() { + return __CLASS__; + } +} + +$maintClass = UpdateForMultipleDevicesSupport::class; +require_once RUN_MAINTENANCE_IF_MAIN; diff --git a/sql/mysql/tables-generated.sql b/sql/mysql/tables-generated.sql index 93833887..c559417a 100644 --- a/sql/mysql/tables-generated.sql +++ b/sql/mysql/tables-generated.sql @@ -1,10 +1,22 @@ -- This file is automatically generated using maintenance/generateSchemaSql.php. --- Source: ./tables.json +-- Source: sql/tables.json -- Do not modify this file directly. -- See https://www.mediawiki.org/wiki/Manual:Schema_changes -CREATE TABLE /*_*/oathauth_users ( - id INT NOT NULL, - module VARCHAR(255) NOT NULL, - data BLOB DEFAULT NULL, - PRIMARY KEY(id) +CREATE TABLE /*_*/oathauth_types ( + oat_id INT AUTO_INCREMENT NOT NULL, + oat_name VARBINARY(255) NOT NULL, + UNIQUE INDEX oat_name (oat_name), + PRIMARY KEY(oat_id) +) /*$wgDBTableOptions*/; + + +CREATE TABLE /*_*/oathauth_devices ( + oad_id INT AUTO_INCREMENT NOT NULL, + oad_user INT NOT NULL, + oad_type INT NOT NULL, + oad_name VARBINARY(255) DEFAULT NULL, + oad_created BINARY(14) DEFAULT NULL, + oad_data BLOB DEFAULT NULL, + INDEX oad_user (oad_user), + PRIMARY KEY(oad_id) ) /*$wgDBTableOptions*/; diff --git a/sql/postgres/tables-generated.sql b/sql/postgres/tables-generated.sql index 2beb6a33..7f63cfc6 100644 --- a/sql/postgres/tables-generated.sql +++ b/sql/postgres/tables-generated.sql @@ -1,10 +1,24 @@ -- This file is automatically generated using maintenance/generateSchemaSql.php. --- Source: ./tables.json +-- Source: sql/tables.json -- Do not modify this file directly. -- See https://www.mediawiki.org/wiki/Manual:Schema_changes -CREATE TABLE oathauth_users ( - id INT NOT NULL, - module VARCHAR(255) NOT NULL, - data TEXT DEFAULT NULL, - PRIMARY KEY(id) +CREATE TABLE oathauth_types ( + oat_id SERIAL NOT NULL, + oat_name TEXT NOT NULL, + PRIMARY KEY(oat_id) ); + +CREATE UNIQUE INDEX oat_name ON oathauth_types (oat_name); + + +CREATE TABLE oathauth_devices ( + oad_id SERIAL NOT NULL, + oad_user INT NOT NULL, + oad_type INT NOT NULL, + oad_name TEXT DEFAULT NULL, + oad_created TIMESTAMPTZ DEFAULT NULL, + oad_data TEXT DEFAULT NULL, + PRIMARY KEY(oad_id) +); + +CREATE INDEX oad_user ON oathauth_devices (oad_user); diff --git a/sql/sqlite/tables-generated.sql b/sql/sqlite/tables-generated.sql index 9140bd0c..51138015 100644 --- a/sql/sqlite/tables-generated.sql +++ b/sql/sqlite/tables-generated.sql @@ -1,10 +1,20 @@ -- This file is automatically generated using maintenance/generateSchemaSql.php. --- Source: ./tables.json +-- Source: sql/tables.json -- Do not modify this file directly. -- See https://www.mediawiki.org/wiki/Manual:Schema_changes -CREATE TABLE /*_*/oathauth_users ( - id INTEGER NOT NULL, - module VARCHAR(255) NOT NULL, - data BLOB DEFAULT NULL, - PRIMARY KEY(id) +CREATE TABLE /*_*/oathauth_types ( + oat_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + oat_name BLOB NOT NULL ); + +CREATE UNIQUE INDEX oat_name ON /*_*/oathauth_types (oat_name); + + +CREATE TABLE /*_*/oathauth_devices ( + oad_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + oad_user INTEGER NOT NULL, oad_type INTEGER NOT NULL, + oad_name BLOB DEFAULT NULL, oad_created BLOB DEFAULT NULL, + oad_data BLOB DEFAULT NULL +); + +CREATE INDEX oad_user ON /*_*/oathauth_devices (oad_user); diff --git a/sql/tables.json b/sql/tables.json index b82b92d9..f449b86d 100644 --- a/sql/tables.json +++ b/sql/tables.json @@ -1,27 +1,78 @@ [ { - "name": "oathauth_users", + "name": "oathauth_types", + "comment": "Possible authentication device types", "columns": [ { - "name": "id", + "name": "oat_id", + "comment": "Unique ID of this device type", + "type": "integer", + "options": { "autoincrement": true, "notnull": false } + }, + { + "name": "oat_name", + "comment": "Internal name of this device type, matching the keys of attributes.OATHAuth.Modules in extension.json", + "type": "binary", + "options": { "notnull": true, "length": 255 } + } + ], + "indexes": [ + { + "name": "oat_name", + "columns": [ "oat_name" ], + "unique": true + } + ], + "pk": [ "oat_id" ] + }, + { + "name": "oathauth_devices", + "comment": "Enrolled authentication devices", + "columns": [ + { + "name": "oad_id", + "comment": "Unique ID of this authentication device", + "type": "integer", + "options": { "autoincrement": true, "notnull": false } + }, + { + "name": "oad_user", "comment": "User ID", "type": "integer", "options": { "notnull": true } }, { - "name": "module", - "comment": "Module user has selected", - "type": "string", - "options": { "notnull": true, "length": 255 } + "name": "oad_type", + "comment": "Device type ID, references the oauthauth_types table", + "type": "integer", + "options": { "notnull": true } }, { - "name": "data", + "name": "oad_name", + "comment": "User-specified name of this device", + "type": "binary", + "options": { "notnull": false, "length": 255 } + }, + { + "name": "oad_created", + "comment": "Timestamp when this authentication device was created", + "type": "mwtimestamp", + "options": { "notnull": false } + }, + { + "name": "oad_data", "comment": "Data", "type": "blob", "options": { "length": 65530, "notnull": false } } ], - "indexes": [], - "pk": [ "id" ] + "indexes": [ + { + "name": "oad_user", + "columns": [ "oad_user" ], + "unique": false + } + ], + "pk": [ "oad_id" ] } ] diff --git a/src/Hook/UpdateTables.php b/src/Hook/UpdateTables.php index 51e40f27..ef6de247 100644 --- a/src/Hook/UpdateTables.php +++ b/src/Hook/UpdateTables.php @@ -5,6 +5,7 @@ namespace MediaWiki\Extension\OATHAuth\Hook; use ConfigException; use DatabaseUpdater; use FormatJson; +use MediaWiki\Extension\OATHAuth\Maintenance\UpdateForMultipleDevicesSupport; use MediaWiki\Installer\Hook\LoadExtensionSchemaUpdatesHook; use MediaWiki\MediaWikiServices; use Wikimedia\Rdbms\IMaintainableDatabase; @@ -16,50 +17,65 @@ class UpdateTables implements LoadExtensionSchemaUpdatesHook { */ public function onLoadExtensionSchemaUpdates( $updater ) { $type = $updater->getDB()->getType(); - $typePath = dirname( __DIR__, 2 ) . "/sql/{$type}"; + $baseDir = dirname( __DIR__, 2 ); + $typePath = "$baseDir/sql/$type"; $updater->addExtensionTable( - 'oathauth_users', + 'oathauth_types', "$typePath/tables-generated.sql" ); - switch ( $type ) { - case 'mysql': - case 'sqlite': - // 1.34 - $updater->addExtensionField( - 'oathauth_users', - 'module', - "$typePath/patch-add_generic_fields.sql" - ); + // If the old table exists, ensure it's up-to-date so the migration + // from the old schema to the new one can be done properly. + if ( $updater->tableExists( 'oathauth_users' ) ) { + switch ( $type ) { + case 'mysql': + case 'sqlite': + // 1.34 + $updater->addExtensionField( + 'oathauth_users', + 'module', + "$typePath/patch-add_generic_fields.sql" + ); - $updater->addExtensionUpdate( - [ [ __CLASS__, 'schemaUpdateSubstituteForGenericFields' ] ] - ); - $updater->dropExtensionField( - 'oathauth_users', - 'secret', - "$typePath/patch-remove_module_specific_fields.sql" - ); + $updater->addExtensionUpdate( + [ [ __CLASS__, 'schemaUpdateSubstituteForGenericFields' ] ] + ); + $updater->dropExtensionField( + 'oathauth_users', + 'secret', + "$typePath/patch-remove_module_specific_fields.sql" + ); - $updater->addExtensionUpdate( - [ [ __CLASS__, 'schemaUpdateTOTPToMultipleKeys' ] ] - ); + $updater->addExtensionUpdate( + [ [ __CLASS__, 'schemaUpdateTOTPToMultipleKeys' ] ] + ); - $updater->addExtensionUpdate( - [ [ __CLASS__, 'schemaUpdateTOTPScratchTokensToArray' ] ] - ); + $updater->addExtensionUpdate( + [ [ __CLASS__, 'schemaUpdateTOTPScratchTokensToArray' ] ] + ); - break; + break; - case 'postgres': - // 1.38 - $updater->modifyExtensionTable( - 'oathauth_users', - "$typePath/patch-oathauth_users-drop-oathauth_users_id_seq.sql" - ); - break; + case 'postgres': + // 1.38 + $updater->modifyExtensionTable( + 'oathauth_users', + "$typePath/patch-oathauth_users-drop-oathauth_users_id_seq.sql" + ); + break; + } + + $updater->addExtensionUpdate( [ + 'runMaintenance', + UpdateForMultipleDevicesSupport::class, + "$baseDir/maintenance/UpdateForMultipleDevicesSupport.php" + ] ); + + $updater->dropExtensionTable( 'oathauth_users' ); } + + // add new updates here } /** diff --git a/src/IModule.php b/src/IModule.php index b8699f19..9479c8bf 100644 --- a/src/IModule.php +++ b/src/IModule.php @@ -26,12 +26,6 @@ interface IModule { */ public function newKey( array $data ); - /** - * @param OATHUser $user - * @return array - */ - public function getDataFromUser( OATHUser $user ); - /** * @return AbstractSecondaryAuthenticationProvider */ diff --git a/src/Key/TOTPKey.php b/src/Key/TOTPKey.php index f45aa92a..e357526d 100644 --- a/src/Key/TOTPKey.php +++ b/src/Key/TOTPKey.php @@ -187,7 +187,7 @@ class TOTPKey implements IAuthKey { foreach ( $this->scratchTokens as $i => $scratchToken ) { if ( hash_equals( $token, $scratchToken ) ) { // If we used a scratch token, remove it from the scratch token list. - // This is saved below via OATHUserRepository::persist, TOTP::getDataFromUser. + // This is saved below via OATHUserRepository::persist array_splice( $this->scratchTokens, $i, 1 ); $logger->info( 'OATHAuth user {user} used a scratch token from {clientip}', [ @@ -200,8 +200,8 @@ class TOTPKey implements IAuthKey { /** @var OATHUserRepository $userRepo */ $userRepo = MediaWikiServices::getInstance()->getService( 'OATHUserRepository' ); - $user->addKey( $this ); - $user->setModule( $module ); + // TODO: support for multiple keys + $user->setKeys( [ $this ] ); $userRepo->persist( $user, $clientIP ); return true; diff --git a/src/Module/TOTP.php b/src/Module/TOTP.php index 3238c1df..3246404d 100644 --- a/src/Module/TOTP.php +++ b/src/Module/TOTP.php @@ -54,21 +54,6 @@ class TOTP implements IModule { return TOTPKey::newFromArray( $data ); } - /** - * @param OATHUser $user - * @return array - * @throws MWException - */ - public function getDataFromUser( OATHUser $user ) { - $key = $user->getFirstKey(); - if ( !( $key instanceof TOTPKey ) ) { - throw new MWException( 'oathauth-invalid-key-type' ); - } - return [ - 'keys' => [ $key->jsonSerialize() ] - ]; - } - /** * @return AbstractSecondaryAuthenticationProvider */ diff --git a/src/OATHAuthDatabase.php b/src/OATHAuthDatabase.php index 5b496243..1e57d11b 100644 --- a/src/OATHAuthDatabase.php +++ b/src/OATHAuthDatabase.php @@ -59,4 +59,10 @@ class OATHAuthDatabase { $db = $this->options->get( 'OATHAuthDatabase' ); return $this->lbFactory->getMainLB( $db )->getConnectionRef( $index, [], $db ); } + + public function waitForReplication(): void { + $this->lbFactory->waitForReplication( [ + 'domain' => $this->options->get( 'OATHAuthDatabase' ), + ] ); + } } diff --git a/src/OATHAuthModuleRegistry.php b/src/OATHAuthModuleRegistry.php index e2b0a4d4..8d19d0e1 100644 --- a/src/OATHAuthModuleRegistry.php +++ b/src/OATHAuthModuleRegistry.php @@ -20,21 +20,38 @@ namespace MediaWiki\Extension\OATHAuth; -use ExtensionRegistry; +use Exception; class OATHAuthModuleRegistry { + /** @var OATHAuthDatabase */ + private OATHAuthDatabase $database; + + /** @var array */ + private $modules; + /** @var array|null */ - private $modules = null; + private $moduleIds; + + /** + * @param OATHAuthDatabase $database + * @param array $modules + */ + public function __construct( + OATHAuthDatabase $database, + array $modules + ) { + $this->database = $database; + $this->modules = $modules; + } /** * @param string $key * @return IModule|null */ public function getModuleByKey( string $key ): ?IModule { - $this->collectModules(); - if ( isset( $this->modules[$key] ) ) { - $module = call_user_func_array( $this->modules[$key], [] ); + if ( isset( $this->getModules()[$key] ) ) { + $module = call_user_func_array( $this->getModules()[$key], [] ); if ( !$module instanceof IModule ) { return null; } @@ -50,10 +67,8 @@ class OATHAuthModuleRegistry { * @return IModule[] */ public function getAllModules(): array { - $this->collectModules(); - $modules = []; - foreach ( $this->modules as $key => $callback ) { + foreach ( $this->getModules() as $key => $callback ) { $module = $this->getModuleByKey( $key ); if ( !( $module instanceof IModule ) ) { continue; @@ -63,11 +78,66 @@ class OATHAuthModuleRegistry { return $modules; } - private function collectModules() { - if ( $this->modules !== null ) { - return; + /** + * Returns the numerical ID for the module with the specified key. + * @param string $key + * @return int + */ + public function getModuleId( string $key ): int { + $ids = $this->getModuleIds(); + if ( isset( $ids[$key] ) ) { + return $ids[$key]; } - $this->modules = ExtensionRegistry::getInstance()->getAttribute( 'OATHAuthModules' ); + throw new Exception( "Module $key does not seem to exist" ); + } + + /** + * @return array + */ + public function getModuleIds(): array { + if ( $this->moduleIds === null ) { + $this->moduleIds = $this->getModuleIdsFromDatabase( DB_REPLICA ); + } + + $missing = array_diff( + array_keys( $this->getModules() ), + array_keys( $this->moduleIds ) + ); + + if ( $missing ) { + $rows = []; + foreach ( $missing as $name ) { + $rows[] = [ 'oat_name' => $name ]; + } + + $this->database + ->getDB( DB_PRIMARY ) + ->insert( 'oathauth_types', $rows, __METHOD__ ); + $this->moduleIds = $this->getModuleIdsFromDatabase( DB_PRIMARY ); + } + + return $this->moduleIds; + } + + private function getModuleIdsFromDatabase( int $index ): array { + $ids = []; + + $rows = $this->database->getDB( $index ) + ->newSelectQueryBuilder() + ->select( [ 'oat_id', 'oat_name' ] ) + ->from( 'oathauth_types' ) + ->caller( __METHOD__ ) + ->fetchResultSet(); + + foreach ( $rows as $row ) { + $ids[$row->oat_name] = (int)$row->oat_id; + } + + return $ids; + } + + private function getModules(): array { + return $this->modules; } } diff --git a/src/OATHAuthServices.php b/src/OATHAuthServices.php index 084e4e4e..f06fdb8c 100644 --- a/src/OATHAuthServices.php +++ b/src/OATHAuthServices.php @@ -54,4 +54,11 @@ class OATHAuthServices { public function getDatabase(): OATHAuthDatabase { return $this->services->getService( 'OATHAuthDatabase' ); } + + /** + * @return OATHAuthModuleRegistry + */ + public function getModuleRegistry(): OATHAuthModuleRegistry { + return $this->services->getService( 'OATHAuthModuleRegistry' ); + } } diff --git a/src/OATHUserRepository.php b/src/OATHUserRepository.php index 5552f81f..319518c5 100644 --- a/src/OATHUserRepository.php +++ b/src/OATHUserRepository.php @@ -27,6 +27,7 @@ use MWException; use Psr\Log\LoggerAwareInterface; use Psr\Log\LoggerInterface; use RequestContext; +use RuntimeException; use User; class OATHUserRepository implements LoggerAwareInterface { @@ -50,11 +51,12 @@ class OATHUserRepository implements LoggerAwareInterface { /** @internal Only public for service wiring use. */ public const CONSTRUCTOR_OPTIONS = [ - 'OATHAuthDatabase', + 'OATHAuthMultipleDevicesMigrationStage', ]; /** * OATHUserRepository constructor. + * * @param ServiceOptions $options * @param OATHAuthDatabase $database * @param BagOStuff $cache @@ -99,25 +101,61 @@ class OATHUserRepository implements LoggerAwareInterface { $uid = $this->centralIdLookupFactory->getLookup() ->centralIdFromLocalUser( $user ); - $res = $this->database->getDB( DB_REPLICA )->selectRow( - 'oathauth_users', - [ 'module', 'data' ], - [ 'id' => $uid ], - __METHOD__ - ); - if ( $res ) { - $module = $this->moduleRegistry->getModuleByKey( $res->module ); + + $moduleId = null; + $keys = []; + if ( $this->getMultipleDevicesMigrationStage() & SCHEMA_COMPAT_READ_NEW ) { + $res = $this->database->getDB( DB_REPLICA )->newSelectQueryBuilder() + ->select( [ + 'oad_data', + 'oat_name', + ] ) + ->from( 'oathauth_devices' ) + ->join( 'oathauth_types', null, [ 'oat_id = oad_type' ] ) + ->where( [ 'oad_user' => $uid ] ) + ->caller( __METHOD__ ) + ->fetchResultSet(); + + foreach ( $res as $row ) { + if ( $moduleId && $row->oat_name !== $moduleId ) { + // Not supported by current application-layer code. + throw new RuntimeException( "user {$uid} has multiple different oauth modules defined" ); + } + + $moduleId = $row->oat_name; + $keys[] = FormatJson::decode( $row->oad_data, true ); + } + } + + if ( $this->getMultipleDevicesMigrationStage() & SCHEMA_COMPAT_READ_OLD && !$moduleId ) { + $res = $this->database->getDB( DB_REPLICA )->selectRow( + 'oathauth_users', + [ 'module', 'data' ], + [ 'id' => $uid ], + __METHOD__ + ); + + if ( $res ) { + $moduleId = $res->module; + $decodedData = FormatJson::decode( $res->data, true ); + + if ( is_array( $decodedData['keys'] ) ) { + $keys = $decodedData['keys']; + } + } + } + + if ( $moduleId ) { + $module = $this->moduleRegistry->getModuleByKey( $moduleId ); if ( $module === null ) { throw new MWException( 'oathauth-module-invalid' ); } $oathUser->setModule( $module ); - $decodedData = FormatJson::decode( $res->data, true ); - if ( is_array( $decodedData['keys'] ) ) { - foreach ( $decodedData['keys'] as $keyData ) { - $key = $module->newKey( $keyData ); - $oathUser->addKey( $key ); - } + + foreach ( $keys as $keyData ) { + $key = $module->newKey( $keyData ); + $oathUser->addKey( $key ); } } @@ -137,19 +175,52 @@ class OATHUserRepository implements LoggerAwareInterface { $clientInfo = RequestContext::getMain()->getRequest()->getIP(); } $prevUser = $this->findByUser( $user->getUser() ); - $data = $user->getModule()->getDataFromUser( $user ); + $userId = $this->centralIdLookupFactory->getLookup()->centralIdFromLocalUser( $user->getUser() ); - $this->database->getDB( DB_PRIMARY )->replace( - 'oathauth_users', - 'id', - [ - 'id' => $this->centralIdLookupFactory->getLookup() - ->centralIdFromLocalUser( $user->getUser() ), - 'module' => $user->getModule()->getName(), - 'data' => FormatJson::encode( $data ) - ], - __METHOD__ - ); + if ( $this->getMultipleDevicesMigrationStage() & SCHEMA_COMPAT_WRITE_NEW ) { + $moduleId = $this->moduleRegistry->getModuleId( $user->getModule()->getName() ); + $rows = []; + foreach ( $user->getKeys() as $key ) { + $rows[] = [ + 'oad_user' => $userId, + 'oad_type' => $moduleId, + 'oad_data' => FormatJson::encode( $key->jsonSerialize() ) + ]; + } + + // TODO: only update changed rows + $dbw = $this->database->getDB( DB_PRIMARY ); + $dbw->delete( + 'oathauth_devices', + [ 'oad_user' => $userId ], + __METHOD__ + ); + $dbw->insert( + 'oathauth_devices', + $rows, + __METHOD__ + ); + } + if ( $this->getMultipleDevicesMigrationStage() & SCHEMA_COMPAT_WRITE_OLD ) { + $data = [ + 'keys' => [] + ]; + + foreach ( $user->getKeys() as $key ) { + $data['keys'][] = $key->jsonSerialize(); + } + + $this->database->getDB( DB_PRIMARY )->replace( + 'oathauth_users', + 'id', + [ + 'id' => $userId, + 'module' => $user->getModule()->getName(), + 'data' => FormatJson::encode( $data ) + ], + __METHOD__ + ); + } $userName = $user->getUser()->getName(); $this->cache->set( $userName, $user ); @@ -178,12 +249,23 @@ class OATHUserRepository implements LoggerAwareInterface { * @param bool $self Whether they disabled it themselves */ public function remove( OATHUser $user, $clientInfo, bool $self ) { - $this->database->getDB( DB_PRIMARY )->delete( - 'oathauth_users', - [ 'id' => $this->centralIdLookupFactory->getLookup() - ->centralIdFromLocalUser( $user->getUser() ) ], - __METHOD__ - ); + $userId = $this->centralIdLookupFactory->getLookup() + ->centralIdFromLocalUser( $user->getUser() ); + if ( $this->getMultipleDevicesMigrationStage() & SCHEMA_COMPAT_WRITE_NEW ) { + $this->database->getDB( DB_PRIMARY )->delete( + 'oathauth_devices', + [ 'oad_user' => $userId ], + __METHOD__ + ); + } + + if ( $this->getMultipleDevicesMigrationStage() & SCHEMA_COMPAT_WRITE_OLD ) { + $this->database->getDB( DB_PRIMARY )->delete( + 'oathauth_users', + [ 'id' => $userId ], + __METHOD__ + ); + } $userName = $user->getUser()->getName(); $this->cache->delete( $userName ); @@ -195,4 +277,8 @@ class OATHUserRepository implements LoggerAwareInterface { ] ); Notifications\Manager::notifyDisabled( $user, $self ); } + + private function getMultipleDevicesMigrationStage(): int { + return $this->options->get( 'OATHAuthMultipleDevicesMigrationStage' ); + } } diff --git a/tests/phpunit/OATHAuthModuleRegistryTest.php b/tests/phpunit/OATHAuthModuleRegistryTest.php new file mode 100644 index 00000000..8cbbc8e4 --- /dev/null +++ b/tests/phpunit/OATHAuthModuleRegistryTest.php @@ -0,0 +1,59 @@ + + * @group Database + */ +class OATHAuthModuleRegistryTest extends MediaWikiIntegrationTestCase { + /** @var string[] */ + protected $tablesUsed = [ 'oathauth_types' ]; + + /** + * @covers \MediaWiki\Extension\OATHAuth\OATHAuthModuleRegistry::getModuleIds + */ + public function testGetModuleIds() { + $this->db->insert( + 'oathauth_types', + [ 'oat_name' => 'first' ], + __METHOD__ + ); + + $database = $this->createMock( OATHAuthDatabase::class ); + $database->method( 'getDB' )->willReturn( $this->db ); + + $registry = new OATHAuthModuleRegistry( + $database, + [ + 'first' => 'does not matter', + 'second' => 'does not matter', + 'third' => 'does not matter', + ] + ); + + $this->assertEquals( + [ 'first', 'second', 'third' ], + array_keys( $registry->getModuleIds() ) + ); + } +}