46
46
import io .grpc .InternalServerInterceptors ;
47
47
import io .grpc .Metadata ;
48
48
import io .grpc .ServerCall ;
49
+ import io .grpc .ServerCallExecutorSupplier ;
49
50
import io .grpc .ServerCallHandler ;
50
51
import io .grpc .ServerInterceptor ;
51
52
import io .grpc .ServerMethodDefinition ;
52
53
import io .grpc .ServerServiceDefinition ;
53
54
import io .grpc .ServerTransportFilter ;
54
55
import io .grpc .Status ;
56
+ import io .grpc .StatusException ;
55
57
import io .perfmark .Link ;
56
58
import io .perfmark .PerfMark ;
57
59
import io .perfmark .Tag ;
@@ -125,6 +127,7 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume
125
127
private final InternalChannelz channelz ;
126
128
private final CallTracer serverCallTracer ;
127
129
private final Deadline .Ticker ticker ;
130
+ private final ServerCallExecutorSupplier executorSupplier ;
128
131
129
132
/**
130
133
* Construct a server.
@@ -159,6 +162,7 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume
159
162
this .serverCallTracer = builder .callTracerFactory .create ();
160
163
this .ticker = checkNotNull (builder .ticker , "ticker" );
161
164
channelz .addServer (this );
165
+ this .executorSupplier = builder .executorSupplier ;
162
166
}
163
167
164
168
/**
@@ -469,11 +473,11 @@ private void streamCreatedInternal(
469
473
final Executor wrappedExecutor ;
470
474
// This is a performance optimization that avoids the synchronization and queuing overhead
471
475
// that comes with SerializingExecutor.
472
- if (executor == directExecutor ()) {
476
+ if (executorSupplier != null || executor != directExecutor ()) {
477
+ wrappedExecutor = new SerializingExecutor (executor );
478
+ } else {
473
479
wrappedExecutor = new SerializeReentrantCallsDirectExecutor ();
474
480
stream .optimizeForDirectExecutor ();
475
- } else {
476
- wrappedExecutor = new SerializingExecutor (executor );
477
481
}
478
482
479
483
if (headers .containsKey (MESSAGE_ENCODING_KEY )) {
@@ -499,52 +503,124 @@ private void streamCreatedInternal(
499
503
500
504
final JumpToApplicationThreadServerStreamListener jumpListener
501
505
= new JumpToApplicationThreadServerStreamListener (
502
- wrappedExecutor , executor , stream , context , tag );
506
+ wrappedExecutor , executor , stream , context , tag );
503
507
stream .setListener (jumpListener );
504
- // Run in wrappedExecutor so jumpListener.setListener() is called before any callbacks
505
- // are delivered, including any errors. Callbacks can still be triggered, but they will be
506
- // queued.
507
-
508
- final class StreamCreated extends ContextRunnable {
509
- StreamCreated () {
508
+ final SettableFuture <ServerCallParameters <?,?>> future = SettableFuture .create ();
509
+ // Run in serializing executor so jumpListener.setListener() is called before any callbacks
510
+ // are delivered, including any errors. MethodLookup() and HandleServerCall() are proactively
511
+ // queued before any callbacks are queued at serializing executor.
512
+ // MethodLookup() runs on the default executor.
513
+ // When executorSupplier is enabled, MethodLookup() may set/change the executor in the
514
+ // SerializingExecutor before it finishes running.
515
+ // Then HandleServerCall() and callbacks would switch to the executorSupplier executor.
516
+ // Otherwise, they all run on the default executor.
517
+
518
+ final class MethodLookup extends ContextRunnable {
519
+ MethodLookup () {
510
520
super (context );
511
521
}
512
522
513
523
@ Override
514
524
public void runInContext () {
515
- PerfMark .startTask ("ServerTransportListener$StreamCreated .startCall" , tag );
525
+ PerfMark .startTask ("ServerTransportListener$MethodLookup .startCall" , tag );
516
526
PerfMark .linkIn (link );
517
527
try {
518
528
runInternal ();
519
529
} finally {
520
- PerfMark .stopTask ("ServerTransportListener$StreamCreated .startCall" , tag );
530
+ PerfMark .stopTask ("ServerTransportListener$MethodLookup .startCall" , tag );
521
531
}
522
532
}
523
533
524
534
private void runInternal () {
525
- ServerStreamListener listener = NOOP_LISTENER ;
535
+ ServerMethodDefinition <?, ?> wrapMethod ;
536
+ ServerCallParameters <?, ?> callParams ;
526
537
try {
527
538
ServerMethodDefinition <?, ?> method = registry .lookupMethod (methodName );
528
539
if (method == null ) {
529
540
method = fallbackRegistry .lookupMethod (methodName , stream .getAuthority ());
530
541
}
531
542
if (method == null ) {
532
543
Status status = Status .UNIMPLEMENTED .withDescription (
533
- "Method not found: " + methodName );
544
+ "Method not found: " + methodName );
534
545
// TODO(zhangkun83): this error may be recorded by the tracer, and if it's kept in
535
546
// memory as a map whose key is the method name, this would allow a misbehaving
536
547
// client to blow up the server in-memory stats storage by sending large number of
537
548
// distinct unimplemented method
538
549
// names. (https://github.com/grpc/grpc-java/issues/2285)
539
550
stream .close (status , new Metadata ());
540
551
context .cancel (null );
552
+ future .cancel (false );
541
553
return ;
542
554
}
543
- listener = startCall (stream , methodName , method , headers , context , statsTraceCtx , tag );
555
+ wrapMethod = wrapMethod (stream , method , statsTraceCtx );
556
+ callParams = maySwitchExecutor (wrapMethod , stream , headers , context , tag );
557
+ future .set (callParams );
544
558
} catch (Throwable t ) {
545
559
stream .close (Status .fromThrowable (t ), new Metadata ());
546
560
context .cancel (null );
561
+ future .cancel (false );
547
562
throw t ;
563
+ }
564
+ }
565
+
566
+ private <ReqT , RespT > ServerCallParameters <ReqT , RespT > maySwitchExecutor (
567
+ final ServerMethodDefinition <ReqT , RespT > methodDef ,
568
+ final ServerStream stream ,
569
+ final Metadata headers ,
570
+ final Context .CancellableContext context ,
571
+ final Tag tag ) {
572
+ final ServerCallImpl <ReqT , RespT > call = new ServerCallImpl <>(
573
+ stream ,
574
+ methodDef .getMethodDescriptor (),
575
+ headers ,
576
+ context ,
577
+ decompressorRegistry ,
578
+ compressorRegistry ,
579
+ serverCallTracer ,
580
+ tag );
581
+ if (executorSupplier != null ) {
582
+ Executor switchingExecutor = executorSupplier .getExecutor (call , headers );
583
+ if (switchingExecutor != null ) {
584
+ ((SerializingExecutor )wrappedExecutor ).setExecutor (switchingExecutor );
585
+ }
586
+ }
587
+ return new ServerCallParameters <>(call , methodDef .getServerCallHandler ());
588
+ }
589
+ }
590
+
591
+ final class HandleServerCall extends ContextRunnable {
592
+ HandleServerCall () {
593
+ super (context );
594
+ }
595
+
596
+ @ Override
597
+ public void runInContext () {
598
+ PerfMark .startTask ("ServerTransportListener$HandleServerCall.startCall" , tag );
599
+ PerfMark .linkIn (link );
600
+ try {
601
+ runInternal ();
602
+ } finally {
603
+ PerfMark .stopTask ("ServerTransportListener$HandleServerCall.startCall" , tag );
604
+ }
605
+ }
606
+
607
+ private void runInternal () {
608
+ ServerStreamListener listener = NOOP_LISTENER ;
609
+ ServerCallParameters <?,?> callParameters ;
610
+ try {
611
+ if (future .isCancelled ()) {
612
+ return ;
613
+ }
614
+ if (!future .isDone () || (callParameters = future .get ()) == null ) {
615
+ Status status = Status .INTERNAL .withDescription (
616
+ "Unexpected failure retrieving server call parameters." );
617
+ throw new StatusException (status );
618
+ }
619
+ listener = startWrappedCall (methodName , callParameters , headers );
620
+ } catch (Throwable ex ) {
621
+ stream .close (Status .fromThrowable (ex ), new Metadata ());
622
+ context .cancel (null );
623
+ throw new IllegalStateException (ex );
548
624
} finally {
549
625
jumpListener .setListener (listener );
550
626
}
@@ -568,7 +644,8 @@ public void cancelled(Context context) {
568
644
}
569
645
}
570
646
571
- wrappedExecutor .execute (new StreamCreated ());
647
+ wrappedExecutor .execute (new MethodLookup ());
648
+ wrappedExecutor .execute (new HandleServerCall ());
572
649
}
573
650
574
651
private Context .CancellableContext createContext (
@@ -593,9 +670,8 @@ private Context.CancellableContext createContext(
593
670
}
594
671
595
672
/** Never returns {@code null}. */
596
- private <ReqT , RespT > ServerStreamListener startCall (ServerStream stream , String fullMethodName ,
597
- ServerMethodDefinition <ReqT , RespT > methodDef , Metadata headers ,
598
- Context .CancellableContext context , StatsTraceContext statsTraceCtx , Tag tag ) {
673
+ private <ReqT , RespT > ServerMethodDefinition <?,?> wrapMethod (ServerStream stream ,
674
+ ServerMethodDefinition <ReqT , RespT > methodDef , StatsTraceContext statsTraceCtx ) {
599
675
// TODO(ejona86): should we update fullMethodName to have the canonical path of the method?
600
676
statsTraceCtx .serverCallStarted (
601
677
new ServerCallInfoImpl <>(
@@ -609,34 +685,31 @@ private <ReqT, RespT> ServerStreamListener startCall(ServerStream stream, String
609
685
ServerMethodDefinition <ReqT , RespT > interceptedDef = methodDef .withServerCallHandler (handler );
610
686
ServerMethodDefinition <?, ?> wMethodDef = binlog == null
611
687
? interceptedDef : binlog .wrapMethodDefinition (interceptedDef );
612
- return startWrappedCall (fullMethodName , wMethodDef , stream , headers , context , tag );
688
+ return wMethodDef ;
689
+ }
690
+
691
+ private final class ServerCallParameters <ReqT , RespT > {
692
+ ServerCallImpl <ReqT , RespT > call ;
693
+ ServerCallHandler <ReqT , RespT > callHandler ;
694
+
695
+ public ServerCallParameters (ServerCallImpl <ReqT , RespT > call ,
696
+ ServerCallHandler <ReqT , RespT > callHandler ) {
697
+ this .call = call ;
698
+ this .callHandler = callHandler ;
699
+ }
613
700
}
614
701
615
702
private <WReqT , WRespT > ServerStreamListener startWrappedCall (
616
703
String fullMethodName ,
617
- ServerMethodDefinition <WReqT , WRespT > methodDef ,
618
- ServerStream stream ,
619
- Metadata headers ,
620
- Context .CancellableContext context ,
621
- Tag tag ) {
622
-
623
- ServerCallImpl <WReqT , WRespT > call = new ServerCallImpl <>(
624
- stream ,
625
- methodDef .getMethodDescriptor (),
626
- headers ,
627
- context ,
628
- decompressorRegistry ,
629
- compressorRegistry ,
630
- serverCallTracer ,
631
- tag );
632
-
633
- ServerCall .Listener <WReqT > listener =
634
- methodDef .getServerCallHandler ().startCall (call , headers );
635
- if (listener == null ) {
704
+ ServerCallParameters <WReqT , WRespT > params ,
705
+ Metadata headers ) {
706
+ ServerCall .Listener <WReqT > callListener =
707
+ params .callHandler .startCall (params .call , headers );
708
+ if (callListener == null ) {
636
709
throw new NullPointerException (
637
- "startCall() returned a null listener for method " + fullMethodName );
710
+ "startCall() returned a null listener for method " + fullMethodName );
638
711
}
639
- return call .newServerStreamListener (listener );
712
+ return params . call .newServerStreamListener (callListener );
640
713
}
641
714
}
642
715
0 commit comments