/*++

Copyright (c) 1987-1996  Microsoft Corporation

Module Name:

    joincrypt.c

Abstract:

    Authentication related functions required by netjoin

Author:

    kumarp 29-May-1999

Notes:
    The functions in this file used to be made available by
    including net\svcdlls\logonsrv\server\ssiauth.c. That caused several
    problems, so the functions have now been copied from that file
    into this separate file.

--*/


#pragma hdrstop

#define WKSTA_NETLOGON
#define NETSETUP_JOIN

#include <netsetp.h>
#include <crypt.h>
#include <ntsam.h>
#include <logonmsv.h>
#include <lmshare.h>
#include <wincrypt.h>
#include <netlogon.h>
#include <logonp.h>
#include <logonmsv.h>
#include <ssi.h>
#include <wchar.h>
#include "joinp.h"

//
// Not referenced by the code visible in this file.
// NOTE(review): presumably kept because other join code links against it
// (it has external linkage) -- confirm before removing.
//
LONG NlGlobalSessionCounter = 0;

//
// Crypto provider handle consumed by NlGenerateRandomBits.  In this file it
// is acquired by NetpValidateMachineAccount just before the client challenge
// is generated, then released and reset to NULL right afterwards.
//
HCRYPTPROV NlGlobalCryptProvider = (HCRYPTPROV)NULL;


//
// The Netlogon debug-print and buffer-dump helpers used by the code this file
// was copied from are compiled out here: both macros expand to nothing, so
// their arguments are never evaluated.
//
#define NlPrint(x)
#define NlpDumpBuffer(x,y,z)

//
// Fills Buffer with BufferLen random bytes; defined later in this file.
//
BOOLEAN
NlGenerateRandomBits(
    PUCHAR Buffer,
    ULONG  BufferLen
    );

//
// Maps Netlogon's NlQuerySystemTime onto the Win32 GetSystemTimeAsFileTime.
// Not used by the code visible in this file.
//
#define NlQuerySystemTime( _Time ) GetSystemTimeAsFileTime( (LPFILETIME)(_Time) )


//
// Forward declarations for the challenge/credential helpers defined below.
//
VOID
NlComputeChallenge(
    OUT PNETLOGON_CREDENTIAL Challenge
    );

VOID
NlComputeCredentials(
    IN PNETLOGON_CREDENTIAL Challenge,
    OUT PNETLOGON_CREDENTIAL Credential,
    IN PNETLOGON_SESSION_KEY SessionKey
    );



NTSTATUS
NlMakeSessionKey(
    IN ULONG NegotiatedFlags,
    IN PNT_OWF_PASSWORD CryptKey,
    IN PNETLOGON_CREDENTIAL ClientChallenge,
    IN PNETLOGON_CREDENTIAL ServerChallenge,
    OUT PNETLOGON_SESSION_KEY SessionKey
    )
