diff --git a/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java b/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java index 04344b18ce824..11fd20c973cf3 100644 --- a/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java +++ b/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java @@ -300,7 +300,7 @@ private void registerRequestHandler(DynamicActionRegistry dynamicActionRegistry) * Loads a single extension * @param extension The extension to be loaded */ - public void loadExtension(Extension extension) throws IOException { + public DiscoveryExtensionNode loadExtension(Extension extension) throws IOException { validateExtension(extension); DiscoveryExtensionNode discoveryExtensionNode = new DiscoveryExtensionNode( extension.getName(), @@ -314,6 +314,12 @@ public void loadExtension(Extension extension) throws IOException { extensionIdMap.put(extension.getUniqueId(), discoveryExtensionNode); extensionSettingsMap.put(extension.getUniqueId(), extension); logger.info("Loaded extension with uniqueId " + extension.getUniqueId() + ": " + extension); + return discoveryExtensionNode; + } + + public void initializeExtension(Extension extension) throws IOException { + DiscoveryExtensionNode node = loadExtension(extension); + initializeExtensionNode(node); } private void validateField(String fieldName, String value) throws IOException { @@ -340,13 +346,11 @@ private void validateExtension(Extension extension) throws IOException { */ public void initialize() { for (DiscoveryExtensionNode extension : extensionIdMap.values()) { - if (! initializedExtensions.containsKey(extension)) { - initializeExtension(extension); - } + initializeExtensionNode(extension); } } - private void initializeExtension(DiscoveryExtensionNode extension) { + private void initializeExtensionNode(DiscoveryExtensionNode extension) { final CompletableFuture inProgressFuture = new CompletableFuture<>(); final TransportResponseHandler initializeExtensionResponseHandler = new TransportResponseHandler< @@ -386,7 +390,7 @@ public String executor() { transportService.getThreadPool().generic().execute(new AbstractRunnable() { @Override public void onFailure(Exception e) { - logger.warn(String.format("Error registering extension: %s", extension.getId()), e); + logger.warn("Error registering extension: " + extension.getId(), e); extensionIdMap.remove(extension.getId()); if (e.getCause() instanceof ConnectTransportException) { logger.info("No response from extension to request.", e); diff --git a/server/src/main/java/org/opensearch/extensions/rest/RestActionsRequestHandler.java b/server/src/main/java/org/opensearch/extensions/rest/RestActionsRequestHandler.java index da0bc093f6b98..383796f0c3b44 100644 --- a/server/src/main/java/org/opensearch/extensions/rest/RestActionsRequestHandler.java +++ b/server/src/main/java/org/opensearch/extensions/rest/RestActionsRequestHandler.java @@ -63,7 +63,7 @@ public TransportResponse handleRegisterRestActionsRequest( ) throws Exception { DiscoveryExtensionNode discoveryExtensionNode = extensionIdMap.get(restActionsRequest.getUniqueId()); if (discoveryExtensionNode == null) { - throw new IllegalStateException(String.format("Missing extension node for %s", restActionsRequest.getUniqueId())); + throw new IllegalStateException("Missing extension node for " + restActionsRequest.getUniqueId()); } RestHandler handler = new RestSendToExtensionAction( restActionsRequest, diff --git a/server/src/main/java/org/opensearch/extensions/rest/RestInitializeExtensionAction.java b/server/src/main/java/org/opensearch/extensions/rest/RestInitializeExtensionAction.java index 4b622b841a040..fc7c21a6eccd6 100644 --- a/server/src/main/java/org/opensearch/extensions/rest/RestInitializeExtensionAction.java +++ b/server/src/main/java/org/opensearch/extensions/rest/RestInitializeExtensionAction.java @@ -159,8 +159,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client extAdditionalSettings ); try { - extensionsManager.loadExtension(extension); - extensionsManager.initialize(); + extensionsManager.initializeExtension(extension); } catch (CompletionException e) { Throwable cause = e.getCause(); if (cause instanceof TimeoutException) { diff --git a/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java b/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java index d0cd127360031..c61afdd5c5261 100644 --- a/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java +++ b/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java @@ -36,6 +36,7 @@ import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.indices.breaker.NoneCircuitBreakerService; import org.opensearch.core.transport.TransportResponse; +import org.opensearch.discovery.InitializeExtensionRequest; import org.opensearch.env.Environment; import org.opensearch.env.EnvironmentSettingsResponse; import org.opensearch.extensions.ExtensionsSettings.Extension; @@ -67,7 +68,6 @@ import java.util.HashMap; import java.util.List; import java.util.Set; -import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -78,6 +78,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; @@ -410,36 +411,37 @@ public void testInitialize() throws Exception { ) ); - // Test needs to be changed to mock the connection between the local node and an extension. Assert statment is commented out for - // now. + // Test needs to be changed to mock the connection between the local node and an extension. // Link to issue: https://github.com/opensearch-project/OpenSearch/issues/4045 // mockLogAppender.assertAllExpectationsMatched(); } } - public void testInitializeExtensionTwice() throws Exception { + public void testInitializeExtension() throws Exception { ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService); - initialize(extensionsManager); - - ThreadPool mockThreadPool = spy(threadPool); - ExecutorService mockExecutorService = mock(ExecutorService.class); - when(mockThreadPool.generic()).thenReturn(mockExecutorService); - TransportService transportService = new TransportService( - Settings.EMPTY, - mock(Transport.class), - mockThreadPool, - TransportService.NOOP_TRANSPORT_INTERCEPTOR, - x -> null, - null, - Collections.emptySet(), - NoopTracer.INSTANCE + TransportService mockTransportService = spy( + new TransportService( + Settings.EMPTY, + mock(Transport.class), + threadPool, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + x -> null, + null, + Collections.emptySet(), + NoopTracer.INSTANCE + ) ); + doNothing().when(mockTransportService).connectToExtensionNode(any(DiscoveryExtensionNode.class)); + + doNothing().when(mockTransportService) + .sendRequest(any(DiscoveryExtensionNode.class), anyString(), any(InitializeExtensionRequest.class), any()); + extensionsManager.initializeServicesAndRestHandler( actionModule, settingsModule, - transportService, + mockTransportService, clusterService, settings, client, @@ -458,8 +460,7 @@ public void testInitializeExtensionTwice() throws Exception { null ); - extensionsManager.loadExtension(firstExtension); - extensionsManager.initialize(); + extensionsManager.initializeExtension(firstExtension); Extension secondExtension = new Extension( "secondExtension", @@ -473,12 +474,18 @@ public void testInitializeExtensionTwice() throws Exception { null ); - extensionsManager.loadExtension(secondExtension); - extensionsManager.initialize(); + extensionsManager.initializeExtension(secondExtension); - // When execution is mocked, the successful registration callback is not called and the extension is never added to - // registered extensions. - // verify(mockExecutorService, times(2)).execute(any()); + ThreadPool.terminate(threadPool, 3, TimeUnit.SECONDS); + + verify(mockTransportService, times(2)).connectToExtensionNode(any(DiscoveryExtensionNode.class)); + + verify(mockTransportService, times(2)).sendRequest( + any(DiscoveryExtensionNode.class), + anyString(), + any(InitializeExtensionRequest.class), + any() + ); } public void testHandleRegisterRestActionsRequest() throws Exception { @@ -515,7 +522,7 @@ public void testHandleRegisterRestActionsRequestRequiresDiscoveryNode() throws E ); } - public void testHandleRegisterTwoRestActionsRequest() throws Exception { + public void testHandleRegisterRestActionsRequestMultiple() throws Exception { ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService); initialize(extensionsManager); @@ -523,12 +530,12 @@ public void testHandleRegisterTwoRestActionsRequest() throws Exception { List actionsList = List.of("GET /foo foo", "PUT /bar bar", "POST /baz baz"); List deprecatedActionsList = List.of("GET /deprecated/foo foo_deprecated", "It's deprecated!"); for (int i = 0; i < 2; i++) { - String uniqueIdStr = String.format("uniqueid-%d", i); + String uniqueIdStr = "uniqueid-%d" + i; Set> additionalSettings = extAwarePlugin.getExtensionSettings().stream().collect(Collectors.toSet()); ExtensionScopedSettings extensionScopedSettings = new ExtensionScopedSettings(additionalSettings); Extension firstExtension = new Extension( - String.format("Extension %s", i), + "Extension %s" + i, uniqueIdStr, "127.0.0.0", "9300",