#include <windows.h>
#include <shlwapi.h>
#include <tchar.h>
#include <winhttp.h>
#include "iucommon.h"
#include "logging.h"
#include "download.h"
#include "dlutil.h"
#include "malloc.h"

#include "wusafefn.h"

///////////////////////////////////////////////////////////////////////////////
// 

// Function-pointer types for the advapi32.dll security entry points that
//  AmIPrivileged() binds at run time with GetProcAddress.  Each one mirrors
//  the parameter list of the Win32 API of the same name.
typedef BOOL  (WINAPI *pfn_OpenProcessToken)(HANDLE, DWORD, PHANDLE);
typedef BOOL  (WINAPI *pfn_OpenThreadToken)(HANDLE, DWORD, BOOL, PHANDLE);
typedef BOOL  (WINAPI *pfn_SetThreadToken)(PHANDLE, HANDLE);
typedef BOOL  (WINAPI *pfn_GetTokenInformation)(HANDLE, TOKEN_INFORMATION_CLASS, LPVOID, DWORD, PDWORD);
typedef BOOL  (WINAPI *pfn_IsValidSid)(PSID);
// The final parameter is an out-parameter that receives the newly allocated
//  SID, so it is a PSID*, not a PSID.
typedef BOOL  (WINAPI *pfn_AllocateAndInitializeSid)(PSID_IDENTIFIER_AUTHORITY, BYTE, DWORD, DWORD, DWORD, DWORD, DWORD, DWORD, DWORD, DWORD, PSID *);
typedef BOOL  (WINAPI *pfn_EqualSid)(PSID, PSID);
typedef PVOID (WINAPI *pfn_FreeSid)(PSID);

// Registry key path that CheckDebugRegKey() opens under HKEY_LOCAL_MACHINE.
const TCHAR c_szRPWU[]        = _T("Software\\Microsoft\\Windows\\CurrentVersion\\WindowsUpdate");
// Name of the value under c_szRPWU that CheckDebugRegKey() reads.
const TCHAR c_szRVTransport[] = _T("DownloadTransport");
// Module that AmIPrivileged() loads to resolve the security APIs.
const TCHAR c_szAdvapi32[]    = _T("advapi32.dll");

