#include "MetadataPool.h"

#include <cstring>
#include <vector>

#include "utils/Il2CppHashMap.h"
#include "utils/Il2CppHashSet.h"
#include "utils/HashUtils.h"
#include "utils/MemoryPool.h"
#include "metadata/Il2CppTypeHash.h"
#include "metadata/Il2CppTypeCompare.h"
#include "metadata/Il2CppGenericInstHash.h"
#include "metadata/Il2CppGenericInstCompare.h"
#include "metadata/Il2CppGenericMethodHash.h"
#include "metadata/Il2CppGenericMethodCompare.h"
#include "metadata/Il2CppGenericClassHash.h"
#include "metadata/Il2CppGenericClassCompare.h"
#include "vm/MetadataAlloc.h"
#include "vm/MetadataLock.h"
#include "vm/GlobalMetadata.h"

#include "MetadataModule.h"

namespace hybridclr
{
namespace metadata
{
	using il2cpp::utils::HashUtils;

	// Hash functor for Il2CppArrayType pointers (multi-dimensional array descriptors).
	// The static Hash is defined out of line further down because it needs
	// Il2CppTypeFullHash, which is declared after this class.
	class Il2CppArrayTypeHash
	{
	public:
		static size_t Hash(const Il2CppArrayType* t1);

		size_t operator()(const Il2CppArrayType* t1) const
		{
			return Il2CppArrayTypeHash::Hash(t1);
		}
	};

	// Hash functor over the full structure of an Il2CppType: the type tag, the
	// byref/attrs/pinned bits, and the type-specific payload. It must stay
	// consistent with Il2CppTypeFullEqualityComparer::AreEqual.
	class Il2CppTypeFullHash
	{
	public:
		size_t operator()(const Il2CppType* t1) const
		{
			return Hash(t1);
		}

		static size_t Hash(const Il2CppType* t1)
		{
			size_t hash = t1->type;

			hash = HashUtils::Combine(hash, t1->byref);
			hash = HashUtils::Combine(hash, t1->attrs);
			// num_mods and valuetype are intentionally not part of the hash.
			hash = HashUtils::Combine(hash, t1->pinned);

			switch (t1->type)
			{
			case IL2CPP_TYPE_VALUETYPE:
			case IL2CPP_TYPE_CLASS:
				// Class/value types are identified by their definition handle.
				hash = HashUtils::Combine(hash, reinterpret_cast<size_t>(t1->data.typeHandle));
				break;
			case IL2CPP_TYPE_SZARRAY:
			case IL2CPP_TYPE_PTR:
				// Recurse into the element / pointee type.
				hash = HashUtils::Combine(hash, Hash(t1->data.type));
				break;
			case IL2CPP_TYPE_ARRAY:
				// Il2CppArrayTypeHash covers rank, sizes, lobounds and the element type.
				hash = HashUtils::Combine(hash, Il2CppArrayTypeHash::Hash(t1->data.array));
				break;
			case IL2CPP_TYPE_GENERICINST:
				// Generic classes referenced by pooled types are themselves pooled,
				// so pointer identity is sufficient.
				hash = HashUtils::Combine(hash, reinterpret_cast<size_t>(t1->data.generic_class));
				break;
			case IL2CPP_TYPE_VAR:
			case IL2CPP_TYPE_MVAR:
				hash = HashUtils::Combine(hash, reinterpret_cast<size_t>(t1->data.genericParameterHandle));
				break;
			default:
				// Primitive and other payload-free kinds: the header is the whole identity.
				break;
			}
			return hash;
		}
	};

	// Hashes an array descriptor. The shape fields (rank, numsizes, numlobounds)
	// come first, then the element type's full hash, then every size and lower bound.
	size_t Il2CppArrayTypeHash::Hash(const Il2CppArrayType* t1)
	{
		size_t h = t1->rank;
		h = HashUtils::Combine(h, t1->numsizes);
		h = HashUtils::Combine(h, t1->numlobounds);
		h = HashUtils::Combine(h, Il2CppTypeFullHash::Hash(t1->etype));

		for (uint32_t dim = 0; dim < t1->numsizes; ++dim)
			h = HashUtils::Combine(h, t1->sizes[dim]);

		for (uint32_t dim = 0; dim < t1->numlobounds; ++dim)
			h = HashUtils::Combine(h, t1->lobounds[dim]);

		return h;
	}

