«

»

31

WDK驱动开发使用C++模板实现动态存储类CLinkList

WDK支持C++编译,但大多数人写驱动还是习惯用纯C甚至C/C++滥用…最近写一个小项目,驱动代码我则尝试了完全使用C++来实现,之后我发现,即便是驱动开发,C++也要比纯C语言方便得多。虽然内核里没有STL,但是WDK仍支持C++的很多特性,类、重载、模板,以及多态的很多特性,都可以正常使用(当然..我也只是个C++初学者,很多高级功能并没有尝试)。

在内核编程时,很多时候我们要遍历一些内核里的数据结构并做记录,而这些数据的大小通常是动态的,我们在开始遍历时并不知道储存他们的记录要占用多大空间。面对这种问题,最简单的有两个解决方法:

1.分配一个足够大的内存,保证能存储下数据(大多数情况下会浪费很多空间)

2.预先遍历一遍并记录总大小,然后再遍历一次进行记录(毫无疑问地效率低)

以上两种简单的方法都有不可忽视的缺陷,所以我们要用稍高级一些的数据结构来解决这个问题,针对此问题的一个“完美”的解决方案就是动态链表。因为动态链表是不受空间限制的,而且随用随分配,非常好使。于是我建立了一个动态链表来保存我遍历到的数据,效果的确令我满意。但是,一天之后问题就来了,我又要遍历一类新的数据,也不得不采用链表这种动态存储方案。也就是说我又要把相似的代码重写一遍,而且链表的各种操作还不是闭着眼睛就能写好的。这样做无疑是非常不可取的。学习C++之前我对这个问题无可奈何,但现在我掌握了C++的模板机制,使用模板正好可以解决这个代码冗余的问题。于是我便写了一个内核下的动态存储类,方便记录各种类型的数据。下面把代码分享一下,并用一个简单地示例来演示它的使用方法。不过在这之前,要先介绍一下内核编程中的动态内存分配。

 

WDK下C++的动态内存分配

因为内核里动态分配的内存有分页和未分页的区别,所以new/delete操作符稍有不同,通常我们可以自己实现一组简易的new/delete:

头文件:

//stdafx.h

 

//全局new操作符  
void * __cdecl operator new(size_t size,POOL_TYPE PoolType);
//全局delete操作符  
void __cdecl operator delete(void* pointer);

 源文件:

//stdafx.cpp

 

#include "stdafx.h"

// 全局new操作符  
void * __cdecl operator new(size_t size, POOL_TYPE PoolType)
{
    return ExAllocatePool(PoolType, size);
}

// 全局delete操作符  
void __cdecl operator delete(void* pointer)
{
    ExFreePool(pointer);
}

 new使用时跟平常的new略有区别(delete无变化),形式如下:

int *buf = new(NonPagedPool) int;

另外要注意stdafx.h最好不要被重复包含,不然链接时可能会出问题。

 

CLinkList代码实现

接下来就是CLinkList存储类的实现了,这里还写了一个模板CListEntry,是为了用它来建立不同数据类型的链表。

其实这玩意的作用跟STL的vector类似,但是还远不能跟vector媲美,比如没有迭代器..代码考虑也不够周全,这些问题以后我应该还会改善的,等把STL学明白了,写一个内核下的KTL也未尝不可~就目前来讲,CLinkList对我来说已经够用了。

注意下面这段代码是全部写在CLinkList.h里面的,标准来讲代码的实现应该写在.cpp里,跟声明分开。但是WDK对C++模板的支持还是没有那么全面,如果那样写的话,会出现链接错误。如果一定要分成CLinkList.h和CLinkList.cpp的话,那么别的.cpp要使用CLinkList,就必须#include "CLinkList.cpp"而不是#include "CLinkList.h"。所以与其那样还不如直接将所有代码都写到一个.h里,直接包含就可以使用了~

WDK还有一点问题,就是除了模板类的构造函数外,其他成员函数里不能对类型参数使用sizeof()运算。比如有template<T> class A,我们在A::A(T Target)里使用sizeof(T)或sizeof(Target)是正确的,但是在其他成员函数里就都会产生编译错误。所以要获取T类型的大小字节数,最好是在构造函数里记录一下。当然了,C++的标准是允许在成员函数里sizeof(T)的,这里出现这个问题是编译器的BUG。

