96 basis1_(basis1),basis2_(basis2)
98 INTREPID2_TEST_FOR_EXCEPTION(basis1->getBasisType() != basis2->getBasisType(), std::invalid_argument,
"basis1 and basis2 must agree in basis type");
99 INTREPID2_TEST_FOR_EXCEPTION(basis1->getBaseCellTopology().getKey() != basis2->getBaseCellTopology().getKey(),
100 std::invalid_argument,
"basis1 and basis2 must agree in cell topology");
101 INTREPID2_TEST_FOR_EXCEPTION(basis1->getCoordinateSystem() != basis2->getCoordinateSystem(),
102 std::invalid_argument,
"basis1 and basis2 must agree in coordinate system");
104 this->basisCardinality_ = basis1->getCardinality() + basis2->getCardinality();
105 this->basisDegree_ = std::max(basis1->getDegree(), basis2->getDegree());
108 std::ostringstream basisName;
109 basisName << basis1->getName() <<
" + " << basis2->getName();
110 name_ = basisName.str();
113 this->basisCellTopology_ = basis1->getBaseCellTopology();
114 this->basisType_ = basis1->getBasisType();
115 this->basisCoordinates_ = basis1->getCoordinateSystem();
117 if (this->basisType_ == BASIS_FEM_HIERARCHICAL)
119 int degreeLength = basis1_->getPolynomialDegreeLength();
120 INTREPID2_TEST_FOR_EXCEPTION(degreeLength != basis2_->getPolynomialDegreeLength(), std::invalid_argument,
"Basis1 and Basis2 must agree on polynomial degree length");
122 this->fieldOrdinalPolynomialDegree_ = OrdinalTypeArray2DHost(
"DirectSumBasis degree lookup",this->basisCardinality_,degreeLength);
124 for (
int fieldOrdinal1=0; fieldOrdinal1<basis1_->getCardinality(); fieldOrdinal1++)
126 int fieldOrdinal = fieldOrdinal1;
127 auto polynomialDegree = basis1->getPolynomialDegreeOfField(fieldOrdinal1);
128 for (
int d=0; d<degreeLength; d++)
130 this->fieldOrdinalPolynomialDegree_(fieldOrdinal,d) = polynomialDegree(d);
133 for (
int fieldOrdinal2=0; fieldOrdinal2<basis2_->getCardinality(); fieldOrdinal2++)
135 int fieldOrdinal = basis1->getCardinality() + fieldOrdinal2;
137 auto polynomialDegree = basis2->getPolynomialDegreeOfField(fieldOrdinal2);
138 for (
int d=0; d<degreeLength; d++)
140 this->fieldOrdinalPolynomialDegree_(fieldOrdinal,d) = polynomialDegree(d);
147 const auto & cardinality = this->basisCardinality_;
150 const ordinal_type tagSize = 4;
151 const ordinal_type posScDim = 0;
152 const ordinal_type posScOrd = 1;
153 const ordinal_type posDfOrd = 2;
155 OrdinalTypeArray1DHost tagView(
"tag view", cardinality*tagSize);
157 shards::CellTopology cellTopo = this->basisCellTopology_;
159 unsigned spaceDim = cellTopo.getDimension();
161 ordinal_type basis2Offset = basis1_->getCardinality();
163 for (
unsigned d=0; d<=spaceDim; d++)
165 unsigned subcellCount = cellTopo.getSubcellCount(d);
166 for (
unsigned subcellOrdinal=0; subcellOrdinal<subcellCount; subcellOrdinal++)
168 ordinal_type subcellDofCount1 = basis1->getDofCount(d, subcellOrdinal);
169 ordinal_type subcellDofCount2 = basis2->getDofCount(d, subcellOrdinal);
171 ordinal_type subcellDofCount = subcellDofCount1 + subcellDofCount2;
172 for (ordinal_type localDofID=0; localDofID<subcellDofCount; localDofID++)
174 ordinal_type fieldOrdinal;
175 if (localDofID < subcellDofCount1)
178 fieldOrdinal = basis1_->getDofOrdinal(d, subcellOrdinal, localDofID);
183 fieldOrdinal = basis2Offset + basis2_->getDofOrdinal(d, subcellOrdinal, localDofID - subcellDofCount1);
185 tagView(fieldOrdinal*tagSize+0) = d;
186 tagView(fieldOrdinal*tagSize+1) = subcellOrdinal;
187 tagView(fieldOrdinal*tagSize+2) = localDofID;
188 tagView(fieldOrdinal*tagSize+3) = subcellDofCount;
194 this->setOrdinalTagData(this->tagToOrdinal_,
197 this->basisCardinality_,
215 const int numScalarFamilies1 = basisValues1.numTensorDataFamilies();
216 if (numScalarFamilies1 > 0)
219 const int numScalarFamilies2 = basisValues2.numTensorDataFamilies();
220 INTREPID2_TEST_FOR_EXCEPTION(basisValues2.numTensorDataFamilies() <=0, std::invalid_argument,
"When basis1 has scalar value, basis2 must also");
221 std::vector< TensorData<OutputValueType,DeviceType> > scalarFamilies(numScalarFamilies1 + numScalarFamilies2);
222 for (
int i=0; i<numScalarFamilies1; i++)
224 scalarFamilies[i] = basisValues1.
tensorData(i);
226 for (
int i=0; i<numScalarFamilies2; i++)
228 scalarFamilies[i+numScalarFamilies1] = basisValues2.
tensorData(i);
235 INTREPID2_TEST_FOR_EXCEPTION(!basisValues1.
vectorData().
isValid(), std::invalid_argument,
"When basis1 does not have tensorData() defined, it must have a valid vectorData()");
236 INTREPID2_TEST_FOR_EXCEPTION(basisValues2.numTensorDataFamilies() > 0, std::invalid_argument,
"When basis1 has vector value, basis2 must also");
238 const auto & vectorData1 = basisValues1.
vectorData();
239 const auto & vectorData2 = basisValues2.
vectorData();
241 const int numFamilies1 = vectorData1.
numFamilies();
242 const int numComponents = vectorData1.numComponents();
243 INTREPID2_TEST_FOR_EXCEPTION(numComponents != vectorData2.numComponents(), std::invalid_argument,
"basis1 and basis2 must agree on the number of components in each vector");
244 const int numFamilies2 = vectorData2.numFamilies();
246 const int numFamilies = numFamilies1 + numFamilies2;
249 for (
int i=0; i<numFamilies1; i++)
251 for (
int j=0; j<numComponents; j++)
253 vectorComponents[i][j] = vectorData1.getComponent(i,j);
256 for (
int i=0; i<numFamilies2; i++)
258 for (
int j=0; j<numComponents; j++)
260 vectorComponents[i+numFamilies1][j] = vectorData2.getComponent(i,j);
374 virtual void getValues( OutputViewType outputValues,
const PointViewType inputPoints,
375 const EOperator operatorType = OPERATOR_VALUE )
const override
377 int cardinality1 = basis1_->getCardinality();
378 int cardinality2 = basis2_->getCardinality();
380 auto range1 = std::make_pair(0,cardinality1);
381 auto range2 = std::make_pair(cardinality1,cardinality1+cardinality2);
382 if (outputValues.rank() == 2)
384 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL());
385 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL());
387 basis1_->getValues(outputValues1, inputPoints, operatorType);
388 basis2_->getValues(outputValues2, inputPoints, operatorType);
390 else if (outputValues.rank() == 3)
392 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL());
393 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL());
395 basis1_->getValues(outputValues1, inputPoints, operatorType);
396 basis2_->getValues(outputValues2, inputPoints, operatorType);
398 else if (outputValues.rank() == 4)
400 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
401 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
403 basis1_->getValues(outputValues1, inputPoints, operatorType);
404 basis2_->getValues(outputValues2, inputPoints, operatorType);
406 else if (outputValues.rank() == 5)
408 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
409 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
411 basis1_->getValues(outputValues1, inputPoints, operatorType);
412 basis2_->getValues(outputValues2, inputPoints, operatorType);
414 else if (outputValues.rank() == 6)
416 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
417 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
419 basis1_->getValues(outputValues1, inputPoints, operatorType);
420 basis2_->getValues(outputValues2, inputPoints, operatorType);
422 else if (outputValues.rank() == 7)
424 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
425 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
427 basis1_->getValues(outputValues1, inputPoints, operatorType);
428 basis2_->getValues(outputValues2, inputPoints, operatorType);
432 INTREPID2_TEST_FOR_EXCEPTION(
true, std::invalid_argument,
"Unsupported outputValues rank");