#include        "defs.h"

/*
 * imageAdd -- record a mapped PE image in image_list_head and register
 * its memory regions through vmmapAdd().
 *
 * ImageBase: base address of the mapped image; it is read as an
 *            IMAGE_DOS_HEADER followed (at e_lfanew) by the PE header.
 *
 * The image name is obtained with NtQueryVirtualMemory
 * (MemoryMappedFilenameInformation), narrowed to char and lower-cased
 * (A-Z only).  If the query fails the image is still added, with an
 * empty name.  If either allocation fails nothing is added.
 *
 * The node and its image_name are allocated with dlmalloc; imageRemove()
 * releases both.
 */
void    imageAdd(__in PVOID  ImageBase){
        pimage_struct   pimage;
        PIMAGE_DOS_HEADER       pmz;
        PPEHEADER32             pe32;
        size_t                  name_chars;
        size_t                  i;
        char                    *as;
        WCHAR                   wsImageName[MAX_PATH * 4];
        WCHAR                   *unicode;
        PUNICODE_STRING         pus;
        NTSTATUS                status;
        ULONG                   cbNeeded;
        MEMORY_BASIC_INFORMATION mbi;
        ULONG_PTR               index;
        
        pimage = dlmalloc(sizeof(image_struct));
        if (pimage == NULL)
                return;
        
        memset(pimage, 0, sizeof(image_struct));
        memset(wsImageName, 0, sizeof(wsImageName));
        
        status = NtQueryVirtualMemory((HANDLE)(ULONG_PTR)-1, 
                                      (PVOID)ImageBase,
                                      MemoryMappedFilenameInformation,
                                      wsImageName,
                                      sizeof(wsImageName),
                                      &cbNeeded);
        
        pmz = (PIMAGE_DOS_HEADER)ImageBase;
        pe32= (PPEHEADER32)((ULONG_PTR)ImageBase + pmz->e_lfanew);
                
        pimage->image_start = (ULONG_PTR)ImageBase;
        pimage->image_end   = (ULONG_PTR)ImageBase + pe32->pe_sizeofimage;
        
        /*
         * The output buffer starts with a UNICODE_STRING header; the text
         * itself is reached through pus->Buffer and its size is pus->Length
         * (in bytes).  Measuring the raw buffer with wcslen() would count
         * header words instead of characters.
         */
        pus        = (PUNICODE_STRING)wsImageName;
        unicode    = pus->Buffer;
        name_chars = 0;
        /* NTSTATUS failure codes are negative; on failure keep an empty name. */
        if (status >= 0 && unicode != NULL)
                name_chars = pus->Length / sizeof(WCHAR);
        
        pimage->image_name  = dlmalloc(name_chars + 1);
        if (pimage->image_name == NULL){
                dlfree(pimage);
                return;
        }
        /* Zero-fill so the name is NUL-terminated wherever the copy stops. */
        memset(pimage->image_name, 0, name_chars + 1);
        
        as = pimage->image_name;
        for (i = 0; i < name_chars && unicode[i]; i++){
                as[i] = (char)unicode[i];
                if (as[i] >= 'A' && as[i] <= 'Z')
                        as[i] = as[i] + ('a' - 'A');
        }
        
        DbgPrint(("%s -- adding image to the list : %s", __FUNCTION__, pimage->image_name));
        InsertTailList(&image_list_head, (PLIST)pimage);
        
        /* Walk the image range region by region and register each region. */
        index = pimage->image_start;
        while (index < pimage->image_end){
                memset(&mbi, 0, sizeof(mbi));
                ntVirtualQuery((void *)index, &mbi, sizeof(mbi));
                
                /* A zero-sized region means the query produced nothing;
                   stop instead of looping forever on the same address. */
                if (mbi.RegionSize == 0)
                        break;
                
                vmmapAdd((void *)mbi.BaseAddress, mbi.RegionSize, mbi.Protect);
                index += mbi.RegionSize;
        }
}

/*
 * imageRemove -- unlink and free the first list entry whose image_start
 * equals ImageBase.  Does nothing when no entry matches.
 */
void    imageRemove(__in PVOID ImageBase){
        pimage_struct     cur;
        
        for (cur = (pimage_struct)image_list_head.Flink;
             (ULONG_PTR)cur != (ULONG_PTR)&image_list_head;
             cur = (pimage_struct)cur->Next.Flink){
                
                if (cur->image_start != (ULONG_PTR)ImageBase)
                        continue;
                
                /* Splice the node out by pointing its neighbours at each other. */
                cur->Next.Flink->Blink = cur->Next.Blink;
                cur->Next.Blink->Flink = cur->Next.Flink;
                
                DbgPrint(("%s -- removing image from a list : %s", __FUNCTION__, cur->image_name));
                
                /* The name buffer is a separate allocation; free it before the node. */
                dlfree(cur->image_name);
                dlfree(cur);
                return;
        }
}


/*
 * imageNext -- return the entry following pimage, or NULL when pimage is
 * the last entry (its forward link is the list head).
 */
pimage_struct imageNext(__in pimage_struct pimage){
        pimage_struct   successor = (pimage_struct)pimage->Next.Flink;
        
        if ((ULONG_PTR)successor != (ULONG_PTR)&image_list_head)
                return successor;
        return NULL;
}

pimage_struct imageFirst(){
        if ((ULONG_PTR)image_list_head.Flink == (ULONG_PTR)&image_list_head) return NULL;
        return (pimage_struct)image_list_head.Flink;
}

/*
 * imageFindImageForAddress -- return the first list entry whose range
 * [image_start, image_end) contains addr, or NULL if none does.
 *
 * image_end is treated as exclusive for every entry: imageAdd() sets it
 * to image_start + pe_sizeofimage, i.e. one past the last byte of the
 * image.
 */
pimage_struct imageFindImageForAddress(__in unsigned long addr){       
        pimage_struct   pimage;
          
        for (pimage = imageFirst(); pimage != NULL; pimage = imageNext(pimage)){
                if (addr >= pimage->image_start && addr < pimage->image_end)
                        return pimage;        
        }     
        return NULL;
}