	// Equality functor for Il2CppArrayType pointers. It is the counterpart of
	// Il2CppArrayTypeHash. The static AreEqual is defined out of line further down
	// because it needs Il2CppTypeFullEqualityComparer.
	class Il2CppArrayTypeEqualityComparer
	{
	public:
		static bool AreEqual(const Il2CppArrayType* t1, const Il2CppArrayType* t2);

		bool operator()(const Il2CppArrayType* t1, const Il2CppArrayType* t2) const
		{
			return Il2CppArrayTypeEqualityComparer::AreEqual(t1, t2);
		}
	};

	// Structural equality for Il2CppType pointers, matching Il2CppTypeFullHash.
	// Two types are equal when their tag and byref/attrs/pinned bits match and
	// their type-specific payloads compare equal.
	class Il2CppTypeFullEqualityComparer
	{
	public:
		bool operator()(const Il2CppType* t1, const Il2CppType* t2) const
		{
			return AreEqual(t1, t2);
		}

		static bool AreEqual(const Il2CppType* t1, const Il2CppType* t2)
		{
			const bool sameHeader = t1->type == t2->type
				&& t1->byref == t2->byref
				&& t1->attrs == t2->attrs
				&& t1->pinned == t2->pinned;
			if (!sameHeader)
			{
				return false;
			}

			switch (t1->type)
			{
			case IL2CPP_TYPE_VALUETYPE:
			case IL2CPP_TYPE_CLASS:
				return t1->data.typeHandle == t2->data.typeHandle;
			case IL2CPP_TYPE_PTR:
			case IL2CPP_TYPE_SZARRAY:
				// Element / pointee types are compared structurally, not by pointer.
				return AreEqual(t1->data.type, t2->data.type);
			case IL2CPP_TYPE_ARRAY:
				return Il2CppArrayTypeEqualityComparer::AreEqual(t1->data.array, t2->data.array);
			case IL2CPP_TYPE_GENERICINST:
				return t1->data.generic_class == t2->data.generic_class;
			case IL2CPP_TYPE_VAR:
			case IL2CPP_TYPE_MVAR:
				return t1->data.genericParameterHandle == t2->data.genericParameterHandle;
			default:
				// No payload: a matching header means the types are equal.
				return true;
			}
		}
	};

	// Two array descriptors are equal when they have the same rank, the same
	// dimension counts, element-wise equal sizes and lower bounds, and structurally
	// equal element types.
	bool Il2CppArrayTypeEqualityComparer::AreEqual(const Il2CppArrayType* a1, const Il2CppArrayType* a2)
	{
		const bool sameShape = a1->rank == a2->rank
			&& a1->numsizes == a2->numsizes
			&& a1->numlobounds == a2->numlobounds;
		if (!sameShape)
			return false;

		for (uint32_t dim = 0; dim < a1->numsizes; ++dim)
		{
			if (a1->sizes[dim] != a2->sizes[dim])
				return false;
		}
		for (uint32_t dim = 0; dim < a1->numlobounds; ++dim)
		{
			if (a1->lobounds[dim] != a2->lobounds[dim])
				return false;
		}
		// The element type is compared last because that comparison may recurse.
		return Il2CppTypeFullEqualityComparer::AreEqual(a1->etype, a2->etype);
	}

	// NOTE(review): not referenced anywhere in this file; possibly a leftover.
	static size_t s_methodInfoSize = 0;

	// Interning pools. Each set keeps one pooled instance per structurally distinct
	// metadata object, as decided by the hash/equality functors in its typedef.

	// Pooled Il2CppTypes, compared by full structure (tag, byref, attrs, pinned, payload).
	typedef Il2CppHashSet<const Il2CppType*, Il2CppTypeFullHash, Il2CppTypeFullEqualityComparer> Il2CppTypeHashSet;
	static Il2CppTypeHashSet s_Il2CppTypePool;