// ***************************************************************************
// Reports whether the user SID of the current process token is one of the
//  well-known service accounts listed in rgRIDs below (LocalSystem,
//  LocalService, NetworkService).
//
// The security APIs are resolved dynamically from advapi32.dll.  If opening
//  the process token fails with ERROR_ACCESS_DENIED, the thread's
//  impersonation token is temporarily removed, the open is retried, and the
//  impersonation token is put back.
//
// Returns TRUE only on a positive SID match.  FALSE means either "not one of
//  those accounts" or "some step failed"; callers cannot tell the two apart
//  from the return value.
static
BOOL AmIPrivileged(void)
{
    LOG_Block("AmIPrivileged()");

    pfn_AllocateAndInitializeSid    pfnAllocateAndInitializeSid = NULL;
    pfn_GetTokenInformation         pfnGetTokenInformation = NULL;
    pfn_OpenProcessToken            pfnOpenProcessToken = NULL;
    pfn_OpenThreadToken             pfnOpenThreadToken = NULL;
    pfn_SetThreadToken              pfnSetThreadToken = NULL;
    pfn_IsValidSid                  pfnIsValidSid = NULL;
    pfn_EqualSid                    pfnEqualSid = NULL;
    pfn_FreeSid                     pfnFreeSid = NULL;
    HMODULE                         hmod = NULL;
    
    SID_IDENTIFIER_AUTHORITY        siaNT = SECURITY_NT_AUTHORITY;
    TOKEN_USER                      *ptu = NULL;
    HANDLE                          hToken = NULL, hTokenImp = NULL;
    DWORD                           cb = 0, cbGot = 0, i;
    PSID                            psid = NULL;
    BOOL                            fRet = FALSE;

    // the accounts that count as "privileged"
    DWORD                           rgRIDs[] = { SECURITY_LOCAL_SYSTEM_RID,
                                                 SECURITY_LOCAL_SERVICE_RID,
                                                 SECURITY_NETWORK_SERVICE_RID };

    hmod = LoadLibraryFromSystemDir(c_szAdvapi32);
    if (hmod == NULL)
        goto done;

    pfnAllocateAndInitializeSid = (pfn_AllocateAndInitializeSid)GetProcAddress(hmod, "AllocateAndInitializeSid");
    pfnGetTokenInformation      = (pfn_GetTokenInformation)GetProcAddress(hmod, "GetTokenInformation");
    pfnOpenProcessToken         = (pfn_OpenProcessToken)GetProcAddress(hmod, "OpenProcessToken");
    pfnOpenThreadToken          = (pfn_OpenThreadToken)GetProcAddress(hmod, "OpenThreadToken");
    pfnSetThreadToken           = (pfn_SetThreadToken)GetProcAddress(hmod, "SetThreadToken");
    pfnIsValidSid               = (pfn_IsValidSid)GetProcAddress(hmod, "IsValidSid");
    pfnEqualSid                 = (pfn_EqualSid)GetProcAddress(hmod, "EqualSid");
    pfnFreeSid                  = (pfn_FreeSid)GetProcAddress(hmod, "FreeSid");
    if (pfnAllocateAndInitializeSid == NULL || 
        pfnGetTokenInformation == NULL || 
        pfnOpenProcessToken == NULL ||
        pfnOpenThreadToken == NULL ||
        pfnSetThreadToken == NULL ||
        pfnIsValidSid == NULL ||
        pfnEqualSid == NULL ||
        pfnFreeSid == NULL)
    {
        SetLastError(ERROR_PROC_NOT_FOUND);
        goto done;
    }

    // need the process token
    fRet = (*pfnOpenProcessToken)(GetCurrentProcess(), TOKEN_READ, &hToken);
    if (fRet == FALSE)
    {
        if (GetLastError() == ERROR_ACCESS_DENIED)
        {
            // Grab the thread's impersonation token so it can be restored
            //  after the retry below.
            fRet = (*pfnOpenThreadToken)(GetCurrentThread(), 
                                         TOKEN_READ | TOKEN_IMPERSONATE,
                                         TRUE, &hTokenImp);
            if (fRet == FALSE)
                goto done;

            // Drop the impersonation for the retry.  The result is
            //  deliberately ignored: the retry below decides the outcome.
            (*pfnSetThreadToken)(NULL, NULL);

            fRet = (*pfnOpenProcessToken)(GetCurrentProcess(), TOKEN_READ, 
                                          &hToken);

            // Always put the impersonation token back; failing to do so is
            //  treated as a failure of the whole check.
            if ((*pfnSetThreadToken)(NULL, hTokenImp) == FALSE)
                fRet = FALSE;
        }

        if (fRet == FALSE)
            goto done;
    }

    // need the SID from the token.  The first call is only a size probe: it
    //  is expected to fail with ERROR_INSUFFICIENT_BUFFER and fill in cb.
    //  Anything else (including an unexpected success with a NULL buffer)
    //  means cb cannot be trusted, so bail out.
    fRet = (*pfnGetTokenInformation)(hToken, TokenUser, NULL, 0, &cb);
    if (fRet != FALSE || GetLastError() != ERROR_INSUFFICIENT_BUFFER)
    {
        fRet = FALSE;
        goto done;
    }

    ptu = (TOKEN_USER *)HeapAlloc(GetProcessHeap(), 0, cb);
    if (ptu == NULL)
    {
        SetLastError(ERROR_OUTOFMEMORY);
        fRet = FALSE;
        goto done;
    }

    fRet = (*pfnGetTokenInformation)(hToken, TokenUser, (LPVOID)ptu, cb, 
                                     &cbGot);
    if (fRet == FALSE)
        goto done;

    fRet = (*pfnIsValidSid)(ptu->User.Sid);
    if (fRet == FALSE)
        goto done;

    // loop thru & check against the SIDs we are interested in
    for (i = 0; i < ARRAYSIZE(rgRIDs); i++)
    {
        fRet = (*pfnAllocateAndInitializeSid)(&siaNT, 1, rgRIDs[i], 0, 0, 0, 
                                              0, 0, 0, 0, &psid);
        if (fRet == FALSE)
            goto done;

        // on failure psid is released in the cleanup section
        fRet = (*pfnIsValidSid)(psid);
        if (fRet == FALSE)
            goto done;

        // if we get a SID match, then return TRUE
        fRet = (*pfnEqualSid)(psid, ptu->User.Sid);
        (*pfnFreeSid)(psid);
        psid = NULL;
        if (fRet)
        {
            fRet = TRUE;
            goto done;
        }
    }

    // only way to get here is to fail all the SID checks above.  So we ain't
    //  privileged.
    fRet = FALSE;
    
done:
    // release everything acquired above (the impersonation token, if any,
    //  was already put back on the thread right after the retry)
    if (ptu != NULL)
        HeapFree(GetProcessHeap(), 0, ptu);
    if (hToken != NULL)
        CloseHandle(hToken);
    if (hTokenImp != NULL)
        CloseHandle(hTokenImp);
    if (psid != NULL && pfnFreeSid != NULL)
        (*pfnFreeSid)(psid);
    if (hmod != NULL)
        FreeLibrary(hmod);

    return fRet;
}

