diff --git a/CHANGELOG.md b/CHANGELOG.md index 00d30dc67e..5deb883e98 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ When updating the changelog, remember to be very clear about what behavior has c and what APIs have changed, if applicable. ## [Unreleased] +- Fix TimingKey Memory Leak ## [29.17.2] - 2021-04-11 - Fix the default value resolution logic in Avro schema translator to match the PDL behavior. diff --git a/d2/src/main/java/com/linkedin/d2/balancer/clients/DynamicClient.java b/d2/src/main/java/com/linkedin/d2/balancer/clients/DynamicClient.java index 8b7864808a..bc9467650e 100644 --- a/d2/src/main/java/com/linkedin/d2/balancer/clients/DynamicClient.java +++ b/d2/src/main/java/com/linkedin/d2/balancer/clients/DynamicClient.java @@ -157,6 +157,7 @@ public void shutdown(final Callback callback) callback.onSuccess(None.none()); }); + TimingKey.unregisterKey(TIMING_KEY); } @Override diff --git a/r2-core/src/main/java/com/linkedin/r2/filter/FilterChain.java b/r2-core/src/main/java/com/linkedin/r2/filter/FilterChain.java index bd6273fb89..f2e9f142c3 100644 --- a/r2-core/src/main/java/com/linkedin/r2/filter/FilterChain.java +++ b/r2-core/src/main/java/com/linkedin/r2/filter/FilterChain.java @@ -23,7 +23,7 @@ import com.linkedin.r2.message.rest.RestResponse; import com.linkedin.r2.message.stream.StreamRequest; import com.linkedin.r2.message.stream.StreamResponse; - +import java.util.List; import java.util.Map; /** @@ -189,4 +189,14 @@ void onStreamResponse(StreamResponse res, void onStreamError(Exception ex, RequestContext requestContext, Map wireAttrs); + + /** + * Returns a copy of a list of RestFilters + */ + List getRestFilters(); + + /** + * Returns a copy of a list of StreamFilters + */ + List getStreamFilters(); } diff --git a/r2-core/src/main/java/com/linkedin/r2/filter/FilterChainImpl.java b/r2-core/src/main/java/com/linkedin/r2/filter/FilterChainImpl.java index dee5ff2ced..725f3193f7 100644 --- a/r2-core/src/main/java/com/linkedin/r2/filter/FilterChainImpl.java +++ b/r2-core/src/main/java/com/linkedin/r2/filter/FilterChainImpl.java @@ -78,6 +78,16 @@ public FilterChain addLast(StreamFilter filter) return new FilterChainImpl(_restFilters, doAddLast(_streamFilters, decorateStreamFilter(filter))); } + @Override + public List getRestFilters() { + return new ArrayList(_restFilters); + } + + @Override + public List getStreamFilters() { + return new ArrayList(_streamFilters); + } + private RestFilter decorateRestFilter(RestFilter filter) { return new TimedRestFilter(filter); diff --git a/r2-core/src/main/java/com/linkedin/r2/filter/TimedRestFilter.java b/r2-core/src/main/java/com/linkedin/r2/filter/TimedRestFilter.java index 0261b5a0f8..1a39272aa9 100644 --- a/r2-core/src/main/java/com/linkedin/r2/filter/TimedRestFilter.java +++ b/r2-core/src/main/java/com/linkedin/r2/filter/TimedRestFilter.java @@ -23,6 +23,8 @@ import com.linkedin.r2.message.rest.RestRequest; import com.linkedin.r2.message.rest.RestResponse; import com.linkedin.r2.message.timing.TimingImportance; +import java.util.Arrays; +import java.util.List; import java.util.Map; @@ -31,7 +33,7 @@ * * @author Xialin Zhu */ -/* package private */ class TimedRestFilter implements RestFilter +public class TimedRestFilter implements RestFilter { protected static final String ON_REQUEST_SUFFIX = "onRequest"; protected static final String ON_RESPONSE_SUFFIX = "onResponse"; @@ -91,4 +93,8 @@ public void onRestError(Throwable ex, TimingContextUtil.markTiming(requestContext, _onErrorTimingKey); _restFilter.onRestError(ex, requestContext, wireAttrs, new TimedNextFilter<>(_onErrorTimingKey, nextFilter)); } + + public List getTimingKeyList() { + return Arrays.asList(_onErrorTimingKey, _onRequestTimingKey, _onResponseTimingKey); + } } diff --git a/r2-core/src/main/java/com/linkedin/r2/filter/TimedStreamFilter.java b/r2-core/src/main/java/com/linkedin/r2/filter/TimedStreamFilter.java index c1759decac..b661bef842 100644 --- a/r2-core/src/main/java/com/linkedin/r2/filter/TimedStreamFilter.java +++ b/r2-core/src/main/java/com/linkedin/r2/filter/TimedStreamFilter.java @@ -23,6 +23,8 @@ import com.linkedin.r2.message.stream.StreamRequest; import com.linkedin.r2.message.stream.StreamResponse; import com.linkedin.r2.message.timing.TimingImportance; +import java.util.Arrays; +import java.util.List; import java.util.Map; import static com.linkedin.r2.filter.TimedRestFilter.ON_ERROR_SUFFIX; @@ -34,7 +36,7 @@ * * @author Xialin Zhu */ -/* package private */ class TimedStreamFilter implements StreamFilter +public class TimedStreamFilter implements StreamFilter { private final StreamFilter _streamFilter; private final TimingKey _onRequestTimingKey; @@ -91,4 +93,8 @@ public void onStreamError(Throwable ex, TimingContextUtil.markTiming(requestContext, _onErrorTimingKey); _streamFilter.onStreamError(ex, requestContext, wireAttrs, new TimedNextFilter<>(_onErrorTimingKey, nextFilter)); } + + public List getTimingKeyList() { + return Arrays.asList(_onErrorTimingKey, _onRequestTimingKey, _onResponseTimingKey); + } } diff --git a/r2-core/src/main/java/com/linkedin/r2/filter/transport/FilterChainClient.java b/r2-core/src/main/java/com/linkedin/r2/filter/transport/FilterChainClient.java index 11e4e15840..685f483756 100644 --- a/r2-core/src/main/java/com/linkedin/r2/filter/transport/FilterChainClient.java +++ b/r2-core/src/main/java/com/linkedin/r2/filter/transport/FilterChainClient.java @@ -17,10 +17,13 @@ /* $Id$ */ package com.linkedin.r2.filter.transport; - import com.linkedin.common.callback.Callback; import com.linkedin.common.util.None; import com.linkedin.r2.filter.FilterChain; +import com.linkedin.r2.filter.TimedRestFilter; +import com.linkedin.r2.filter.TimedStreamFilter; +import com.linkedin.r2.filter.message.rest.RestFilter; +import com.linkedin.r2.filter.message.stream.StreamFilter; import com.linkedin.r2.message.RequestContext; import com.linkedin.r2.message.Response; import com.linkedin.r2.message.rest.RestRequest; @@ -29,12 +32,15 @@ import com.linkedin.r2.message.stream.StreamResponse; import com.linkedin.r2.message.timing.FrameworkTimingKeys; import com.linkedin.r2.message.timing.TimingContextUtil; +import com.linkedin.r2.message.timing.TimingKey; import com.linkedin.r2.transport.common.bridge.client.TransportClient; import com.linkedin.r2.transport.common.bridge.common.TransportCallback; - import com.linkedin.r2.transport.common.bridge.common.TransportResponse; +import java.util.Collection; +import java.util.List; import java.util.Map; + /** * {@link TransportClient} adapter which composes a {@link TransportClient} * and a {@link FilterChain}. @@ -46,6 +52,7 @@ public class FilterChainClient implements TransportClient { private final TransportClient _client; private final FilterChain _filters; + private final FilterChain _sharedFilters; /** * Construct a new instance by composing the specified {@link TransportClient} @@ -53,8 +60,9 @@ public class FilterChainClient implements TransportClient * * @param client the {@link TransportClient} to be composed. * @param filters the {@link FilterChain} to be composed. + * @param sharedFilters the {@link FilterChain} can be used by other clients. */ - public FilterChainClient(TransportClient client, FilterChain filters) + public FilterChainClient(TransportClient client, FilterChain filters, FilterChain sharedFilters) { _client = client; @@ -66,6 +74,7 @@ public FilterChainClient(TransportClient client, FilterChain filters) .addLastRest(requestFilter) .addFirst(responseFilter) .addLast(requestFilter); + _sharedFilters = sharedFilters; } @Override @@ -94,6 +103,27 @@ public void streamRequest(StreamRequest request, public void shutdown(Callback callback) { _client.shutdown(callback); + + List streamFilters = _filters.getStreamFilters(); + List restFilters = _filters.getRestFilters(); + List sharedStreamFilters = _sharedFilters.getStreamFilters(); + List sharedRestFilters = _sharedFilters.getRestFilters(); + + streamFilters.stream() + .filter(filter -> !sharedStreamFilters.contains(filter)) + .filter(TimedStreamFilter.class::isInstance) + .map(TimedStreamFilter.class::cast) + .map(TimedStreamFilter::getTimingKeyList) + .flatMap(Collection::stream) + .forEach(TimingKey::unregisterKey); + + restFilters.stream() + .filter(filter -> !sharedRestFilters.contains(filter)) + .filter(TimedRestFilter.class::isInstance) + .map(TimedRestFilter.class::cast) + .map(TimedRestFilter::getTimingKeyList) + .flatMap(Collection::stream) + .forEach(TimingKey::unregisterKey); } /** diff --git a/r2-core/src/main/java/com/linkedin/r2/message/timing/TimingKey.java b/r2-core/src/main/java/com/linkedin/r2/message/timing/TimingKey.java index 8d551af59f..7012121e83 100644 --- a/r2-core/src/main/java/com/linkedin/r2/message/timing/TimingKey.java +++ b/r2-core/src/main/java/com/linkedin/r2/message/timing/TimingKey.java @@ -17,9 +17,12 @@ package com.linkedin.r2.message.timing; import java.util.Map; +import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; import com.linkedin.r2.message.RequestContext; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; /** @@ -32,6 +35,7 @@ public class TimingKey { private static final Map _pool = new ConcurrentHashMap<>(); + private static final ExecutorService _unregisterExecutor = Executors.newFixedThreadPool(1); private final String _name; private final String _type; @@ -130,4 +134,26 @@ public static TimingKey registerNewKey(String uniqueName, String type, TimingImp { return registerNewKey(new TimingKey(uniqueName, type, timingImportance)); } + + /** + * Unregister a TimingKey to reclaim the memory + * + */ + public static void unregisterKey(TimingKey key) + { + _unregisterExecutor.submit(new Callable() { + public Void call() throws Exception { + _pool.remove(key.getName()); + return null; + } + }); + } + + /** + * Return how many registered keys, for testing purpose. + */ + public static int getCount() { + return _pool.size(); + } + } diff --git a/r2-netty/src/main/java/com/linkedin/r2/netty/client/HttpNettyClient.java b/r2-netty/src/main/java/com/linkedin/r2/netty/client/HttpNettyClient.java index d458a2f1a1..e9d0ac9fac 100644 --- a/r2-netty/src/main/java/com/linkedin/r2/netty/client/HttpNettyClient.java +++ b/r2-netty/src/main/java/com/linkedin/r2/netty/client/HttpNettyClient.java @@ -246,6 +246,7 @@ public void onSuccess(None result) { callback.onError(new IllegalStateException("Shutdown has already been requested.")); } + TimingKey.unregisterKey(TIMING_KEY); } private void sendStreamRequestAsRestRequest(StreamRequest request, RequestContext requestContext, diff --git a/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/HttpClientFactory.java b/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/HttpClientFactory.java index 2be59979f3..6ed7b0b372 100644 --- a/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/HttpClientFactory.java +++ b/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/HttpClientFactory.java @@ -958,7 +958,6 @@ public TransportClient getClient(Map properties) properties = new HashMap(properties); sslContext = coerceAndRemoveFromMap(HTTP_SSL_CONTEXT, properties, SSLContext.class); sslParameters = coerceAndRemoveFromMap(HTTP_SSL_PARAMS, properties, SSLParameters.class); - return getClient(properties, sslContext, sslParameters); } @@ -1121,7 +1120,7 @@ private TransportClient getClient(Map properties, filters = filters.addLastRest(disruptFilter); filters = filters.addLast(disruptFilter); - client = new FilterChainClient(client, filters); + client = new FilterChainClient(client, filters, _filters); client = new FactoryClient(client); synchronized (_mutex) { diff --git a/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/common/AbstractNettyClient.java b/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/common/AbstractNettyClient.java index d08dbc1e8a..4db1b81ee9 100644 --- a/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/common/AbstractNettyClient.java +++ b/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/common/AbstractNettyClient.java @@ -284,6 +284,7 @@ public void onSuccess(None result) _shutdownTimeout); _jmxManager.onProviderShutdown(_channelPoolManager); _jmxManager.onProviderShutdown(_sslChannelPoolManager); + TimingKey.unregisterKey(TIMING_KEY); } else { diff --git a/r2-netty/src/test/java/com/linkedin/r2/transport/http/client/TestHttpClientFactory.java b/r2-netty/src/test/java/com/linkedin/r2/transport/http/client/TestHttpClientFactory.java index 136b599c9d..9ca82c882b 100644 --- a/r2-netty/src/test/java/com/linkedin/r2/transport/http/client/TestHttpClientFactory.java +++ b/r2-netty/src/test/java/com/linkedin/r2/transport/http/client/TestHttpClientFactory.java @@ -27,6 +27,7 @@ import com.linkedin.r2.message.rest.RestRequest; import com.linkedin.r2.message.rest.RestRequestBuilder; import com.linkedin.r2.message.rest.RestResponse; +import com.linkedin.r2.message.timing.TimingKey; import com.linkedin.r2.testutils.server.HttpServerBuilder; import com.linkedin.r2.transport.common.Client; import com.linkedin.r2.transport.common.bridge.client.TransportClient; @@ -93,13 +94,17 @@ public void testSuccessfulRequest(boolean restOverStream, String protocolVersion { server.start(); List clients = new ArrayList<>(); + + int savedTimingKeyCount = TimingKey.getCount(); for (int i = 0; i < 100; i++) { HashMap properties = new HashMap<>(); properties.put(HttpClientFactory.HTTP_PROTOCOL_VERSION, protocolVersion); clients.add(new TransportClientAdapter(factory.getClient(properties), restOverStream)); } - + int addedTimingKeyCount = TimingKey.getCount() - savedTimingKeyCount; + // In current implementation, one client can have around 30 TimingKeys by default. + Assert.assertTrue(addedTimingKeyCount >= 30 * clients.size()); for (Client c : clients) { RestRequest r = new RestRequestBuilder(new URI(URI)).build(); @@ -107,6 +112,7 @@ public void testSuccessfulRequest(boolean restOverStream, String protocolVersion } Assert.assertEquals(httpServerStatsProvider.requestCount(), expectedRequests); + savedTimingKeyCount = TimingKey.getCount(); for (Client c : clients) { FutureCallback callback = new FutureCallback<>(); @@ -117,6 +123,8 @@ public void testSuccessfulRequest(boolean restOverStream, String protocolVersion FutureCallback factoryShutdown = new FutureCallback<>(); factory.shutdown(factoryShutdown); factoryShutdown.get(30, TimeUnit.SECONDS); + int removedTimingKeyCount = savedTimingKeyCount - TimingKey.getCount(); + Assert.assertEquals(addedTimingKeyCount, removedTimingKeyCount); } finally {