/*++

Routine Description:

      Build an encryption key for use in authentication for
      this RequestorName.

      When NETSETUP_JOIN is defined (this file defines it before including
      any headers) only the weak, DES-based key is computed and
      NegotiatedFlags is ignored.

Arguments:

      NegotiatedFlags - Determines the strength of the key.  Ignored when
        NETSETUP_JOIN is defined.

      CryptKey -- The NT OWF password of the account being used.  For the
        weak key its bytes are used as two DES keys: the first
        BLOCK_KEY_LENGTH bytes, and BLOCK_KEY_LENGTH bytes starting at
        offset BLOCK_KEY_LENGTH + 2.

      ClientChallenge --  8 byte (64 bit) number generated by caller

      ServerChallenge -- 8 byte (64 bit) number generated by primary

      SessionKey --  16 byte (128 bit) number generated at both ends.
        If the key strength is weak, the last 64 bits will be zero.

Return Value:

      STATUS_SUCCESS if the key was computed; otherwise the NT status code
      returned by the cryptographic routine that failed.

--*/
{
    NTSTATUS Status;
    BLOCK_KEY BlockKey;
    NETLOGON_SESSION_KEY TempSessionKey;
    ULONG ClientHalf;
    ULONG ServerHalf;

#ifndef NETSETUP_JOIN
    PCHECKSUM_BUFFER CheckBuffer = NULL;
    PCHECKSUM_FUNCTION Check;
#endif // NETSETUP_JOIN

    //
    // Start with a zero key
    //
    RtlZeroMemory(SessionKey, sizeof(NETLOGON_SESSION_KEY));

#ifdef NETSETUP_JOIN
    UNREFERENCED_PARAMETER( NegotiatedFlags );
#else // NETSETUP_JOIN
    //
    // If the caller wants a strong key,
    //  Compute it.
    //
    if ( NegotiatedFlags & NETLOGON_SUPPORTS_STRONG_KEY ) {

        // PCRYPTO_SYSTEM CryptSystem;

        UCHAR LocalChecksum[sizeof(*SessionKey)];
        // ULONG OutputSize;

        //
        // Initialize the checksum routines.
        //

        Status = CDLocateCheckSum( KERB_CHECKSUM_MD5_HMAC, &Check);
        if (!NT_SUCCESS(Status)) {
            NlPrint(( NL_CRITICAL,"NlMakeSessionKey: Failed to load checksum routines: 0x%x\n", Status));
            goto Cleanup;
        }

        ASSERT(Check->CheckSumSize <= sizeof(LocalChecksum));

        Status = Check->InitializeEx(
                    (LPBYTE)CryptKey,
                    sizeof( *CryptKey ),
                    0,              // no message type
                    &CheckBuffer );

        if (!NT_SUCCESS(Status)) {
            NlPrint(( NL_CRITICAL,"NlMakeSessionKey: Failed to initialize checksum routines: 0x%x\n", Status));
            goto Cleanup;
        }


        //
        // Sum in the client challenge, a constant, and the server challenge
        //
        // NOTE(review): no constant is actually summed below, only the two
        // challenges.  Confirm the expected input against ssiauth.c (and the
        // MS-NRPC strong-key derivation) before building this path.
        //

        Check->Sum( CheckBuffer,
                    sizeof(*ClientChallenge),
                    (PUCHAR)ClientChallenge );

        Check->Sum( CheckBuffer,
                    sizeof(*ServerChallenge),
                    (PUCHAR)ServerChallenge );

        //
        // Finish the checksum
        //

        (void) Check->Finalize(CheckBuffer, LocalChecksum);


        //
        // Copy the checksum into the message.
        //

        ASSERT( sizeof(LocalChecksum) >= sizeof(*SessionKey) );
        RtlCopyMemory( SessionKey, LocalChecksum, sizeof(*SessionKey) );


    //
    // Compute weaker (but backward compatible key)
    //
    } else {
#endif // NETSETUP_JOIN

        //
        // we will have a 128 bit key (64 bit encrypted rest padded with 0s)
        //
        // SessionKey = C + P  (arithmetic sum ignore carry)
        //
        // Each 32 bit half is summed independently.  The halves are copied
        // through ULONG locals instead of dereferencing casted pointers: the
        // credential and session-key types are byte buffers that are not
        // guaranteed to be ULONG aligned, and the cast also breaks strict
        // aliasing.  The resulting bytes are the same as the cast version.
        //

        RtlCopyMemory( &ClientHalf, ClientChallenge, sizeof(ULONG) );
        RtlCopyMemory( &ServerHalf, ServerChallenge, sizeof(ULONG) );
        ClientHalf += ServerHalf;
        RtlCopyMemory( SessionKey, &ClientHalf, sizeof(ULONG) );

        RtlCopyMemory( &ClientHalf,
                       (LPBYTE)ClientChallenge + sizeof(ULONG),
                       sizeof(ULONG) );
        RtlCopyMemory( &ServerHalf,
                       (LPBYTE)ServerChallenge + sizeof(ULONG),
                       sizeof(ULONG) );
        ClientHalf += ServerHalf;
        RtlCopyMemory( (LPBYTE)SessionKey + sizeof(ULONG),
                       &ClientHalf,
                       sizeof(ULONG) );


        //
        // CryptKey is our 16 byte key to be used as described in codespec
        // use first 7 bytes of CryptKey for first encryption
        //

        RtlCopyMemory( &BlockKey, CryptKey, BLOCK_KEY_LENGTH );

        Status = RtlEncryptBlock(
                    (PCLEAR_BLOCK) SessionKey,   // Clear text
                    &BlockKey,                  // Key
                    (PCYPHER_BLOCK) &TempSessionKey);    // Cypher Block

        if ( !NT_SUCCESS( Status ) ) {
            goto Cleanup;
        }


        //
        // Further encrypt the encrypted "SessionKey" using upper 7 bytes
        // (the key at offset BLOCK_KEY_LENGTH + 2 ends exactly at the last
        // byte of the NT OWF password).
        //

        ASSERT( sizeof(*CryptKey) == 2*BLOCK_KEY_LENGTH+2 );

        RtlCopyMemory( &BlockKey,
                       ((PUCHAR)CryptKey) + 2 + BLOCK_KEY_LENGTH,
                       BLOCK_KEY_LENGTH );

        Status = RtlEncryptBlock(
                    (PCLEAR_BLOCK) &TempSessionKey,   // Clear text
                    &BlockKey,                  // Key
                    (PCYPHER_BLOCK) SessionKey);    // Cypher Block

        if ( !NT_SUCCESS( Status ) ) {
            goto Cleanup;
        }
#ifndef NETSETUP_JOIN
    }
#endif // NETSETUP_JOIN

Cleanup:
#ifndef NETSETUP_JOIN
    if (CheckBuffer != NULL) {
        Status = Check->Finish(&CheckBuffer);

        if (!NT_SUCCESS(Status)) {
            NlPrint(( NL_CRITICAL,"NlMakeSessionKey: Failed to finish checksum: 0x%x\n", Status));
        }
    }
#endif // NETSETUP_JOIN

    return Status;
}