#if defined(DEBUG) || defined(DBG)

// **************************************************************************
// Debug-only override of the download transport.
//
// Reads the DownloadTransport value from the WindowsUpdate key under HKLM
//  and maps it onto a transport flag:
//      0 -> 0 (no restriction flag)
//      1 -> WUDF_ALLOWWINHTTPONLY
//      2 -> WUDF_ALLOWWININETONLY
//
// *pdwAllowed is written ONLY when the value exists, is a REG_DWORD and is
//  one of the values above; in every other case it is left untouched.
//
// Returns TRUE if *pdwAllowed was overwritten, FALSE otherwise.
static
BOOL CheckDebugRegKey(DWORD *pdwAllowed)
{
    LOG_Block("CheckDebugRegKey()");

    // indexed by the registry value
    static const DWORD c_rgdwTransport[] = { 0, 
                                             WUDF_ALLOWWINHTTPONLY, 
                                             WUDF_ALLOWWININETONLY };

    HKEY    hkeyWU = NULL;
    DWORD   dwTransport, dwValType;
    DWORD   cbData = sizeof(dwTransport);
    BOOL    fOverridden = FALSE;

    if (RegOpenKeyEx(HKEY_LOCAL_MACHINE, c_szRPWU, 0, KEY_READ, 
                     &hkeyWU) == ERROR_SUCCESS &&
        RegQueryValueEx(hkeyWU, c_szRVTransport, 0, &dwValType, 
                        (LPBYTE)&dwTransport, &cbData) == ERROR_SUCCESS)
    {
        // a value of the wrong type is handled exactly like an out-of-range
        //  number
        if (dwValType == REG_DWORD && dwTransport < ARRAYSIZE(c_rgdwTransport))
        {
            *pdwAllowed = c_rgdwTransport[dwTransport];
            fOverridden = TRUE;
        }
        else
        {
            LOG_Internet(_T("Bad reg value in DownloadTransport.  Ignoring."));
        }
    }

    if (hkeyWU != NULL)
        RegCloseKey(hkeyWU);

    return fOverridden;
}

#endif

