45#include "restraints/RestraintForceModifier.hpp"
56#include "brains/ForceModifier.hpp"
58#include "restraints/MolecularRestraint.hpp"
59#include "restraints/ObjectRestraint.hpp"
60#include "selection/SelectionEvaluator.hpp"
61#include "selection/SelectionManager.hpp"
62#include "utils/Constants.hpp"
65#include "utils/simError.h"
// Constructs the restraint force modifier from the restraint stamps found in
// the simulation parameters. For every stamp whose upper-cased type is
// "MOLECULAR" or "OBJECT", a restraint object is created, configured from the
// stamp's optional settings, appended to restraints_, and attached to its
// target under the property name "Restraint".
// NOTE(review): several control-flow lines (else branches, closing braces,
// preprocessor guards, some declarations) are not visible in this listing;
// the comments below describe only the statements that are shown.
69 RestraintForceModifier::RestraintForceModifier(SimInfo* info) :
70 ForceModifier {info} {
84 Globals* simParam = info_->getSimParams();
85 currSnapshot_ = info_->getSnapshotManager()->getCurrentSnapshot();
// currRestTime_ is the threshold that modifyForces() compares the snapshot
// time against before writing restraint data; it starts at the current time.
87 currRestTime_ = currSnapshot_->getTime();
// restTime_ is the increment added to currRestTime_ after each write. It is
// taken from statusTime when that parameter is set.
89 if (simParam->haveStatusTime()) {
90 restTime_ = simParam->getStatusTime();
// Warning text about running restraints without statusTime; restTime_ is
// then assigned from the run time (presumably in the else branch -- confirm).
92 snprintf(painCave.errMsg, MAX_SIM_ERROR_MSG_LENGTH,
93 "Restraint warning: If you use restraints without setting\n"
94 "\tstatusTime, no restraint data will be written to the rest\n"
98 restTime_ = simParam->getRunTime();
101 int nRestraintStamps = simParam->getNRestraintStamps();
102 std::vector<RestraintStamp*> stamp = simParam->getRestraintStamps();
// Collects the global integrable-object index of every object selected by an
// OBJECT restraint; later handed to RestReader.
104 std::vector<int> stuntDoubleIndex;
106 for (
int i = 0; i < nRestraintStamps; i++) {
// Stamp type is compared after upper-casing.
107 std::string myType = toUpperCopy(stamp[i]->getType());
109 if (myType.compare(
"MOLECULAR") == 0) {
// A molecular restraint must provide molIndex; its absence is flagged fatal.
113 if (!stamp[i]->haveMolIndex()) {
114 snprintf(painCave.errMsg, MAX_SIM_ERROR_MSG_LENGTH,
115 "Restraint Error: A molecular restraint was specified\n"
116 "\twithout providing a value for molIndex.\n");
117 painCave.isFatal = 1;
120 molIndex = stamp[i]->getMolIndex();
// Fatal error text for a negative molIndex (the guarding condition is not
// visible in this listing).
124 snprintf(painCave.errMsg, MAX_SIM_ERROR_MSG_LENGTH,
125 "Restraint Error: A molecular restraint was specified\n"
126 "\twith a molIndex that was less than 0\n");
127 painCave.isFatal = 1;
// molIndex must be strictly below the global molecule count.
130 if (molIndex >= info_->getNGlobalMolecules()) {
131 snprintf(painCave.errMsg, MAX_SIM_ERROR_MSG_LENGTH,
132 "Restraint Error: A molecular restraint was specified with\n"
133 "\ta molIndex that was greater than the total number of "
135 painCave.isFatal = 1;
139 Molecule* mol = info_->getMoleculeByGlobalIndex(molIndex);
// Only the rank that getMolToProc() reports for this molecule takes part in
// the "no molecule was found" check below.
// NOTE(review): presumably this runs only in MPI builds -- confirm the guard.
148 MPI_Comm_rank(MPI_COMM_WORLD, &myrank);
150 if (info_->getMolToProc(molIndex) == myrank) {
155 painCave.errMsg, MAX_SIM_ERROR_MSG_LENGTH,
156 "Restraint Error: A molecular restraint was specified, but\n"
157 "\tno molecule was found with global index %d.\n",
159 painCave.isFatal = 1;
// The restraint itself is created under the same rank-ownership test.
171 MPI_Comm_rank(MPI_COMM_WORLD, &myrank);
172 if (info_->getMolToProc(molIndex) == myrank) {
175 MolecularRestraint* rest =
new MolecularRestraint();
// Restraint name is "mol_" followed by the molecule's global index.
177 std::string restPre(
"mol_");
178 std::stringstream restName;
179 restName << restPre << molIndex;
180 rest->setRestraintName(restName.str());
// Each setting below is copied from the stamp only when the stamp provides it.
182 if (stamp[i]->haveDisplacementSpringConstant()) {
183 rest->setDisplacementForceConstant(
184 stamp[i]->getDisplacementSpringConstant());
186 if (stamp[i]->haveAbsoluteSpringConstant()) {
187 rest->setAbsoluteForceConstant(
188 stamp[i]->getAbsoluteSpringConstant());
190 if (stamp[i]->haveTwistSpringConstant()) {
191 rest->setTwistForceConstant(stamp[i]->getTwistSpringConstant());
193 if (stamp[i]->haveSwingXSpringConstant()) {
194 rest->setSwingXForceConstant(stamp[i]->getSwingXSpringConstant());
196 if (stamp[i]->haveSwingYSpringConstant()) {
197 rest->setSwingYForceConstant(stamp[i]->getSwingYSpringConstant());
199 if (stamp[i]->haveAbsolutePositionZ()) {
200 rest->setAbsolutePositionZ(stamp[i]->getAbsolutePositionZ());
// Restrained angles are multiplied by PI/180 (degrees to radians) before
// being stored on the molecular restraint.
202 if (stamp[i]->haveRestrainedTwistAngle()) {
203 rest->setRestrainedTwistAngle(stamp[i]->getRestrainedTwistAngle() *
204 Constants::PI / 180.0);
206 if (stamp[i]->haveRestrainedSwingYAngle()) {
207 rest->setRestrainedSwingYAngle(
208 stamp[i]->getRestrainedSwingYAngle() * Constants::PI / 180.0);
210 if (stamp[i]->haveRestrainedSwingXAngle()) {
211 rest->setRestrainedSwingXAngle(
212 stamp[i]->getRestrainedSwingXAngle() * Constants::PI / 180.0);
214 if (stamp[i]->havePrint()) {
215 rest->setPrintRestraint(stamp[i]->getPrint());
// Register the restraint: keep the pointer in restraints_, attach it to the
// molecule as the "Restraint" property, and remember the molecule so that
// doRestraints() visits it.
218 restraints_.push_back(rest);
219 mol->addProperty(std::make_shared<RestraintData>(
"Restraint", rest));
220 restrainedMols_.push_back(mol);
224 }
else if (myType.compare(
"OBJECT") == 0) {
225 std::string objectSelection;
// An object restraint must provide a selection script; its absence is fatal.
227 if (!stamp[i]->haveObjectSelection()) {
228 snprintf(painCave.errMsg, MAX_SIM_ERROR_MSG_LENGTH,
229 "Restraint Error: An object restraint was specified\n"
230 "\twithout providing a selection script in the\n"
231 "\tobjectSelection variable.\n");
232 painCave.isFatal = 1;
235 objectSelection = stamp[i]->getObjectSelection();
// Evaluate the selection script and count the selected objects.
238 SelectionEvaluator evaluator(info);
239 SelectionManager seleMan(info);
241 evaluator.loadScriptString(objectSelection);
242 seleMan.setSelectionSet(evaluator.evaluate());
243 int selectionCount = seleMan.getSelectionCount();
// Sum the per-rank selection counts in place.
246 MPI_Allreduce(MPI_IN_PLACE, &selectionCount, 1, MPI_INT, MPI_SUM,
// Non-fatal informational message reporting the selection and its size.
250 snprintf(painCave.errMsg, MAX_SIM_ERROR_MSG_LENGTH,
251 "Restraint Info: The specified restraint objectSelection,\n"
253 "\twill result in %d integrable objects being\n"
255 objectSelection.c_str(), selectionCount);
256 painCave.severity = OPENMD_INFO;
257 painCave.isFatal = 0;
// One ObjectRestraint is created per selected object.
263 for (sd = seleMan.beginSelected(selei); sd != NULL;
264 sd = seleMan.nextSelected(selei)) {
265 stuntDoubleIndex.push_back(sd->getGlobalIntegrableObjectIndex());
267 ObjectRestraint* rest =
new ObjectRestraint();
// Each setting below is copied from the stamp only when the stamp provides it.
269 if (stamp[i]->haveDisplacementSpringConstant()) {
270 rest->setDisplacementForceConstant(
271 stamp[i]->getDisplacementSpringConstant());
273 if (stamp[i]->haveAbsoluteSpringConstant()) {
274 rest->setAbsoluteForceConstant(
275 stamp[i]->getAbsoluteSpringConstant());
277 if (stamp[i]->haveTwistSpringConstant()) {
278 rest->setTwistForceConstant(stamp[i]->getTwistSpringConstant());
280 if (stamp[i]->haveSwingXSpringConstant()) {
281 rest->setSwingXForceConstant(stamp[i]->getSwingXSpringConstant());
283 if (stamp[i]->haveSwingYSpringConstant()) {
284 rest->setSwingYForceConstant(stamp[i]->getSwingYSpringConstant());
286 if (stamp[i]->haveAbsolutePositionZ()) {
287 rest->setAbsolutePositionZ(stamp[i]->getAbsolutePositionZ());
// NOTE(review): unlike the MOLECULAR branch, the restrained angles here are
// passed through without the PI/180 factor -- confirm whether ObjectRestraint
// expects a different unit or whether the conversion is missing.
289 if (stamp[i]->haveRestrainedTwistAngle()) {
290 rest->setRestrainedTwistAngle(stamp[i]->getRestrainedTwistAngle());
292 if (stamp[i]->haveRestrainedSwingXAngle()) {
293 rest->setRestrainedSwingXAngle(
294 stamp[i]->getRestrainedSwingXAngle());
296 if (stamp[i]->haveRestrainedSwingYAngle()) {
297 rest->setRestrainedSwingYAngle(
298 stamp[i]->getRestrainedSwingYAngle());
300 if (stamp[i]->havePrint()) {
301 rest->setPrintRestraint(stamp[i]->getPrint());
// Register the restraint: keep the pointer in restraints_, attach it to the
// object as the "Restraint" property, and remember the object so that
// doRestraints() visits it.
304 restraints_.push_back(rest);
305 sd->addProperty(std::make_shared<RestraintData>(
"Restraint", rest));
306 restrainedObjs_.push_back(sd);
// When restraints are enabled, read the reference structure from the
// configured restraint file, passing along the collected object indices.
// NOTE(review): rr is allocated with new; its release is not visible in this
// listing -- confirm it is deleted.
314 if (simParam->getUseRestraints()) {
315 std::string refFile = simParam->getRestraint_file();
316 RestReader* rr =
new RestReader(info, refFile, stuntDoubleIndex);
317 rr->readReferenceStructure();
// Restraint output goes to "<prefix of final config file name>.rest".
321 restOutput_ =
getPrefix(info_->getFinalConfigFileName()) +
".rest";
322 restOut =
new RestWriter(info_, restOutput_.c_str(), restraints_);
// Fatal error text for a failed RestWriter creation (the guarding condition
// is not visible in this listing).
324 snprintf(painCave.errMsg, MAX_SIM_ERROR_MSG_LENGTH,
325 "Restraint error: Failed to create RestWriter\n");
326 painCave.isFatal = 1;
// Every restraint starts with a scale factor of 1.0.
331 std::vector<Restraint*>::const_iterator resti;
332 for (resti = restraints_.begin(); resti != restraints_.end(); ++resti) {
333 (*resti)->setScaleFactor(1.0);
// Destructor: hands restraints_ (the raw Restraint pointers created with new
// in the constructor) to Utils::deletePointers, presumably deleting each one.
// NOTE(review): restOut is also allocated with new in the constructor; its
// release is not visible in this listing -- confirm.
337 RestraintForceModifier::~RestraintForceModifier() {
338 Utils::deletePointers(restraints_);
// Applies all restraints with a scaling factor of 1.0, folds the resulting
// restraint potential into the current snapshot, and writes restraint data
// once the snapshot time has reached the next scheduled write time.
343 void RestraintForceModifier::modifyForces() {
344 RealType restPot(0.0);
346 restPot = doRestraints(1.0);
// Sum the restraint potential across ranks in place.
// NOTE(review): presumably compiled only in MPI builds -- confirm the guard.
349 MPI_Allreduce(MPI_IN_PLACE, &restPot, 1, MPI_REALTYPE, MPI_SUM,
// Accumulate onto the restraint potential already stored in the snapshot.
353 currSnapshot_ = info_->getSnapshotManager()->getCurrentSnapshot();
354 RealType rp = currSnapshot_->getRestraintPotential();
355 currSnapshot_->setRestraintPotential(rp + restPot);
// Store the pre-restraint potential energy as the raw potential, then add
// the restraint contribution to the potential energy.
357 RealType pe = currSnapshot_->getPotentialEnergy();
358 currSnapshot_->setRawPotential(pe);
359 currSnapshot_->setPotentialEnergy(pe + restPot);
// Write the collected restraint info when the write time is reached and
// advance the next write time by restTime_.
362 if (currSnapshot_->getTime() >= currRestTime_) {
363 restOut->writeRest(restInfo_);
364 currRestTime_ += restTime_;
// Evaluates every molecular and object restraint with the given scaling
// factor, adds the resulting forces (and torques, for directional objects)
// to the restrained objects, and collects printable restraint info into
// restInfo_.
//
// scalingFactor: passed to each restraint via setScaleFactor() and used to
//                scale the returned potential.
// Returns unscaledPotential_ * scalingFactor, where unscaledPotential_ is the
// sum of getUnscaledPotential() over all restraints visited here.
// NOTE(review): the declarations of sd and index, and several branch and
// closing-brace lines, are not visible in this listing.
368 RealType RestraintForceModifier::doRestraints(RealType scalingFactor) {
369 std::vector<Molecule*>::const_iterator rm;
370 std::shared_ptr<GenericData> data;
371 Molecule::IntegrableObjectIterator ioi;
372 MolecularRestraint* mRest = NULL;
373 ObjectRestraint* oRest = NULL;
376 std::vector<StuntDouble*>::const_iterator ro;
378 std::map<int, Restraint::RealPair> restInfo;
// The running sum is reset on every call.
380 unscaledPotential_ = 0.0;
// --- Molecular restraints ---
384 for (rm = restrainedMols_.begin(); rm != restrainedMols_.end(); ++rm) {
// Recover the MolecularRestraint stored under the "Restraint" property:
// GenericData -> RestraintData -> MolecularRestraint. Each failed step has a
// fatal error message below.
386 data = (*rm)->getPropertyByName(
"Restraint");
387 if (data !=
nullptr) {
389 std::shared_ptr<RestraintData> restData =
390 std::dynamic_pointer_cast<RestraintData>(data);
391 if (restData !=
nullptr) {
394 mRest =
dynamic_cast<MolecularRestraint*
>(restData->getData());
396 snprintf(painCave.errMsg, MAX_SIM_ERROR_MSG_LENGTH,
397 "Can not cast RestraintData to MolecularRestraint\n");
398 painCave.severity = OPENMD_ERROR;
399 painCave.isFatal = 1;
403 snprintf(painCave.errMsg, MAX_SIM_ERROR_MSG_LENGTH,
404 "Can not cast GenericData to RestraintData\n");
405 painCave.severity = OPENMD_ERROR;
406 painCave.isFatal = 1;
410 snprintf(painCave.errMsg, MAX_SIM_ERROR_MSG_LENGTH,
411 "Can not find Restraint for RestrainedObject\n");
412 painCave.severity = OPENMD_ERROR;
413 painCave.isFatal = 1;
// Gather the molecule's centre of mass and the positions of its integrable
// objects as input to the restraint force calculation.
420 Vector3d molCom = (*rm)->getCom();
422 std::vector<Vector3d> struc;
423 std::vector<Vector3d> forces;
425 for (sd = (*rm)->beginIntegrableObject(ioi); sd != NULL;
426 sd = (*rm)->nextIntegrableObject(ioi)) {
427 struc.push_back(sd->getPos());
430 mRest->setScaleFactor(scalingFactor);
431 mRest->calcForce(struc, molCom);
432 forces = mRest->getRestraintForces();
// Apply the per-object forces, walking the integrable objects in the same
// iteration order used to build struc.
// NOTE(review): struc is appended to again here, after calcForce has already
// consumed it, and no later read of struc is visible -- this push_back looks
// unnecessary; confirm.
435 for (sd = (*rm)->beginIntegrableObject(ioi); sd != NULL;
436 sd = (*rm)->nextIntegrableObject(ioi)) {
437 sd->addFrc(forces[index]);
438 struc.push_back(sd->getPos());
442 unscaledPotential_ += mRest->getUnscaledPotential();
// Collect restraint info only for restraints flagged for printing.
445 if (mRest->getPrintRestraint()) {
446 restInfo = mRest->getRestraintInfo();
447 restInfo_.push_back(restInfo);
// --- Object restraints ---
451 for (ro = restrainedObjs_.begin(); ro != restrainedObjs_.end(); ++ro) {
// Recover the ObjectRestraint stored under the "Restraint" property, with the
// same cast chain and fatal error messages as the molecular case.
453 data = (*ro)->getPropertyByName(
"Restraint");
456 std::shared_ptr<RestraintData> restData =
457 std::dynamic_pointer_cast<RestraintData>(data);
458 if (restData !=
nullptr) {
461 oRest =
dynamic_cast<ObjectRestraint*
>(restData->getData());
463 snprintf(painCave.errMsg, MAX_SIM_ERROR_MSG_LENGTH,
464 "Can not cast RestraintData to ObjectRestraint\n");
465 painCave.severity = OPENMD_ERROR;
466 painCave.isFatal = 1;
470 snprintf(painCave.errMsg, MAX_SIM_ERROR_MSG_LENGTH,
471 "Can not cast GenericData to RestraintData\n");
472 painCave.severity = OPENMD_ERROR;
473 painCave.isFatal = 1;
477 snprintf(painCave.errMsg, MAX_SIM_ERROR_MSG_LENGTH,
478 "Can not find Restraint for RestrainedObject\n");
479 painCave.severity = OPENMD_ERROR;
480 painCave.isFatal = 1;
486 oRest->setScaleFactor(scalingFactor);
488 Vector3d pos = (*ro)->getPos();
// Directional objects pass their rotation matrix to calcForce and receive
// both a force and a torque; the other calcForce call takes the position
// only and adds just a force.
490 if ((*ro)->isDirectional()) {
494 RotMat3x3d A = (*ro)->getA();
495 oRest->calcForce(pos, A);
496 (*ro)->addFrc(oRest->getRestraintForce());
497 (*ro)->addTrq(oRest->getRestraintTorque());
502 oRest->calcForce(pos);
503 (*ro)->addFrc(oRest->getRestraintForce());
506 unscaledPotential_ += oRest->getUnscaledPotential();
// Collect restraint info only for restraints flagged for printing.
509 if (oRest->getPrintRestraint()) {
510 restInfo = oRest->getRestraintInfo();
511 restInfo_.push_back(restInfo);
// The scaling factor is applied once, to the accumulated unscaled potential.
515 return unscaledPotential_ * scalingFactor;
This basic Periodic Table class was originally taken from the data.cpp file in OpenBabel.
std::string getPrefix(const std::string &str)