VOID
NlComputeChallenge(
    OUT PNETLOGON_CREDENTIAL Challenge
    )

/*++

Routine Description:

    Generates a 64 bit random challenge with NlGenerateRandomBits.
    NlGlobalCryptProvider must already hold an acquired provider handle.

Arguments:

    Challenge - Returns the computed challenge.  If random generation fails
        the challenge is set to all zeros, so the caller never receives
        indeterminate data (callers typically pass an uninitialized local).

Return Value:

    None.  This routine cannot report failure.  Callers that must detect a
    random-generation failure should call NlGenerateRandomBits directly.

--*/
{

    //
    // Use an ideal random bit generator.
    //

    if (!NlGenerateRandomBits( (LPBYTE)Challenge, sizeof(*Challenge) )) {
        NlPrint((NL_CRITICAL, "Can't NlGenerateRandomBits\n" ));

        //
        // Don't hand back whatever happened to be in the caller's buffer.
        // Zeroing only on failure leaves the input contents available to
        // CryptGenRandom on the success path.
        //
        RtlZeroMemory( Challenge, sizeof(*Challenge) );
    }

    return;
}

VOID
NlComputeCredentials(
    IN PNETLOGON_CREDENTIAL Challenge,
    OUT PNETLOGON_CREDENTIAL Credential,
    IN PNETLOGON_SESSION_KEY SessionKey
    )
/*++

Routine Description:

    Derives an 8 byte credential from an 8 byte challenge by DES-encrypting
    the challenge twice.  The first pass is keyed with the first
    BLOCK_KEY_LENGTH (7) bytes of the session key.  The second pass is keyed
    with the next BLOCK_KEY_LENGTH bytes and encrypts the result of the
    first pass.

Arguments:

    Challenge  - Supplies the 8 byte (64 bit) challenge.

    Credential - Receives the 8 byte (64 bit) credential.  It is zeroed
        before any encryption is attempted.

    SessionKey - Supplies the session key.  Only its first 14 bytes (112
        bits) are used, although the buffer is 16 bytes (128 bits) long.
        For a weak key the trailing 8 bytes are zero; for a strong key this
        routine ignores the trailing 2 bytes of useful key.

Return Value:

    None.  Status codes from RtlEncryptBlock are only checked with ASSERT
    and are not passed back to the caller.

--*/
{
    PUCHAR KeyBytes = (PUCHAR)SessionKey;
    CYPHER_BLOCK FirstPass;
    BLOCK_KEY DesKey;
    NTSTATUS EncryptStatus;

    RtlZeroMemory( Credential, sizeof(NETLOGON_CREDENTIAL) );

    //
    // Pass 1: Challenge -> FirstPass, keyed with the leading 7 key bytes.
    //
    RtlCopyMemory( &DesKey, KeyBytes, BLOCK_KEY_LENGTH );
    EncryptStatus = RtlEncryptBlock( (PCLEAR_BLOCK)Challenge,
                                     &DesKey,
                                     &FirstPass );
    ASSERT( NT_SUCCESS(EncryptStatus) );

    //
    // Pass 2: FirstPass -> Credential, keyed with the following 7 key bytes.
    //
    RtlCopyMemory( &DesKey, KeyBytes + BLOCK_KEY_LENGTH, BLOCK_KEY_LENGTH );
    EncryptStatus = RtlEncryptBlock( (PCLEAR_BLOCK)&FirstPass,
                                     &DesKey,
                                     Credential );
    ASSERT( NT_SUCCESS(EncryptStatus) );
}