// **************************************************************************
// Computes the transport flags a download is allowed to use.
//
// Only the bits covered by WUDF_TRANSPORTMASK are ever modified; all other
//  bits of dwFlagsInitial are passed through unchanged.
//
//  - ANSI builds: the transport bits are always WUDF_ALLOWWININETONLY.
//  - UNICODE builds: if the caller did not already ask for wininet-only and
//    AmIPrivileged() returns FALSE, the transport bits are forced to
//    WUDF_ALLOWWININETONLY.  In debug builds CheckDebugRegKey() may then
//    overwrite the transport bits.
DWORD GetAllowedDownloadTransport(DWORD dwFlagsInitial)
{
    const DWORD dwPassThrough = (dwFlagsInitial & ~WUDF_TRANSPORTMASK);
    DWORD       dwTransport;

#if !defined(UNICODE)

    // only allow wininet on ANSI
    dwTransport = WUDF_ALLOWWININETONLY;

#else // !defined(UNICODE)

    dwTransport = (dwFlagsInitial & WUDF_TRANSPORTMASK);

    // the account check is skipped entirely when wininet-only was already
    //  requested
    if ((dwTransport & WUDF_ALLOWWININETONLY) == 0 && 
        AmIPrivileged() == FALSE)
    {
        dwTransport = WUDF_ALLOWWININETONLY;
    }

#if defined(DEBUG) || defined(DBG)
    CheckDebugRegKey(&dwTransport);
#endif // defined(DEBUG) || defined(DBG)

#endif // !defined(UNICODE)

    return (dwTransport | dwPassThrough);
}

///////////////////////////////////////////////////////////////////////////////
// 

// **************************************************************************
// Decides whether the server copy of a file differs from the local copy
//  that hFile refers to.
//
//  ftServerTime     - timestamp reported for the server file
//  dwServerFileSize - size in bytes reported for the server file
//  hFile            - handle to the local file, or INVALID_HANDLE_VALUE
//
// Returns FALSE only when the local size equals dwServerFileSize AND the
//  local creation time equals ftServerTime.  Returns TRUE in every other
//  case (including an invalid handle or a GetFileTime failure), so the
//  caller downloads a fresh copy.
static inline
BOOL IsServerFileDifferentWorker(FILETIME &ftServerTime, 
                                 DWORD dwServerFileSize, HANDLE hFile)
{
    LOG_Block("IsServerFileDifferentWorker()");

    FILETIME    ftCreateTime;
    DWORD       cbLocalFile;

    // if we don't have a valid file handle, just return TRUE to download a 
    //  new copy
    if (hFile == INVALID_HANDLE_VALUE)
        return TRUE;

    cbLocalFile = GetFileSize(hFile, NULL);

    // sizes are unsigned, so log them with %u
    LOG_Internet(_T("IsServerFileDifferent: Local size: %u.  Remote size: %u"),
                 cbLocalFile, dwServerFileSize);

    // if the sizes are not equal, the files are different
    if (cbLocalFile != dwServerFileSize)
        return TRUE;

    // if the timestamp can't be read, err on the side of downloading
    if (GetFileTime(hFile, &ftCreateTime, NULL, NULL) == FALSE)
        return TRUE;

    // Zero-pad both halves to 8 hex digits so the logged 64-bit value is
    //  unambiguous.
    LOG_Internet(_T("IsServerFileDifferent: Local time: %08x%08x.  Remote time: %08x%08x."),
                 ftCreateTime.dwHighDateTime, ftCreateTime.dwLowDateTime,
                 ftServerTime.dwHighDateTime, ftServerTime.dwLowDateTime);

    // if the local file has a different timestamp, then return TRUE.  
    return (CompareFileTime(&ftCreateTime, &ftServerTime) != 0);
}

// **************************************************************************
// Unicode-path wrapper: opens wszLocalFile for reading and compares its size
//  and timestamp against the server values via IsServerFileDifferentWorker.
//
// Returns TRUE if the local file cannot be opened or the worker reports a
//  difference, FALSE if the worker reports the files match.
BOOL IsServerFileDifferentW(FILETIME &ftServerTime, DWORD dwServerFileSize, 
                            LPCWSTR wszLocalFile)
{
    LOG_Block("IsServerFileDifferentW()");

    HANDLE hLocal = CreateFileW(wszLocalFile, GENERIC_READ, FILE_SHARE_READ, 
                                NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 
                                NULL);

    // any failure to open the file means "go download a new copy"
    if (hLocal == INVALID_HANDLE_VALUE)
    {
        LOG_Internet(_T("IsServerFileDifferent: %ls does not exist."), wszLocalFile);
        return TRUE;
    }

    BOOL fDifferent = IsServerFileDifferentWorker(ftServerTime, 
                                                  dwServerFileSize, hLocal);
    CloseHandle(hLocal);

    return fDifferent;
}