//CLinkList.h

 

#ifndef CLINKLIST_INCLUDE
#define CLINKLIST_INCLUDE

#include "stdafx.h"

template <typename T> struct CListEntry {
    LIST_ENTRY List;
    T Data;
};

template <typename T> class CLinkList {
private:
    unsigned long m_Total;
    long m_MaxSize;

    CListEntry<T> *m_lpLinkHead;
    
    bool Insert(const T & Target, bool IsLinkHead);

public:
    CLinkList();
    ~CLinkList();
    T Record(const unsigned long Index);
    T operator[](const unsigned long Index);
    T Top(void);
    inline long Size() {
        return m_Total;
    };
    inline bool IsEmpty() {
        return (m_Total == 0? true: false);
    };
    bool InsertHead(const T & Target);
    bool InsertTail(const T & Target);
    bool push(const T & Target);
    T pop(void);
    bool IsExist(const T & Target);
    T Remove(const unsigned long Index);
    T Remove(const T & Target);
    void Clear();
};

template <typename T> CLinkList<T>::CLinkList()
{
    m_Total = 0;

    //
    // 记录类型的长度不能为负
    //
    m_MaxSize = sizeof(T);
    if (m_MaxSize <= 0)
        KeBugCheck(STATUS_UNSUCCESSFUL);
}

template <typename T> CLinkList<T>::~CLinkList()
{
    //
    // 析构时要清理已建立的链表
    //
    Clear();
}

template <typename T> T CLinkList<T>::Record(const unsigned long Index)
{
    //
    // 索引不能超出范围
    //
    if (Index > m_Total – 1)
        KeBugCheck(STATUS_UNSUCCESSFUL);

    //
    // 指针后移Index次
    //
    CListEntry<T> *lpLinkCurrent = m_lpLinkHead;
    for (int i = 0; i < Index; i++)
        lpLinkCurrent = (CListEntry<T>*)lpLinkCurrent->List.Flink;

    //
    // 返回内容
    //
    return lpLinkCurrent->Data;
}

template <typename T> T CLinkList<T>::operator[](const unsigned long Index)
{
    //
    // 重载[]以便代替Record()
    //
    return Record(Index);
}

template <typename T> T CLinkList<T>::Top(void)
{
    //
    // 返回链首的记录
    //
    return Record(Size() – 1);
}

template <typename T> bool CLinkList<T>::Insert(const T & Target, bool IsLinkHead)
{
    if (IsExist(Target))
        return false;

    //
    // 分配缓冲区。缓冲区首先是一个LIST_ENTRY,紧接着便是用户传入的结构
    //
    long Size = sizeof(LIST_ENTRY) + m_MaxSize;
    CListEntry<T> *lpLinkCurrent = new(NonPagedPool) CListEntry<T>();
    lpLinkCurrent->Data = Target;

    //
    // 将该记录插入链尾
    //
    if (m_Total == 0) {
        m_lpLinkHead = lpLinkCurrent;
        lpLinkCurrent->List.Blink = lpLinkCurrent->List.Flink = (PLIST_ENTRY)m_lpLinkHead;
    } else {
        //
        // 将新的Target插入链表: … <-> LinkTail <-> LinkCurrent <-> LinkHead <-> …
        //

        // LinkCurrent与LinkTail相连
        m_lpLinkHead->List.Blink->Flink = (PLIST_ENTRY)lpLinkCurrent;
        lpLinkCurrent->List.Blink = m_lpLinkHead->List.Blink;

        // LinkCurrent与LinkHead相连
        m_lpLinkHead->List.Blink = (PLIST_ENTRY)lpLinkCurrent;
        lpLinkCurrent->List.Flink = (PLIST_ENTRY)m_lpLinkHead;
    }

    //
    // 如果指定要插入链首,则将该记录设为新的链首
    //
    if (IsLinkHead == true)
        m_lpLinkHead = lpLinkCurrent;

    //
    // 记录总数+1
    //
    m_Total++;

    return true;
}

template <typename T> bool CLinkList<T>::InsertHead(const T & Target)
{
    //
    // 在链首插入记录
    //
    return Insert(Target, true);
}

