| // Copyright 2017 The Chromium Authors. All rights reserved. | 
 | // Use of this source code is governed by a BSD-style license that can be | 
 | // found in the LICENSE file. | 
 |  | 
 | #include "base/win/com_init_util.h" | 
 |  | 
 | #include <windows.h> | 
 | #include <winternl.h> | 
 |  | 
 | namespace base { | 
 | namespace win { | 
 |  | 
 | #if DCHECK_IS_ON() | 
 |  | 
 | namespace { | 
 |  | 
 | const char kComNotInitialized[] = "COM is not initialized on this thread."; | 
 |  | 
// Derived from combase.dll.
//
// Partial mirror of the per-thread OLE data block whose address is stored in
// TEB::ReservedForOle (see GetOleTlsData()). Only the leading fields up to and
// including |apartment_flags| are declared. The declaration order and the
// field types must not change: the struct is never constructed, only
// reinterpret_cast onto that memory, so |apartment_flags| is located purely by
// the fields declared before it.
struct OleTlsData {
  // Bit values found in |apartment_flags|. STA is a single bit; MTA (0x140)
  // spans two bits (0x100 | 0x40).
  enum ApartmentFlags {
    LOGICAL_THREAD_REGISTERED = 0x2,
    STA = 0x80,
    MTA = 0x140,
  };

  // Pointer-sized fields: their size differs between x86 and x64, which shifts
  // the offset of the DWORD fields that follow.
  void* thread_base;
  void* sm_allocator;
  DWORD apartment_id;
  DWORD apartment_flags;
  // There are many more fields than this, but for our purposes, we only care
  // about |apartment_flags|. Correctly declaring the previous types allows this
  // to work between x86 and x64 builds.
};
 |  | 
 | OleTlsData* GetOleTlsData() { | 
 |   TEB* teb = NtCurrentTeb(); | 
 |   return reinterpret_cast<OleTlsData*>(teb->ReservedForOle); | 
 | } | 
 |  | 
 | ComApartmentType GetComApartmentTypeForThread() { | 
 |   OleTlsData* ole_tls_data = GetOleTlsData(); | 
 |   if (!ole_tls_data) | 
 |     return ComApartmentType::NONE; | 
 |  | 
 |   if (ole_tls_data->apartment_flags & OleTlsData::ApartmentFlags::STA) | 
 |     return ComApartmentType::STA; | 
 |  | 
 |   if ((ole_tls_data->apartment_flags & OleTlsData::ApartmentFlags::MTA) == | 
 |       OleTlsData::ApartmentFlags::MTA) { | 
 |     return ComApartmentType::MTA; | 
 |   } | 
 |  | 
 |   return ComApartmentType::NONE; | 
 | } | 
 |  | 
 | }  // namespace | 
 |  | 
 | void AssertComInitialized(const char* message) { | 
 |   if (GetComApartmentTypeForThread() != ComApartmentType::NONE) | 
 |     return; | 
 |  | 
 |   // COM worker threads don't always set up the apartment, but they do perform | 
 |   // some thread registration, so we allow those. | 
 |   OleTlsData* ole_tls_data = GetOleTlsData(); | 
 |   if (ole_tls_data && (ole_tls_data->apartment_flags & | 
 |                        OleTlsData::ApartmentFlags::LOGICAL_THREAD_REGISTERED)) { | 
 |     return; | 
 |   } | 
 |  | 
 |   NOTREACHED() << (message ? message : kComNotInitialized); | 
 | } | 
 |  | 
// Uses DCHECK_EQ to check that the calling thread's COM apartment, as
// classified by GetComApartmentTypeForThread(), equals |apartment_type|.
// The argument expressions are kept exactly as written because DCHECK_EQ
// stringifies them into its failure message.
void AssertComApartmentType(ComApartmentType apartment_type) {
  DCHECK_EQ(apartment_type, GetComApartmentTypeForThread());
}
 |  | 
 | #endif  // DCHECK_IS_ON() | 
 |  | 
 | }  // namespace win | 
 | }  // namespace base |