// **************************************************************************
// ANSI-path wrapper: opens szLocalFile for reading and compares its size
//  and timestamp against the server values via IsServerFileDifferentWorker.
//
// Returns TRUE if the local file cannot be opened or the worker reports a
//  difference, FALSE if the worker reports the files match.
BOOL IsServerFileDifferentA(FILETIME &ftServerTime, DWORD dwServerFileSize, 
                            LPCSTR szLocalFile)
{
    LOG_Block("IsServerFileDifferentA()");

    HANDLE hFile = INVALID_HANDLE_VALUE;

    // if we have an error opening the file, just return TRUE to download a 
    //  new copy
    hFile = CreateFileA(szLocalFile, GENERIC_READ, FILE_SHARE_READ, NULL, 
                        OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
    if (hFile == INVALID_HANDLE_VALUE)
    {
        // szLocalFile is always a narrow string, but the format string is
        //  TCHAR-based.  %hs means "narrow string" in both ANSI and UNICODE
        //  builds, whereas a plain %s would be read as a wide string when
        //  UNICODE is defined.
        LOG_Internet(_T("IsServerFileDifferent: %hs does not exist."), szLocalFile);
        return TRUE;
    }
    else
    {
        BOOL fRet;
        fRet = IsServerFileDifferentWorker(ftServerTime, dwServerFileSize, hFile);
        CloseHandle(hFile);
        return fRet;
    }
}

// **************************************************************************
// helper function to handle quit events
//
// return TRUE if okay to continue
// return FALSE if we should quit now!
// Polls (without blocking) the given handles.
//
//  phEvents    - array of handles to check; may be NULL
//  nEventCount - number of handles in phEvents; may be 0
//
// Returns TRUE when there is nothing to check or the zero-timeout wait
//  reports WAIT_TIMEOUT.  Any other wait result is logged as a quit event
//  and FALSE is returned.
BOOL HandleEvents(HANDLE *phEvents, UINT nEventCount)
{
    LOG_Block("HandleEvents()");

    // are there any events to handle?
    if (phEvents != NULL && nEventCount != 0)
    {
        // a zero timeout turns the wait into a pure "is anything signaled
        //  right now?" check
        if (WaitForMultipleObjects(nEventCount, phEvents, FALSE, 0) != WAIT_TIMEOUT)
        {
            LOG_Internet(_T("HandleEvents: A quit event was signaled.  Aborting..."));
            return FALSE;
        }
    }

    return TRUE;
}


///////////////////////////////////////////////////////////////////////////////
// 

// **************************************************************************
// Pumps data from an open request into a file until the reader reports
//  0 bytes.
//
//  pfnRead       - function used to read the next chunk from hRequest
//  hRequest      - request handle passed through to pfnRead
//  hFile         - destination file handle (written with WriteFile)
//  cbFile        - total file size; only forwarded to the callback
//  cbBuffer      - size in bytes of the transfer buffer to allocate
//  rghEvents     - optional quit events, polled after every chunk
//  cEvents       - number of handles in rghEvents
//  fpnCallback   - optional progress callback
//  pCallbackData - opaque value handed to fpnCallback
//  pcbDownloaded - optional; receives the total bytes read, written ONLY
//                  when the download completes successfully
//
// Returns S_OK on success, E_OUTOFMEMORY if the buffer can't be allocated,
//  E_ABORT if the callback or a quit event requested an abort, or a failure
//  HRESULT if a read or write fails.
HRESULT PerformDownloadToFile(pfn_ReadDataFromSite pfnRead,
                              HINTERNET hRequest, 
                              HANDLE hFile, DWORD cbFile,
                              DWORD cbBuffer,
                              HANDLE *rghEvents, DWORD cEvents,
                              PFNDownloadCallback fpnCallback, LPVOID pCallbackData,
                              DWORD *pcbDownloaded)
{
    LOG_Block("PerformDownloadToFile()");

    HRESULT hr = S_OK;
    PBYTE   pbBuffer = NULL;
    DWORD   cbDownloaded = 0, cbRead = 0, cbWritten = 0;
    LONG    lCallbackRequest = 0;

    pbBuffer = (PBYTE)HeapAlloc(GetProcessHeap(), 0, cbBuffer);
    if (pbBuffer == NULL)
    {
        hr = E_OUTOFMEMORY;
        LOG_ErrorMsg(hr);
        goto done;
    }

    // Download the File
    for(;;)
    {
        // reset so a reader that fails without touching the count can never
        //  leave a stale value from the previous iteration behind
        cbRead = 0;

        if ((*pfnRead)(hRequest, pbBuffer, cbBuffer, &cbRead) == FALSE)
        {
            // A failed read must end the download.  If the reader did not
            //  set a last-error code, HRESULT_FROM_WIN32 yields S_OK, so
            //  substitute a generic failure rather than carrying on with an
            //  unreliable byte count.
            hr = HRESULT_FROM_WIN32(GetLastError());
            if (SUCCEEDED(hr))
                hr = E_FAIL;

            LOG_ErrorMsg(hr);
            goto done;
        }
            
        if (cbRead == 0)
        {
            BYTE bTemp[32];
            
            // Make one final call to WinHttpReadData to commit the file to
            //  Cache.  (the download is not complete otherwise)
            (*pfnRead)(hRequest, bTemp, ARRAYSIZE(bTemp), &cbRead);
            break;
        }
        
        cbDownloaded += cbRead;

        if (fpnCallback != NULL)
        {
            fpnCallback(pCallbackData, DOWNLOAD_STATUS_OK, cbFile, cbRead, NULL, 
                        &lCallbackRequest);

            // NOTE(review): 4 is presumably the callback's "abort" request
            //  code - confirm against the PFNDownloadCallback contract.
            if (lCallbackRequest == 4)
            {
                // QuitEvent was Signaled.. abort requested. We will do 
                //  another callback and pass the Abort State back
                fpnCallback(pCallbackData, DOWNLOAD_STATUS_ABORTED, cbFile, cbRead, NULL, NULL);
                
                hr = E_ABORT; // set return result to abort.
                goto done;
            }
        }

        if (WriteFile(hFile, pbBuffer, cbRead, &cbWritten, NULL) == FALSE)
        {
            hr = HRESULT_FROM_WIN32(GetLastError());
            LOG_ErrorMsg(hr);
            goto done;
        }

        if (HandleEvents(rghEvents, cEvents) == FALSE)
        {
            // a quit event fired: tell the callback about the abort and
            //  stop.  Nothing here deletes what was already written.
            if (fpnCallback != NULL)
                fpnCallback(pCallbackData, DOWNLOAD_STATUS_ABORTED, cbFile, cbRead, NULL, NULL);

            hr = E_ABORT; // set return result to abort.
            goto done;
        }
    }

    if (pcbDownloaded != NULL)
        *pcbDownloaded = cbDownloaded;

done:
    SafeHeapFree(pbBuffer);

    return hr;

}


///////////////////////////////////////////////////////////////////////////////
//

// OS version information bundled with the locale id that is chosen from it.
struct MY_OSVERSIONINFOEX
{
    // filled in by GetVersionEx inside GetOSVersionInfo()
    OSVERSIONINFOEX osvi;
    // locale handed to CompareStringA by LangNeutralStrCmpNIA(); set by
    //  GetOSVersionInfo() based on the OS version
    LCID            lcidCompare;
};
// process-wide cached copy, populated on first use by GetOSVersionInfo()
static MY_OSVERSIONINFOEX g_myosvi;
// TRUE once g_myosvi has been populated
static BOOL               g_fInit = FALSE;

// **************************************************************************
// Loads the current OS version info if needed, and returns a pointer to
//  a cached copy of it.
// On the first call, queries the OS version with GetVersionEx, stores it in
//  g_myosvi and picks the comparison locale that goes with it:
//      version 5.1 or later -> LOCALE_INVARIANT
//      anything older       -> US English, default sort
// Later calls skip straight to the cached data.  The return value of
//  GetVersionEx is not checked.
//
// Returns a pointer to the cached OSVERSIONINFOEX.
const OSVERSIONINFOEX* GetOSVersionInfo(void)
{
    if (g_fInit != FALSE)
        return &g_myosvi.osvi;

    g_myosvi.osvi.dwOSVersionInfoSize = sizeof(g_myosvi.osvi);
    GetVersionEx((OSVERSIONINFO*)&g_myosvi.osvi);

    const DWORD dwMajor = g_myosvi.osvi.dwMajorVersion;
    const DWORD dwMinor = g_myosvi.osvi.dwMinorVersion;

    // WinXP-specific stuff
    if (dwMajor > 5 || (dwMajor == 5 && dwMinor >= 1))
    {
        g_myosvi.lcidCompare = LOCALE_INVARIANT;
    }
    else
    {
        g_myosvi.lcidCompare = MAKELCID(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US), SORT_DEFAULT);
    }

    g_fInit = TRUE;

    return &g_myosvi.osvi;
}

