Skip to content
This repository was archived by the owner on Jan 12, 2024. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 78 additions & 49 deletions src/QirRuntime/lib/Simulators/FullstateSimulator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
#include <assert.h>
#include <bitset>
#include <complex>
#include <exception>
#include <iostream>
#include <memory>
#include <vector>

#include "SimFactory.hpp"
#include "QuantumApi_I.hpp"
#include "SimFactory.hpp"

using namespace std;

Expand All @@ -21,15 +22,36 @@ typedef HMODULE QUANTUM_SIMULATOR;
typedef void* QUANTUM_SIMULATOR;
#endif

QUANTUM_SIMULATOR LoadQuantumSimulator()
namespace
{
#ifdef _WIN32
return ::LoadLibraryA("Microsoft.Quantum.Simulator.Runtime.dll");
const char* FULLSTATESIMULATORLIB = "Microsoft.Quantum.Simulator.Runtime.dll";
#elif __APPLE__
return ::dlopen("libMicrosoft.Quantum.Simulator.Runtime.dylib", RTLD_LAZY);
const char* FULLSTATESIMULATORLIB = "libMicrosoft.Quantum.Simulator.Runtime.dylib";
#else
const char* FULLSTATESIMULATORLIB = "libMicrosoft.Quantum.Simulator.Runtime.so";
#endif

QUANTUM_SIMULATOR LoadQuantumSimulator()
{
QUANTUM_SIMULATOR handle = 0;
#ifdef _WIN32
handle = ::LoadLibraryA(FULLSTATESIMULATORLIB);
if (handle == NULL)
{
throw std::runtime_error(
std::string("Failed to load ") + FULLSTATESIMULATORLIB +
" (error code: " + std::to_string(GetLastError()) + ")");
}
#else
return ::dlopen("libMicrosoft.Quantum.Simulator.Runtime.so", RTLD_LAZY);
handle = ::dlopen(FULLSTATESIMULATORLIB, RTLD_LAZY);
if (handle == nullptr)
{
throw std::runtime_error(
std::string("Failed to load ") + FULLSTATESIMULATORLIB + " (" + ::dlerror() + ")");
}
#endif
return handle;
}

bool UnloadQuantumSimulator(QUANTUM_SIMULATOR handle)
Expand All @@ -49,6 +71,7 @@ void* LoadProc(QUANTUM_SIMULATOR handle, const char* procName)
return ::dlsym(handle, procName);
#endif
}
} // namespace

namespace Microsoft
{
Expand Down Expand Up @@ -78,7 +101,7 @@ namespace Quantum
return static_cast<unsigned>(pauli);
}

const QUANTUM_SIMULATOR handle;
const QUANTUM_SIMULATOR handle = 0;
unsigned simulatorId = -1;
unsigned nextQubitId = 0; // the QuantumSimulator expects contiguous ids, starting from 0

Expand Down Expand Up @@ -112,25 +135,39 @@ namespace Quantum
std::cout << "*********************" << std::endl;
}

void* GetProc(const char* name)
{
void* proc = LoadProc(this->handle, name);
if (proc == nullptr)
{
throw std::runtime_error(std::string("Failed to find '") + name + "' proc in " + FULLSTATESIMULATORLIB);
}
return proc;
}

public:
CFullstateSimulator()
: handle(LoadQuantumSimulator())
{
typedef unsigned (*TInit)();
static TInit initSimulatorInstance = reinterpret_cast<TInit>(LoadProc(this->handle, "init"));
static TInit initSimulatorInstance = reinterpret_cast<TInit>(this->GetProc("init"));

this->simulatorId = initSimulatorInstance();
}
~CFullstateSimulator()
{
typedef unsigned (*TDestroy)(unsigned);
static TDestroy destroySimulatorInstance = reinterpret_cast<TDestroy>(LoadProc(this->handle, "destroy"));

destroySimulatorInstance(this->simulatorId);

// TODO: It seems that simulator might still be doing something on background threads so attempting to
// unload it might crash.
// UnloadQuantumSimulator(this->handle);
if (this->simulatorId != -1)
{
typedef unsigned (*TDestroy)(unsigned);
static TDestroy destroySimulatorInstance =
reinterpret_cast<TDestroy>(LoadProc(this->handle, "destroy"));
assert(destroySimulatorInstance);
destroySimulatorInstance(this->simulatorId);

// TODO: It seems that simulator might still be doing something on background threads so attempting to
// unload it might crash.
// UnloadQuantumSimulator(this->handle);
}
}

