cyberark/secretless-broker

View on GitHub
internal/plugin/connectors/tcp/mysql/protocol/protocol.go

Summary

Maintainability
D
1 day
Test Coverage
C
75%
/*
MIT License

Copyright (c) 2017 Aleksandr Fedotov

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/

package protocol

import (
    "bytes"
    "crypto/rand"
    "crypto/rsa"
    "crypto/sha1"
    "crypto/sha256"
    "crypto/x509"
    "encoding/binary"
    "encoding/pem"
    "errors"
    "fmt"
    "io"
)

// Sentinel errors returned by the packet decoders in this package.
var (
	// ErrInvalidPacketLength is for invalid packet lengths
	ErrInvalidPacketLength = errors.New("Protocol: Invalid packet length")

	// ErrInvalidPacketType is for invalid packet types
	ErrInvalidPacketType = errors.New("Protocol: Invalid packet type")

	// ErrFieldTypeNotImplementedYet is for field types that are not yet implemented
	ErrFieldTypeNotImplementedYet = errors.New("Protocol: Required field type not implemented yet")
)

// UnpackErrResponse decodes an ERR_Packet sent by the server into an Error.
// Part of basic packet structure shown below.
//
//	int<3> PacketLength
//	int<1> PacketNumber
//	int<1> PacketType (0xFF)
//	int<2> ErrorCode
//	if clientCapabilities & clientProtocol41
//	{
//		string<1> SqlStateMarker (#)
//		string<5> SqlState
//	}
//	string<EOF> Error
//
// The packet-type byte is skipped, not validated. ErrInvalidPacketLength is
// returned when the packet is shorter than the fixed fields, including the
// 5-byte SQL state announced by a '#' marker.
func UnpackErrResponse(data []byte) error {
	// Min packet length =
	// header(4 bytes)
	// + PacketType(1 byte)
	// + ErrorCode(2 bytes)
	// + string<EOF>(at least 1 byte)
	if err := CheckPacketLength(8, data); err != nil {
		return err
	}

	// Skip the 4-byte header and the 1-byte packet type.
	pos := 5

	// Error Number [16 bit uint]
	errno := binary.LittleEndian.Uint16(data[pos : pos+2])
	pos += 2

	sqlstate := ""
	// SQL State [optional: # + 5bytes string]
	if data[pos] == '#' {
		// The marker promises 5 more bytes; reject a truncated packet
		// instead of slicing past the end of data.
		if err := CheckPacketLength(pos+1+5, data); err != nil {
			return err
		}
		pos++

		sqlstate = string(data[pos : pos+5])
		pos += 5
	}

	// Error Message [string<EOF>]
	return Error{
		Code:     errno,
		SQLState: sqlstate,
		Message:  string(data[pos:]),
	}
}

// GetPacketType returns the packet-type byte, which immediately follows the
// 4-byte packet header (3-byte payload length + 1-byte sequence number).
//
//	int<3> PacketLength
//	int<1> PacketNumber
//	int<1> PacketType (e.g. 0x00 for OK, 0xFF for ERR)
//	... more ...
//
// The caller must supply at least 5 bytes; shorter input panics.
func GetPacketType(packet []byte) byte {
	const headerLen = 4
	return packet[headerLen]
}

// OkResponse represents packet sent from the server to the client to signal successful completion of a command
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_ok_packet.html
//
// Values are produced by UnpackOkResponse; each field holds the decoded value
// of the packet field with the same name.
type OkResponse struct {
    // PacketType is the first payload byte (UnpackOkResponse only accepts ResponseOk).
    PacketType   byte
    // AffectedRows is decoded from a length-encoded integer.
    AffectedRows uint64
    // LastInsertID is decoded from a length-encoded integer.
    LastInsertID uint64
    // StatusFlags is the 2-byte little-endian server status field.
    StatusFlags  uint16
    // Warnings is the 2-byte little-endian warning count.
    Warnings     uint16
}

// UnpackOkResponse decodes an OK_Packet sent by the server.
// Part of basic packet structure shown below.
//
//	int<3> PacketLength
//	int<1> PacketNumber
//	int<1> PacketType (0x00 or 0xFE)
//	int<lenenc> AffectedRows
//	int<lenenc> LastInsertID
//	int<2> StatusFlags
//	int<2> Warnings
//	... more ...
//
// Only packets whose type byte equals ResponseOk are accepted; any other type
// yields a "Malformed packet" error. If the packet ends before StatusFlags or
// Warnings are complete, io.EOF / io.ErrUnexpectedEOF is returned instead of
// silently zero-filled values.
func UnpackOkResponse(packet []byte) (*OkResponse, error) {
	// Min packet length = header(4 bytes) + PacketType(1 byte)
	if err := CheckPacketLength(5, packet); err != nil {
		return nil, err
	}

	r := bytes.NewReader(packet)

	// Skip packet header
	if _, err := GetPacketHeader(r); err != nil {
		return nil, err
	}

	// Read packet type, validate OK
	packetType, err := r.ReadByte()
	if err != nil {
		return nil, err
	}
	if packetType != ResponseOk {
		return nil, errors.New("Malformed packet")
	}

	// Read affected rows (expected value: 0 for auth)
	affectedRows, err := ReadLenEncodedInteger(r)
	if err != nil {
		return nil, err
	}

	// Read last insert ID (expected value: 0 for auth)
	lastInsertID, err := ReadLenEncodedInteger(r)
	if err != nil {
		return nil, err
	}

	// Read status flags. io.ReadFull (unlike a bare Read) fails on a short
	// read instead of leaving zero bytes in the buffer.
	statusBuf := make([]byte, 2)
	if _, err := io.ReadFull(r, statusBuf); err != nil {
		return nil, err
	}
	status := binary.LittleEndian.Uint16(statusBuf)

	// Read warnings
	warningsBuf := make([]byte, 2)
	if _, err := io.ReadFull(r, warningsBuf); err != nil {
		return nil, err
	}
	warnings := binary.LittleEndian.Uint16(warningsBuf)

	return &OkResponse{
		PacketType:   packetType,
		AffectedRows: affectedRows,
		LastInsertID: lastInsertID,
		StatusFlags:  status,
		Warnings:     warnings,
	}, nil
}

// HandshakeV10 represents the server's initial handshake packet.
// See https://mariadb.com/kb/en/mariadb/1-connecting-connecting/#initial-handshake-packet
//
// Values are produced by UnpackHandshakeV10 and encoded by PackHandshakeV10.
type HandshakeV10 struct {
    // ProtocolVersion is the first payload byte.
    ProtocolVersion    byte
    // SequenceID is the sequence number from the 4-byte packet header.
    SequenceID         uint8
    // ServerVersion is the NUL-terminated server version string.
    ServerVersion      string
    // ConnectionID is the 4-byte little-endian connection ID.
    ConnectionID       uint32
    // StatusFlags is the 2-byte little-endian server status.
    StatusFlags        uint16
    // CharacterSet is the server default collation byte.
    CharacterSet       uint8
    // ServerCapabilities combines the lower and upper 16-bit capability halves.
    ServerCapabilities uint32
    // AuthPlugin is only populated when ServerCapabilities has ClientPluginAuth.
    AuthPlugin         string
    // Salt is AuthPluginDataPart1 (8 bytes) followed by AuthPluginDataPart2
    // without its trailing NUL byte.
    Salt               []byte
}

// UnpackHandshakeV10 decodes the initial handshake packet sent by the server.
// Basic packet structure shown below.
// See http://imysql.com/mysql-internal-manual/connection-phase-packets.html#packet-Protocol::HandshakeV10
//
//	int<3> PacketLength
//	int<1> PacketNumber
//	int<1> ProtocolVersion
//	string<NUL> ServerVersion
//	int<4> ConnectionID
//	string<8> AuthPluginDataPart1 (authentication seed)
//	string<1> Reserved (always 0x00)
//	int<2> ServerCapabilities (1st part)
//	int<1> ServerDefaultCollation
//	int<2> StatusFlags
//	int<2> ServerCapabilities (2nd part)
//	if capabilities & clientPluginAuth
//	{
//		int<1> AuthPluginDataLength
//	}
//	else
//	{
//		int<1> 0x00
//	}
//	string<10> Reserved (all 0x00)
//	if capabilities & clientSecureConnection
//	{
//		string[$len] AuthPluginDataPart2 ($len=MAX(13, AuthPluginDataLength - 8))
//	}
//	if capabilities & clientPluginAuth
//	{
//		string[NUL] AuthPluginName
//	}
//
// The returned Salt is AuthPluginDataPart1 followed by AuthPluginDataPart2
// without its trailing NUL byte. Truncated fixed-size fields yield
// io.EOF / io.ErrUnexpectedEOF; a part-2 block whose last byte is not NUL
// yields a "Malformed packet" error.
func UnpackHandshakeV10(packet []byte) (*HandshakeV10, error) {
	r := bytes.NewReader(packet)

	// Header
	header, err := GetPacketHeader(r)
	if err != nil {
		return nil, err
	}

	// Read ProtocolVersion
	protoVersion, err := r.ReadByte()
	if err != nil {
		return nil, err
	}

	// Read ServerVersion
	serverVersion := ReadNullTerminatedString(r)

	// Read ConnectionID
	connectionIDBuf := make([]byte, 4)
	if _, err := io.ReadFull(r, connectionIDBuf); err != nil {
		return nil, err
	}
	connectionID := binary.LittleEndian.Uint32(connectionIDBuf)

	// Read AuthPluginDataPart1 (first 8 bytes of the salt)
	salt := make([]byte, 8)
	if _, err := io.ReadFull(r, salt); err != nil {
		return nil, err
	}

	// Skip filler
	if _, err := r.ReadByte(); err != nil {
		return nil, err
	}

	// Read lower half of ServerCapabilities into the first 2 bytes of a
	// 4-byte little-endian buffer; the upper half is read further below.
	serverCapabilitiesBuf := make([]byte, 4)
	if _, err := io.ReadFull(r, serverCapabilitiesBuf[:2]); err != nil {
		return nil, err
	}

	// Read ServerCharacterSet and StatusFlags
	serverCharacterSet, err := r.ReadByte()
	if err != nil {
		return nil, err
	}
	serverStatusFlagsBuf := make([]byte, 2)
	if _, err := io.ReadFull(r, serverStatusFlagsBuf); err != nil {
		return nil, err
	}
	serverStatusFlags := binary.LittleEndian.Uint16(serverStatusFlagsBuf)

	// Read upper half of ServerCapabilities
	if _, err := io.ReadFull(r, serverCapabilitiesBuf[2:]); err != nil {
		return nil, err
	}
	serverCapabilities := binary.LittleEndian.Uint32(serverCapabilitiesBuf)

	// The AuthPluginDataLength byte is always present, but only meaningful
	// when ClientPluginAuth is set; otherwise it is 0x00 and ignored.
	authPluginDataLength, err := r.ReadByte()
	if err != nil {
		return nil, err
	}
	if serverCapabilities&ClientPluginAuth == 0 {
		authPluginDataLength = 0
	}

	// Skip reserved (all 0x00)
	if _, err := r.Seek(10, io.SeekCurrent); err != nil {
		return nil, err
	}

	// Get AuthPluginDataPart2: $len = MAX(13, AuthPluginDataLength - 8).
	// Enforcing the 13-byte minimum also guarantees numBytes-1 below is a
	// valid index (a zero-length part would otherwise index salt2[-1]).
	if serverCapabilities&ClientSecureConnection != 0 {
		numBytes := int(authPluginDataLength) - 8
		if numBytes < 13 {
			numBytes = 13
		}

		salt2 := make([]byte, numBytes)
		if _, err := io.ReadFull(r, salt2); err != nil {
			return nil, err
		}

		// the last byte has to be 0, and is not part of the data
		if salt2[numBytes-1] != 0 {
			return nil, errors.New("Malformed packet")
		}
		salt = append(salt, salt2[:numBytes-1]...)
	}

	var authPlugin string
	if serverCapabilities&ClientPluginAuth != 0 {
		authPlugin = ReadNullTerminatedString(r)
	}

	return &HandshakeV10{
		SequenceID:         header[3],
		ProtocolVersion:    protoVersion,
		ServerVersion:      serverVersion,
		ConnectionID:       connectionID,
		ServerCapabilities: serverCapabilities,
		AuthPlugin:         authPlugin,
		Salt:               salt,
		StatusFlags:        serverStatusFlags,
		CharacterSet:       serverCharacterSet,
	}, nil
}

// PackHandshakeV10 takes in a HandshakeV10 object and returns the encoded
// initial handshake packet, header included, in the layout decoded by
// UnpackHandshakeV10:
//
//	int<1>       ProtocolVersion
//	string<NUL>  ServerVersion
//	int<4>       ConnectionID
//	string[8]    AuthPluginDataPart1 (Salt[:8])
//	int<1>       Filler (0x00)
//	int<2>       ServerCapabilities (lower half)
//	int<1>       CharacterSet
//	int<2>       StatusFlags
//	int<2>       ServerCapabilities (upper half)
//	int<1>       AuthPluginDataLength (len(Salt)+1 with ClientPluginAuth, else 0x00)
//	string[10]   Reserved (all 0x00)
//	string[$len] AuthPluginDataPart2 + NUL (Salt[8:], at most 13 bytes) if ClientSecureConnection
//	string<NUL>  AuthPlugin if ClientPluginAuth
//
// An error is returned when Salt is shorter than the 8 bytes required for
// AuthPluginDataPart1.
func PackHandshakeV10(serverHandshake *HandshakeV10) ([]byte, error) {
	salt := serverHandshake.Salt
	if len(salt) < 8 {
		return nil, fmt.Errorf("handshake salt must be at least 8 bytes, got %d", len(salt))
	}
	capabilities := serverHandshake.ServerCapabilities

	buffer := new(bytes.Buffer)

	// Writes into a bytes.Buffer cannot fail, so the errors returned by
	// binary.Write and the Write* methods are intentionally ignored.

	// Write ProtocolVersion (int<1>)
	buffer.WriteByte(serverHandshake.ProtocolVersion)

	// Write ServerVersion (string<NUL>)
	buffer.WriteString(serverHandshake.ServerVersion)
	buffer.WriteByte(0)

	// Write ConnectionID (int<4>)
	binary.Write(buffer, binary.LittleEndian, serverHandshake.ConnectionID)

	// Write AuthPluginDataPart1 (string<8>)
	buffer.Write(salt[:8])

	// Write Reserved filler (int<1>)
	buffer.WriteByte(0)

	// Write ServerCapabilities (int<2>), the lower part
	binary.Write(buffer, binary.LittleEndian, uint16(capabilities&0xFFFF))

	// Write ServerCharacterSet (int<1>)
	buffer.WriteByte(serverHandshake.CharacterSet)

	// Write StatusFlags (int<2>)
	binary.Write(buffer, binary.LittleEndian, serverHandshake.StatusFlags)

	// Write ServerCapabilities (int<2>), the higher part
	binary.Write(buffer, binary.LittleEndian, uint16(capabilities>>16))

	// Write AuthPluginDataLength (int<1>). The byte is always on the wire:
	// the salt length plus its NUL terminator with ClientPluginAuth, a plain
	// 0x00 otherwise (UnpackHandshakeV10 reads it in both cases).
	if capabilities&ClientPluginAuth > 0 {
		buffer.WriteByte(byte(len(salt) + 1))
	} else {
		buffer.WriteByte(0)
	}

	// Write Reserved (string<10>)
	buffer.Write(make([]byte, 10))

	// Write AuthPluginDataPart2 (at most 13 bytes) plus its NUL terminator
	if capabilities&ClientSecureConnection != 0 {
		part2 := salt[8:]
		if len(part2) > 13 {
			part2 = part2[:13]
		}
		buffer.Write(part2)
		buffer.WriteByte(0)
	}

	// Write AuthPluginName (string<NUL>) if required
	if capabilities&ClientPluginAuth > 0 {
		buffer.WriteString(serverHandshake.AuthPlugin)
		buffer.WriteByte(0)
	}

	return AddHeaderToPacket(serverHandshake.SequenceID, buffer.Bytes()), nil
}

// RemoveSSLFromHandshakeV10 removes Client SSL Capability from Server
// Handshake Packet.  Secretless needs to do this to force the client to
// communicate with Secretless without using SSL.  That half of the connection
// is insecure by design.  Secretless then (usually) adds SSL for the other
// half of the communication -- between Secretless and the MySQL server.
//
// The input packet is not modified; a copy with the ClientSSL bit cleared in
// both capability-flag halves is returned. A packet that is truncated before
// the upper capability flags yields io.EOF / io.ErrUnexpectedEOF.
func RemoveSSLFromHandshakeV10(packet []byte) ([]byte, error) {
	r := bytes.NewReader(packet)
	initialLen := r.Len()

	// Skip packet header
	if _, err := GetPacketHeader(r); err != nil {
		return nil, err
	}

	// Skip ProtocolVersion
	if _, err := r.ReadByte(); err != nil {
		return nil, err
	}

	// Skip ServerVersion
	ReadNullTerminatedString(r)

	// Skip ConnectionID (4 bytes), AuthPluginDataPart1 (8 bytes) and the
	// filler byte; only their presence matters here.
	skipped := make([]byte, 4+8+1)
	if _, err := io.ReadFull(r, skipped); err != nil {
		return nil, err
	}

	// Read lower half of ServerCapabilities, remembering its offset
	serverCapabilitiesIndex := initialLen - r.Len()
	capabilitiesBuf := make([]byte, 4)
	if _, err := io.ReadFull(r, capabilitiesBuf[:2]); err != nil {
		return nil, err
	}

	// Skip ServerDefaultCollation and StatusFlags
	if _, err := r.Seek(3, io.SeekCurrent); err != nil {
		return nil, err
	}

	// Read upper half of ServerCapabilities, remembering its offset
	exServerCapabilitiesIndex := initialLen - r.Len()
	if _, err := io.ReadFull(r, capabilitiesBuf[2:]); err != nil {
		return nil, err
	}
	serverCapabilities := binary.LittleEndian.Uint32(capabilitiesBuf)

	// Clear the SSL bit. AND NOT (rather than XOR) guarantees the bit ends
	// up unset even when the server never advertised SSL; XOR would set it.
	serverCapabilities &^= ClientSSL

	newPacket := make([]byte, len(packet))
	copy(newPacket, packet)

	// update Lower part of the capability flags.
	writeUint16(newPacket, serverCapabilitiesIndex, uint16(serverCapabilities))

	// update Upper part of the capability flags.
	writeUint16(newPacket, exServerCapabilitiesIndex, uint16(serverCapabilities>>16))

	return newPacket, nil
}

// writeUint16 stores value in little-endian order at data[pos] and
// data[pos+1], overwriting their previous contents. It panics if data does
// not have room for both bytes at pos.
func writeUint16(data []byte, pos int, value uint16) {
	binary.LittleEndian.PutUint16(data[pos:], value)
}

// HandshakeResponse41 represents handshake response packet sent by 4.1+ clients supporting clientProtocol41 capability,
// if the server announced it in its initial handshake packet.
// See https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_response.html#sect_protocol_connection_phase_packets_protocol_handshake_response41
//
// Values are produced by UnpackHandshakeResponse41, modified by
// InjectCredentials and re-encoded by PackHandshakeResponse41.
//
// The format of the packet header is also described here:
//
//      https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_packets.html
//
//     +-------------+----------------+---------------------------------------------+
//     |    Type     |      Name      |                 Description                 |
//     +-------------+----------------+---------------------------------------------+
//     | int<3>      | payload_length | Length of the payload. The number of bytes  |
//     |             |                | in the packet beyond the initial 4 bytes    |
//     |             |                | that make up the packet header.             |
//     | int<1>      | sequence_id    | Sequence ID                                 |
//     | string<var> | payload        | [len=payload_length] payload of the packet  |
//     +-------------+----------------+---------------------------------------------+
type HandshakeResponse41 struct {
    // SequenceID is the sequence number from the packet header.
    SequenceID      uint8
    // CapabilityFlags are the client's 4-byte little-endian capability flags.
    CapabilityFlags uint32
    // MaxPacketSize is the client's 4-byte little-endian max packet size.
    MaxPacketSize   uint32
    // ClientCharset is the client's character set byte.
    ClientCharset   uint8
    // Username is the NUL-terminated user name.
    Username        string
    // AuthLength is the 1-byte length prefix of AuthResponse when
    // ClientSecureConnection is set (0 as parsed otherwise).
    // PackHandshakeResponse41 treats a non-positive value as "no auth data".
    AuthLength      int64
    // AuthPluginName is only parsed when CapabilityFlags has ClientPluginAuth.
    AuthPluginName  string
    // AuthResponse is the (scrambled) auth data sent by the client.
    AuthResponse    []byte
    // Database is only parsed when CapabilityFlags has ClientConnectWithDB.
    Database        string
    // PacketTail holds any bytes left after the parsed fields; it is
    // re-emitted verbatim by PackHandshakeResponse41.
    PacketTail      []byte
}

// UnpackHandshakeResponse41 decodes the handshake response packet sent by the client.
// Payload layout after the 4-byte header:
//
//	int<4>        CapabilityFlags (must include ClientProtocol41, must not include ClientSSL)
//	int<4>        MaxPacketSize
//	int<1>        ClientCharset
//	string[23]    Reserved (all 0x00)
//	string<NUL>   Username
//	if capabilities & ClientSecureConnection
//		int<1>        AuthLength
//		string[$len]  AuthResponse
//	else
//		string<NUL>   AuthResponse
//	if capabilities & ClientConnectWithDB
//		string<NUL>   Database
//	if capabilities & ClientPluginAuth
//		string<NUL>   AuthPluginName
//	string<EOF>   PacketTail (kept verbatim)
//
// Truncated fixed-size fields yield io.EOF / io.ErrUnexpectedEOF. The
// NUL-terminated strings are read leniently: a missing terminator ends the
// string at the end of the packet.
func UnpackHandshakeResponse41(packet []byte) (*HandshakeResponse41, error) {
	r := bytes.NewReader(packet)

	// Skip packet header (but save in struct)
	header, err := GetPacketHeader(r)
	if err != nil {
		return nil, err
	}

	// Read CapabilityFlags
	clientCapabilitiesBuf := make([]byte, 4)
	if _, err := io.ReadFull(r, clientCapabilitiesBuf); err != nil {
		return nil, err
	}
	capabilityFlags := binary.LittleEndian.Uint32(clientCapabilitiesBuf)

	// Check that the client is using protocol 4.1
	if capabilityFlags&ClientProtocol41 == 0 {
		return nil, errors.New("Client Protocol mismatch")
	}

	// client requesting SSL, we don't support it
	if capabilityFlags&ClientSSL > 0 {
		return nil, errors.New("SSL Protocol mismatch")
	}

	// Read MaxPacketSize
	maxPacketSizeBuf := make([]byte, 4)
	if _, err := io.ReadFull(r, maxPacketSizeBuf); err != nil {
		return nil, err
	}
	maxPacketSize := binary.LittleEndian.Uint32(maxPacketSizeBuf)

	// Read Charset
	charset, err := r.ReadByte()
	if err != nil {
		return nil, err
	}

	// Skip 23 byte reserved filler
	if _, err := r.Seek(23, io.SeekCurrent); err != nil {
		return nil, err
	}

	// Read Username
	username := ReadNullTerminatedString(r)

	// Read Auth
	var auth []byte
	var authLength int64
	if capabilityFlags&ClientSecureConnection > 0 {
		authLengthByte, err := r.ReadByte()
		if err != nil {
			return nil, err
		}
		authLength = int64(authLengthByte)

		// io.ReadFull succeeds for a zero-length response at the very end of
		// the packet, where a bare Read would report io.EOF, and fails on a
		// short read instead of returning a partially filled buffer.
		auth = make([]byte, authLength)
		if _, err := io.ReadFull(r, auth); err != nil {
			return nil, err
		}
	} else {
		auth = ReadNullTerminatedBytes(r)
	}

	// Read Database
	var database string
	if capabilityFlags&ClientConnectWithDB > 0 {
		database = ReadNullTerminatedString(r)
	}

	// check whether the auth method was specified
	var authPluginName string
	if capabilityFlags&ClientPluginAuth > 0 {
		authPluginName = ReadNullTerminatedString(r)
	}

	// get the rest of the packet
	var packetTail []byte
	if remaining := r.Len(); remaining > 0 {
		packetTail = make([]byte, remaining)
		if _, err := io.ReadFull(r, packetTail); err != nil {
			return nil, err
		}
	}

	return &HandshakeResponse41{
		SequenceID:      header[3],
		CapabilityFlags: capabilityFlags,
		MaxPacketSize:   maxPacketSize,
		ClientCharset:   charset,
		Username:        username,
		AuthLength:      authLength,
		AuthPluginName:  authPluginName,
		AuthResponse:    auth,
		Database:        database,
		PacketTail:      packetTail,
	}, nil
}

// CreateAuthResponse scrambles password with the server-provided salt in the
// way authPlugin expects. "mysql_native_password" delegates to NativePassword
// and "caching_sha2_password" to scrambleSHA256Password; any other plugin
// name is rejected with an "Unknown auth plugin" error.
func CreateAuthResponse(authPlugin string, password []byte, salt []byte) ([]byte, error) {
	if authPlugin == "mysql_native_password" {
		response, err := NativePassword(password, salt)
		if err != nil {
			return nil, err
		}
		return response, nil
	}

	if authPlugin == "caching_sha2_password" {
		return scrambleSHA256Password(password, salt), nil
	}

	return nil, fmt.Errorf("Unknown auth plugin: %s", authPlugin)
}

// InjectCredentials replaces the credentials carried by a client's
// HandshakeResponse41. It scrambles password with the server's salt using
// authPlugin (via CreateAuthResponse), then stores the plugin name, username,
// auth response and its length in clientHandshake. If scrambling fails the
// error is returned and clientHandshake is left untouched.
func InjectCredentials(authPlugin string, clientHandshake *HandshakeResponse41, salt []byte, username string, password string) (err error) {
	response, respErr := CreateAuthResponse(authPlugin, []byte(password), salt)
	if respErr != nil {
		return respErr
	}

	clientHandshake.AuthPluginName = authPlugin
	clientHandshake.Username = username
	clientHandshake.AuthResponse = response
	clientHandshake.AuthLength = int64(len(response))
	return nil
}

// scrambleSHA256Password builds the caching_sha2_password scramble (MySQL 8+):
//
//	XOR(SHA256(password), SHA256(SHA256(SHA256(password)) || scramble))
//
// An empty password yields a nil result.
func scrambleSHA256Password(password []byte, scramble []byte) []byte {
	if len(password) == 0 {
		return nil
	}

	stage1 := sha256.Sum256(password)
	stage2 := sha256.Sum256(stage1[:])

	h := sha256.New()
	h.Write(stage2[:])
	h.Write(scramble)
	mask := h.Sum(nil)

	out := stage1[:]
	for i := range out {
		out[i] ^= mask[i]
	}
	return out
}

// PackHandshakeResponse41 takes in a HandshakeResponse41 object and returns
// the encoded handshake response packet, header included. The layout mirrors
// UnpackHandshakeResponse41: optional fields (database, auth plugin name) are
// written only when the matching capability flag is set, and PacketTail is
// appended verbatim.
//
// With ClientSecureConnection the auth response is prefixed by a 1-byte
// length, so a response longer than 255 bytes is rejected with an error
// rather than silently truncated.
func PackHandshakeResponse41(clientHandshake *HandshakeResponse41) ([]byte, error) {
	var buf bytes.Buffer

	// write the capability flags and max packet size (int<4> each)
	fixed := make([]byte, 8)
	binary.LittleEndian.PutUint32(fixed[0:4], clientHandshake.CapabilityFlags)
	binary.LittleEndian.PutUint32(fixed[4:8], clientHandshake.MaxPacketSize)
	buf.Write(fixed)

	// write 1 byte char set
	buf.WriteByte(clientHandshake.ClientCharset)

	// write string[23] reserved (all zero)
	buf.Write(make([]byte, 23))

	// write string username
	buf.WriteString(clientHandshake.Username)
	buf.WriteByte(0)

	// write auth
	if clientHandshake.CapabilityFlags&ClientSecureConnection > 0 {
		// A non-positive AuthLength means "no auth data": only a zero
		// length byte is written.
		if clientHandshake.AuthLength > 0 {
			if len(clientHandshake.AuthResponse) > 255 {
				return nil, fmt.Errorf("auth response too long for a 1-byte length prefix: %d bytes", len(clientHandshake.AuthResponse))
			}
			buf.WriteByte(uint8(len(clientHandshake.AuthResponse)))
			buf.Write(clientHandshake.AuthResponse)
		} else {
			buf.WriteByte(0)
		}
	} else {
		buf.Write(clientHandshake.AuthResponse)
		buf.WriteByte(0)
	}

	// write database (if set)
	if clientHandshake.CapabilityFlags&ClientConnectWithDB > 0 {
		buf.WriteString(clientHandshake.Database)
		buf.WriteByte(0)
	}

	// write auth plugin name only when ClientPluginAuth is advertised;
	// UnpackHandshakeResponse41 reads it only in that case, so writing it
	// unconditionally would shift everything that follows (PacketTail).
	if clientHandshake.CapabilityFlags&ClientPluginAuth > 0 {
		buf.WriteString(clientHandshake.AuthPluginName)
		buf.WriteByte(0)
	}

	// write tail of packet (no-op when empty)
	buf.Write(clientHandshake.PacketTail)

	return AddHeaderToPacket(clientHandshake.SequenceID, buf.Bytes()), nil
}

// GetLenEncodedIntegerSize reports the size associated with a length-encoded
// integer, given its first byte: 0xfc, 0xfd and 0xfe announce 2, 3 and 8
// following bytes respectively, and every other first byte reports 1.
func GetLenEncodedIntegerSize(firstByte byte) byte {
	if firstByte < 0xfc || firstByte == 0xff {
		return 1
	}
	// firstByte is 0xfc, 0xfd or 0xfe here.
	return [...]byte{2, 3, 8}[firstByte-0xfc]
}

// ReadLenEncodedInteger reads a length-encoded integer from r and returns its value.
// See https://mariadb.com/kb/en/mariadb/protocol-data-types/#length-encoded-integers
//
// Encoding, selected by the first byte:
//
//	< 0xfb : the byte itself is the value
//	0xfb   : NULL, returned as 0
//	0xfc   : 2-byte little-endian value follows
//	0xfd   : 3-byte little-endian value follows
//	0xfe   : 8-byte little-endian value follows
//	0xff   : returned as the value 255
//
// It returns io.EOF if r is empty or nothing follows a multi-byte prefix, and
// io.ErrUnexpectedEOF if the value bytes are only partially present (a bare
// Read would have returned a silently truncated value instead).
func ReadLenEncodedInteger(r *bytes.Reader) (value uint64, err error) {
	firstByte, err := r.ReadByte()
	if err != nil {
		return 0, err
	}

	var width int
	switch firstByte {
	case 0xfb:
		return 0, nil
	case 0xfc:
		width = 2
	case 0xfd:
		width = 3
	case 0xfe:
		width = 8
	default:
		return uint64(firstByte), nil
	}

	// Read into the low bytes of an 8-byte buffer; the zero upper bytes
	// make a single little-endian decode correct for every width.
	var buf [8]byte
	if _, err := io.ReadFull(r, buf[:width]); err != nil {
		return 0, err
	}
	return binary.LittleEndian.Uint64(buf[:]), nil
}

// ReadLenEncodedString returns a parsed length-encoded string and its length.
// Length-encoded strings are prefixed by a length-encoded integer which describes
// the length of the string, followed by the string value.
// See https://mariadb.com/kb/en/mariadb/protocol-data-types/#length-encoded-strings
//
// Errors from decoding the length prefix are returned unchanged. A declared
// length larger than the bytes remaining in r yields io.ErrUnexpectedEOF; it
// is checked before allocating so a corrupt or hostile prefix cannot force a
// huge allocation.
func ReadLenEncodedString(r *bytes.Reader) (string, uint64, error) {
	strLen, err := ReadLenEncodedInteger(r)
	if err != nil {
		return "", 0, err
	}

	if strLen > uint64(r.Len()) {
		return "", 0, io.ErrUnexpectedEOF
	}

	strBuf := make([]byte, strLen)
	if _, err := io.ReadFull(r, strBuf); err != nil {
		return "", 0, err
	}

	return string(strBuf), strLen, nil
}

// ReadEOFLengthString interprets all of data as a string<EOF>: a string with
// no length prefix or terminator whose extent is the rest of the packet.
// See https://mariadb.com/kb/en/mariadb/protocol-data-types/#end-of-file-length-strings
func ReadEOFLengthString(data []byte) string {
	str := string(data)
	return str
}

// ReadNullTerminatedString consumes bytes from r up to and including the next
// 0x00 byte and returns the bytes before it as a string. If r is exhausted
// first, whatever was read is returned; no error is reported.
// See https://mariadb.com/kb/en/mariadb/protocol-data-types/#null-terminated-strings
func ReadNullTerminatedString(r *bytes.Reader) string {
	var acc bytes.Buffer
	for {
		c, err := r.ReadByte()
		if err != nil || c == 0x00 {
			break
		}
		acc.WriteByte(c)
	}
	return acc.String()
}

// AuthSwitchRequest represents a request from the server to switch to a different authentication method.
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_request.html
//
// Values are produced by UnpackAuthSwitchRequest.
type AuthSwitchRequest struct {
    // SequenceNumber is the sequence ID from the packet header.
    SequenceNumber uint8
    // PluginName is the NUL-terminated name of the requested auth plugin.
    PluginName     string
    // PluginData is everything after the plugin name's NUL terminator, kept
    // verbatim (nil when nothing follows). UnpackAuthSwitchRequest returns a
    // sub-slice of its input here, not a copy.
    PluginData     []byte
}

// UnpackAuthSwitchRequest decodes an AuthSwitchRequest packet from the provided data.
//
//	int<3>      PacketLength
//	int<1>      SequenceID
//	int<1>      Status byte (skipped, not validated)
//	string<NUL> PluginName
//	string<EOF> PluginData
//
// PluginData is a sub-slice of data holding every byte after the plugin
// name's terminator (including any trailing NUL); it is nil when nothing
// follows. Packets too short to hold the header and status byte yield
// ErrInvalidPacketLength; an empty or unterminated plugin name yields an
// error.
func UnpackAuthSwitchRequest(data []byte) (*AuthSwitchRequest, error) {
	// Header (4 bytes) + status byte; indexing data[3] below needs this.
	if len(data) < 5 {
		return nil, ErrInvalidPacketLength
	}

	sequenceNumber := data[3]
	payload := data[4:]

	// payload[0] is the status byte; the plugin name starts at payload[1]
	// and runs up to the first NUL.
	nameLen := bytes.IndexByte(payload[1:], 0x00)
	if nameLen <= 0 {
		return nil, fmt.Errorf("Invalid AuthSwitchRequest packet: Missing plugin name")
	}
	nullTerminatorIndex := 1 + nameLen
	pluginName := string(payload[1:nullTerminatorIndex])

	// Extract the plugin provided data
	var pluginData []byte
	if nullTerminatorIndex+1 < len(payload) {
		pluginData = payload[nullTerminatorIndex+1:]
	}

	return &AuthSwitchRequest{
		SequenceNumber: sequenceNumber,
		PluginName:     pluginName,
		PluginData:     pluginData,
	}, nil
}

// UnpackAuthRequestPubKeyResponse decodes a response from the server to a request for its public key.
func UnpackAuthRequestPubKeyResponse(data []byte) (*rsa.PublicKey, error) {
    // Parse public key
    if data[4] != ResponseAuthMoreData {
        return nil, fmt.Errorf("expected ResponseAuthMoreData packet")
    }

    block, rest := pem.Decode(data[5:])
    if block == nil {
        return nil, fmt.Errorf("no pem data found, data: %s", rest)
    }
    pkix, err := x509.ParsePKIXPublicKey(block.Bytes)
    if err != nil {
        return nil, fmt.Errorf("failed to parse public key: %s", err)
    }
    return pkix.(*rsa.PublicKey), nil
}

// ReadNullTerminatedBytes consumes bytes from r up to and including the next
// 0x00 byte and returns the bytes before it (nil when there are none). If r
// is exhausted first, whatever was read is returned; no error is reported.
func ReadNullTerminatedBytes(r *bytes.Reader) (str []byte) {
	c, err := r.ReadByte()
	for err == nil && c != 0x00 {
		str = append(str, c)
		c, err = r.ReadByte()
	}
	return str
}

// GetPacketHeader reads the 4-byte packet header (3-byte little-endian
// payload length followed by the 1-byte sequence ID) from r and returns it,
// leaving r positioned at the start of the payload.
//
// It returns io.EOF if r is empty and io.ErrUnexpectedEOF if fewer than 4
// bytes remain. Previously a short read returned a partly zero-filled header
// with a nil error.
func GetPacketHeader(r *bytes.Reader) (s []byte, e error) {
	header := make([]byte, 4)
	if _, err := io.ReadFull(r, header); err != nil {
		return nil, err
	}
	return header, nil
}

// CheckPacketLength returns nil when packet holds at least expected bytes and
// ErrInvalidPacketLength otherwise.
func CheckPacketLength(expected int, packet []byte) error {
	if len(packet) >= expected {
		return nil
	}
	return ErrInvalidPacketLength
}

// NativePassword calculates the mysql_native_password auth response expected
// by the server in HandshakeResponse41:
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_response.html#sect_protocol_connection_phase_packets_protocol_handshake_response41
//
//	SHA1( password ) XOR SHA1( "20-bytes random data from server" <concat> SHA1( SHA1( password ) ) )
//
// An empty password yields an empty (nil) response. This matches
// scrambleSHA256Password and the way MySQL clients signal an account
// without a password; a 20-byte scramble of "" would not be accepted for
// such an account. The returned error is always nil and is kept for API
// compatibility.
func NativePassword(password []byte, salt []byte) (nativePassword []byte, err error) {
	if len(password) == 0 {
		return nil, nil
	}

	// The hasher is named h so it does not shadow the crypto/sha1 package.
	h := sha1.New()
	h.Write(password)
	passwordSHA1 := h.Sum(nil)

	h.Reset()
	h.Write(passwordSHA1)
	passwordDoubleSHA1 := h.Sum(nil)

	h.Reset()
	h.Write(salt)
	h.Write(passwordDoubleSHA1)
	saltedSHA1 := h.Sum(nil)

	// nativePassword = passwordSHA1 ^ saltedSHA1
	nativePassword = make([]byte, len(saltedSHA1))
	for i := range saltedSHA1 {
		nativePassword[i] = passwordSHA1[i] ^ saltedSHA1[i]
	}

	return nativePassword, nil
}

// AddHeaderToPacket returns a new slice holding the 4-byte MySQL packet
// header followed by restOfPacket. The header is the payload length as a
// 3-byte little-endian integer, then sequenceID. Only the low 24 bits of the
// length are encoded, so payloads of 16 MiB or more cannot be represented in
// a single packet. restOfPacket itself is not modified.
func AddHeaderToPacket(sequenceID uint8, restOfPacket []byte) []byte {
	n := len(restOfPacket)
	packet := make([]byte, 0, 4+n)
	packet = append(packet, byte(n), byte(n>>8), byte(n>>16), sequenceID)
	return append(packet, restOfPacket...)
}

// PackAuthSwitchResponse creates an AuthSwitchResponse packet with the provided response data.
func PackAuthSwitchResponse(authSwitchRequestSequenceID uint8, data []byte) ([]byte, error) {
    // Create a buffer to write the packet data
    buffer := new(bytes.Buffer)

    // Write the response data to the buffer
    buffer.Write(data)

    return AddHeaderToPacket(authSwitchRequestSequenceID, buffer.Bytes()), nil
}

// AuthMoreDataResponse represents a packet sent from the server to request more auth data from the client.
// Values are produced by UnpackAuthMoreDataResponse.
type AuthMoreDataResponse struct {
    // SequenceID is the sequence number from the packet header.
    SequenceID uint8
    // PacketType is the first payload byte (always ResponseAuthMoreData when decoded).
    PacketType byte
    // StatusTag is the byte following PacketType; it is not validated on decode.
    StatusTag  byte
}

// UnpackAuthMoreDataResponse decodes AuthMoreData from server.
// Basic packet structure shown below.
//
//	int<3> PacketLength
//	int<1> PacketNumber
//	int<1> PacketType (0x01)
//	int<1> StatusTag (0x03 or 0x04)
//	string<EOF> AuthenticationMethodData (unused by secretless)
//
// Packets shorter than 6 bytes yield ErrInvalidPacketLength. A packet type
// other than ResponseAuthMoreData yields a "Malformed packet" error. The
// status tag is returned as-is without validation.
func UnpackAuthMoreDataResponse(packet []byte) (*AuthMoreDataResponse, error) {
	// Min packet length = header(4 bytes) + PacketType(1 byte) + StatusTag(1 byte)
	if err := CheckPacketLength(6, packet); err != nil {
		return nil, err
	}

	r := bytes.NewReader(packet)

	header, err := GetPacketHeader(r)
	if err != nil {
		return nil, err
	}

	// Read packet type, validate AuthMoreData
	packetType, err := r.ReadByte()
	if err != nil {
		return nil, err
	}
	if packetType != ResponseAuthMoreData {
		return nil, errors.New("Malformed packet")
	}

	// Read status tag
	statusTag, err := r.ReadByte()
	if err != nil {
		return nil, err
	}

	return &AuthMoreDataResponse{
		SequenceID: header[3],
		PacketType: packetType,
		StatusTag:  statusTag,
	}, nil
}

// PackAuthRequestPubKeyResponse encodes the request for the server's public key
func PackAuthRequestPubKeyResponse(sequenceID uint8) []byte {
    return AddHeaderToPacket(sequenceID, []byte{CachingSha2PasswordRequestPublicKey})
}

// PackAuthEncryptedPasswordResponse frames encPwd (for example the output of
// EncryptPassword) as a packet with the given sequence ID.
func PackAuthEncryptedPasswordResponse(sequenceID uint8, encPwd []byte) []byte {
	packet := AddHeaderToPacket(sequenceID, encPwd)
	return packet
}

// EncryptPassword encrypts a password using the provided seed and public key.
// The password is NUL-terminated, XORed byte-by-byte with seed (repeated
// cyclically), and then encrypted with RSA-OAEP using SHA-1 under pub.
//
// An empty seed or a nil key returns an error instead of panicking
// (division by zero / nil dereference).
func EncryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) {
	if len(seed) == 0 {
		return nil, errors.New("cannot encrypt password: empty seed")
	}
	if pub == nil {
		return nil, errors.New("cannot encrypt password: nil public key")
	}

	// +1 leaves room for the NUL terminator, which is masked too.
	plain := make([]byte, len(password)+1)
	copy(plain, password)
	for i := range plain {
		plain[i] ^= seed[i%len(seed)]
	}

	// For this stage of the authentication, we must use sha1. See
	// https://github.com/go-sql-driver/mysql/blob/19171b59bf90e6bf7a5bdf979e5e24a84b328b8a/auth.go#L217-L226
	return rsa.EncryptOAEP(sha1.New(), rand.Reader, pub, plain, nil)
}