// **************************************************************************
// String lengths can be -1 if the strings are null-terminated.
// Case-insensitive comparison of two ANSI strings using the locale cached
//  in g_myosvi.lcidCompare (initialized on demand).
//
//  psz1, cch1 - first string and its length; cch1 may be -1 if psz1 is
//               null-terminated
//  psz2, cch2 - second string and its length; cch2 may be -1 if psz2 is
//               null-terminated
//
// Returns the CompareStringA result shifted down by 2, i.e. -1 / 0 / 1 for
//  less / equal / greater.  CompareStringA returns 0 on failure, which
//  comes back from here as -2.
int LangNeutralStrCmpNIA(LPCSTR psz1, int cch1, LPCSTR psz2, int cch2)
{
    // make sure the comparison locale has been chosen
    if (!g_fInit)
        GetOSVersionInfo();

    // convert from (1, 2, 3) to (-1, 0, 1)
    return CompareStringA(g_myosvi.lcidCompare, NORM_IGNORECASE,
                          psz1, cch1, psz2, cch2) - 2;
}

// **************************************************************************
// Finds the first instance of pszSearchFor in pszSearchIn, case-insensitive.
//  Returns an index into pszSearchIn if found, or -1 if not.
//  You can pass -1 for either or both of the lengths.
// Case-insensitive search for pszSearchFor inside pszSearchIn.
//
//  pszSearchIn,  cchSearchIn  - text to search and its length
//  pszSearchFor, cchSearchFor - text to look for and its length
//  Either length may be -1 if the corresponding string is null-terminated.
//
// Returns the index into pszSearchIn of the first match, or -1 if there is
//  none.  A NULL pointer or an empty search string also yields -1.
int LangNeutralStrStrNIA(LPCSTR pszSearchIn, int cchSearchIn, 
                         LPCSTR pszSearchFor, int cchSearchFor)
{
    char        chLower, chUpper;
    ULONG_PTR   uchFirst;

    if (pszSearchIn == NULL || pszSearchFor == NULL)
        return -1;
    
    if (cchSearchIn == -1)
        cchSearchIn = lstrlenA(pszSearchIn);
    if (cchSearchFor == -1)
        cchSearchFor = lstrlenA(pszSearchFor);

    // With nothing to search for, the tail comparison below would be
    //  called with a length of -1 ("null-terminated") and read past the
    //  caller's buffers, so report "not found" up front.
    if (cchSearchFor <= 0)
        return -1;

    // Note: since this is lang-neutral, we can assume no DBCS search chars.
    //  CharLowerA/CharUpperA convert a single character when it is passed
    //  in the low-order word of the pointer argument and hand it back the
    //  same way.  Go through unsigned char so characters >= 0x80 are not
    //  sign-extended, and through ULONG_PTR so the pointer-sized result is
    //  narrowed explicitly (a direct pointer-to-char cast truncates on
    //  64-bit).
    uchFirst = (ULONG_PTR)(unsigned char)*pszSearchFor;
    chLower  = (char)(ULONG_PTR)CharLowerA((LPSTR)uchFirst);
    chUpper  = (char)(ULONG_PTR)CharUpperA((LPSTR)uchFirst);

    // Note: since search-for is lang-neutral, we can ignore any DBCS chars 
    //        in search-in
    for (int ichIn = 0; ichIn <= cchSearchIn - cchSearchFor; ichIn++)
    {
        if (pszSearchIn[ichIn] == chLower || pszSearchIn[ichIn] == chUpper)
        {
            // first character matches; a one-character search is done,
            //  otherwise compare the remainder
            if (cchSearchFor == 1 ||
                LangNeutralStrCmpNIA(pszSearchIn + ichIn + 1, cchSearchFor - 1, 
                                     pszSearchFor + 1, cchSearchFor - 1) == 0)
            {
                return ichIn;
            }
        }
    }

    return -1;
}

