Skip to content

Commit fac2b04

Browse files
committed
support mysql protocol connection attributes (vitessio#18548)
Signed-off-by: Michael Demmer <[email protected]>
1 parent 15e1e38 commit fac2b04

File tree

5 files changed

+145
-32
lines changed

5 files changed

+145
-32
lines changed

go/mysql/client.go

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@ type connectResult struct {
4848
// FIXME(alainjobart) once we have more of a server side, add test cases
4949
// to cover all failure scenarios.
5050
func Connect(ctx context.Context, params *ConnParams) (*Conn, error) {
51+
return ConnectWithAttributes(ctx, params, ConnectionAttributes{})
52+
}
53+
54+
// ConnectWithAttributes creates a connection to a server with connection attributes.
55+
// It then handles the initial handshake.
56+
//
57+
// If context is canceled before the end of the process, this function
58+
// will return nil, ctx.Err().
59+
func ConnectWithAttributes(ctx context.Context, params *ConnParams, attributes ConnectionAttributes) (*Conn, error) {
5160
if params.ConnectTimeoutMs != 0 {
5261
var cancel context.CancelFunc
5362
ctx, cancel = context.WithTimeout(ctx, time.Duration(params.ConnectTimeoutMs)*time.Millisecond)
@@ -116,7 +125,7 @@ func Connect(ctx context.Context, params *ConnParams) (*Conn, error) {
116125
// make any read or write just return with an error
117126
// right away.
118127
status <- connectResult{
119-
err: c.clientHandshake(params),
128+
err: c.clientHandshake(params, attributes),
120129
}
121130
}()
122131

@@ -198,7 +207,7 @@ func (c *Conn) Ping() error {
198207
// clientHandshake handles the client side of the handshake.
199208
// Note the connection can be closed while this is running.
200209
// Returns a SQLError.
201-
func (c *Conn) clientHandshake(params *ConnParams) error {
210+
func (c *Conn) clientHandshake(params *ConnParams, attributes ConnectionAttributes) error {
202211
// if EnableQueryInfo is set, make sure that all queries starting with the handshake
203212
// will actually process the INFO fields in QUERY_OK packets
204213
if params.EnableQueryInfo {
@@ -295,9 +304,14 @@ func (c *Conn) clientHandshake(params *ConnParams) error {
295304
return sqlerror.NewSQLError(sqlerror.CRSSLConnectionError, sqlerror.SSUnknownSQLState, "server doesn't support ClientSessionTrack but client asked for it")
296305
}
297306

307+
// Connection attributes.
308+
if capabilities&CapabilityClientConnAttr != 0 && len(attributes) > 0 {
309+
c.Capabilities |= CapabilityClientConnAttr
310+
}
311+
298312
// Build and send our handshake response 41.
299313
// Note this one will never have SSL flag on.
300-
if err := c.writeHandshakeResponse41(capabilities, scrambledPassword, uint8(params.Charset), params); err != nil {
314+
if err := c.writeHandshakeResponse41(capabilities, scrambledPassword, uint8(params.Charset), params, attributes); err != nil {
301315
return err
302316
}
303317

@@ -527,7 +541,7 @@ const CapabilityFlagsSsl = CapabilityFlags |
527541

528542
// writeHandshakeResponse41 writes the handshake response.
529543
// Returns a SQLError.
530-
func (c *Conn) writeHandshakeResponse41(capabilities uint32, scrambledPassword []byte, characterSet uint8, params *ConnParams) error {
544+
func (c *Conn) writeHandshakeResponse41(capabilities uint32, scrambledPassword []byte, characterSet uint8, params *ConnParams, attributes ConnectionAttributes) error {
531545
// Build our flags.
532546
capabilityFlags := CapabilityFlags |
533547
// If the server supported
@@ -564,6 +578,17 @@ func (c *Conn) writeHandshakeResponse41(capabilities uint32, scrambledPassword [
564578
length++
565579
}
566580

581+
// If the server supports CapabilityClientConnAttr and there are attributes to be
582+
// sent, then calculate the length of the attributes and include it in the overall length.
583+
var attrLength int
584+
if capabilities&CapabilityClientConnAttr != 0 && len(attributes) > 0 {
585+
capabilityFlags |= CapabilityClientConnAttr
586+
for key, value := range attributes {
587+
attrLength += lenEncStringSize(key) + lenEncStringSize(value)
588+
}
589+
length += lenEncIntSize(uint64(attrLength)) + attrLength
590+
}
591+
567592
data, pos := c.startEphemeralPacketWithHeader(length)
568593

569594
// Client capability flags.
@@ -600,6 +625,16 @@ func (c *Conn) writeHandshakeResponse41(capabilities uint32, scrambledPassword [
600625
// Assume native client during response
601626
pos = writeNullString(data, pos, string(c.authPluginName))
602627

628+
// Client conn attributes
629+
if attrLength > 0 {
630+
pos = writeLenEncInt(data, pos, uint64(attrLength))
631+
632+
for key, value := range attributes {
633+
pos = writeLenEncString(data, pos, key)
634+
pos = writeLenEncString(data, pos, value)
635+
}
636+
}
637+
603638
// Sanity-check the length.
604639
if pos != len(data) {
605640
return sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "writeHandshakeResponse41: only packed %v bytes, out of %v allocated", pos, len(data))

go/mysql/conn.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,10 @@ type Conn struct {
128128
// It is set during the initial handshake.
129129
UserData Getter
130130

131+
// ConnectionAttributes stores arbitrary client-supplied attributes sent in the
132+
// connection handshake.
133+
Attributes ConnectionAttributes
134+
131135
bufferedReader *bufio.Reader
132136
flushTimer *time.Timer
133137
flushDelay time.Duration

go/mysql/constants.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ const (
3737
// implemented authentication methods.
3838
type AuthMethodDescription string
3939

40+
// ConnectionAttributes is a map of key/value pairs sent by the client during
41+
// the connection phase.
42+
type ConnectionAttributes map[string]string
43+
4044
// Supported auth forms.
4145
const (
4246
// MysqlNativePassword uses a salt and transmits a hash on the wire.

go/mysql/server.go

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,7 @@ func (c *Conn) writeHandshakeV10(serverVersion string, authServer AuthServer, ch
688688
}
689689

690690
// parseClientHandshakePacket parses the handshake sent by the client.
691-
// Returns the username, auth method, auth data, error.
691+
// Returns the username, auth method, auth data, connection attributes, error.
692692
// The original data is not pointed at, and can be freed.
693693
func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []byte) (string, AuthMethodDescription, []byte, error) {
694694
pos := 0
@@ -806,58 +806,43 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
806806

807807
// Decode connection attributes send by the client
808808
if clientFlags&CapabilityClientConnAttr != 0 {
809-
if _, _, err := parseConnAttrs(data, pos); err != nil {
809+
clientAttributes, _, err := parseConnAttrs(data, pos)
810+
if err != nil {
810811
log.Warningf("Decode connection attributes send by the client: %v", err)
811812
}
813+
814+
c.Attributes = clientAttributes
812815
}
813816

814817
return username, AuthMethodDescription(authMethod), authResponse, nil
815818
}
816819

817-
func parseConnAttrs(data []byte, pos int) (map[string]string, int, error) {
818-
var attrLen uint64
820+
func parseConnAttrs(data []byte, pos int) (ConnectionAttributes, int, error) {
821+
attrs := make(map[string]string)
819822

820823
attrLen, pos, ok := readLenEncInt(data, pos)
821824
if !ok {
822825
return nil, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read connection attributes variable length")
823826
}
824827

825-
var attrLenRead uint64
826-
827-
attrs := make(map[string]string)
828-
829-
for attrLenRead < attrLen {
830-
var keyLen byte
831-
keyLen, pos, ok = readByte(data, pos)
832-
if !ok {
833-
return nil, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read connection attribute key length")
834-
}
835-
attrLenRead += uint64(keyLen) + 1
828+
addrEndPos := pos + int(attrLen)
836829

837-
var connAttrKey []byte
838-
connAttrKey, pos, ok = readBytes(data, pos, int(keyLen))
830+
var key, value string
831+
for pos < addrEndPos {
832+
key, pos, ok = readLenEncString(data, pos)
839833
if !ok {
840834
return nil, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read connection attribute key")
841835
}
842836

843-
var valLen byte
844-
valLen, pos, ok = readByte(data, pos)
845-
if !ok {
846-
return nil, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read connection attribute value length")
847-
}
848-
attrLenRead += uint64(valLen) + 1
849-
850-
var connAttrVal []byte
851-
connAttrVal, pos, ok = readBytes(data, pos, int(valLen))
837+
value, pos, ok = readLenEncString(data, pos)
852838
if !ok {
853839
return nil, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read connection attribute value")
854840
}
855841

856-
attrs[string(connAttrKey[:])] = string(connAttrVal[:])
842+
attrs[key] = value
857843
}
858844

859845
return attrs, pos, nil
860-
861846
}
862847

863848
// writeAuthSwitchRequest writes an auth switch request packet.

go/mysql/server_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,91 @@ func TestClientFoundRows(t *testing.T) {
471471
c.Close()
472472
}
473473

474+
func TestConnAttrs(t *testing.T) {
475+
ctx := utils.LeakCheckContext(t)
476+
th := &testHandler{}
477+
478+
authServer := NewAuthServerStatic("", "", 0)
479+
authServer.entries["user1"] = []*AuthServerStaticEntry{{
480+
Password: "password1",
481+
UserData: "userData1",
482+
}}
483+
defer authServer.close()
484+
485+
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false)
486+
require.NoError(t, err, "NewListener failed")
487+
host, port := getHostPort(t, l.Addr())
488+
489+
// Test with attrs.
490+
params := &ConnParams{
491+
Host: host,
492+
Port: port,
493+
Uname: "user1",
494+
Pass: "password1",
495+
}
496+
497+
attributes := ConnectionAttributes{
498+
"key1": "value1",
499+
"k2": "v2",
500+
}
501+
502+
go l.Accept()
503+
defer cleanupListener(ctx, l, params)
504+
505+
clientConn, err := ConnectWithAttributes(ctx, params, attributes)
506+
require.NoError(t, err, "Connect failed")
507+
508+
serverConn := th.LastConn()
509+
assert.Equal(t, uint32(CapabilityClientConnAttr), clientConn.Capabilities&CapabilityClientConnAttr, "ConnAttr flag: %x, bit must be set", th.LastConn().Capabilities)
510+
assert.Equal(t, serverConn.Attributes, attributes, "attributes should be sent and parsed")
511+
512+
clientConn.Close()
513+
assert.True(t, clientConn.IsClosed(), "IsClosed should be true on Close-d connection.")
514+
515+
// Empty attrs do not even set the capability flag
516+
params = &ConnParams{
517+
Host: host,
518+
Port: port,
519+
Uname: "user1",
520+
Pass: "password1",
521+
}
522+
523+
clientConn, err = Connect(ctx, params)
524+
require.NoError(t, err, "Connect failed")
525+
526+
serverConn = th.LastConn()
527+
assert.Equal(t, uint32(0), clientConn.Capabilities&CapabilityClientConnAttr, "ConnAttr flag: %x, bit must not be set", th.LastConn().Capabilities)
528+
assert.Equal(t, 0, len(serverConn.Attributes), "attributes should be empty")
529+
530+
clientConn.Close()
531+
assert.True(t, clientConn.IsClosed(), "IsClosed should be true on Close-d connection.")
532+
533+
// Test long attributes more than 255 bytes
534+
params = &ConnParams{
535+
Host: host,
536+
Port: port,
537+
Uname: "user1",
538+
Pass: "password1",
539+
}
540+
541+
longAttributes := ConnectionAttributes{
542+
"short": strings.Repeat("a", 10),
543+
"long": strings.Repeat("b", 256),
544+
"longer": strings.Repeat("c", 1024*1024),
545+
}
546+
547+
clientConn, err = ConnectWithAttributes(ctx, params, longAttributes)
548+
require.NoError(t, err, "Connect failed")
549+
550+
serverConn = th.LastConn()
551+
assert.Equal(t, uint32(CapabilityClientConnAttr), clientConn.Capabilities&CapabilityClientConnAttr, "ConnAttr flag: %x, bit must be set", th.LastConn().Capabilities)
552+
assert.Equal(t, serverConn.Attributes, longAttributes, "attributes should be sent and parsed")
553+
554+
clientConn.Close()
555+
assert.True(t, clientConn.IsClosed(), "IsClosed should be true on Close-d connection.")
556+
557+
}
558+
474559
func TestConnCounts(t *testing.T) {
475560
th := &testHandler{}
476561

0 commit comments

Comments
 (0)