	// Pooled multi-dimensional array descriptors (payloads of IL2CPP_TYPE_ARRAY types).
	typedef Il2CppHashSet<const Il2CppArrayType*, Il2CppArrayTypeHash, Il2CppArrayTypeEqualityComparer> Il2CppArrayTypeHashSet;
	static Il2CppArrayTypeHashSet s_Il2CppArrayTypePool;

	// Pooled generic argument lists; hashing and comparison come from il2cpp::metadata.
	typedef Il2CppHashSet<const Il2CppGenericInst*, il2cpp::metadata::Il2CppGenericInstHash, il2cpp::metadata::Il2CppGenericInstCompare> Il2CppGenericInstHashSet;
	static Il2CppGenericInstHashSet s_Il2CppGenericInstPool;

	// Pooled generic classes (generic type definition + class instantiation).
	typedef Il2CppHashSet<const Il2CppGenericClass*, il2cpp::metadata::Il2CppGenericClassHash, il2cpp::metadata::Il2CppGenericClassCompare> Il2CppGenericClassHashSet;
	static Il2CppGenericClassHashSet s_Il2CppGenericClassPool;

	// Pooled generic methods (method definition + generic context).
	typedef Il2CppHashSet<const Il2CppGenericMethod*, il2cpp::metadata::Il2CppGenericMethodHash, il2cpp::metadata::Il2CppGenericMethodCompare> Il2CppGenericMethodHashSet;
	static Il2CppGenericMethodHashSet s_Il2CppGenericMethodPool;

	// Arena behind MetadataMallocT / MetadataCallocT; created in MetadataPool::Initialize().
	static il2cpp::utils::MemoryPool* s_metadataPool = nullptr;

	// Guards all pools above. Pool accessors call each other while holding it
	// (e.g. GetPooledArrayType -> GetPooledIl2CppArrayType -> GetPooledIl2CppType),
	// so it must be re-entrant.
	// NOTE(review): confirm that FastMutex is recursive on the Unity 2019 path.
#if HYBRIDCLR_UNITY_2019
	il2cpp::os::FastMutex s_metadataPoolLock;
#else
	baselib::ReentrantLock s_metadataPoolLock;
#endif

	// Declares a local FastAutoLock named `lock` on s_metadataPoolLock. It is presumably
	// released when the enclosing scope ends.
#define USE_METAPOOL_LOCK() il2cpp::os::FastAutoLock lock(&s_metadataPoolLock)


	// Allocates storage for a single T from s_metadataPool. The object is not
	// constructed, and no matching free exists in this file.
	template<typename T>
	T* MetadataMallocT()
	{
		void* storage = s_metadataPool->Malloc(sizeof(T));
		return static_cast<T*>(storage);
	}

	// Allocates storage for `count` objects of type T from s_metadataPool via Calloc.
	template<typename T>
	T* MetadataCallocT(size_t count)
	{
		void* storage = s_metadataPool->Calloc(count, sizeof(T));
		return static_cast<T*>(storage);
	}


	// Creates the arena behind every pooled metadata allocation in this file.
	// The pool hash sets are not pre-sized and start at their default capacity.
	void MetadataPool::Initialize()
	{
		s_metadataPool = new il2cpp::utils::MemoryPool();
	}


	// Copies `num` ints from `arr` into a new array allocated from the metadata pool.
	// GetPooledIl2CppArrayType uses it to give pooled array types their own
	// sizes/lobounds storage.
	static int* CopyIntArray(const int* arr, uint8_t num)
	{
		int* newArr = MetadataCallocT<int>(num);
		// The values must actually be copied. Pooled array types are hashed and compared
		// by their sizes/lobounds, so a zero-filled copy would not match the key it
		// was created for.
		if (num > 0)
		{
			std::memcpy(newArr, arr, sizeof(int) * num);
		}
		return newArr;
	}

