Skip to content

Commit d55e9c2

Browse files
authored
[SYCL] Do not decompose non-trivial classes with pointers (#6886)
Instead of decomposing such classes, the compiler now generates the equivalent of the following code:
```
void ocl_kernel(__generated_type GT) {
  Kernel KernelObjClone { *(reinterpret_cast<UsersType*>(&GT)) };
}
```
1 parent 70b6767 commit d55e9c2

9 files changed

+109
-234
lines changed

clang/lib/Sema/SemaSYCL.cpp

+48-78
Original file line numberDiff line numberDiff line change
@@ -1796,20 +1796,10 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler {
17961796
CollectionStack.back() = true;
17971797
PointerStack.pop_back();
17981798
} else if (PointerStack.pop_back_val()) {
1799-
// FIXME: Stop triggering decomposition for non-trivial types with
1800-
// pointers
1801-
if (RD->isTrivial()) {
1802-
PointerStack.back() = true;
1803-
if (!RD->hasAttr<SYCLGenerateNewTypeAttr>())
1804-
RD->addAttr(
1805-
SYCLGenerateNewTypeAttr::CreateImplicit(SemaRef.getASTContext()));
1806-
} else {
1807-
// We are visiting a non-trivial type with pointer.
1808-
CollectionStack.back() = true;
1809-
if (!RD->hasAttr<SYCLRequiresDecompositionAttr>())
1810-
RD->addAttr(SYCLRequiresDecompositionAttr::CreateImplicit(
1811-
SemaRef.getASTContext()));
1812-
}
1799+
PointerStack.back() = true;
1800+
if (!RD->hasAttr<SYCLGenerateNewTypeAttr>())
1801+
RD->addAttr(
1802+
SYCLGenerateNewTypeAttr::CreateImplicit(SemaRef.getASTContext()));
18131803
}
18141804
return true;
18151805
}
@@ -2916,6 +2906,18 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
29162906
Init.get());
29172907
}
29182908

2909+
void addBaseInit(const CXXBaseSpecifier &BS, QualType Ty,
2910+
InitializationKind InitKind, MultiExprArg Args) {
2911+
InitializedEntity Entity = InitializedEntity::InitializeBase(
2912+
SemaRef.Context, &BS, /*IsInheritedVirtualBase*/ false, &VarEntity);
2913+
InitializationSequence InitSeq(SemaRef, Entity, InitKind, Args);
2914+
ExprResult Init = InitSeq.Perform(SemaRef, Entity, InitKind, Args);
2915+
2916+
InitListExpr *ParentILE = CollectionInitExprs.back();
2917+
ParentILE->updateInit(SemaRef.getASTContext(), ParentILE->getNumInits(),
2918+
Init.get());
2919+
}
2920+
29192921
void addSimpleBaseInit(const CXXBaseSpecifier &BS, QualType Ty) {
29202922
InitializationKind InitKind =
29212923
InitializationKind::CreateCopy(KernelCallerSrcLoc, KernelCallerSrcLoc);
@@ -2938,85 +2940,53 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
29382940
addFieldInit(FD, Ty, ParamRef);
29392941
}
29402942

2941-
Expr *addDerivedToBaseCastExpr(const CXXRecordDecl *RD,
2942-
const CXXBaseSpecifier &BS,
2943-
Expr *LocalCloneRef) {
2944-
CXXCastPath BasePath;
2945-
QualType DerivedTy(RD->getTypeForDecl(), 0);
2946-
QualType BaseTy = BS.getType();
2947-
SemaRef.CheckDerivedToBaseConversion(DerivedTy, BaseTy, KernelCallerSrcLoc,
2948-
SourceRange(), &BasePath,
2949-
/*IgnoreBaseAccess*/ true);
2950-
auto Cast = ImplicitCastExpr::Create(
2951-
SemaRef.Context, SemaRef.Context.getPointerType(BaseTy),
2952-
CK_DerivedToBase, LocalCloneRef,
2953-
/* CXXCastPath=*/&BasePath, VK_LValue, FPOptionsOverride());
2954-
return Cast;
2955-
}
2956-
29572943
Expr *createGetAddressOf(Expr *E) {
29582944
return UnaryOperator::Create(SemaRef.Context, E, UO_AddrOf,
29592945
SemaRef.Context.getPointerType(E->getType()),
29602946
VK_PRValue, OK_Ordinary, KernelCallerSrcLoc,
29612947
false, SemaRef.CurFPFeatureOverrides());
29622948
}
29632949