// **************************************************************************
// Opens the given file and looks for "<html" (case-insensitive) within the
//  first 200 characters. If there are any binary chars before "<html", the
//  file is assumed to *not* be HTML.
// Returns S_OK if so, S_FALSE if not, or an error if file couldn't be opened.
// Checks whether a file looks like HTML by searching the first 200 bytes
//  (or the whole file, if smaller) for "<html", case-insensitively.
//
//  pszFileName - path of the file to inspect
//
// Returns S_OK if "<html" is found and no binary characters precede it,
//  S_FALSE if the file is empty, the tag is not found, or a binary
//  character comes before the tag, and a failure HRESULT if the file
//  cannot be opened or mapped.
HRESULT IsFileHtml(LPCTSTR pszFileName)
{
    LOG_Block("IsFileHtml()");

    HRESULT hr = S_FALSE;
    LPCSTR  pszFile;
    HANDLE  hFile = INVALID_HANDLE_VALUE;
    HANDLE  hMapping = NULL;
    LPVOID  pvMem = NULL;
    DWORD   cbFile;
    // declared up here (not at first use) so that none of the "goto done"
    //  statements below jumps over an initialized declaration
    int     ichHtml;

    hFile = CreateFile(pszFileName, GENERIC_READ, 
                       FILE_SHARE_READ | FILE_SHARE_WRITE,
                       NULL, OPEN_EXISTING, FILE_FLAG_SEQUENTIAL_SCAN, NULL);
    if (hFile == INVALID_HANDLE_VALUE)
    {
        hr = HRESULT_FROM_WIN32(GetLastError());
        LOG_ErrorMsg(hr);
        goto done;
    }

    // an empty file can't be mapped and can't be HTML: report S_FALSE
    cbFile = GetFileSize(hFile, NULL);
    if (cbFile == 0)
        goto done;

    // Only examine the 1st 200 bytes
    if (cbFile > 200)
        cbFile = 200;

    hMapping = CreateFileMapping(hFile, NULL, PAGE_READONLY, 0, cbFile, NULL);
    if (hMapping == NULL)
    {
        hr = HRESULT_FROM_WIN32(GetLastError());
        LOG_ErrorMsg(hr);
        goto done;
    }

    pvMem = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, cbFile);
    if (pvMem == NULL)
    {
        hr = HRESULT_FROM_WIN32(GetLastError());
        LOG_ErrorMsg(hr);
        goto done;
    }

    pszFile = (LPCSTR)pvMem;
    ichHtml = LangNeutralStrStrNIA(pszFile, cbFile, "<html", 5);
    if (ichHtml != -1)
    {
        // Looks like html...
        hr = S_OK;

        // Just make sure there aren't any binary chars before the <HTML> tag.
        //  NOTE(review): where plain char is signed, bytes >= 0x80 also
        //  compare below 32 and are therefore treated as binary - confirm
        //  that is intended.
        for (int ich = 0; ich < ichHtml; ich++)
        {
            char ch = pszFile[ich];
            if (ch < 32 && ch != '\t' && ch != '\r' && ch != '\n')
            {
                // Found a binary character (before <HTML>)
                hr = S_FALSE;
                break;
            }
        }
    }

done:
    if (pvMem != NULL)
        UnmapViewOfFile(pvMem);
    if (hMapping != NULL)
        CloseHandle(hMapping);
    if (hFile != INVALID_HANDLE_VALUE)
        CloseHandle(hFile);

    return hr;
}