	// Allocates a pool-owned copy of `type` and redirects nested references to
	// their pooled instances:
	//   - PTR / SZARRAY: data.type -> pooled element/pointee type
	//   - ARRAY:         data.array -> pooled Il2CppArrayType
	//   - GENERICINST:   data.generic_class is kept as-is; it must already be the pooled
	//                    generic class (asserted in debug builds)
	// All other kinds are fully described by the shallow copy.
	static Il2CppType* DeepCloneIl2CppType(const Il2CppType& type)
	{
		Il2CppType* newType = MetadataMallocT<Il2CppType>();
		// Shallow copy first; the switch below fixes up pointer payloads.
		*newType = type;

#if HYBRIDCLR_UNITY_2021_OR_NEW
		//IL2CPP_ASSERT(!(type.byref && type.valuetype));
#endif
		switch (type.type)
		{
		case IL2CPP_TYPE_VOID:
			break;
		// Primitive-like kinds: no pointer payload to fix up.
		case IL2CPP_TYPE_BOOLEAN:
		case IL2CPP_TYPE_I1:
		case IL2CPP_TYPE_U1:
		case IL2CPP_TYPE_CHAR:
		case IL2CPP_TYPE_I2:
		case IL2CPP_TYPE_U2:
		case IL2CPP_TYPE_I4:
		case IL2CPP_TYPE_U4:
		case IL2CPP_TYPE_R4:
		case IL2CPP_TYPE_I8:
		case IL2CPP_TYPE_U8:
		case IL2CPP_TYPE_R8:
		case IL2CPP_TYPE_I:
		case IL2CPP_TYPE_U:
		case IL2CPP_TYPE_TYPEDBYREF:
		case IL2CPP_TYPE_FNPTR:
		{
#if HYBRIDCLR_UNITY_2021_OR_NEW
			//IL2CPP_ASSERT(type.byref || type.valuetype);
#endif
			break;
		}
		// Kinds that must never be flagged valuetype; the handle/payload is copied as-is.
		case IL2CPP_TYPE_STRING:
		case IL2CPP_TYPE_OBJECT:
		case IL2CPP_TYPE_BYREF:
		case IL2CPP_TYPE_VAR:
		case IL2CPP_TYPE_MVAR:
		{
#if HYBRIDCLR_UNITY_2021_OR_NEW
			IL2CPP_ASSERT(!type.valuetype);
#endif
			break;
		}
		// Definition handle is copied as-is by the shallow copy above.
		case IL2CPP_TYPE_CLASS:
		case IL2CPP_TYPE_VALUETYPE:
		{
			//if (type.data.typeHandle)
			//{
			//	hybridclr::metadata::MetadataModule::GetImage(((Il2CppTypeDefinition*)type.data.typeHandle));
			//}
			break;
		}
		case IL2CPP_TYPE_PTR:
		{
#if HYBRIDCLR_UNITY_2021_OR_NEW
			//IL2CPP_ASSERT(type.byref || type.valuetype);
#endif
			// Pointee type is replaced by its pooled instance (re-enters the pool lock).
			newType->data.type = MetadataPool::GetPooledIl2CppType(*type.data.type);
			break;
		}
		case IL2CPP_TYPE_SZARRAY:
		{
#if HYBRIDCLR_UNITY_2021_OR_NEW
			IL2CPP_ASSERT(!type.valuetype);
#endif
			// Element type is replaced by its pooled instance.
			newType->data.type = MetadataPool::GetPooledIl2CppType(*type.data.type);
			break;
		}
		case IL2CPP_TYPE_ARRAY:
		{
#if HYBRIDCLR_UNITY_2021_OR_NEW
			IL2CPP_ASSERT(!type.valuetype);
#endif

			// Array descriptor (rank/sizes/lobounds/etype) is replaced by its pooled instance.
			newType->data.array = const_cast<Il2CppArrayType*>(MetadataPool::GetPooledIl2CppArrayType(*type.data.array));
			break;
		}
		case IL2CPP_TYPE_GENERICINST:
		{
#if HYBRIDCLR_UNITY_2021_OR_NEW
			//IL2CPP_ASSERT(type.byref || type.valuetype == type.data.generic_class->type->valuetype);
#endif
			// The caller must already supply the pooled generic class; only checked in debug.
			IL2CPP_ASSERT(MetadataPool::GetPooledIl2CppGenericClass(type.data.generic_class->type, type.data.generic_class->context.class_inst) == type.data.generic_class);
			newType->data.generic_class = const_cast<Il2CppGenericClass*>(type.data.generic_class);
			break;
		}
		default:;
		}

		return newType;
	}