2964-
Expr *buildMemCpyCall(Expr *From, Expr *To, QualType T) {
2965-
// Compute the size of the memory buffer to be copied.
2966-
QualType SizeType = SemaRef.Context.getSizeType();
2967-
llvm::APInt Size(SemaRef.Context.getTypeSize(SizeType),
2968-
SemaRef.Context.getTypeSizeInChars(T).getQuantity());
2969-
2970-
LookupResult R(SemaRef, &SemaRef.Context.Idents.get("__builtin_memcpy"),
2971-
KernelCallerSrcLoc, Sema::LookupOrdinaryName);
2972-
SemaRef.LookupName(R, SemaRef.TUScope, true);
2973-
2974-
FunctionDecl *MemCpy = R.getAsSingle<FunctionDecl>();
2975-
2976-
assert(MemCpy && "__builtin_memcpy should be found");
2977-
2978-
ExprResult MemCpyRef =
2979-
SemaRef.BuildDeclRefExpr(MemCpy, SemaRef.Context.BuiltinFnTy,
2980-
VK_PRValue, KernelCallerSrcLoc, nullptr);
2981-
2982-
assert(MemCpyRef.isUsable() && "Builtin reference cannot fail");
2983-
2984-
Expr *CallArgs[] = {To, From,
2985-
IntegerLiteral::Create(SemaRef.Context, Size, SizeType,
2986-
KernelCallerSrcLoc)};
2987-
ExprResult Call =
2988-
SemaRef.BuildCallExpr(/*Scope=*/nullptr, MemCpyRef.get(),
2989-
KernelCallerSrcLoc, CallArgs, KernelCallerSrcLoc);
2950+
Expr *createDerefOp(Expr *E) {
2951+
return UnaryOperator::Create(SemaRef.Context, E, UO_Deref,
2952+
E->getType()->getPointeeType(), VK_LValue,
2953+
OK_Ordinary, KernelCallerSrcLoc, false,
2954+
SemaRef.CurFPFeatureOverrides());
2955+
}
29902956

2991-
assert(!Call.isInvalid() && "Call to __builtin_memcpy cannot fail!");
2992-
return Call.getAs<Expr>();
2957+
Expr *createReinterpretCastExpr(Expr *E, QualType To) {
2958+
return CXXReinterpretCastExpr::Create(
2959+
SemaRef.Context, To, VK_PRValue, CK_BitCast, E,
2960+
/*Path=*/nullptr, SemaRef.Context.CreateTypeSourceInfo(To),
2961+
SourceLocation(), SourceLocation(), SourceRange());
29932962
}
29942963

2995-
// Adds default initializer for generated type and creates
2996-
// a call to __builtin_memcpy to initialize local clone from
2997-
// kernel argument.
29982964
void handleGeneratedType(FieldDecl *FD, QualType Ty) {
2999-
addFieldInit(FD, Ty, None,
3000-
InitializationKind::CreateDefault(KernelCallerSrcLoc));
3001-
addFieldMemberExpr(FD, Ty);
3002-
Expr *ParamRef = createGetAddressOf(createParamReferenceExpr());
3003-
Expr *LocalCloneRef = createGetAddressOf(MemberExprBases.back());
3004-
Expr *MemCpyCallExpr = buildMemCpyCall(ParamRef, LocalCloneRef, Ty);
3005-
BodyStmts.push_back(MemCpyCallExpr);
3006-
removeFieldMemberExpr(FD, Ty);
2965+
// Equivalent of the following code is generated here:
2966+
// void ocl_kernel(__generated_type GT) {
2967+
// Kernel KernelObjClone { *(reinterpret_cast<UsersType*>(&GT)) };
2968+
// }
2969+
2970+
Expr *RCE = createReinterpretCastExpr(
2971+
createGetAddressOf(createParamReferenceExpr()),
2972+
SemaRef.Context.getPointerType(Ty));
2973+
Expr *Initializer = createDerefOp(RCE);
2974+
addFieldInit(FD, Ty, Initializer);
30072975
}
30082976

3009-
// Adds default initializer for generated base and creates
3010-
// a call to __builtin_memcpy to initialize the base of local clone
3011-
// from kernel argument.
30122977
void handleGeneratedType(const CXXRecordDecl *RD, const CXXBaseSpecifier &BS,
30132978
QualType Ty) {
3014-
addBaseInit(BS, Ty, InitializationKind::CreateDefault(KernelCallerSrcLoc));
3015-
Expr *ParamRef = createGetAddressOf(createParamReferenceExpr());
3016-
Expr *LocalCloneRef = createGetAddressOf(MemberExprBases.back());
3017-
LocalCloneRef = addDerivedToBaseCastExpr(RD, BS, LocalCloneRef);
3018-
Expr *MemCpyCallExpr = buildMemCpyCall(ParamRef, LocalCloneRef, Ty);
3019-
BodyStmts.push_back(MemCpyCallExpr);
2979+
// Equivalent of the following code is generated here:
2980+
// void ocl_kernel(__generated_type GT) {
2981+
// Kernel KernelObjClone { *(reinterpret_cast<UsersType*>(&GT)) };
2982+
// }
2983+
Expr *RCE = createReinterpretCastExpr(
2984+
createGetAddressOf(createParamReferenceExpr()),
2985+
SemaRef.Context.getPointerType(Ty));
2986+
Expr *Initializer = createDerefOp(RCE);
2987+
InitializationKind InitKind =
2988+
InitializationKind::CreateCopy(KernelCallerSrcLoc, KernelCallerSrcLoc);
2989+
addBaseInit(BS, Ty, InitKind, Initializer);
30202990
}
30212991

30222992
MemberExpr *buildMemberExpr(Expr *Base, ValueDecl *Member) {

clang/test/CodeGenSYCL/inheritance.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,14 @@ int main() {
6464
// Initialize 'base' subobject
6565
// CHECK: call void @llvm.memcpy.p4.p4.i64(ptr addrspace(4) align 8 %[[LOCAL_OBJECT]], ptr addrspace(4) align 4 %[[ARG_BASE]], i64 12, i1 false)
6666

67-
// Initialize field 'a'
68-
// CHECK: %[[GEP_A:[a-zA-Z0-9]+]] = getelementptr inbounds %struct.derived, ptr addrspace(4) %[[LOCAL_OBJECT]], i32 0, i32 3
69-
// CHECK: %[[LOAD_A:[0-9]+]] = load i32, ptr addrspace(4) %[[ARG_A]], align 4
70-
// CHECK: store i32 %[[LOAD_A]], ptr addrspace(4) %[[GEP_A]]
71-
7267
// Initialize 'second_base' subobject
7368
// First, derived-to-base cast with offset:
7469
// CHECK: %[[OFFSET_CALC:.*]] = getelementptr inbounds i8, ptr addrspace(4) %[[LOCAL_OBJECT]], i64 16
7570
// Initialize 'second_base'
7671
// CHECK: call void @llvm.memcpy.p4.p4.i64(ptr addrspace(4) align 8 %[[OFFSET_CALC]], ptr addrspace(4) align 8 %[[ARG_BASE1]], i64 8, i1 false)
72+
73+
// Initialize field 'a'
74+
// CHECK: %[[GEP_A:[a-zA-Z0-9]+]] = getelementptr inbounds %struct.derived, ptr addrspace(4) %[[LOCAL_OBJECT]], i32 0, i32 3
75+
// CHECK: %[[LOAD_A:[0-9]+]] = load i32, ptr addrspace(4) %[[ARG_A]], align 4
76+
// CHECK: store i32 %[[LOAD_A]], ptr addrspace(4) %[[GEP_A]]
77+

clang/test/CodeGenSYCL/no_opaque_inheritance.cpp

+9-8
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,18 @@ int main() {
6767
// CHECK: %[[PARAM_TO_PTR:.*]] = bitcast %struct.base addrspace(4)* %[[ARG_BASE]] to i8 addrspace(4)*
6868
// CHECK: call void @llvm.memcpy.p4i8.p4i8.i64(i8 addrspace(4)* align 8 %[[BASE_TO_PTR]], i8 addrspace(4)* align 4 %[[PARAM_TO_PTR]], i64 12, i1 false)
6969

70-
// Initialize field 'a'
71-
// CHECK: %[[GEP_A:[a-zA-Z0-9]+]] = getelementptr inbounds %struct.derived, %struct.derived addrspace(4)* %[[LOCAL_OBJECT]], i32 0, i32 3
72-
// CHECK: %[[LOAD_A:[0-9]+]] = load i32, i32 addrspace(4)* %[[ARG_A]], align 4
73-
// CHECK: store i32 %[[LOAD_A]], i32 addrspace(4)* %[[GEP_A]]
74-
7570
// Initialize 'second_base' subobject
7671
// First, derived-to-base cast with offset:
7772
// CHECK: %[[DERIVED_PTR:.*]] = bitcast %struct.derived addrspace(4)* %[[LOCAL_OBJECT]] to i8 addrspace(4)*
7873
// CHECK: %[[OFFSET_CALC:.*]] = getelementptr inbounds i8, i8 addrspace(4)* %[[DERIVED_PTR]], i64 16
7974
// CHECK: %[[TO_SECOND_BASE:.*]] = bitcast i8 addrspace(4)* %[[OFFSET_CALC]] to %class.second_base addrspace(4)*
80-
// CHECK: %[[SECOND_BASE_TO_PTR:.*]] = bitcast %class.second_base addrspace(4)* %[[TO_SECOND_BASE]] to i8 addrspace(4)*
81-
// CHECK: %[[SECOND_PARAM_TO_PTR:.*]] = bitcast %class.__generated_second_base addrspace(4)* %[[ARG_BASE1]] to i8 addrspace(4)*
82-
// CHECK: call void @llvm.memcpy.p4i8.p4i8.i64(i8 addrspace(4)* align 8 %[[SECOND_BASE_TO_PTR]], i8 addrspace(4)* align 8 %[[SECOND_PARAM_TO_PTR]], i64 8, i1 false)
75+
// CHECK: %[[GEN_TO_SECOND_BASE:.*]] = bitcast %class.__generated_second_base addrspace(4)* %[[ARG_BASE1]] to %class.second_base addrspace(4)*
76+
// CHECK: %[[TO:.*]] = bitcast %class.second_base addrspace(4)* %[[TO_SECOND_BASE]] to i8 addrspace(4)*
77+
// CHECK: %[[FROM:.*]] = bitcast %class.second_base addrspace(4)* %[[GEN_TO_SECOND_BASE]] to i8 addrspace(4)*
78+
// CHECK: call void @llvm.memcpy.p4i8.p4i8.i64(i8 addrspace(4)* align 8 %[[TO]], i8 addrspace(4)* align 8 %[[FROM]], i64 8, i1 false)
8379

80+
81+
// Initialize field 'a'
82+
// CHECK: %[[GEP_A:[a-zA-Z0-9]+]] = getelementptr inbounds %struct.derived, %struct.derived addrspace(4)* %[[LOCAL_OBJECT]], i32 0, i32 3
83+
// CHECK: %[[LOAD_A:[0-9]+]] = load i32, i32 addrspace(4)* %[[ARG_A]], align 4
84+
// CHECK: store i32 %[[LOAD_A]], i32 addrspace(4)* %[[GEP_A]]

clang/test/CodeGenSYCL/no_opaque_pointers-in-structs.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,4 @@ int main() {
4545
// CHECK-SAME: %[[GENERATED_A]]* noundef byval(%[[GENERATED_A]]) align 8 %_arg_F3,
4646
// CHECK-SAME: %[[WRAPPER_F4_1]]* noundef byval(%[[WRAPPER_F4_1]]) align 8 %_arg_F4
4747
// CHECK-SAME: %[[WRAPPER_F4_2]]* noundef byval(%[[WRAPPER_F4_2]]) align 8 %_arg_F41
48-
// CHECK: define {{.*}}spir_kernel void @{{.*}}lambdas{{.*}}(%[[WRAPPER_LAMBDA_PTR]]* noundef byval(%[[WRAPPER_LAMBDA_PTR]]) align 8 %_arg_Ptr)
48+
// CHECK: define {{.*}}spir_kernel void @{{.*}}lambdas{{.*}}(%[[WRAPPER_LAMBDA_PTR]]* noundef byval(%[[WRAPPER_LAMBDA_PTR]]) align 8 %_arg_Lambda)

clang/test/CodeGenSYCL/pointers-in-structs.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,4 @@ int main() {
4545
// CHECK-SAME: ptr noundef byval(%[[GENERATED_A]]) align 8 %_arg_F3,
4646
// CHECK-SAME: ptr noundef byval(%[[WRAPPER_F4_1]]) align 8 %_arg_F4
4747
// CHECK-SAME: ptr noundef byval(%[[WRAPPER_F4_2]]) align 8 %_arg_F41
48-
// CHECK: define {{.*}}spir_kernel void @{{.*}}lambdas{{.*}}(ptr noundef byval(%[[WRAPPER_LAMBDA_PTR]]) align 8 %_arg_Ptr)
48+
// CHECK: define {{.*}}spir_kernel void @{{.*}}lambdas{{.*}}(ptr noundef byval(%[[WRAPPER_LAMBDA_PTR]]) align 8 %_arg_Lambda)

0 commit comments

Comments
 (0)