BOOLEAN
NlGenerateRandomBits(
    PUCHAR Buffer,
    ULONG  BufferLen
    )
/*++

Routine Description:

    Fills a buffer with random bytes from CryptGenRandom, using the provider
    handle held in NlGlobalCryptProvider.  That handle must already have
    been acquired by the caller.

Arguments:

    Buffer - Buffer to fill.

    BufferLen - Number of bytes to write into Buffer.

Return Value:

    TRUE if CryptGenRandom succeeded.  Otherwise FALSE; the failure code can
    then be retrieved with GetLastError().

--*/

{
    BOOL Generated;

    Generated = CryptGenRandom( NlGlobalCryptProvider, BufferLen, (LPBYTE)Buffer );
    if ( Generated ) {
        return TRUE;
    }

    NlPrint((NL_CRITICAL, "CryptGenRandom failed with %lu\n", GetLastError() ));
    return FALSE;
}



NET_API_STATUS
NET_API_FUNCTION
NetpValidateMachineAccount(
    IN  LPWSTR      lpDc,
    IN  LPWSTR      lpDomain,
    IN  LPWSTR      lpMachine,
    IN  LPWSTR      lpPassword
    )
/*++

Routine Description:

    Performs validation that the machine account exists and has the same
    password we expect.  It runs a challenge/response exchange with lpDc,
    using a session key derived from lpPassword, and checks the credential
    the DC returns.

    The internals of this function were lifted completely from SimulateFullSync() in
    ..\svcdlls\logonsrv\server\nltest.c,

Arguments:

    lpDc -- Name of the Dc
    lpDomain -- Name of the domain (unused)
    lpMachine -- Current machine.  lpMachine followed by
        SSI_ACCOUNT_NAME_POSTFIX must fit in SSI_ACCOUNT_NAME_LENGTH
        characters.
    lpPassword -- Password that should be on the account.


Returns:

    NERR_Success -- Success
    ERROR_INVALID_PARAMETER -- lpMachine or lpPassword is too long.
    Otherwise one of the following:
        - the Win32 error from acquiring the crypto provider;
        - the Win32 error from generating the client challenge;
        - the Win32 mapping of the failing NT status.  STATUS_ACCESS_DENIED
          is reported as the mapping of STATUS_LOGON_FAILURE.

--*/
{
    NTSTATUS Status = STATUS_SUCCESS;
    DWORD RandomError = ERROR_SUCCESS;
    size_t PasswordLength;

    NETLOGON_CREDENTIAL ServerChallenge;
    NETLOGON_CREDENTIAL ClientChallenge;
    NETLOGON_CREDENTIAL ComputedServerCredential;
    NETLOGON_CREDENTIAL ReturnedServerCredential;
    NETLOGON_CREDENTIAL AuthenticationSeed;
    NETLOGON_SESSION_KEY SessionKey;
    WCHAR AccountName[SSI_ACCOUNT_NAME_LENGTH+1];
    UNICODE_STRING Password;
    NT_OWF_PASSWORD NtOwfPassword;

    UNREFERENCED_PARAMETER( lpDomain );

    ASSERT( lpPassword );

    //
    // Reject inputs that would overflow AccountName (machine name plus the
    // account postfix) or the USHORT byte-length fields of a UNICODE_STRING.
    // This is checked before acquiring the crypto provider so nothing needs
    // to be released on this path.
    //

    if ( wcslen( lpMachine ) + wcslen( SSI_ACCOUNT_NAME_POSTFIX ) >
         SSI_ACCOUNT_NAME_LENGTH ) {

        Status = STATUS_INVALID_PARAMETER;
        goto ValidateMachineAccountError;
    }

    PasswordLength = wcslen( lpPassword );

    if ( PasswordLength > MAXUSHORT / sizeof(WCHAR) ) {

        Status = STATUS_INVALID_PARAMETER;
        goto ValidateMachineAccountError;
    }

    //
    // initialize Crypto Provider.
    // (required to generate the client challenge).
    //

    if ( !CryptAcquireContext(
                    &NlGlobalCryptProvider,
                    NULL,
                    NULL,
                    PROV_RSA_FULL,
                    CRYPT_VERIFYCONTEXT
                    ))
    {
        NlGlobalCryptProvider = (HCRYPTPROV)NULL;
        return (NET_API_STATUS)GetLastError();
    }

    //
    // Prepare our challenge.  NlGenerateRandomBits is called directly
    // instead of NlComputeChallenge because NlComputeChallenge cannot report
    // failure.  Calling it here lets us stop rather than continue with a
    // challenge that was never generated.  The error is captured before the
    // provider is released, because the release may overwrite the last-error
    // value.
    //

    if ( !NlGenerateRandomBits( (PUCHAR)&ClientChallenge,
                                sizeof(ClientChallenge) ) ) {

        RandomError = GetLastError();
        if ( RandomError == ERROR_SUCCESS ) {
            RandomError = ERROR_GEN_FAILURE;
        }
    }

    //
    // free cryptographic service provider.
    //
    if ( NlGlobalCryptProvider ) {
        CryptReleaseContext( NlGlobalCryptProvider, 0 );
        NlGlobalCryptProvider = (HCRYPTPROV)NULL;
    }

    if ( RandomError != ERROR_SUCCESS ) {

        NetpLog(( "Failed to generate a client challenge for %ws: 0x%lx\n",
                  lpMachine, RandomError ));
        return (NET_API_STATUS)RandomError;
    }


    //
    // Get the primary's challenge
    //

    Status = I_NetServerReqChallenge(lpDc,
                                     lpMachine,
                                     &ClientChallenge,
                                     &ServerChallenge );

    if ( !NT_SUCCESS( Status ) ) {

        goto ValidateMachineAccountError;
    }


    //
    // PasswordLength was bounded above, so the byte count fits in a USHORT.
    //
    Password.Length = Password.MaximumLength =
        (USHORT)( PasswordLength * sizeof(WCHAR) );
    Password.Buffer = lpPassword;

    //
    // Compute the NT OWF password for this user.
    //

    Status = RtlCalculateNtOwfPassword( &Password, &NtOwfPassword );

    if ( !NT_SUCCESS( Status ) ) {

        goto ValidateMachineAccountError;

    }


    //
    // Actually compute the session key given the two challenges and the
    // password.  No flags are negotiated, so the weak (backward compatible)
    // key is requested.  NlMakeSessionKey always takes NegotiatedFlags, so
    // the argument must not be conditionally compiled out.
    //

    Status = NlMakeSessionKey( 0,
                               &NtOwfPassword,
                               &ClientChallenge,
                               &ServerChallenge,
                               &SessionKey );

    if ( !NT_SUCCESS( Status ) ) {

        goto ValidateMachineAccountError;
    }

    //
    // Prepare credentials using our challenge.
    //

    NlComputeCredentials( &ClientChallenge,
                          &AuthenticationSeed,
                          &SessionKey );

    //
    // Send these credentials to primary. The primary will compute
    // credentials using the challenge supplied by us and compare
    // with these. If both match then it will compute credentials
    // using its challenge and return it to us for verification
    //
    // The length check at the top guarantees both copies fit in AccountName.
    //

    wcscpy( AccountName, lpMachine );
    wcscat( AccountName, SSI_ACCOUNT_NAME_POSTFIX);

    Status = I_NetServerAuthenticate( lpDc,
                                      AccountName,
                                      WorkstationSecureChannel,
                                      lpMachine,
                                      &AuthenticationSeed,
                                      &ReturnedServerCredential );

    if ( !NT_SUCCESS( Status ) ) {

        goto ValidateMachineAccountError;

    }


    //
    // The DC returned a server credential to us,
    //  ensure the server credential matches the one we would compute.
    //

    NlComputeCredentials( &ServerChallenge,
                          &ComputedServerCredential,
                          &SessionKey);


    if (RtlCompareMemory( &ReturnedServerCredential,
                          &ComputedServerCredential,
                          sizeof(ReturnedServerCredential)) !=
                          sizeof(ReturnedServerCredential)) {

        Status =  STATUS_ACCESS_DENIED;
    }


ValidateMachineAccountError:

    if ( Status == STATUS_ACCESS_DENIED ) {

        Status = STATUS_LOGON_FAILURE;
    }

    if ( !NT_SUCCESS( Status ) ) {

        NetpLog(( "Failed to validate machine account for %ws against %ws: 0x%lx\n",
                  lpMachine, lpDc, Status ));
    }



    return( RtlNtStatusToDosError( Status ) );
}