	// Debug-only check that every type argument of `genericInst` is already its
	// pooled Il2CppType instance. The body is compiled out unless IL2CPP_DEBUG is set.
	void ValidateIl2CppGenericInst(const Il2CppGenericInst* genericInst)
	{
#if IL2CPP_DEBUG
		// The index uses the same width as type_argc. A uint8_t index would wrap at
		// 256 and loop forever for instances with 256 or more type arguments.
		for (uint32_t i = 0; i < genericInst->type_argc; ++i)
		{
			const Il2CppType* type = genericInst->type_argv[i];
			IL2CPP_ASSERT(type == MetadataPool::GetPooledIl2CppType(*type));
		}
#endif
	}
	// Returns the pooled Il2CppType that is structurally equal to `type`.
	// On the first request it creates a deep clone and registers it.
	const Il2CppType* MetadataPool::GetPooledIl2CppType(const Il2CppType& type)
	{
		USE_METAPOOL_LOCK();
		auto found = s_Il2CppTypePool.find(&type);
		if (found == s_Il2CppTypePool.end())
		{
			// Cloning may itself insert nested types into the pool, so `found` is not reused.
			Il2CppType* clone = DeepCloneIl2CppType(type);
			auto inserted = s_Il2CppTypePool.insert(clone);
			IL2CPP_ASSERT(inserted.second);
			return clone;
		}
		return *found;
	}


	// Returns the pooled Il2CppArrayType that is structurally equal to `arrayType`.
	// On a miss it builds a pool-owned copy with a pooled element type and its own
	// sizes/lobounds arrays, then registers that copy.
	const Il2CppArrayType* MetadataPool::GetPooledIl2CppArrayType(const Il2CppArrayType& arrayType)
	{
		USE_METAPOOL_LOCK();
		auto it = s_Il2CppArrayTypePool.find(&arrayType);
		if (it != s_Il2CppArrayTypePool.end())
		{
			return *it;
		}

		Il2CppArrayType* pooled = MetadataMallocT<Il2CppArrayType>();
		pooled->etype = MetadataPool::GetPooledIl2CppType(*arrayType.etype);
		pooled->rank = arrayType.rank;
		pooled->numsizes = arrayType.numsizes;
		pooled->numlobounds = arrayType.numlobounds;

		pooled->sizes = nullptr;
		if (arrayType.sizes)
		{
			pooled->sizes = CopyIntArray(arrayType.sizes, arrayType.numsizes);
		}
		pooled->lobounds = nullptr;
		if (arrayType.lobounds)
		{
			pooled->lobounds = CopyIntArray(arrayType.lobounds, arrayType.numlobounds);
		}

		auto it2 = s_Il2CppArrayTypePool.insert(pooled);
		IL2CPP_ASSERT(it2.second);
		IL2CPP_ASSERT(it2.first->key == pooled);
		return pooled;
	}

	// Returns the pooled type equal to `type` with its array element type replaced by
	// `newEleType`. `type` is expected to be an IL2CPP_TYPE_ARRAY, since its data.array is read.
	const Il2CppType* MetadataPool::GetPooledArrayType(const Il2CppType& type, const Il2CppType* newEleType)
	{
		USE_METAPOOL_LOCK();

		// Stack-local array descriptor with the substituted element type...
		Il2CppArrayType patchedArray = *type.data.array;
		patchedArray.etype = newEleType;

		// ...and a stack-local type that points at the pooled descriptor.
		Il2CppType patchedType = type;
		patchedType.data.array = const_cast<Il2CppArrayType*>(GetPooledIl2CppArrayType(patchedArray));

		return GetPooledIl2CppType(patchedType);
	}