IQuantumGateSet* AsQuantumGateSet() override
Expand All @@ -145,7 +182,7 @@ namespace Quantum
void GetState(TGetStateCallback callback) override
{
typedef bool (*TDump)(unsigned, TGetStateCallback);
static TDump dump = reinterpret_cast<TDump>(LoadProc(this->handle, "Dump"));
static TDump dump = reinterpret_cast<TDump>(this->GetProc("Dump"));
dump(this->simulatorId, callback);
}

Expand All @@ -157,8 +194,7 @@ namespace Quantum
Qubit AllocateQubit() override
{
typedef void (*TAllocateQubit)(unsigned, unsigned);
static TAllocateQubit allocateQubit =
reinterpret_cast<TAllocateQubit>(LoadProc(this->handle, "allocateQubit"));
static TAllocateQubit allocateQubit = reinterpret_cast<TAllocateQubit>(this->GetProc("allocateQubit"));

const unsigned id = this->nextQubitId;
allocateQubit(this->simulatorId, id);
Expand All @@ -169,23 +205,23 @@ namespace Quantum
void ReleaseQubit(Qubit q) override
{
typedef void (*TReleaseQubit)(unsigned, unsigned);
static TReleaseQubit releaseQubit = reinterpret_cast<TReleaseQubit>(LoadProc(this->handle, "release"));
static TReleaseQubit releaseQubit = reinterpret_cast<TReleaseQubit>(this->GetProc("release"));

releaseQubit(this->simulatorId, GetQubitId(q));
}

Result M(Qubit q) override
{
typedef unsigned (*TM)(unsigned, unsigned);
static TM m = reinterpret_cast<TM>(LoadProc(this->handle, "M"));
static TM m = reinterpret_cast<TM>(this->GetProc("M"));
return reinterpret_cast<Result>(m(this->simulatorId, GetQubitId(q)));
}

Result Measure(long numBases, PauliId bases[], long numTargets, Qubit targets[]) override
{
assert(numBases == numTargets);
typedef unsigned (*TMeasure)(unsigned, unsigned, unsigned*, unsigned*);
static TMeasure m = reinterpret_cast<TMeasure>(LoadProc(this->handle, "Measure"));
static TMeasure m = reinterpret_cast<TMeasure>(this->GetProc("Measure"));
vector<unsigned> ids = GetQubitIds(numTargets, targets);
return reinterpret_cast<Result>(
m(this->simulatorId, numBases, reinterpret_cast<unsigned*>(bases), ids.data()));
Expand Down Expand Up @@ -217,128 +253,122 @@ namespace Quantum

void X(Qubit q) override
{
static TSingleQubitGate op = reinterpret_cast<TSingleQubitGate>(LoadProc(this->handle, "X"));
static TSingleQubitGate op = reinterpret_cast<TSingleQubitGate>(this->GetProc("X"));
op(this->simulatorId, GetQubitId(q));
}

void ControlledX(long numControls, Qubit controls[], Qubit target) override
{
static TSingleQubitControlledGate op =
reinterpret_cast<TSingleQubitControlledGate>(LoadProc(this->handle, "MCX"));
static TSingleQubitControlledGate op = reinterpret_cast<TSingleQubitControlledGate>(this->GetProc("MCX"));
vector<unsigned> ids = GetQubitIds(numControls, controls);
op(this->simulatorId, numControls, ids.data(), GetQubitId(target));
}

void Y(Qubit q) override
{
static TSingleQubitGate op = reinterpret_cast<TSingleQubitGate>(LoadProc(this->handle, "Y"));
static TSingleQubitGate op = reinterpret_cast<TSingleQubitGate>(this->GetProc("Y"));
op(this->simulatorId, GetQubitId(q));
}

void ControlledY(long numControls, Qubit controls[], Qubit target) override
{
static TSingleQubitControlledGate op =
reinterpret_cast<TSingleQubitControlledGate>(LoadProc(this->handle, "MCY"));
static TSingleQubitControlledGate op = reinterpret_cast<TSingleQubitControlledGate>(this->GetProc("MCY"));
vector<unsigned> ids = GetQubitIds(numControls, controls);
op(this->simulatorId, numControls, ids.data(), GetQubitId(target));
}

void Z(Qubit q) override
{
static TSingleQubitGate op = reinterpret_cast<TSingleQubitGate>(LoadProc(this->handle, "Z"));
static TSingleQubitGate op = reinterpret_cast<TSingleQubitGate>(this->GetProc("Z"));
op(this->simulatorId, GetQubitId(q));
}

void ControlledZ(long numControls, Qubit controls[], Qubit target) override
{
static TSingleQubitControlledGate op =
reinterpret_cast<TSingleQubitControlledGate>(LoadProc(this->handle, "MCZ"));
static TSingleQubitControlledGate op = reinterpret_cast<TSingleQubitControlledGate>(this->GetProc("MCZ"));
vector<unsigned> ids = GetQubitIds(numControls, controls);
op(this->simulatorId, numControls, ids.data(), GetQubitId(target));
}

void H(Qubit q) override
{
static TSingleQubitGate op = reinterpret_cast<TSingleQubitGate>(LoadProc(this->handle, "H"));
static TSingleQubitGate op = reinterpret_cast<TSingleQubitGate>(this->GetProc("H"));
op(this->simulatorId, GetQubitId(q));
}

void ControlledH(long numControls, Qubit controls[], Qubit target) override
{
static TSingleQubitControlledGate op =
reinterpret_cast<TSingleQubitControlledGate>(LoadProc(this->handle, "MCH"));
static TSingleQubitControlledGate op = reinterpret_cast<TSingleQubitControlledGate>(this->GetProc("MCH"));
vector<unsigned> ids = GetQubitIds(numControls, controls);
op(this->simulatorId, numControls, ids.data(), GetQubitId(target));
}

void S(Qubit q) override
{
static TSingleQubitGate op = reinterpret_cast<TSingleQubitGate>(LoadProc(this->handle, "S"));
static TSingleQubitGate op = reinterpret_cast<TSingleQubitGate>(this->GetProc("S"));
op(this->simulatorId, GetQubitId(q));
}

void ControlledS(long numControls, Qubit controls[], Qubit target) override
{
static TSingleQubitControlledGate op =
reinterpret_cast<TSingleQubitControlledGate>(LoadProc(this->handle, "MCS"));
static TSingleQubitControlledGate op = reinterpret_cast<TSingleQubitControlledGate>(this->GetProc("MCS"));
vector<unsigned> ids = GetQubitIds(numControls, controls);
op(this->simulatorId, numControls, ids.data(), GetQubitId(target));
}

void AdjointS(Qubit q) override
{
static TSingleQubitGate op = reinterpret_cast<TSingleQubitGate>(LoadProc(this->handle, "AdjS"));
static TSingleQubitGate op = reinterpret_cast<TSingleQubitGate>(this->GetProc("AdjS"));
op(this->simulatorId, GetQubitId(q));
}

void ControlledAdjointS(long numControls, Qubit controls[], Qubit target) override
{
static TSingleQubitControlledGate op =
reinterpret_cast<TSingleQubitControlledGate>(LoadProc(this->handle, "MCAdjS"));
reinterpret_cast<TSingleQubitControlledGate>(this->GetProc("MCAdjS"));
vector<unsigned> ids = GetQubitIds(numControls, controls);
op(this->simulatorId, numControls, ids.data(), GetQubitId(target));
}

void T(Qubit q) override
{
static TSingleQubitGate op = reinterpret_cast<TSingleQubitGate>(LoadProc(this->handle, "T"));
static TSingleQubitGate op = reinterpret_cast<TSingleQubitGate>(this->GetProc("T"));
op(this->simulatorId, GetQubitId(q));
}

void ControlledT(long numControls, Qubit controls[], Qubit target) override
{
static TSingleQubitControlledGate op =
reinterpret_cast<TSingleQubitControlledGate>(LoadProc(this->handle, "MCT"));
static TSingleQubitControlledGate op = reinterpret_cast<TSingleQubitControlledGate>(this->GetProc("MCT"));
vector<unsigned> ids = GetQubitIds(numControls, controls);
op(this->simulatorId, numControls, ids.data(), GetQubitId(target));
}

void AdjointT(Qubit q) override
{
static TSingleQubitGate op = reinterpret_cast<TSingleQubitGate>(LoadProc(this->handle, "AdjT"));
static TSingleQubitGate op = reinterpret_cast<TSingleQubitGate>(this->GetProc("AdjT"));
op(this->simulatorId, GetQubitId(q));
}

void ControlledAdjointT(long numControls, Qubit controls[], Qubit target) override
{
static TSingleQubitControlledGate op =
reinterpret_cast<TSingleQubitControlledGate>(LoadProc(this->handle, "MCAdjT"));
reinterpret_cast<TSingleQubitControlledGate>(this->GetProc("MCAdjT"));
vector<unsigned> ids = GetQubitIds(numControls, controls);
op(this->simulatorId, numControls, ids.data(), GetQubitId(target));
}

void R(PauliId axis, Qubit target, double theta) override
{
typedef unsigned (*TR)(unsigned, unsigned, double, unsigned);
static TR r = reinterpret_cast<TR>(LoadProc(this->handle, "R"));
static TR r = reinterpret_cast<TR>(this->GetProc("R"));

r(this->simulatorId, GetBasis(axis), theta, GetQubitId(target));
}

void ControlledR(long numControls, Qubit controls[], PauliId axis, Qubit target, double theta) override
{
typedef unsigned (*TMCR)(unsigned, unsigned, double, unsigned, unsigned*, unsigned);
static TMCR cr = reinterpret_cast<TMCR>(LoadProc(this->handle, "MCR"));
static TMCR cr = reinterpret_cast<TMCR>(this->GetProc("MCR"));

vector<unsigned> ids = GetQubitIds(numControls, controls);
cr(this->simulatorId, GetBasis(axis), theta, numControls, ids.data(), GetQubitId(target));
Expand All @@ -347,7 +377,7 @@ namespace Quantum
void Exp(long numTargets, PauliId paulis[], Qubit targets[], double theta) override
{
typedef unsigned (*TExp)(unsigned, unsigned, unsigned*, double, unsigned*);
static TExp exp = reinterpret_cast<TExp>(LoadProc(this->handle, "Exp"));
static TExp exp = reinterpret_cast<TExp>(this->GetProc("Exp"));
vector<unsigned> ids = GetQubitIds(numTargets, targets);
exp(this->simulatorId, numTargets, reinterpret_cast<unsigned*>(paulis), theta, ids.data());
}
Expand All @@ -361,7 +391,7 @@ namespace Quantum
double theta) override
{
typedef unsigned (*TMCExp)(unsigned, unsigned, unsigned*, double, unsigned, unsigned*, unsigned*);
static TMCExp cexp = reinterpret_cast<TMCExp>(LoadProc(this->handle, "MCExp"));
static TMCExp cexp = reinterpret_cast<TMCExp>(this->GetProc("MCExp"));
vector<unsigned> idsTargets = GetQubitIds(numTargets, targets);
vector<unsigned> idsControls = GetQubitIds(numControls, controls);
cexp(
Expand All @@ -384,8 +414,7 @@ namespace Quantum
const char* failureMessage) override
{
typedef double (*TOp)(unsigned id, unsigned n, int* b, unsigned* q);
static TOp jointEnsembleProbability =
reinterpret_cast<TOp>(LoadProc(this->handle, "JointEnsembleProbability"));
static TOp jointEnsembleProbability = reinterpret_cast<TOp>(this->GetProc("JointEnsembleProbability"));

vector<unsigned> ids = GetQubitIds(numTargets, targets);
double actualProbability =
Expand Down