/*++

Copyright (c) 1994  Microsoft Corporation

Module Name:

    iwinsock.cxx

Abstract:

    Contains functions to load sockets DLL and entry points. Functions and data
    in this module take care of indirecting sockets calls, hence _I_ in front
    of the function name

    Contents:
        IwinsockInitialize
        IwinsockTerminate
        LoadWinsock
        UnloadWinsock
        SafeCloseSocket

Author:

    Richard L Firth (rfirth) 12-Apr-1995

Environment:

    Win32(s) user-mode DLL

Revision History:

    12-Apr-1995 rfirth
        Created

    08-May-1996 arthurbi
        Added support for Socks Firewalls.

    05-Mar-1998 rfirth
        Moved SOCKS support into ICSocket class. Removed SOCKS library
        loading/unloading from this module (revert to pre-SOCKS)

--*/

#include <wininetp.h>

#if defined(__cplusplus)
extern "C" {
#endif

//
// uncomment RLF_DEBUG to route DPRINTF() output to dprintf() in debug builds
//

//#define RLF_DEBUG   1

#if INET_DEBUG

//
// DPRINTF - when tracing is disabled the macro expands to a (void) cast, so
// DPRINTF("fmt", arg) becomes the discarded comma expression
// (void)("fmt", arg) and no call is made
//

#ifdef RLF_DEBUG
#define DPRINTF dprintf
#else
#define DPRINTF (void)
#endif

//
// debug-only socket tracking setup/teardown; defined in the INET_DEBUG section
// at the end of this module
//

BOOL
InitDebugSock(
    VOID
    );

VOID
TerminateDebugSock(
    VOID
    );

#else

//
// non-debug builds: DPRINTF() always compiles away (see above)
//

#define DPRINTF (void)
#endif

//
// private types
//

//
// SOCKETS_FUNCTION - one entry of the SocketsFunctions[] table: pairs the
// identifier handed to GetProcAddress() with the address of the _I_ function
// pointer that receives the resolved entry point
//

typedef struct {
    LPSTR FunctionOrdinal;      // passed to GetProcAddress(); the table in this module supplies export names, not ordinals
    FARPROC * FunctionAddress;  // location of the indirection pointer to fill in (or reset to NULL on unload)
} SOCKETS_FUNCTION;


//
// global data
//

//
// Indirected sockets entry points. Each _I_xxx pointer starts out NULL, is
// filled in by LoadWinsock() via GetProcAddress() (see SocketsFunctions[]) and
// is reset to NULL again by UnloadWinsock(). Callers must not use them while
// the sockets DLL is not loaded
//

GLOBAL
SOCKET
(PASCAL FAR * _I_accept)(
    SOCKET s,
    struct sockaddr FAR *addr,
    int FAR *addrlen
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _I_bind)(
    SOCKET s,
    const struct sockaddr FAR *addr,
    int namelen
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _I_closesocket)(
    SOCKET s
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _I_connect)(
    SOCKET s,
    const struct sockaddr FAR *name,
    int namelen
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _I_gethostname)(
    char FAR * name,
    int namelen
    ) = NULL;

GLOBAL
LPHOSTENT
(PASCAL FAR * _I_gethostbyname)(
    LPSTR lpHostName
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _I_getsockname)(
    SOCKET s,
    struct sockaddr FAR *name,
    int FAR * namelen
    ) = NULL;

//
// _I_getsockopt - indirected getsockopt(). Explicitly initialized to NULL like
// every other indirection pointer in this module; resolved by LoadWinsock()
// and reset by UnloadWinsock()
//

GLOBAL
int
(PASCAL FAR * _I_getsockopt)(
    SOCKET s,
    int level,
    int optname,
    char FAR * optval,
    int FAR *optlen
    ) = NULL;

//
// Remaining indirected sockets entry points (htonl .. WSAGetLastError). All are
// NULL until LoadWinsock() resolves them through SocketsFunctions[], and are
// reset to NULL by UnloadWinsock()
//

GLOBAL
u_long
(PASCAL FAR * _I_htonl)(
    u_long hostlong
    ) = NULL;

GLOBAL
u_short
(PASCAL FAR * _I_htons)(
    u_short hostshort
    ) = NULL;

GLOBAL
unsigned long
(PASCAL FAR * _I_inet_addr)(
    const char FAR * cp
    ) = NULL;

GLOBAL
char FAR *
(PASCAL FAR * _I_inet_ntoa)(
    struct in_addr in
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _I_ioctlsocket)(
    SOCKET s,
    long cmd,
    u_long FAR *argp
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _I_listen)(
    SOCKET s,
    int backlog
    ) = NULL;

GLOBAL
u_short
(PASCAL FAR * _I_ntohs)(
    u_short netshort
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _I_recv)(
    SOCKET s,
    char FAR * buf,
    int len,
    int flags
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _I_WSARecv)(
    SOCKET s,
    LPWSABUF lpBuffers,
    DWORD dwBufferCount,
    LPDWORD lpNumberOfBytesRecvd,
    LPDWORD lpFlags,
    LPWSAOVERLAPPED lpOverlapped,
    LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _I_recvfrom)(
    SOCKET s,
    char FAR * buf,
    int len,
    int flags,
    struct sockaddr FAR *from,
    int FAR * fromlen
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _I_select)(
    int nfds,
    fd_set FAR *readfds,
    fd_set FAR *writefds,
    fd_set FAR *exceptfds,
    const struct timeval FAR *timeout
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _I_send)(
    SOCKET s,
    const char FAR * buf,
    int len,
    int flags
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _I_WSASend)(
    SOCKET s,
    LPWSABUF lpBuffers,
    DWORD dwBufferCount,
    LPDWORD lpNumberOfBytesSent,
    DWORD dwFlags,
    LPWSAOVERLAPPED lpOverlapped,
    LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _I_sendto)(
    SOCKET s,
    const char FAR * buf,
    int len,
    int flags,
    const struct sockaddr FAR *to,
    int tolen
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _I_setsockopt)(
    SOCKET s,
    int level,
    int optname,
    const char FAR * optval,
    int optlen
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _I_shutdown)(
    SOCKET s,
    int how
    ) = NULL;

GLOBAL
SOCKET
(PASCAL FAR * _I_socket)(
    int af,
    int type,
    int protocol
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _I_WSAStartup)(
    WORD wVersionRequired,
    LPWSADATA lpWSAData
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _I_WSACleanup)(
    void
    ) = NULL;

//
// __I_WSAGetLastError (two leading underscores) receives the real
// WSAGetLastError export. Callers use _I_WSAGetLastError, which points at the
// ___I_WSAGetLastError() wrapper so that errors can be trapped in one place.
// Original note (VENKATK): temporary - remove the extra level later
//

GLOBAL
int
(PASCAL FAR * __I_WSAGetLastError)(
    void
    ) = NULL;


//
// ___I_WSAGetLastError - wrapper around the real WSAGetLastError export
// (__I_WSAGetLastError). _I_WSAGetLastError is statically initialized to this
// function, so it is callable even when the sockets DLL is not loaded.
//
// Returns the value of the real WSAGetLastError(), or WSANOTINITIALISED when
// the real entry point has not been resolved (before LoadWinsock() succeeds or
// after UnloadWinsock() has reset the table), instead of calling through a
// NULL pointer
//

int
___I_WSAGetLastError(
    VOID
    )
{
    if (__I_WSAGetLastError == NULL) {

        //
        // sockets DLL not loaded: there is no real error value to fetch
        //

        return WSANOTINITIALISED;
    }

    int nError = __I_WSAGetLastError();
/*
    VENKATK_BUG - OK to have WSAENOTSOCK - could happen for timeout situations.
    INET_ASSERT (nError != WSAENOTSOCK);
 */
    return nError;
}


//
// _I_WSAGetLastError - what the rest of the code calls. Unlike the other
// indirection pointers it is never NULL and is not part of SocketsFunctions[]:
// it always points at the ___I_WSAGetLastError() wrapper
//

GLOBAL
int
(PASCAL FAR * _I_WSAGetLastError)(
    void
    ) = ___I_WSAGetLastError;
    
GLOBAL
void
(PASCAL FAR * _I_WSASetLastError)(
    int iError
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _I___WSAFDIsSet)(
    SOCKET,
    fd_set FAR *
    ) = NULL;

#if INET_DEBUG

//
// debug only: installs the _II_ tracing wrappers over socket/accept/closesocket
//

void SetupSocketsTracing(void);

#endif

//
// private data
//

//
// InitializationLock - serializes LoadWinsock() and UnloadWinsock() so that
// only one thread at a time changes the loaded state of the sockets DLL and
// the entry point table
//

PRIVATE CCritSec InitializationLock;

//
// hWinsock - module handle of the loaded sockets DLL (LoadWinsock() tries
// "ws2_32" first, then "wsock32"); NULL when no sockets DLL is loaded
//

PRIVATE HINSTANCE hWinsock = NULL;

//
// WinsockLoadCount - the number of LoadWinsock() calls made while the DLL was
// already loaded that have not yet been balanced by UnloadWinsock(). The first
// (real) load leaves it at 0; UnloadWinsock() unloads the DLL for real when it
// is called with the count at 0, otherwise it just decrements the count
//

PRIVATE DWORD WinsockLoadCount = 0;

//
// SocketsFunctions - the list of sockets DLL exports that WININET.DLL needs,
// each paired with the indirection pointer that receives its address.
// LoadWinsock() treats a failure to resolve ANY entry as a failure to load;
// UnloadWinsock() resets every listed pointer to NULL.
//
// Note that "WSAGetLastError" is stored in __I_WSAGetLastError (the raw
// pointer), not in _I_WSAGetLastError (which stays pointed at the wrapper).
//
// NOTE(review): WSARecv/WSASend are required entries; confirm they can be
// resolved when LoadWinsock() falls back to "wsock32"
//

PRIVATE
SOCKETS_FUNCTION
SocketsFunctions[] = {
    "accept",           (FARPROC*)&_I_accept,
    "bind",             (FARPROC*)&_I_bind,
    "closesocket",      (FARPROC*)&_I_closesocket,
    "connect",          (FARPROC*)&_I_connect,
    "getsockname",      (FARPROC*)&_I_getsockname,
    "getsockopt",       (FARPROC*)&_I_getsockopt,
    "htonl",            (FARPROC*)&_I_htonl,
    "htons",            (FARPROC*)&_I_htons,

    "inet_addr",        (FARPROC*)&_I_inet_addr,
    "inet_ntoa",        (FARPROC*)&_I_inet_ntoa,
    "ioctlsocket",      (FARPROC*)&_I_ioctlsocket,

    "listen",           (FARPROC*)&_I_listen,
    "ntohs",            (FARPROC*)&_I_ntohs,
    "recv",             (FARPROC*)&_I_recv,
    "recvfrom",         (FARPROC*)&_I_recvfrom,
    "select",           (FARPROC*)&_I_select,
    "send",             (FARPROC*)&_I_send,
    "sendto",           (FARPROC*)&_I_sendto,
    "setsockopt",       (FARPROC*)&_I_setsockopt,
    "shutdown",         (FARPROC*)&_I_shutdown,
    "socket",           (FARPROC*)&_I_socket,
    "gethostbyname",    (FARPROC*)&_I_gethostbyname,
    "gethostname",      (FARPROC*)&_I_gethostname,
    "WSAGetLastError",  (FARPROC*)&__I_WSAGetLastError,
    "WSASetLastError",  (FARPROC*)&_I_WSASetLastError,
    "WSAStartup",       (FARPROC*)&_I_WSAStartup,
    "WSACleanup",       (FARPROC*)&_I_WSACleanup,
    "__WSAFDIsSet",     (FARPROC*)&_I___WSAFDIsSet,
    "WSARecv",          (FARPROC*)&_I_WSARecv,
    "WSASend",          (FARPROC*)&_I_WSASend
};


//
// private prototypes
//

#if INET_DEBUG

//
// debug only; this repeats the declaration that already appears after the
// global data section (harmless redeclaration)
//

void SetupSocketsTracing(void);

#endif

//
// functions
//


BOOL
IwinsockInitialize(
    VOID
    )

/*++

Routine Description:

    One-time setup for this module: initializes the lock that serializes
    loading/unloading of the sockets DLL and, in debug builds, the debug
    socket tracking state

Arguments:

    None.

Return Value:

    BOOL
        Success - result of InitializationLock.Init() (non-debug builds) or of
                  InitDebugSock() (debug builds)

        Failure - FALSE; the lock, or in debug builds the debug socket state,
                  could not be initialized

--*/

{
    //
    // the critical section guards against several threads trying to
    // load/unload Winsock at once
    //

    BOOL fInitialized = InitializationLock.Init();

#if INET_DEBUG

    //
    // debug socket tracking is only set up once the main lock is usable
    //

    if (!fInitialized) {
        return fInitialized;
    }
    fInitialized = InitDebugSock();
#endif

    return fInitialized;
}


VOID
IwinsockTerminate(
    VOID
    )

/*++

Routine Description:

    Module teardown, the counterpart of IwinsockInitialize(): releases the
    lock that serializes sockets DLL load/unload and, in debug builds, the
    debug socket tracking lock

Arguments:

    None.

Return Value:

    None.

--*/

{
    //
    // release the load/unload serialization lock first
    //

    InitializationLock.FreeLock();

#if INET_DEBUG

    //
    // then the debug-only socket tracking state
    //

    TerminateDebugSock();
#endif
}


DWORD
LoadWinsock(
    VOID
    )

/*++

Routine Description:

    Dynamically loads the Windows sockets library ("ws2_32", falling back to
    "wsock32"), resolves every entry point listed in SocketsFunctions[] and
    calls WSAStartup().

    If the library is already loaded, only WinsockLoadCount is incremented;
    each such call must be balanced by a call to UnloadWinsock(). The first
    (real) load does not increment the count.

    On any failure during a real load, UnloadWinsock() is called to undo the
    partial load.

Arguments:

    None.

Return Value:

    DWORD
        Success - ERROR_SUCCESS

        Failure - ERROR_NOT_ENOUGH_MEMORY
                    the initialization lock could not be acquired

                  ERROR_NOT_SUPPORTED
                    any other failure (LoadLibrary(), GetProcAddress() or
                    WSAStartup()); the underlying Win32/WSA error is
                    collected but then replaced by this code

--*/

{
    DEBUG_ENTER((DBG_SOCKETS,
                Dword,
                "LoadWinsock",
                NULL
                ));

    DWORD error = ERROR_SUCCESS;

    //
    // ensure no 2 threads are trying to modify the loaded state of winsock at
    // the same time. Note that the lock-failure path jumps straight to quit,
    // so its error code is NOT remapped to ERROR_NOT_SUPPORTED below
    //

    if (!InitializationLock.Lock())
    {
        error = ERROR_NOT_ENOUGH_MEMORY;
        goto quit;
    }

    if (hWinsock == NULL) {

        //
        // failed - set when any step of the real load goes wrong; triggers
        // the cleanup (UnloadWinsock()) below
        //

        BOOL failed = FALSE;

        //
        // BUGBUG - read this value from registry
        //
        // prefer ws2_32; if it cannot be loaded, fall back to wsock32
        //

        hWinsock = LoadLibrary("ws2_32");
        
        if (hWinsock == NULL) {
            
            DEBUG_PRINT(SOCKETS,
                        INFO,
                        ("failed to load ws2_32.dll"));
            
            hWinsock = LoadLibrary("wsock32");
        }
        
        if (hWinsock != NULL) {

            //
            // load the entry points. Stop at the first one that cannot be
            // resolved: a partially filled table is treated as a failed load
            //

            for (int i = 0; i < ARRAY_ELEMENTS(SocketsFunctions); ++i) {

                FARPROC farProc;

                farProc = GetProcAddress(
                                hWinsock,
                                (LPCSTR)SocketsFunctions[i].FunctionOrdinal
                                );
                if (farProc == NULL) {
                    failed = TRUE;
                    break;
                }
                *SocketsFunctions[i].FunctionAddress = farProc;
            }
            if (!failed) {

                //
                // although we need a WSADATA for WSAStartup(), it is an
                // expendible structure (not required for any other sockets
                // calls)
                //

                WSADATA wsaData;

                //
                // 0x0101 requests Winsock version 1.1. WSAStartup() returns
                // its error code directly, so a non-zero value is kept in
                // error and GetLastError() is not consulted for it
                //

                error = _I_WSAStartup(0x0101, &wsaData);
                if (error == ERROR_SUCCESS) {

                    DEBUG_PRINT(SOCKETS,
                                INFO,
                                ("winsock description: %q\n",
                                wsaData.szDescription
                                ));

                    int stringLen;

                    //
                    // detect the Novell Client32 stack from the description
                    // string: both "novell" and "wsock32" must be found by
                    // strnistr() (presumably a length-limited,
                    // case-insensitive substring search - confirm)
                    //

                    stringLen = lstrlen(wsaData.szDescription);
                    if (strnistr(wsaData.szDescription, "novell", stringLen)
                    && strnistr(wsaData.szDescription, "wsock32", stringLen)) {

                        DEBUG_PRINT(SOCKETS,
                                    INFO,
                                    ("running on Novell Client32 stack\n"
                                    ));

                        GlobalRunningNovellClient32 = TRUE;
                    }
#if INET_DEBUG
                    SetupSocketsTracing();
#endif
                } else {
                    failed = TRUE;
                }
            }
        } else {
            failed = TRUE;
        }

        //
        // if we failed to load the library, failed to find an entry point or
        // WSAStartup() returned an error then unload the library
        //

        if (failed) {

            //
            // important: there should be no API calls between determining the
            // failure and coming here to get the error code
            //
            // if error == ERROR_SUCCESS then we have to get the last error, else
            // it is the error returned by WSAStartup()
            //

            if (error == ERROR_SUCCESS) {
                error = GetLastError();

                INET_ASSERT(error != ERROR_SUCCESS);

            }

            //
            // undo the partial load. The lock is already held by this thread
            // here and UnloadWinsock() acquires it again.
            // NOTE(review): if the failure happened before "WSACleanup" was
            // resolved, the INET_ASSERT in UnloadWinsock() will fire; if it
            // happened after, WSACleanup() runs without a successful
            // WSAStartup() - confirm both are acceptable
            //

            UnloadWinsock();
        }
    } else {

        //
        // just increment the number of times we have called LoadWinsock()
        // without a corresponding call to UnloadWinsock();
        //

        ++WinsockLoadCount;
    }

    InitializationLock.Unlock();

    //
    // if we failed for any reason, need to report that TCP/IP not available
    // (the specific Win32/WSA error collected above is discarded)
    //

    if (error != ERROR_SUCCESS) {
        error = ERROR_NOT_SUPPORTED;
    }

quit:
    DEBUG_LEAVE(error);

    return error;
}


VOID
UnloadWinsock(
    VOID
    )

/*++

Routine Description:

    Balances one call to LoadWinsock(). If there are outstanding virtual loads
    (WinsockLoadCount != 0) the count is just decremented. Otherwise the
    sockets DLL is unloaded for real: async support is terminated,
    WSACleanup() is called, every pointer in SocketsFunctions[] is reset to
    NULL, the DLL is freed and hWinsock is reset so that the library can be
    loaded again.

    Does nothing if the DLL is not loaded or the initialization lock cannot be
    acquired.

Arguments:

    None.

Return Value:

    None.

--*/

{
    DEBUG_ENTER((DBG_SOCKETS,
                 None,
                 "UnloadWinsock",
                 NULL
                 ));

    //
    // ensure no 2 threads are trying to modify the loaded state of winsock at
    // the same time
    //

    if (!InitializationLock.Lock())
    {
        goto quit;
    }

    //
    // only unload the DLL if it has been mapped into process memory
    //

    if (hWinsock != NULL) {

        //
        // and only if this is the last load instance
        //

        if (WinsockLoadCount == 0) {

            INET_ASSERT(_I_WSACleanup != NULL);

            //
            // _I_WSACleanup can be NULL here when we are undoing a load that
            // failed before "WSACleanup" was resolved; in that case skip the
            // sockets cleanup and just reset the table and free the DLL
            //

            if (_I_WSACleanup != NULL) {

                //
                // need to terminate async support too - it is reliant on
                // Winsock
                //

                //
                // original author's note: this path is reached only via
                // LoadWinsock(), which is called only from
                // INTERNET_HANDLE_OBJECT(), i.e. not during dynamic unload,
                // so it is all right to clean up here
                //
                TerminateAsyncSupport(TRUE);

                int serr = _I_WSACleanup();

                if (serr != 0) {

                    //
                    // _I_WSAGetLastError is statically initialized to the
                    // wrapper function, so the NULL test below is purely
                    // defensive
                    //

                    DEBUG_PRINT(SOCKETS,
                                ERROR,
                                ("WSACleanup() returns %d; WSA error = %d\n",
                                serr,
                                (_I_WSAGetLastError != NULL)
                                    ? _I_WSAGetLastError()
                                    : -1
                                ));

                }
            }

            //
            // invalidate every indirection pointer before the code they point
            // at is unmapped by FreeLibrary()
            //

            for (int i = 0; i < ARRAY_ELEMENTS(SocketsFunctions); ++i) {
                *SocketsFunctions[i].FunctionAddress = (FARPROC)NULL;
            }
            FreeLibrary(hWinsock);
            hWinsock = NULL;
        } else {

            //
            // if there have been multiple virtual loads, then just reduce the
            // load count
            //

            --WinsockLoadCount;
        }
    }

    InitializationLock.Unlock();

quit:
    DEBUG_LEAVE(0);
}


DWORD
SafeCloseSocket(
    IN SOCKET Socket
    )

/*++

Routine Description:

    Closes a socket through the indirected closesocket() entry point. The call
    is wrapped in an exception handler so that a fault raised by the call
    (e.g. because the winsock DLL was unloaded by the system before the
    Wininet DLL) is swallowed and treated as a successful close

Arguments:

    Socket  - socket handle to close

Return Value:

    DWORD
        Success - ERROR_SUCCESS (also returned when the call faulted)

        Failure - the sockets error, translated by MapInternetError()

--*/

{
    int closeResult;

    __try {
        closeResult = _I_closesocket(Socket);
    } __except(EXCEPTION_EXECUTE_HANDLER) {

        //
        // the call itself blew up: nothing more we can do, report success
        //

        closeResult = 0;
    }
    ENDEXCEPT

    if (closeResult != SOCKET_ERROR) {
        return ERROR_SUCCESS;
    }

    //
    // closesocket() reported a failure: fetch and translate the sockets error
    //

    return MapInternetError(_I_WSAGetLastError());
}

//
// GetWrapOverlappedObject - given the address of the m_Overlapped member of a
// CWrapOverlapped, returns the address of the containing CWrapOverlapped
// object. No validation is performed on lpAddress
//

CWrapOverlapped* GetWrapOverlappedObject(LPVOID lpAddress)
{
    CWrapOverlapped* pContainer =
        CONTAINING_RECORD(lpAddress, CWrapOverlapped, m_Overlapped);

    return pContainer;
}

#if INET_DEBUG

//
// debug data types
//

//
// _II_xxx - tracing wrappers. SetupSocketsTracing() stores them in the
// corresponding _I_xxx pointers so that every socket created, accepted or
// closed passes through AddSockEntry()/RemoveSockEntry()
//

SOCKET
PASCAL FAR
_II_socket(
    int af,
    int type,
    int protocol
    );

int
PASCAL FAR
_II_closesocket(
    SOCKET s
    );

SOCKET
PASCAL FAR
_II_accept(
    SOCKET s,
    struct sockaddr FAR *addr,
    int FAR *addrlen
    );

//
// _P_xxx - the previous (real) entry points saved by SetupSocketsTracing()
// when it installs the _II_xxx wrappers; the wrappers call through these
//

GLOBAL
SOCKET
(PASCAL FAR * _P_accept)(
    SOCKET s,
    struct sockaddr FAR *addr,
    int FAR *addrlen
    ) = NULL;

GLOBAL
int
(PASCAL FAR * _P_closesocket)(
    SOCKET s
    ) = NULL;

GLOBAL
SOCKET
(PASCAL FAR * _P_socket)(
    int af,
    int type,
    int protocol
    ) = NULL;

//
// MAX_STACK_TRACE - number of stack frames recorded per tracked socket
// MAX_SOCK_ENTRIES - capacity of the tracked socket table
//

#define MAX_STACK_TRACE     5
#define MAX_SOCK_ENTRIES    1000

//
// DEBUG_SOCK_ENTRY - one slot in the tracked socket table. A slot whose
// Socket field is 0 is free
//

typedef struct _DEBUG_SOCK_ENTRY {
    SOCKET Socket;                          // tracked socket handle; 0 => slot is free
    DWORD StackTraceLength;                 // value returned by the stack capture call (0 when no trace is captured)
    PVOID StackTrace[ MAX_STACK_TRACE ];    // captured caller stack
} DEBUG_SOCK_ENTRY, *LPDEBUG_SOCK_ENTRY;

//
// DebugSockLock protects GlobalSockEntry[] and GlobalSocketsCount
//

CCritSec DebugSockLock;
DEBUG_SOCK_ENTRY GlobalSockEntry[MAX_SOCK_ENTRIES];

//
// GlobalSocketsCount - number of sockets currently recorded in the table
//

DWORD GlobalSocketsCount = 0;


#define LOCK_DEBUG_SOCK()   (DebugSockLock.Lock())
#define UNLOCK_DEBUG_SOCK() (DebugSockLock.Unlock())

//
// NtDllHandle/pRtlCaptureStackBackTrace - ntdll.dll module handle and the
// RtlCaptureStackBackTrace entry point resolved from it by
// SetupSocketsTracing(); used to record the creator's stack for each socket
//

HINSTANCE NtDllHandle;

typedef USHORT (*RTL_CAPTURE_STACK_BACK_TRACE)(
   IN ULONG FramesToSkip,
   IN ULONG FramesToCapture,
   OUT PVOID *BackTrace,
   OUT PULONG BackTraceHash
   );

RTL_CAPTURE_STACK_BACK_TRACE pRtlCaptureStackBackTrace;

//
// InitDebugSock - debug only. Empties the tracked socket table, zeroes the
// socket count and initializes the lock that protects them.
//
// Returns TRUE on success; asserts and returns FALSE if the lock cannot be
// initialized
//

BOOL
InitDebugSock(
    VOID
    )
{
    GlobalSocketsCount = 0;
    memset( GlobalSockEntry, 0x0, sizeof(GlobalSockEntry) );

    if (DebugSockLock.Init()) {
        return TRUE;
    }

    INET_ASSERT(FALSE);
    return FALSE;
}

//
// TerminateDebugSock - debug only. Counterpart of InitDebugSock(): releases
// the lock protecting the tracked socket table
//

VOID
TerminateDebugSock(
    VOID
    )
{

    DebugSockLock.FreeLock();

}

//
// SetupSocketsTracing - debug only. If socket tracing is enabled
// (DBG_TRACE_SOCKETS) and we are running on NT, resolves
// RtlCaptureStackBackTrace from ntdll.dll and redirects the socket, accept and
// closesocket indirection pointers through the _II_ tracing wrappers, saving
// the real entry points in the _P_ pointers.
//
// Expected to be called by LoadWinsock() right after the _I_ pointers have
// been freshly resolved. ntdll.dll is loaded at most once, however many times
// Winsock is reloaded, and NtDllHandle is never left referring to a freed
// module
//

VOID
SetupSocketsTracing(
    VOID
    )
{
    if (!(InternetDebugCategoryFlags & DBG_TRACE_SOCKETS)) {
        return ;
    }
    if (!IsPlatformWinNT()) {
        return ;
    }

    //
    // resolve the stack capture routine only once; a previous call may already
    // have done so (Winsock can be unloaded and loaded again)
    //

    if (pRtlCaptureStackBackTrace == NULL) {
        if (NtDllHandle == NULL) {
            NtDllHandle = LoadLibrary("ntdll.dll");
            if (NtDllHandle == NULL) {
                return ;
            }
        }
        pRtlCaptureStackBackTrace =
            (RTL_CAPTURE_STACK_BACK_TRACE)
                GetProcAddress(NtDllHandle, "RtlCaptureStackBackTrace");
        if (pRtlCaptureStackBackTrace == NULL) {

            //
            // cannot trace without it: drop our reference and forget the
            // handle so that it is not reused after being freed
            //

            FreeLibrary(NtDllHandle);
            NtDllHandle = NULL;
            return ;
        }
    }

    //
    // save the real entry points, then route calls through the wrappers
    //

//#ifdef DONT_DO_FOR_NOW
    _P_accept = _I_accept;
    _I_accept = _II_accept;
    _P_closesocket = _I_closesocket;
    _I_closesocket = _II_closesocket;
    _P_socket = _I_socket;
    _I_socket = _II_socket;
//#endif
}

//
// AddSockEntry - debug only. Records a newly created/accepted socket, together
// with the caller's stack (x86 only), in the first free slot of
// GlobalSockEntry[] and bumps GlobalSocketsCount.
//
// S - socket handle returned by socket()/accept()
//
// Nothing is recorded when tracing is disabled, when S is INVALID_SOCKET (the
// creating call failed, so there is no socket to track), when S is 0 (0 is the
// free-slot marker, so it cannot be represented in the table) or when the
// table lock cannot be acquired. Asserts if the table is full
//

VOID
AddSockEntry(
    SOCKET S
    )
{
    DWORD i;

    if (!(InternetDebugCategoryFlags & DBG_TRACE_SOCKETS)) {
        return ;
    }

    //
    // don't count failed socket()/accept() calls as open sockets, and don't
    // bump the count for a value that would leave the slot looking free
    //

    if ((S == INVALID_SOCKET) || (S == 0)) {
        return ;
    }

    //
    // Lock() can fail; never touch the table (or unlock) without owning it
    //

    if (!LOCK_DEBUG_SOCK()) {
        return ;
    }

    //
    // search for a free entry.
    //

    for( i = 0; i < MAX_SOCK_ENTRIES; i++ ) {

        if( GlobalSockEntry[i].Socket == 0 ) {

            //
            // found a free entry.
            //

            GlobalSockEntry[i].Socket = S;

            //
            // get caller stack.
            //

#if i386
            DWORD Hash = 0;

            //
            // skip 2 frames so that the trace starts above this function and
            // the _II_ wrapper that called it
            //

            GlobalSockEntry[i].StackTraceLength =
                pRtlCaptureStackBackTrace(
                    2,
                    MAX_STACK_TRACE,
                    GlobalSockEntry[i].StackTrace,
                    &Hash );
#else // i386
            GlobalSockEntry[i].StackTraceLength = 0;
#endif // i386


            GlobalSocketsCount++;

            DEBUG_PRINT(SOCKETS,
                        INFO,
                        ("socket count = %ld\n",
                        GlobalSocketsCount
                        ));

            DPRINTF("%d sockets\n", GlobalSocketsCount);

            UNLOCK_DEBUG_SOCK();
            return;
        }
    }

    //
    // we have reached a high handle limit, which is unusal, needs to be
    // debugged.
    //

    INET_ASSERT( FALSE );
    UNLOCK_DEBUG_SOCK();

    return;
}

//
// RemoveSockEntry - debug only. Removes a socket from GlobalSockEntry[] (the
// slot is zeroed, which marks it free) and decrements GlobalSocketsCount.
//
// S - socket handle about to be closed
//
// Does nothing when tracing is disabled, when S is 0 (0 is the free-slot
// marker: searching for it would "find" an unused slot and wrongly decrement
// the count), when the table lock cannot be acquired, or when S is not in the
// table
//

VOID
RemoveSockEntry(
    SOCKET S
    )
{
    DWORD i;

    if (!(InternetDebugCategoryFlags & DBG_TRACE_SOCKETS)) {
        return ;
    }

    //
    // 0 can never be a recorded socket; matching it against free slots would
    // corrupt GlobalSocketsCount
    //

    if (S == 0) {
        return ;
    }

    //
    // Lock() can fail; never touch the table (or unlock) without owning it
    //

    if (!LOCK_DEBUG_SOCK()) {
        return ;
    }

    for( i = 0; i < MAX_SOCK_ENTRIES; i++ ) {

        if( GlobalSockEntry[i].Socket == S ) {

            //
            // found the entry. Free it now.
            //

            memset( &GlobalSockEntry[i], 0x0, sizeof(DEBUG_SOCK_ENTRY) );

            GlobalSocketsCount--;

#ifdef IWINSOCK_DEBUG_PRINT

            DEBUG_PRINT(SOCKETS,
                        INFO,
                        ("count(%ld), RemoveSock(%lx)\n",
                        GlobalSocketsCount,
                        S
                        ));

#endif // IWINSOCK_DEBUG_PRINT

            DPRINTF("%d sockets\n", GlobalSocketsCount);

            UNLOCK_DEBUG_SOCK();
            return;
        }
    }

#ifdef IWINSOCK_DEBUG_PRINT

    DEBUG_PRINT(SOCKETS,
                INFO,
                ("count(%ld), UnknownSock(%lx)\n",
                GlobalSocketsCount,
                S
                ));

#endif // IWINSOCK_DEBUG_PRINT

    //
    // socket entry is not found. Deliberately not asserted.
    //

    // INET_ASSERT( FALSE );

    UNLOCK_DEBUG_SOCK();
    return;
}

//
// _II_socket - debug tracing wrapper for socket(): forwards to the real entry
// point saved in _P_socket, hands the result to AddSockEntry() and returns it
//

SOCKET
PASCAL FAR
_II_socket(
    int af,
    int type,
    int protocol
    )
{
    SOCKET NewSocket = _P_socket( af, type, protocol );

    AddSockEntry( NewSocket );

    return NewSocket;
}

//
// _II_closesocket - debug tracing wrapper for closesocket(): drops the socket
// from the tracking table first, then forwards to the real entry point saved
// in _P_closesocket and returns its result
//

int
PASCAL FAR
_II_closesocket(
    SOCKET s
    )
{
    RemoveSockEntry( s );

    return _P_closesocket( s );
}

//
// _II_accept - debug tracing wrapper for accept(): forwards to the real entry
// point saved in _P_accept, hands the result to AddSockEntry() and returns it
//

SOCKET
PASCAL FAR
_II_accept(
    SOCKET s,
    struct sockaddr FAR *addr,
    int FAR *addrlen
    )
{
    SOCKET Accepted = _P_accept( s, addr, addrlen );

    AddSockEntry( Accepted );

    return Accepted;
}

//
// IWinsockCheckSockets - debug only. Dumps the tracked socket count followed
// by the handle of every socket still recorded in GlobalSockEntry[]. The table
// is read without taking DebugSockLock
//

VOID
IWinsockCheckSockets(
    VOID
    )
{
    DEBUG_PRINT(SOCKETS,
                INFO,
                ("GlobalSocketsCount = %d\n",
                GlobalSocketsCount
                ));

    for (DWORD Slot = 0; Slot < MAX_SOCK_ENTRIES; ++Slot) {

        SOCKET Tracked = GlobalSockEntry[Slot].Socket;

        //
        // 0 marks a free slot - nothing to report
        //

        if (Tracked == 0) {
            continue;
        }

        DEBUG_PRINT(SOCKETS,
                    INFO,
                    ("Socket %#x\n",
                    Tracked
                    ));
    }
}

#endif // INET_DEBUG

#if defined(__cplusplus)
}
#endif