Skip to content
Prev Previous commit
Next Next commit
util/net: implement insecure gRPC
Also removes grpcutil.ListenAndServeGRPC to make sure we're always
exercising gRPC's ServeHTTP in our tests, rather than serving on the
grpc.Server directly.
  • Loading branch information
tamird committed Feb 17, 2016
commit c036c1503cdb3756da4f5c911b77253e62706a45
9 changes: 4 additions & 5 deletions gossip/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import (
"github.com/cockroachdb/cockroach/roachpb"
"github.com/cockroachdb/cockroach/rpc"
"github.com/cockroachdb/cockroach/util"
"github.com/cockroachdb/cockroach/util/grpcutil"
"github.com/cockroachdb/cockroach/util/hlc"
"github.com/cockroachdb/cockroach/util/leaktest"
"github.com/cockroachdb/cockroach/util/stop"
Expand All @@ -46,7 +45,7 @@ func startGossip(nodeID roachpb.NodeID, stopper *stop.Stopper, t *testing.T) *Go
if err != nil {
t.Fatal(err)
}
ln, err := grpcutil.ListenAndServeGRPC(stopper, server, addr, tlsConfig)
ln, err := util.ListenAndServe(stopper, server, addr, tlsConfig)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -111,7 +110,7 @@ func startFakeServerGossips(t *testing.T) (local *Gossip, remote *fakeGossipServ
if err != nil {
t.Fatal(err)
}
lln, err := grpcutil.ListenAndServeGRPC(stopper, lserver, laddr, lTLSConfig)
lln, err := util.ListenAndServe(stopper, lserver, laddr, lTLSConfig)
if err != nil {
t.Fatal(err)
}
Expand All @@ -127,7 +126,7 @@ func startFakeServerGossips(t *testing.T) (local *Gossip, remote *fakeGossipServ
if err != nil {
t.Fatal(err)
}
rln, err := grpcutil.ListenAndServeGRPC(stopper, rserver, raddr, rTLSConfig)
rln, err := util.ListenAndServe(stopper, rserver, raddr, rTLSConfig)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -333,7 +332,7 @@ func TestClientRegisterWithInitNodeID(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ln, err := grpcutil.ListenAndServeGRPC(stopper, server, addr, TLSConfig)
ln, err := util.ListenAndServe(stopper, server, addr, TLSConfig)
if err != nil {
t.Fatal(err)
}
Expand Down
3 changes: 1 addition & 2 deletions gossip/simulation/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"github.com/cockroachdb/cockroach/rpc"
"github.com/cockroachdb/cockroach/util"
"github.com/cockroachdb/cockroach/util/encoding"
"github.com/cockroachdb/cockroach/util/grpcutil"
"github.com/cockroachdb/cockroach/util/hlc"
"github.com/cockroachdb/cockroach/util/log"
"github.com/cockroachdb/cockroach/util/stop"
Expand Down Expand Up @@ -93,7 +92,7 @@ func NewNetwork(nodeCount int) *Network {
func (n *Network) CreateNode() (*Node, error) {
server := grpc.NewServer()
testAddr := util.CreateTestAddr("tcp")
ln, err := grpcutil.ListenAndServeGRPC(n.Stopper, server, testAddr, n.tlsConfig)
ln, err := util.ListenAndServe(n.Stopper, server, testAddr, n.tlsConfig)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion storage/raft_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (lt *localRPCTransport) Listen(id roachpb.StoreID, handler RaftMessageHandl
RegisterMultiRaftServer(grpcServer, handler)

addr := util.CreateTestAddr("tcp")
ln, err := grpcutil.ListenAndServeGRPC(lt.stopper, grpcServer, addr, nil)
ln, err := util.ListenAndServe(lt.stopper, grpcServer, addr, nil)
if err != nil {
return err
}
Expand Down
27 changes: 0 additions & 27 deletions util/grpcutil/grpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
package grpcutil

import (
"crypto/tls"
"net"
"net/http"
"strings"

Expand All @@ -29,34 +27,9 @@ import (
"google.golang.org/grpc/transport"

"github.com/cockroachdb/cockroach/util"
"github.com/cockroachdb/cockroach/util/log"
"github.com/cockroachdb/cockroach/util/stop"
)

// ListenAndServeGRPC creates a listener and serves server on it, closing
// the listener when signalled by the stopper.
func ListenAndServeGRPC(stopper *stop.Stopper, server *grpc.Server, addr net.Addr, config *tls.Config) (net.Listener, error) {
ln, err := util.Listen(addr, config)
if err != nil {
return nil, err
}

stopper.RunWorker(func() {
if err := server.Serve(ln); err != nil && !util.IsClosedConnection(err) {
log.Fatal(err)
}
})

stopper.RunWorker(func() {
<-stopper.ShouldDrain()
if err := ln.Close(); err != nil {
log.Fatal(err)
}
})

return ln, nil
}

// GRPCHandlerFunc returns an http.Handler that delegates to grpcServer on incoming gRPC
// connections or otherHandler otherwise.
func GRPCHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Handler {
Expand Down
68 changes: 65 additions & 3 deletions util/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
package util

import (
"bytes"
"crypto/tls"
"io"
"log"
"net"
"net/http"
Expand All @@ -29,13 +31,51 @@ import (
"github.com/cockroachdb/cockroach/util/stop"
)

type replayableConn struct {
net.Conn
buf bytes.Buffer
reader io.Reader
}

// Do not call `replay` more than once, bad things will happen.
func (bc *replayableConn) replay() *replayableConn {
bc.reader = io.MultiReader(&bc.buf, bc.Conn)
return bc
}

func (bc *replayableConn) Read(b []byte) (int, error) {
return bc.reader.Read(b)
}

func newBufferedConn(conn net.Conn) *replayableConn {
bc := replayableConn{Conn: conn}
bc.reader = io.TeeReader(conn, &bc.buf)
return &bc
}

type replayableConnListener struct {
net.Listener
}

func (ml *replayableConnListener) Accept() (net.Conn, error) {
conn, err := ml.Listener.Accept()
if err == nil {
conn = newBufferedConn(conn)
}
return conn, err
}

// Listen delegates to `net.Listen` and, if tlsConfig is not nil, to `tls.NewListener`.
// The returned listener's Addr() method will return an address with the hostname unresovled,
// which means it can be used to initiate TLS connections.
func Listen(addr net.Addr, tlsConfig *tls.Config) (net.Listener, error) {
ln, err := net.Listen(addr.Network(), addr.String())
if err == nil && tlsConfig != nil {
ln = tls.NewListener(ln, tlsConfig)
if err == nil {
if tlsConfig != nil {
ln = tls.NewListener(ln, tlsConfig)
} else {
ln = &replayableConnListener{ln}
}
}

return ln, err
Expand Down Expand Up @@ -66,7 +106,29 @@ func ListenAndServe(stopper *stop.Stopper, handler http.Handler, addr net.Addr,
mu.Unlock()
},
}
if err := http2.ConfigureServer(&httpServer, nil); err != nil {

var http2Server http2.Server

if tlsConfig == nil {
connOpts := http2.ServeConnOpts{
BaseConfig: &httpServer,
Handler: handler,
}

httpServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.ProtoMajor == 2 {
if conn, _, err := w.(http.Hijacker).Hijack(); err == nil {
http2Server.ServeConn(conn.(*replayableConn).replay(), &connOpts)
} else {
log.Fatal(err)
}
} else {
handler.ServeHTTP(w, r)
}
})
}

if err := http2.ConfigureServer(&httpServer, &http2Server); err != nil {
return nil, err
}

Expand Down