Prevent serialization errors from causing recursion in the Copycat transport

Change-Id: I0a1b0737d6cda3d7ab63bb26a7547d2f9124a434
This commit is contained in:
Jordan Halterman 2017-06-21 15:26:28 -07:00 committed by Ray Milkey
parent 83949a1bd8
commit b6ee9e966f
2 changed files with 31 additions and 9 deletions

View File

@ -303,6 +303,7 @@ public class NettyMessagingManager implements MessagingService {
try { try {
responsePayload = handler.apply(message.sender(), message.payload()); responsePayload = handler.apply(message.sender(), message.payload());
} catch (Exception e) { } catch (Exception e) {
log.debug("An error occurred in a message handler: {}", e);
status = Status.ERROR_HANDLER_EXCEPTION; status = Status.ERROR_HANDLER_EXCEPTION;
} }
sendReply(message, status, Optional.ofNullable(responsePayload)); sendReply(message, status, Optional.ofNullable(responsePayload));
@ -314,7 +315,13 @@ public class NettyMessagingManager implements MessagingService {
checkPermission(CLUSTER_WRITE); checkPermission(CLUSTER_WRITE);
handlers.put(type, message -> { handlers.put(type, message -> {
handler.apply(message.sender(), message.payload()).whenComplete((result, error) -> { handler.apply(message.sender(), message.payload()).whenComplete((result, error) -> {
Status status = error == null ? Status.OK : Status.ERROR_HANDLER_EXCEPTION; Status status;
if (error == null) {
status = Status.OK;
} else {
log.debug("An error occurred in a message handler: {}", error);
status = Status.ERROR_HANDLER_EXCEPTION;
}
sendReply(message, status, Optional.ofNullable(result)); sendReply(message, status, Optional.ofNullable(result));
}); });
}); });

View File

@ -56,6 +56,8 @@ import static org.onosproject.store.primitives.impl.CopycatTransport.SUCCESS;
* Base Copycat Transport connection. * Base Copycat Transport connection.
*/ */
public class CopycatTransportConnection implements Connection { public class CopycatTransportConnection implements Connection {
private static final int MAX_MESSAGE_SIZE = 1024 * 1024;
private final Logger log = LoggerFactory.getLogger(getClass()); private final Logger log = LoggerFactory.getLogger(getClass());
private final long connectionId; private final long connectionId;
private final String localSubject; private final String localSubject;
@ -97,7 +99,11 @@ public class CopycatTransportConnection implements Connection {
((ReferenceCounted<?>) message).release(); ((ReferenceCounted<?>) message).release();
} }
messagingService.sendAsync(endpoint, remoteSubject, baos.toByteArray()) byte[] bytes = baos.toByteArray();
if (bytes.length > MAX_MESSAGE_SIZE) {
throw new IllegalArgumentException(message + " exceeds maximum message size " + MAX_MESSAGE_SIZE);
}
messagingService.sendAsync(endpoint, remoteSubject, bytes)
.whenComplete((r, e) -> { .whenComplete((r, e) -> {
if (e != null) { if (e != null) {
context.executor().execute(() -> future.completeExceptionally(e)); context.executor().execute(() -> future.completeExceptionally(e));
@ -122,9 +128,14 @@ public class CopycatTransportConnection implements Connection {
if (message instanceof ReferenceCounted) { if (message instanceof ReferenceCounted) {
((ReferenceCounted<?>) message).release(); ((ReferenceCounted<?>) message).release();
} }
byte[] bytes = baos.toByteArray();
if (bytes.length > MAX_MESSAGE_SIZE) {
throw new IllegalArgumentException(message + " exceeds maximum message size " + MAX_MESSAGE_SIZE);
}
messagingService.sendAndReceive(endpoint, messagingService.sendAndReceive(endpoint,
remoteSubject, remoteSubject,
baos.toByteArray(), bytes,
context.executor()) context.executor())
.whenComplete((response, error) -> handleResponse(response, error, future)); .whenComplete((response, error) -> handleResponse(response, error, future));
} catch (SerializationException | IOException e) { } catch (SerializationException | IOException e) {
@ -142,11 +153,11 @@ public class CopycatTransportConnection implements Connection {
CompletableFuture<T> future) { CompletableFuture<T> future) {
if (error != null) { if (error != null) {
Throwable rootCause = Throwables.getRootCause(error); Throwable rootCause = Throwables.getRootCause(error);
if (rootCause instanceof MessagingException || rootCause instanceof SocketException) { if (rootCause instanceof MessagingException.NoRemoteHandler) {
future.completeExceptionally(new TransportException(error));
close(rootCause);
} else if (rootCause instanceof SocketException) {
future.completeExceptionally(new TransportException(error)); future.completeExceptionally(new TransportException(error));
if (rootCause instanceof MessagingException.NoRemoteHandler) {
close(rootCause);
}
} else { } else {
future.completeExceptionally(error); future.completeExceptionally(error);
} }
@ -211,7 +222,11 @@ public class CopycatTransportConnection implements Connection {
try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
baos.write(error != null ? FAILURE : SUCCESS); baos.write(error != null ? FAILURE : SUCCESS);
context.serializer().writeObject(error != null ? error : result, baos); context.serializer().writeObject(error != null ? error : result, baos);
return baos.toByteArray(); byte[] bytes = baos.toByteArray();
if (bytes.length > MAX_MESSAGE_SIZE) {
throw new IllegalArgumentException("response exceeds maximum message size " + MAX_MESSAGE_SIZE);
}
return bytes;
} catch (IOException e) { } catch (IOException e) {
Throwables.propagate(e); Throwables.propagate(e);
return null; return null;
@ -278,7 +293,7 @@ public class CopycatTransportConnection implements Connection {
Throwable wrappedError = error; Throwable wrappedError = error;
if (error != null) { if (error != null) {
Throwable rootCause = Throwables.getRootCause(error); Throwable rootCause = Throwables.getRootCause(error);
if (MessagingException.class.isAssignableFrom(rootCause.getClass())) { if (rootCause instanceof MessagingException.NoRemoteHandler) {
wrappedError = new TransportException(error); wrappedError = new TransportException(error);
} }
future.completeExceptionally(wrappedError); future.completeExceptionally(wrappedError);