template <typename T> bool CLinkList<T>::InsertTail(const T & Target)
{
    //
    // 在链尾插入记录
    //
    return Insert(Target, false);
}

template <typename T> bool CLinkList<T>::push(const T & Target)
{
    //
    // 模拟堆的push操作
    //
    return InsertTail(Target);
}

template <typename T> T CLinkList<T>::pop(void)
{
    //
    // 模拟堆的pop操作
    //
    return Remove(Size() – 1);
}

template <typename T> bool CLinkList<T>::IsExist(const T & Target)
{
    //
    // 若记录链为空则不存在任何记录
    //
    if (IsEmpty())
        return false;

    CListEntry<T> *lpLinkCurrent = m_lpLinkHead;
    do {
        //
        // 用Target与链表中保存的记录进行对比,如果内容一致则该记录存在
        //
        if (RtlCompareMemory(&lpLinkCurrent->Data, &Target, sizeof(Target)) == sizeof(Target))
            return true;
        lpLinkCurrent = (CListEntry<T>*)lpLinkCurrent->List.Flink;
    } while (lpLinkCurrent != m_lpLinkHead);

    return false;
}

template <typename T> T CLinkList<T>::Remove(const unsigned long Index)
{
    return Remove(Record(Index));
}

template <typename T> T CLinkList<T>::Remove(const T & Target)
{
    //
    // 若记录链为空则不存在任何记录
    //
    if (IsEmpty())
        KeBugCheck(STATUS_UNSUCCESSFUL);

    CListEntry<T> *lpLinkCurrent = m_lpLinkHead;
    do {
        if (RtlCompareMemory(&lpLinkCurrent->Data, &Target, sizeof(Target)) == sizeof(Target)) {
            //
            // 先复制该项内容以便返回结果
            //
            T Result;
            Result = lpLinkCurrent->Data;

            //
            // 若该项是表头,则设置下一项为新的表头
            //
            if (lpLinkCurrent == m_lpLinkHead) {
                m_lpLinkHead = (CListEntry<T>*)lpLinkCurrent->List.Flink;
            }

            //
            // 将该项脱链并释放内存
            //
            lpLinkCurrent->List.Blink->Flink = lpLinkCurrent->List.Flink;
            lpLinkCurrent->List.Flink->Blink = lpLinkCurrent->List.Blink;
            ExFreePool(lpLinkCurrent);

            //
            // 记录总数-1
            //
            m_Total–;

            return Result;
        }
        lpLinkCurrent = (CListEntry<T>*)lpLinkCurrent->List.Flink;
    } while (lpLinkCurrent != m_lpLinkHead);

    //
    // 不能试图删除不存在的记录
    //
    KeBugCheck(STATUS_UNSUCCESSFUL);
    return Target;
}

template <typename T> void CLinkList<T>::Clear()
{
    //
    // 若记录链为空则不存在任何记录,无须清理
    //
    if (IsEmpty())
        return;

    //
    // 逐一释放每个记录块
    //
    CListEntry<T> *lpLinkCurrent = m_lpLinkHead, *lpLinkNext;
    do {
        lpLinkNext = (CListEntry<T>*)lpLinkCurrent->List.Flink;
        ExFreePool(lpLinkCurrent);
        lpLinkCurrent = lpLinkNext;
    } while (lpLinkCurrent != m_lpLinkHead);

    //
    // 链头指针设为空指针,记录总数设为零
    //
    m_lpLinkHead = NULL;
    m_Total = 0;
}

#endif

 

一个简单的使用示例

上面的是实现代码,下面我写了一个通过ActiveProcessLink链表遍历进程并使用CLinkList动态存储的例子来展示它的使用方法。代码非常简单,CLinkList的使用也非常简单~

//CLinkListDemo.cpp

 

#include "stdafx.h"
#include "CLinkList.h"

//#define DEBUG

struct PROCESS_INFO {
    PVOID EProcess;
    HANDLE Pid;
    UINT8 ImageName[16];
};

void DriverUnload(IN PDRIVER_OBJECT DriverObject) {};
void LinkListTest(CLinkList<PROCESS_INFO> & ProcessRecord);
void QuerySystemProcessInformations(CLinkList<PROCESS_INFO> * ProcessRecord);
void PrintAll(CLinkList<PROCESS_INFO> & ProcessRecord);

