diff --git a/drivers/bmv2/src/main/java/org/onosproject/drivers/bmv2/Bmv2FlowRuleDriver.java b/drivers/bmv2/src/main/java/org/onosproject/drivers/bmv2/Bmv2FlowRuleDriver.java index 3c6acf52b7..2947b5f792 100644 --- a/drivers/bmv2/src/main/java/org/onosproject/drivers/bmv2/Bmv2FlowRuleDriver.java +++ b/drivers/bmv2/src/main/java/org/onosproject/drivers/bmv2/Bmv2FlowRuleDriver.java @@ -20,6 +20,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Triple; +import org.onosproject.bmv2.api.runtime.Bmv2Client; import org.onosproject.bmv2.api.runtime.Bmv2MatchKey; import org.onosproject.bmv2.api.runtime.Bmv2RuntimeException; import org.onosproject.bmv2.api.runtime.Bmv2TableEntry; @@ -84,7 +85,7 @@ public class Bmv2FlowRuleDriver extends AbstractHandlerBehaviour DeviceId deviceId = handler().data().deviceId(); - Bmv2ThriftClient deviceClient; + Bmv2Client deviceClient; try { deviceClient = Bmv2ThriftClient.of(deviceId); } catch (Bmv2RuntimeException e) { diff --git a/drivers/bmv2/src/main/java/org/onosproject/drivers/bmv2/Bmv2PortGetterDriver.java b/drivers/bmv2/src/main/java/org/onosproject/drivers/bmv2/Bmv2PortGetterDriver.java index 927b77d3e2..e1da42bd29 100644 --- a/drivers/bmv2/src/main/java/org/onosproject/drivers/bmv2/Bmv2PortGetterDriver.java +++ b/drivers/bmv2/src/main/java/org/onosproject/drivers/bmv2/Bmv2PortGetterDriver.java @@ -18,6 +18,7 @@ package org.onosproject.drivers.bmv2; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; +import org.onosproject.bmv2.api.runtime.Bmv2Client; import org.onosproject.bmv2.api.runtime.Bmv2RuntimeException; import org.onosproject.bmv2.ctl.Bmv2ThriftClient; import org.onosproject.net.DefaultAnnotations; @@ -41,7 +42,7 @@ public class Bmv2PortGetterDriver extends AbstractHandlerBehaviour @Override public List getPorts() { - Bmv2ThriftClient deviceClient; + Bmv2Client deviceClient; try { deviceClient = Bmv2ThriftClient.of(handler().data().deviceId()); } catch (Bmv2RuntimeException e) { diff --git a/protocols/bmv2/src/main/java/org/onosproject/bmv2/api/runtime/Bmv2Client.java b/protocols/bmv2/src/main/java/org/onosproject/bmv2/api/runtime/Bmv2Client.java new file mode 100644 index 0000000000..252f59bd18 --- /dev/null +++ b/protocols/bmv2/src/main/java/org/onosproject/bmv2/api/runtime/Bmv2Client.java @@ -0,0 +1,89 @@ +/* + * Copyright 2016-present Open Networking Laboratory + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.onosproject.bmv2.api.runtime; + +import java.util.Collection; + +/** + * RPC client to control a BMv2 device. + */ +public interface Bmv2Client { + /** + * Adds a new table entry. + * + * @param entry a table entry value + * @return table-specific entry ID + * @throws Bmv2RuntimeException if any error occurs + */ + long addTableEntry(Bmv2TableEntry entry) throws Bmv2RuntimeException; + + /** + * Modifies a currently installed entry by updating its action. + * + * @param tableName string value of table name + * @param entryId long value of entry ID + * @param action an action value + * @throws Bmv2RuntimeException if any error occurs + */ + void modifyTableEntry(String tableName, + long entryId, Bmv2Action action) + throws Bmv2RuntimeException; + + /** + * Deletes currently installed entry. + * + * @param tableName string value of table name + * @param entryId long value of entry ID + * @throws Bmv2RuntimeException if any error occurs + */ + void deleteTableEntry(String tableName, + long entryId) throws Bmv2RuntimeException; + + /** + * Sets table default action. + * + * @param tableName string value of table name + * @param action an action value + * @throws Bmv2RuntimeException if any error occurs + */ + void setTableDefaultAction(String tableName, Bmv2Action action) + throws Bmv2RuntimeException; + + /** + * Returns information of the ports currently configured in the switch. + * + * @return collection of port information + * @throws Bmv2RuntimeException if any error occurs + */ + Collection getPortsInfo() throws Bmv2RuntimeException; + + /** + * Return a string representation of a table content. + * + * @param tableName string value of table name + * @return table string dump + * @throws Bmv2RuntimeException if any error occurs + */ + String dumpTable(String tableName) throws Bmv2RuntimeException; + + /** + * Reset the state of the switch (e.g. delete all entries, etc.). + * + * @throws Bmv2RuntimeException if any error occurs + */ + void resetState() throws Bmv2RuntimeException; +} diff --git a/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/Bmv2ThriftClient.java b/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/Bmv2ThriftClient.java index eb6687ad0c..f1a86fcda0 100644 --- a/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/Bmv2ThriftClient.java +++ b/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/Bmv2ThriftClient.java @@ -31,13 +31,13 @@ import org.apache.thrift.protocol.TProtocol; import org.apache.thrift.transport.TSocket; import org.apache.thrift.transport.TTransport; import org.apache.thrift.transport.TTransportException; -import org.onlab.util.ImmutableByteSequence; import org.onosproject.bmv2.api.runtime.Bmv2Action; +import org.onosproject.bmv2.api.runtime.Bmv2Client; import org.onosproject.bmv2.api.runtime.Bmv2ExactMatchParam; -import org.onosproject.bmv2.api.runtime.Bmv2RuntimeException; import org.onosproject.bmv2.api.runtime.Bmv2LpmMatchParam; import org.onosproject.bmv2.api.runtime.Bmv2MatchKey; import org.onosproject.bmv2.api.runtime.Bmv2PortInfo; +import org.onosproject.bmv2.api.runtime.Bmv2RuntimeException; import org.onosproject.bmv2.api.runtime.Bmv2TableEntry; import org.onosproject.bmv2.api.runtime.Bmv2TernaryMatchParam; import org.onosproject.bmv2.api.runtime.Bmv2ValidMatchParam; @@ -51,6 +51,8 @@ import org.p4.bmv2.thrift.BmMatchParamType; import org.p4.bmv2.thrift.BmMatchParamValid; import org.p4.bmv2.thrift.DevMgrPortInfo; import org.p4.bmv2.thrift.Standard; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.nio.ByteBuffer; import java.util.Collection; @@ -60,38 +62,43 @@ import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import static com.google.common.base.Preconditions.checkNotNull; +import static org.onosproject.bmv2.ctl.SafeThriftClient.Options; /** * Implementation of a Thrift client to control the Bmv2 switch. */ -public final class Bmv2ThriftClient { - /* - FIXME: derive context_id from device id - Using different context id values should serve to control different - switches responding to the same IP address and port - */ +public final class Bmv2ThriftClient implements Bmv2Client { + + private static final Logger LOG = + LoggerFactory.getLogger(Bmv2ThriftClient.class); + + // FIXME: make context_id arbitrary for each call + // See: https://github.com/p4lang/behavioral-model/blob/master/modules/bm_sim/include/bm_sim/context.h private static final int CONTEXT_ID = 0; - /* - Static transport/client cache: - - avoids opening a new transport session when there's one already open - - close the connection after a predefined timeout of 5 seconds - */ - private static LoadingCache - clientCache = CacheBuilder.newBuilder() - .expireAfterAccess(5, TimeUnit.SECONDS) + // Seconds after a client is expired (and connection closed) in the cache. + private static final int CLIENT_CACHE_TIMEOUT = 60; + // Number of connection retries after a network error. + private static final int NUM_CONNECTION_RETRIES = 10; + // Time between retries in milliseconds. + private static final int TIME_BETWEEN_RETRIES = 200; + + // Static client cache where clients are removed after a predefined timeout. + private static final LoadingCache + CLIENT_CACHE = CacheBuilder.newBuilder() + .expireAfterAccess(CLIENT_CACHE_TIMEOUT, TimeUnit.SECONDS) .removalListener(new ClientRemovalListener()) .build(new ClientLoader()); private final Standard.Iface stdClient; private final TTransport transport; + private final DeviceId deviceId; // ban constructor - private Bmv2ThriftClient(TTransport transport, Standard.Iface stdClient) { + private Bmv2ThriftClient(DeviceId deviceId, TTransport transport, Standard.Iface stdClient) { + this.deviceId = deviceId; this.transport = transport; this.stdClient = stdClient; - } - private void closeTransport() { - this.transport.close(); + LOG.debug("New client created! > deviceId={}", deviceId); } /** @@ -104,8 +111,10 @@ public final class Bmv2ThriftClient { public static Bmv2ThriftClient of(DeviceId deviceId) throws Bmv2RuntimeException { try { checkNotNull(deviceId, "deviceId cannot be null"); - return clientCache.get(deviceId); + LOG.debug("Getting a client from cache... > deviceId{}", deviceId); + return CLIENT_CACHE.get(deviceId); } catch (ExecutionException e) { + LOG.debug("Exception while getting a client from cache: {} > ", e, deviceId); throw new Bmv2RuntimeException(e.getMessage(), e.getCause()); } } @@ -120,9 +129,13 @@ public final class Bmv2ThriftClient { public static boolean ping(DeviceId deviceId) { // poll ports status as workaround to assess device reachability try { - of(deviceId).stdClient.bm_dev_mgr_show_ports(); + LOG.debug("Pinging device... > deviceId={}", deviceId); + Bmv2ThriftClient client = of(deviceId); + client.stdClient.bm_dev_mgr_show_ports(); + LOG.debug("Device reachable! > deviceId={}", deviceId); return true; } catch (TException | Bmv2RuntimeException e) { + LOG.debug("Device NOT reachable! > deviceId={}", deviceId); return false; } } @@ -156,32 +169,34 @@ public final class Bmv2ThriftClient { private static List buildMatchParamsList(Bmv2MatchKey matchKey) { List paramsList = Lists.newArrayList(); matchKey.matchParams().forEach(x -> { + ByteBuffer value; + ByteBuffer mask; switch (x.type()) { case EXACT: + value = ByteBuffer.wrap(((Bmv2ExactMatchParam) x).value().asArray()); paramsList.add( new BmMatchParam(BmMatchParamType.EXACT) - .setExact(new BmMatchParamExact( - ((Bmv2ExactMatchParam) x).value().asReadOnlyBuffer()))); + .setExact(new BmMatchParamExact(value))); break; case TERNARY: + value = ByteBuffer.wrap(((Bmv2TernaryMatchParam) x).value().asArray()); + mask = ByteBuffer.wrap(((Bmv2TernaryMatchParam) x).mask().asArray()); paramsList.add( new BmMatchParam(BmMatchParamType.TERNARY) - .setTernary(new BmMatchParamTernary( - ((Bmv2TernaryMatchParam) x).value().asReadOnlyBuffer(), - ((Bmv2TernaryMatchParam) x).mask().asReadOnlyBuffer()))); + .setTernary(new BmMatchParamTernary(value, mask))); break; case LPM: + value = ByteBuffer.wrap(((Bmv2LpmMatchParam) x).value().asArray()); + int prefixLength = ((Bmv2LpmMatchParam) x).prefixLength(); paramsList.add( new BmMatchParam(BmMatchParamType.LPM) - .setLpm(new BmMatchParamLPM( - ((Bmv2LpmMatchParam) x).value().asReadOnlyBuffer(), - ((Bmv2LpmMatchParam) x).prefixLength()))); + .setLpm(new BmMatchParamLPM(value, prefixLength))); break; case VALID: + boolean flag = ((Bmv2ValidMatchParam) x).flag(); paramsList.add( new BmMatchParam(BmMatchParamType.VALID) - .setValid(new BmMatchParamValid( - ((Bmv2ValidMatchParam) x).flag()))); + .setValid(new BmMatchParamValid(flag))); break; default: // should never be here @@ -198,21 +213,26 @@ public final class Bmv2ThriftClient { * @return list of ByteBuffers */ private static List buildActionParamsList(Bmv2Action action) { - return action.parameters() - .stream() - .map(ImmutableByteSequence::asReadOnlyBuffer) - .collect(Collectors.toList()); + List buffers = Lists.newArrayList(); + action.parameters().forEach(p -> buffers.add(ByteBuffer.wrap(p.asArray()))); + return buffers; } - /** - * Adds a new table entry. - * - * @param entry a table entry value - * @return table-specific entry ID - * @throws Bmv2RuntimeException if any error occurs - */ + private void closeTransport() { + LOG.debug("Closing transport session... > deviceId={}", deviceId); + if (this.transport.isOpen()) { + this.transport.close(); + LOG.debug("Transport session closed! > deviceId={}", deviceId); + } else { + LOG.debug("Transport session was already closed! deviceId={}", deviceId); + } + } + + @Override public final long addTableEntry(Bmv2TableEntry entry) throws Bmv2RuntimeException { + LOG.debug("Adding table entry... > deviceId={}, entry={}", deviceId, entry); + long entryId = -1; try { @@ -237,34 +257,33 @@ public final class Bmv2ThriftClient { CONTEXT_ID, entry.tableName(), entryId, msTimeout); } + LOG.debug("Table entry added! > deviceId={}, entryId={}/{}", deviceId, entry.tableName(), entryId); + return entryId; } catch (TException e) { + LOG.debug("Exception while adding table entry: {} > deviceId={}, tableName={}", + e, deviceId, entry.tableName()); if (entryId != -1) { + // entry is in inconsistent state (unable to add timeout), remove it try { - stdClient.bm_mt_delete_entry( - CONTEXT_ID, entry.tableName(), entryId); - } catch (TException e1) { - // this should never happen as we know the entry is there - throw new Bmv2RuntimeException(e1.getMessage(), e1); + deleteTableEntry(entry.tableName(), entryId); + } catch (Bmv2RuntimeException e1) { + LOG.debug("Unable to remove failed table entry: {} > deviceId={}, tableName={}", + e1, deviceId, entry.tableName()); } } throw new Bmv2RuntimeException(e.getMessage(), e); } } - /** - * Modifies a currently installed entry by updating its action. - * - * @param tableName string value of table name - * @param entryId long value of entry ID - * @param action an action value - * @throws Bmv2RuntimeException if any error occurs - */ + @Override public final void modifyTableEntry(String tableName, long entryId, Bmv2Action action) throws Bmv2RuntimeException { + LOG.debug("Modifying table entry... > deviceId={}, entryId={}/{}", deviceId, tableName, entryId); + try { stdClient.bm_mt_modify_entry( CONTEXT_ID, @@ -272,57 +291,55 @@ public final class Bmv2ThriftClient { entryId, action.name(), buildActionParamsList(action)); + LOG.debug("Table entry modified! > deviceId={}, entryId={}/{}", deviceId, tableName, entryId); } catch (TException e) { + LOG.debug("Exception while modifying table entry: {} > deviceId={}, entryId={}/{}", + e, deviceId, tableName, entryId); throw new Bmv2RuntimeException(e.getMessage(), e); } } - /** - * Deletes currently installed entry. - * - * @param tableName string value of table name - * @param entryId long value of entry ID - * @throws Bmv2RuntimeException if any error occurs - */ + @Override public final void deleteTableEntry(String tableName, long entryId) throws Bmv2RuntimeException { + LOG.debug("Deleting table entry... > deviceId={}, entryId={}/{}", deviceId, tableName, entryId); + try { stdClient.bm_mt_delete_entry(CONTEXT_ID, tableName, entryId); + LOG.debug("Table entry deleted! > deviceId={}, entryId={}/{}", deviceId, tableName, entryId); } catch (TException e) { + LOG.debug("Exception while deleting table entry: {} > deviceId={}, entryId={}/{}", + e, deviceId, tableName, entryId); throw new Bmv2RuntimeException(e.getMessage(), e); } } - /** - * Sets table default action. - * - * @param tableName string value of table name - * @param action an action value - * @throws Bmv2RuntimeException if any error occurs - */ + @Override public final void setTableDefaultAction(String tableName, Bmv2Action action) throws Bmv2RuntimeException { + LOG.debug("Setting table default... > deviceId={}, tableName={}, action={}", deviceId, tableName, action); + try { stdClient.bm_mt_set_default_action( CONTEXT_ID, tableName, action.name(), buildActionParamsList(action)); + LOG.debug("Table default set! > deviceId={}, tableName={}, action={}", deviceId, tableName, action); } catch (TException e) { + LOG.debug("Exception while setting table default : {} > deviceId={}, tableName={}, action={}", + e, deviceId, tableName, action); throw new Bmv2RuntimeException(e.getMessage(), e); } } - /** - * Returns information of the ports currently configured in the switch. - * - * @return collection of port information - * @throws Bmv2RuntimeException if any error occurs - */ + @Override public Collection getPortsInfo() throws Bmv2RuntimeException { + LOG.debug("Retrieving port info... > deviceId={}", deviceId); + try { List portInfos = stdClient.bm_dev_mgr_show_ports(); @@ -333,39 +350,42 @@ public final class Bmv2ThriftClient { .map(Bmv2PortInfo::new) .collect(Collectors.toList())); + LOG.debug("Port info retrieved! > deviceId={}, portInfos={}", deviceId, bmv2PortInfos); + return bmv2PortInfos; } catch (TException e) { + LOG.debug("Exception while retrieving port info: {} > deviceId={}", e, deviceId); throw new Bmv2RuntimeException(e.getMessage(), e); } } - /** - * Return a string representation of a table content. - * - * @param tableName string value of table name - * @return table string dump - * @throws Bmv2RuntimeException if any error occurs - */ + @Override public String dumpTable(String tableName) throws Bmv2RuntimeException { + LOG.debug("Retrieving table dump... > deviceId={}, tableName={}", deviceId, tableName); + try { - return stdClient.bm_dump_table(CONTEXT_ID, tableName); + String dump = stdClient.bm_dump_table(CONTEXT_ID, tableName); + LOG.debug("Table dump retrieved! > deviceId={}, tableName={}", deviceId, tableName); + return dump; } catch (TException e) { + LOG.debug("Exception while retrieving table dump: {} > deviceId={}, tableName={}", + e, deviceId, tableName); throw new Bmv2RuntimeException(e.getMessage(), e); } } - /** - * Reset the state of the switch (e.g. delete all entries, etc.). - * - * @throws Bmv2RuntimeException if any error occurs - */ + @Override public void resetState() throws Bmv2RuntimeException { + LOG.debug("Resetting device state... > deviceId={}", deviceId); + try { stdClient.bm_reset_state(); + LOG.debug("Device state reset! > deviceId={}", deviceId); } catch (TException e) { + LOG.debug("Exception while resetting device state: {} > deviceId={}", e, deviceId); throw new Bmv2RuntimeException(e.getMessage(), e); } } @@ -376,20 +396,26 @@ public final class Bmv2ThriftClient { private static class ClientLoader extends CacheLoader { + // Connection retries options: max 10 retries each 200 ms + private static final Options RECONN_OPTIONS = new Options(NUM_CONNECTION_RETRIES, TIME_BETWEEN_RETRIES); + @Override public Bmv2ThriftClient load(DeviceId deviceId) throws TTransportException { + LOG.debug("Creating new client in cache... > deviceId={}", deviceId); Pair info = parseDeviceId(deviceId); //make the expensive call TTransport transport = new TSocket( info.getLeft(), info.getRight()); TProtocol protocol = new TBinaryProtocol(transport); - Standard.Iface stdClient = new Standard.Client( + Standard.Client stdClient = new Standard.Client( new TMultiplexedProtocol(protocol, "standard")); + // Wrap the client so to automatically have synchronization and resiliency to connectivity problems + Standard.Iface reconnStdIface = SafeThriftClient.wrap(stdClient, + Standard.Iface.class, + RECONN_OPTIONS); - transport.open(); - - return new Bmv2ThriftClient(transport, stdClient); + return new Bmv2ThriftClient(deviceId, transport, reconnStdIface); } } @@ -403,6 +429,7 @@ public final class Bmv2ThriftClient { public void onRemoval( RemovalNotification notification) { // close the transport connection + LOG.debug("Removing client from cache... > deviceId={}", notification.getKey()); notification.getValue().closeTransport(); } } diff --git a/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/SafeThriftClient.java b/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/SafeThriftClient.java new file mode 100644 index 0000000000..bbe0546a08 --- /dev/null +++ b/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/SafeThriftClient.java @@ -0,0 +1,247 @@ +/* + * Copyright 2016-present Open Networking Laboratory + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Most of the code of this class was copied from: + * http://liveramp.com/engineering/reconnecting-thrift-client/ + */ + +package org.onosproject.bmv2.ctl; + +import com.google.common.collect.ImmutableSet; +import org.apache.thrift.TServiceClient; +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TTransportException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.util.Set; + +/** + * Thrift client wrapper that attempts a few reconnects before giving up a method call execution. It al provides + * synchronization between calls (automatically serialize multiple calls to the same client from different threads). + */ +public final class SafeThriftClient { + + private static final Logger LOG = LoggerFactory.getLogger(SafeThriftClient.class); + + /** + * List of causes which suggest a restart might fix things (defined as constants in {@link TTransportException}). + */ + private static final Set RESTARTABLE_CAUSES = ImmutableSet.of(TTransportException.NOT_OPEN, + TTransportException.END_OF_FILE, + TTransportException.TIMED_OUT, + TTransportException.UNKNOWN); + + private SafeThriftClient() { + // ban constructor. + } + + /** + * Reflectively wraps an already existing Thrift client. + * + * @param baseClient the client to wrap + * @param clientInterface the interface that the client implements + * @param options options that control behavior of the reconnecting client + * @param + * @param + * @return + */ + public static C wrap(T baseClient, Class clientInterface, Options options) { + Object proxyObject = Proxy.newProxyInstance(clientInterface.getClassLoader(), + new Class[]{clientInterface}, + new ReconnectingClientProxy(baseClient, + options.getNumRetries(), + options.getTimeBetweenRetries())); + + return (C) proxyObject; + } + + /** + * Reflectively wraps an already existing Thrift client. + * + * @param baseClient the client to wrap + * @param options options that control behavior of the reconnecting client + * @param + * @param + * @return + */ + public static C wrap(T baseClient, Options options) { + Class[] interfaces = baseClient.getClass().getInterfaces(); + + for (Class iface : interfaces) { + if (iface.getSimpleName().equals("Iface") + && iface.getEnclosingClass().equals(baseClient.getClass().getEnclosingClass())) { + return (C) wrap(baseClient, iface, options); + } + } + + throw new RuntimeException("Class needs to implement Iface directly. Use wrap(TServiceClient, Class) instead."); + } + + /** + * Reflectively wraps an already existing Thrift client. + * + * @param baseClient the client to wrap + * @param clientInterface the interface that the client implements + * @param + * @param + * @return + */ + public static C wrap(T baseClient, Class clientInterface) { + return wrap(baseClient, clientInterface, Options.defaults()); + } + + /** + * Reflectively wraps an already existing Thrift client. + * + * @param baseClient the client to wrap + * @param + * @param + * @return + */ + public static C wrap(T baseClient) { + return wrap(baseClient, Options.defaults()); + } + + /** + * Reconnection options for {@link SafeThriftClient}. + */ + public static class Options { + private int numRetries; + private long timeBetweenRetries; + + /** + * Creates new options with the given parameters. + * + * @param numRetries the maximum number of times to try reconnecting before giving up and throwing an + * exception + * @param timeBetweenRetries the number of milliseconds to wait in between reconnection attempts. + */ + public Options(int numRetries, long timeBetweenRetries) { + this.numRetries = numRetries; + this.timeBetweenRetries = timeBetweenRetries; + } + + private static Options defaults() { + return new Options(5, 10000L); + } + + private int getNumRetries() { + return numRetries; + } + + private long getTimeBetweenRetries() { + return timeBetweenRetries; + } + } + + /** + * Helper proxy class. Attempts to call method on proxy object wrapped in try/catch. If it fails, it attempts a + * reconnect and tries the method again. + * + * @param + */ + private static class ReconnectingClientProxy implements InvocationHandler { + private final T baseClient; + private final int maxRetries; + private final long timeBetweenRetries; + + public ReconnectingClientProxy(T baseClient, int maxRetries, long timeBetweenRetries) { + this.baseClient = baseClient; + this.maxRetries = maxRetries; + this.timeBetweenRetries = timeBetweenRetries; + } + + private static void reconnectOrThrowException(TTransport transport, int maxRetries, long timeBetweenRetries) + throws TTransportException { + int errors = 0; + transport.close(); + + while (errors < maxRetries) { + try { + LOG.debug("Attempting to reconnect..."); + transport.open(); + LOG.debug("Reconnection successful"); + break; + } catch (TTransportException e) { + LOG.error("Error while reconnecting:", e); + errors++; + + if (errors < maxRetries) { + try { + LOG.debug("Sleeping for {} milliseconds before retrying", timeBetweenRetries); + Thread.sleep(timeBetweenRetries); + } catch (InterruptedException e2) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + } + } + + if (errors >= maxRetries) { + throw new TTransportException("Failed to reconnect"); + } + } + + @Override + public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { + + // With Thrift clients must be instantiated for each different transport session, i.e. server instance. + // Hence, using baseClient as lock, only calls towards the same server will be synchronized. + + synchronized (baseClient) { + + LOG.debug("Invoking client method... > method={}, fromThread={}", + method.getName(), Thread.currentThread().getId()); + + Object result = null; + + try { + result = method.invoke(baseClient, args); + + } catch (InvocationTargetException e) { + if (e.getTargetException() instanceof TTransportException) { + TTransportException cause = (TTransportException) e.getTargetException(); + + if (RESTARTABLE_CAUSES.contains(cause.getType())) { + reconnectOrThrowException(baseClient.getInputProtocol().getTransport(), + maxRetries, + timeBetweenRetries); + result = method.invoke(baseClient, args); + } + } + + if (result == null) { + LOG.debug("Exception while invoking client method: {} > method={}, fromThread={}", + e, method.getName(), Thread.currentThread().getId()); + throw e.getTargetException(); + } + } + + LOG.debug("Method invoke complete! > method={}, fromThread={}", + method.getName(), Thread.currentThread().getId()); + + return result; + } + } + } +}