flaw-ffi/Flaw/FFI/COM/TH.hs
{-|
Module: Flaw.FFI.COM.TH
Description: Generating declarations for Windows COM interfaces.
License: MIT
-}
{-# LANGUAGE TemplateHaskell #-}
module Flaw.FFI.COM.TH
( genCOMInterface
) where
import Data.Maybe
import qualified Data.UUID as UUID
import Foreign.Ptr
import Foreign.Storable
import Language.Haskell.TH
import Flaw.FFI.COM.Internal
-- | Width, in bytes, of a machine pointer on the target architecture.
-- Every virtual-table slot is one pointer wide, so this is the stride used
-- to lay out method offsets.
ptrSize :: Int
ptrSize = sizeOf (nullPtr :: Ptr ())
-- | Internal info about a single COM method, as produced by 'processMethod'
-- and consumed by 'genCOMInterface'.
data Method = Method
  { methodNameStr :: String
    -- ^ Method name as given by the caller (without the interface prefix).
  , methodName :: Name
    -- ^ Name of the generated class method @m_<Interface>_<method>@.
  , methodType :: TypeQ
    -- ^ Method type, without the @this@ argument.
  , methodFieldName :: Name
    -- ^ Name of the interface record field @mp_<Interface>_<method>@
    -- holding the method bound to its object.
  , methodOffset :: ExpQ
    -- ^ Expression referring to the generated constant holding the byte
    -- offset of this method's virtual-table slot.
  , methodEnd :: ExpQ
    -- ^ Expression for the offset just past this method's slot
    -- (offset + 'ptrSize'); the next method's slot starts here.
  , methodMake :: ExpQ
    -- ^ Expression referring to the generated @dynamic@ foreign import that
    -- converts the slot's 'FunPtr' into a callable Haskell function.
  , methodTopDecs :: DecsQ
    -- ^ Top-level declarations for this method: offset constant, type
    -- synonyms and the foreign import.
  , methodClassDecs :: Name -> ExpQ -> [DecQ]
    -- ^ Given the class type variable and an expression converting a value
    -- of that type to the interface record, the declarations placed in the
    -- interface class (signature and definition of 'methodName').
  , methodField :: VarStrictTypeQ
    -- ^ Record field declaration for 'methodFieldName' in the interface
    -- data type.
  }
-- | Build the 'Method' descriptor for one COM method.
--
-- Arguments are the owning interface name, the method type (without the
-- @this@ argument), the method name, and an expression for the byte offset
-- at which this method's virtual-table slot begins. With @I@ the interface
-- and @m@ the method, the generated names are:
--
-- * @method_offset_I_m@ — 'Int' constant equal to the given start offset;
-- * @Mft_I_m@ — type synonym for the method type;
-- * @Mfft_I_m@ — the same type with a leading @Ptr I@ argument;
-- * @mk_I_m@ — @stdcall@ @dynamic@ import turning a 'FunPtr' of that type
--   into a function;
-- * @mp_I_m@ — record field storing the bound method;
-- * @m_I_m@ — class method exposed to users.
processMethod :: String -> TypeQ -> String -> ExpQ -> Q Method
processMethod interfaceName mt mn prevEndExp = return Method
  { methodNameStr = mn
  , methodName = userFn
  , methodType = mt
  , methodFieldName = slotField
  , methodOffset = varE offsetConst
  , methodEnd = [| $(varE offsetConst) + ptrSize |]
  , methodMake = varE wrapperFn
  , methodTopDecs = mkTopDecs
  , methodClassDecs = mkClassDecs
  , methodField = return (slotField, Bang NoSourceUnpackedness NoSourceStrictness, ConT fnTy)
  }
  where
    -- All generated names share the "<prefix><interface>_<method>" shape.
    named prefix = mkName $ prefix ++ interfaceName ++ "_" ++ mn
    userFn = named "m_"
    fnTy = named "Mft_"
    foreignFnTy = named "Mfft_"
    slotField = named "mp_"
    offsetConst = named "method_offset_"
    wrapperFn = named "mk_"
    mkTopDecs = sequence
      [ sigD offsetConst [t| Int |]
      , valD (varP offsetConst) (normalB prevEndExp) []
      , tySynD fnTy [] mt
      , tySynD foreignFnTy [] [t| Ptr $(conT $ mkName interfaceName) -> $(conT fnTy) |]
      , forImpD stdCall safe "dynamic" wrapperFn [t| FunPtr $(conT foreignFnTy) -> $(conT foreignFnTy) |]
      ]
    -- m_I_m :: a -> Mft_I_m ; m_I_m = mp_I_m . <getter>
    mkClassDecs tyVar getter =
      [ sigD userFn [t| $(varT tyVar) -> $(conT fnTy) |]
      , valD (varP userFn) (normalB [| $(varE slotField) . $getter |]) []
      ]
-- | Build 'Method' descriptors for all methods of an interface, in order.
-- The first method's slot starts at the given offset expression; each
-- following method starts where the previous one ends ('methodEnd').
processMethods :: String -> ExpQ -> [(TypeQ, String)] -> Q [Method]
processMethods interfaceName startExp methodSpecs = go startExp methodSpecs
  where
    go _ [] = return []
    go offsetExp ((methodTy, methodNm) : rest) = do
      method <- processMethod interfaceName methodTy methodNm offsetExp
      (method :) <$> go (methodEnd method) rest
-- | Expand a list of interface names into the full inheritance chain.
--
-- Names given explicitly are kept in order. Only the last one is extended:
-- its @pd_<name>@ parent field (as generated by 'genCOMInterface') is looked
-- up and reified, and the chain continues with the parent's type. If no such
-- field is in scope (a root interface, or reification is unavailable) the
-- chain ends there.
--
-- Fails with a descriptive message if the parent field exists but its
-- reified type is not of the form @<Interface> -> <Parent>@.
getInterfaceChain :: [String] -> Q [String]
getInterfaceChain [] = return []
getInterfaceChain [nameStr] = do
  maybeParentFieldName <- lookupValueName $ "pd_" ++ nameStr
  case maybeParentFieldName of
    Just parentFieldName -> do
      info <- reify parentFieldName
      case info of
        VarI _ (AppT (AppT ArrowT _) (ConT parentName)) _ -> do
          parents <- getInterfaceChain [nameBase parentName]
          return $ nameStr : parents
        _ -> fail $ "getInterfaceChain: parent field " ++ show parentFieldName
          ++ " of interface " ++ nameStr
          ++ " does not have the expected type <Interface> -> <Parent>"
    Nothing -> return [nameStr]
getInterfaceChain (n:ns) = (n :) <$> getInterfaceChain ns
-- | Generate the declarations for a COM interface.
--
-- For an interface named, let's say, @IMyInterface@, this generates:
--
-- * the record type @IMyInterface@ with the this-pointer field
--   @it_IMyInterface@, the immediate parent's record in @pd_IMyInterface@
--   (if there is a parent), and one @mp_IMyInterface_<method>@ field per
--   method;
-- * 'COMInterface', 'Eq' and 'Ord' instances (comparison is by the
--   this-pointer);
-- * the class @IMyInterface_Class@ with accessor @com_get_IMyInterface@ and
--   one @m_IMyInterface_<method>@ function per method, and its instance for
--   @IMyInterface@;
-- * @ie_IMyInterface@ (end offset of the virtual table, in bytes) and
--   @iid_IMyInterface@ (the interface 'IID');
-- * per-method helper declarations, and an instance of every ancestor's
--   @_Class@ for @IMyInterface@.
--
-- To export the interface appropriately, the module's export list needs
-- @IMyInterface(..)@ and @IMyInterface_Class(..)@.
--
-- The list of parent interfaces may be incomplete (e.g. contain only the
-- immediate parent); the remaining ancestors are then found by reification.
-- Specifying all ancestors manually is useful when the stage restriction is
-- in effect (i.e. 'reify' doesn't work).
--
-- Fails at compile time if the IID string is not a valid UUID.
genCOMInterface
  :: String -- ^ Name of the interface.
  -> String -- ^ String representation of IID (interface GUID).
  -> [String] -- ^ Names of parent interfaces, from immediate to root.
  -> [(TypeQ, String)] -- ^ List of methods. Type of method should be without 'this' argument.
  -> Q [Dec]
genCOMInterface interfaceNameStr iid parentInterfaceNames ms = do
  -- Reject a malformed IID while compiling. The generated iid_ value uses
  -- 'fromJust', so without this check a typo would only surface as a
  -- runtime "fromJust: Nothing" error on first use of the IID.
  case UUID.fromString iid of
    Just _ -> return ()
    Nothing -> fail $ "genCOMInterface: invalid IID " ++ show iid ++ " for interface " ++ interfaceNameStr
  let
    interfaceName = mkName interfaceNameStr
    iidName = mkName $ "iid_" ++ interfaceNameStr
    endName = mkName $ "ie_" ++ interfaceNameStr
    parentFieldName = mkName $ "pd_" ++ interfaceNameStr
    thisName = mkName $ "it_" ++ interfaceNameStr
  thisParamName <- newName "this"
  vtParamName <- newName "vt"
  -- Parent-dependent pieces: where this interface's own methods start in the
  -- virtual table, the parent record field, how to peek and store it, and
  -- the ancestor class instances.
  (beginExp, parentFields, peekParentBinds, parentFieldsConstr, parentInstanceDecs) <- case parentInterfaceNames of
    (firstParentInterfaceNameStr:_) -> do
      parentParamName <- newName "parent"
      let
        -- instance <Ancestor>_Class IInterface: delegate to the parent record.
        parentDec parentInterfaceNameStr = let
          parentInterfaceClassName = mkName $ parentInterfaceNameStr ++ "_Class"
          comGetParent = mkName $ "com_get_" ++ parentInterfaceNameStr
          in instanceD (return []) [t| $(conT parentInterfaceClassName) $(conT interfaceName) |]
            [funD comGetParent [clause [recP interfaceName [return (parentFieldName, VarP parentParamName)]] (normalB [| $(varE comGetParent) $(varE parentParamName) |]) []]]
      parentDecs <- mapM parentDec =<< getInterfaceChain parentInterfaceNames
      return
        ( [| sizeOfCOMVirtualTable (undefined :: $(conT $ mkName firstParentInterfaceNameStr)) |]
        , [return (parentFieldName, Bang NoSourceUnpackedness NoSourceStrictness, ConT $ mkName firstParentInterfaceNameStr)]
        , [bindS (varP parentParamName) [| peekCOMVirtualTable (castPtr $(varE thisParamName)) $(varE vtParamName) |]]
        , [return (parentFieldName, VarE parentParamName)]
        , parentDecs
        )
    [] -> return ([| 0 |], [], [], [], [])
  methods <- processMethods interfaceNameStr beginExp ms
  -- data IInterface = IInterface { it_IInterface :: Ptr IInterface, [pd_...,] mp_... }
  dataDec <- dataD (return []) interfaceName [] Nothing [recC interfaceName $ return (thisName, Bang NoSourceUnpackedness NoSourceStrictness, AppT (ConT ''Ptr) $ ConT interfaceName) : parentFields ++ map methodField methods] []
  -- instance COMInterface IInterface
  let
    -- For each method: a statement reading its function pointer from the
    -- virtual-table slot, and the record field binding the wrapped function
    -- to 'this'.
    mp1 method = do
      p <- newName $ "p_" ++ methodNameStr method
      let binding = bindS (varP p) [| peek (plusPtr $(varE vtParamName) $(methodOffset method)) |]
      makeExp <- [| $(methodMake method) (castPtrToFunPtr $(varE p)) $(varE thisParamName) |]
      let field = (methodFieldName method, makeExp)
      return (binding, field)
  mp1s <- mapM mp1 methods
  let
    peekCOMVirtualTableDec = funD 'peekCOMVirtualTable [clause [varP thisParamName, varP vtParamName] body []] where
      body = normalB $ doE $ peekParentBinds ++ map fst mp1s ++ [noBindS $ appE (varE 'return) $ recConE interfaceName $ return (thisName, VarE thisParamName) : parentFieldsConstr ++ map (return . snd) mp1s]
  comInterfaceInstanceDec <- instanceD (return []) [t| COMInterface $(conT interfaceName) |]
    [ funD 'getIID [clause [wildP] (normalB $ varE iidName) []]
    , funD 'getCOMInterfaceName [clause [wildP] (normalB $ litE $ StringL interfaceNameStr) []]
    , funD 'sizeOfCOMVirtualTable [clause [wildP] (normalB $ varE endName) []]
    , funD 'pokeCOMObject [clause [recP interfaceName [return (thisName, VarP thisParamName)]] (normalB $ varE thisParamName) []]
    , peekCOMVirtualTableDec
    ]
  aParam <- newName "a"
  bParam <- newName "b"
  -- instance Eq IInterface (compares this-pointers)
  eqInterfaceInstanceDec <- instanceD (return []) [t| Eq $(conT interfaceName) |]
    [ funD '(==) [clause [recP interfaceName [return (thisName, VarP aParam)], recP interfaceName [return (thisName, VarP bParam)]] (normalB [| $(varE aParam) == $(varE bParam) |]) []]
    ]
  -- instance Ord IInterface (compares this-pointers)
  ordInterfaceInstanceDec <- instanceD (return []) [t| Ord $(conT interfaceName) |]
    [ funD 'compare [clause [recP interfaceName [return (thisName, VarP aParam)], recP interfaceName [return (thisName, VarP bParam)]] (normalB [| compare $(varE aParam) $(varE bParam) |]) []]
    , funD '(<=) [clause [recP interfaceName [return (thisName, VarP aParam)], recP interfaceName [return (thisName, VarP bParam)]] (normalB [| $(varE aParam) <= $(varE bParam) |]) []]
    ]
  let
    className = mkName $ interfaceNameStr ++ "_Class"
    comGetName = mkName $ "com_get_" ++ interfaceNameStr
  paramName <- newName "a"
  -- class COMInterface a => IInterface_Class a
  classDec <- classD (sequence [ [t| COMInterface $(varT paramName) |] ]) className [PlainTV paramName] []
    (sigD comGetName [t| $(varT paramName) -> $(conT interfaceName) |] : concatMap (\method -> methodClassDecs method paramName $ varE comGetName) methods)
  -- instance IInterface_Class IInterface
  instanceDec <- instanceD (return []) [t| $(conT className) $(conT interfaceName) |]
    [ valD (varP comGetName) (normalB [| id |]) []
    ]
  -- ie_IInterface: end of the virtual table = start + one pointer per method.
  endSigDec <- sigD endName [t| Int |]
  endValDec <- valD (varP endName) (normalB [| $beginExp + $(litE $ integerL $ fromIntegral $ length methods) * ptrSize |]) []
  -- iid_IInterface: the IID string was validated above, so fromJust is safe.
  iidSigDec <- sigD iidName [t| IID |]
  iidValDec <- valD (varP iidName) (normalB [| fromJust (UUID.fromString $(litE $ stringL iid)) |]) []
  methodsTopDecs <- mapM methodTopDecs methods
  return $ dataDec : comInterfaceInstanceDec : eqInterfaceInstanceDec : ordInterfaceInstanceDec : classDec : instanceDec : endSigDec : endValDec : iidSigDec : iidValDec : concat methodsTopDecs ++ parentInstanceDecs