#ifdef __cplusplus
extern "C" NTSTATUS DriverEntry(IN PDRIVER_OBJECT DriverObject, IN PUNICODE_STRING  RegistryPath);
#endif

NTSTATUS DriverEntry(IN PDRIVER_OBJECT DriverObject, IN PUNICODE_STRING  RegistryPath)
{
    DriverObject->DriverUnload = DriverUnload;

    //
    // 声明一个PROCESS_INFO类型的模板类
    //
    CLinkList<PROCESS_INFO> ProcessRecord;

    //
    // DEBUG模式下反复记录100次,测试代码的稳定性
    //
#ifdef DEBUG
    for (int i = 0; i < 100; i++) {
#endif
        LinkListTest(ProcessRecord);
#ifdef DEBUG
    }
#endif

    return STATUS_SUCCESS;
}

void LinkListTest(CLinkList<PROCESS_INFO> & ProcessRecord)
{
    //
    // 枚举进程并动态记录
    //
    QuerySystemProcessInformations(&ProcessRecord);

    //
    // 输出所有进程信息
    //
    PrintAll(ProcessRecord);

    //
    // 删掉第二个错误的记录后,再次输出
    //
    ProcessRecord.Remove(1);
    PrintAll(ProcessRecord);

    //
    // 对比最后一个记录删除前后IsExist()的结果
    //
    PROCESS_INFO TestProcess = ProcessRecord[ProcessRecord.Size() – 1];
    if (ProcessRecord.IsExist(TestProcess))
        KdPrint(("Process %s exist.", TestProcess.ImageName));
    else
        KdPrint(("Process %s not found.", TestProcess.ImageName));
    
    ProcessRecord.pop();
    KdPrint(("Process %s record removed.\n", TestProcess.ImageName));

    if (ProcessRecord.IsExist(TestProcess))
        KdPrint(("Process %s exist.", TestProcess.ImageName));
    else
        KdPrint(("Process %s not found.", TestProcess.ImageName));
    KdPrint(("\n"));

    //
    // 清空内容后,再次尝试输出
    //
    ProcessRecord.Clear();
    KdPrint(("All records cleared."));
    PrintAll(ProcessRecord);
}

void QuerySystemProcessInformations(CLinkList<PROCESS_INFO> * ProcessRecord)
{
    //
    // 从当前进程(System)开始枚举
    //
    PCHAR Process = (PCHAR)PsGetCurrentProcess();

    LIST_ENTRY *lpLinkHead, *lpLinkCurrent;
    lpLinkHead = lpLinkCurrent = (LIST_ENTRY*)(Process + 0x88);
    do {
        //
        // 拷贝进程信息
        //
        PROCESS_INFO ProcessInfo;
        ProcessInfo.EProcess = (PVOID)((PCHAR)lpLinkCurrent – 0x88);
        ProcessInfo.Pid = *(HANDLE*)((PCHAR)ProcessInfo.EProcess + 0x84);
        RtlCopyMemory(ProcessInfo.ImageName, (PCHAR)ProcessInfo.EProcess + 0x174, 16 * sizeof(UINT8));

        //
        // 添加记录
        //
        if (!((int)ProcessInfo.Pid % 4))
            ProcessRecord->push(ProcessInfo);

        lpLinkCurrent = lpLinkCurrent->Blink;
    } while (lpLinkCurrent != lpLinkHead);
}

void PrintAll(CLinkList<PROCESS_INFO> & ProcessRecord)
{
    //
    // 输出所有进程记录
    //
    KdPrint(("%d Active Process Records: ", ProcessRecord.Size()));
    for (int i = 0; i != ProcessRecord.Size(); i++)
        KdPrint(("[%d] %s", ProcessRecord[i].Pid, ProcessRecord[i].ImageName));
    KdPrint(("\n"));
}

 

1 comment

  1. TL_GTASA

    好像很流弊的样子……

    [回复]

发表评论

电子邮件地址不会被公开。 必填项已用*标注

You may use these HTML tags and attributes: <a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code> <del datetime=""> <em> <i> <q cite=""> <strike> <strong>