	// Allocates a pool-owned Il2CppGenericInst with room for `typeArgc` argument
	// pointers. The argument vector is allocated via MetadataCallocT, and the caller fills it.
	static Il2CppGenericInst* AllocateGenericInst(uint32_t typeArgc)
	{
		IL2CPP_ASSERT(typeArgc > 0);
		USE_METAPOOL_LOCK();

		Il2CppGenericInst* inst = MetadataMallocT<Il2CppGenericInst>();
		inst->type_argc = typeArgc;
		inst->type_argv = MetadataCallocT<const Il2CppType*>(typeArgc);
		return inst;
	}

	// Returns the pooled Il2CppGenericInst for the given type arguments.
	// Each argument is first normalized (attrs cleared) and pooled, so callers may pass
	// non-pooled types. The returned instance references only pooled types.
	const Il2CppGenericInst* MetadataPool::GetPooledIl2CppGenericInst(const Il2CppType** types, uint32_t typeCount)
	{
		USE_METAPOOL_LOCK();
		IL2CPP_ASSERT(typeCount > 0);

		// Small argument lists use a stack buffer. Larger ones fall back to the heap
		// instead of overrunning the fixed array; the previous "<= 32" limit was only
		// enforced by a debug assert.
		constexpr uint32_t kInlineCapacity = 32;
		const Il2CppType* inlineTypes[kInlineCapacity];
		std::vector<const Il2CppType*> heapTypes;
		const Il2CppType** sharedTypes = inlineTypes;
		if (typeCount > kInlineCapacity)
		{
			heapTypes.resize(typeCount);
			sharedTypes = heapTypes.data();
		}

		for (uint32_t i = 0; i < typeCount; i++)
		{
			// attrs are cleared so argument lists that differ only in attrs share one
			// pooled instance.
			Il2CppType type = *types[i];
			type.attrs = 0;
			sharedTypes[i] = MetadataPool::GetPooledIl2CppType(type);
		}
		Il2CppGenericInst genericInst = { typeCount, sharedTypes };

		auto it = s_Il2CppGenericInstPool.find(&genericInst);
		if (it != s_Il2CppGenericInstPool.end())
			return *it;

		// Copy the (possibly stack/heap-local) argument pointers into pool-owned storage.
		Il2CppGenericInst* newGenericInst = AllocateGenericInst(typeCount);
		std::memcpy(newGenericInst->type_argv, sharedTypes, sizeof(const Il2CppType*) * typeCount);
		auto ret = s_Il2CppGenericInstPool.insert(newGenericInst);
		IL2CPP_ASSERT(ret.second);
		return newGenericInst;
	}

	// Like GetPooledIl2CppGenericInst, but trusts the caller and skips per-argument
	// normalization. Every entry of `types` must already be a pooled Il2CppType with
	// attrs == 0; this is checked only in debug builds.
	const Il2CppGenericInst* MetadataPool::GetPooledIl2CppGenericInstFast(const Il2CppType** types, uint32_t typeCount)
	{
		USE_METAPOOL_LOCK();
#if IL2CPP_DEBUG
		for (uint32_t i = 0; i < typeCount; i++)
		{
			const Il2CppType* type = types[i];
			IL2CPP_ASSERT(type->attrs == 0);
			IL2CPP_ASSERT(type == MetadataPool::GetPooledIl2CppType(*type));
		}
#endif
		Il2CppGenericInst genericInst = { typeCount, types };
		auto it = s_Il2CppGenericInstPool.find(&genericInst);
		if (it != s_Il2CppGenericInstPool.end())
			return *it;

		// Copy the caller's argument pointers into pool-owned storage, since `types` may be transient.
		// (Previously written as `std:memcpy`, which parsed as a label plus an unqualified call.)
		Il2CppGenericInst* newGenericInst = AllocateGenericInst(typeCount);
		std::memcpy(newGenericInst->type_argv, types, sizeof(const Il2CppType*) * typeCount);
		auto ret = s_Il2CppGenericInstPool.insert(newGenericInst);
		IL2CPP_ASSERT(ret.second);
		return newGenericInst;
	}

	// Returns the pooled Il2CppGenericClass for (generic type definition, class instantiation).
	// `inst` must already be pooled; this is asserted in debug builds. A newly pooled class
	// has no method_inst and a null cached_class.
	const Il2CppGenericClass* MetadataPool::GetPooledIl2CppGenericClass(const Il2CppType* genericTypeDefinition, const Il2CppGenericInst* inst)
	{
		const Il2CppType* pooledDefinition = MetadataPool::GetPooledIl2CppType(*genericTypeDefinition);
		IL2CPP_ASSERT(inst == MetadataPool::GetPooledIl2CppGenericInst(inst->type_argv, inst->type_argc));

		// Stack-local lookup key.
		Il2CppGenericClass key;
		key.type = pooledDefinition;
		key.context.class_inst = inst;
		key.context.method_inst = nullptr;
		key.cached_class = nullptr;

		USE_METAPOOL_LOCK();
		auto found = s_Il2CppGenericClassPool.find(&key);
		if (found != s_Il2CppGenericClassPool.end())
			return *found;

		Il2CppGenericClass* pooled = MetadataMallocT<Il2CppGenericClass>();
		pooled->type = pooledDefinition;
		pooled->context.class_inst = inst;
		pooled->context.method_inst = nullptr;
		pooled->cached_class = nullptr;
		auto ret = s_Il2CppGenericClassPool.insert(pooled);
		IL2CPP_ASSERT(ret.second);
		return pooled;
	}

	// Returns the pooled copy of `genericMethod` (method definition + generic context).
	// On the first request it allocates a pool-owned copy and registers it.
	const Il2CppGenericMethod* MetadataPool::GetPooledIl2CppGenericMethod(const Il2CppGenericMethod& genericMethod)
	{
		USE_METAPOOL_LOCK();
		auto found = s_Il2CppGenericMethodPool.find(&genericMethod);
		if (found == s_Il2CppGenericMethodPool.end())
		{
			Il2CppGenericMethod* pooled = MetadataMallocT<Il2CppGenericMethod>();
			pooled->methodDefinition = genericMethod.methodDefinition;
			pooled->context.class_inst = genericMethod.context.class_inst;
			pooled->context.method_inst = genericMethod.context.method_inst;
			auto inserted = s_Il2CppGenericMethodPool.insert(pooled);
			IL2CPP_ASSERT(inserted.second);
			return pooled;
		}
		return *found;
	}

	void MetadataPool::RegisterIl2CppType(const Il2CppType* type)
	{
		s_Il2CppTypePool.insert(type);
	}

	void MetadataPool::RegisterIl2CppGenericInst(const Il2CppGenericInst* genericInst)
	{
		auto it = s_Il2CppGenericInstPool.insert(genericInst);
		IL2CPP_ASSERT(it.second);
	}

	void MetadataPool::RegisterIl2CppGenericClass(const Il2CppGenericClass* genericClass)
	{
#if IL2CPP_DEBUG
		const Il2CppType* type = genericClass->type;
		IL2CPP_ASSERT(type == MetadataPool::GetPooledIl2CppType(*type));
#endif
		ValidateIl2CppGenericInst(genericClass->context.class_inst);
		s_Il2CppGenericClassPool.insert(genericClass);
	}

	void MetadataPool::RegisterIl2CppGenericMethod(const Il2CppGenericMethod* genericMethod)
	{
		s_Il2CppGenericMethodPool.insert(genericMethod);
	}

	// Calls `callback(klass, context)` for every pooled generic class that already has
	// a cached Il2CppClass. Classes whose image's assembly has dheAssembly set are skipped.
	// The pool lock is held for the whole walk.
	void MetadataPool::WalkAllGenericClasses(il2cpp::metadata::GenericMetadata::GenericClassWalkCallback callback, void* context)
	{
		USE_METAPOOL_LOCK();

		for (const Il2CppGenericClass* pooledClass : s_Il2CppGenericClassPool)
		{
			Il2CppClass* klass = pooledClass->cached_class;
			if (klass == nullptr || klass->image->assembly->dheAssembly)
			{
				continue;
			}
			callback(klass, context);
		}
	}
}
}