Skip to content

Commit efe6054

Browse files
author
Mingsheng Hong
committed
1. Fixed SR-7913: remove the call to SILModule::linkFunction(String) … (#11)
* 1. Resolves [SR-7913](https://bugs.swift.org/browse/SR-7913): remove the call to SILModule::linkFunction(String) in TFUtilities.cpp In the new code, after calling fn = module.findFunction(..), fn->isDefinition() returns false. In the old code, the result of module.findFunction() returns true for isDefinition(). The fix is to load and link the function via module.loadFunction() and module.linkFunction() after finding the function, such that fn->isDefinition() returns true. This makes fn->getForwardingSubstitutions().size() == 1, as required in getSingletonSubstitutionFromFunction(). Also tried the alternative as suggested by Richard, but the code below crashed: ``` ty->getMemberSubstitutionMap( fn->getModule().getSwiftModule(), funcDecl, ctx.getProtocol(KnownProtocolKind::TensorSendableReceivable)->getGenericEnvironment()) ``` Here ty is `(struct_type decl=Swift.(file).Float).` funcDecl is ``` (func_decl "receiveFromDevice(_:_:)" interface type='<Scalar where Scalar : AccelerableByTensorFlow> (TensorHandle<Scalar>.Type) -> (_TensorComputation, Int) -> TensorHandle<Scalar>' access=internal final type (parameter_list (parameter "self" interface type='TensorHandle<Scalar>.Type')) (parameter_list (parameter "computation" interface type='_TensorComputation') (parameter "tensorId" interface type='Int'))) ``` crash point: > swift: /usr/local/google/home/hongm/ssd_part/swift-private/swift/lib/AST/Type.cpp:3198: swift::Type swift::TypeBase::getSuperclassForDecl(const swift::ClassDecl *): Assertion `isa<ClassDecl>(nominalDecl) && "expected a class here"' failed. 2. Resolves [SR-7915](https://bugs.swift.org/browse/SR-7915): for a string attr of an tfop inst, the new SIL for such an attr can be %3 below. ``` %0 = string_literal utf8 "foo" // user: %4 %1 = integer_literal $Builtin.Word, 3 // users: %6, %4 %2 = integer_literal $Builtin.Int1, -1 // users: %6, %4 // function_ref specialized String.init(_builtinStringLiteral:utf8CodeUnitCount:isASCII:) %3 = function_ref @$SSS21_builtinStringLiteral17utf8CodeUnitCount7isASCIISSBp_BwBi1_tcfCTf4nnnd_n : $@convention(thin) (Builtin.RawPointer, Builtin.Word, Builtin.Int1) -> @owned String // users: %6, %4 %4 = apply %3(%0, %1, %2) : $@convention(thin) (Builtin.RawPointer, Builtin.Word, Builtin.Int1) -> @owned String // user: %7 ```
1 parent 350a3f5 commit efe6054

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

Diff for: lib/SILOptimizer/Mandatory/TFUtilities.cpp

+24-3
Original file line numberDiff line numberDiff line change
@@ -214,9 +214,14 @@ unsigned tf::convertSwiftTypeToTF(Type ty) {
214214
static SILFunction *lookupOrLinkFunction(StringRef name, SILModule &module) {
215215
if (auto *localFn = module.lookUpFunction(name))
216216
return localFn;
217-
if (module.linkFunction(name))
218-
return module.findFunction(name, SILLinkage::PublicExternal);
219-
return nullptr;
217+
auto *fn = module.findFunction(name, SILLinkage::PublicExternal);
218+
assert(fn);
219+
bool loaded = module.loadFunction(fn);
220+
assert(loaded); (void)loaded;
221+
bool linked = module.linkFunction(fn);
222+
assert(linked); (void)linked;
223+
assert(fn->isDefinition());
224+
return fn;
220225
}
221226

222227
/// Looks up members by `name` in the context of `typeDecl`, `proto` and
@@ -608,6 +613,22 @@ SingleValueInstruction *SILTensorOpInfo::getAttrOperand(SILValue v) {
608613
}
609614
}
610615

616+
// In this case, the expected SIL code to match looks like:
617+
// %0 = string_literal utf8 "foo"
618+
// // function_ref specialized String.init(
619+
// _builtinStringLiteral:utf8CodeUnitCount:isASCII:)
620+
// function_ref @$SSS21_builtinStringLiteral... : $@convention(thin) (
621+
// Builtin.RawPointer...) -> @owned String
622+
// %4 = apply %3(%0, ...
623+
// So we want to follow the first func arg of the ApplyInst (%0 above).
624+
if (auto *ai = dyn_cast<ApplyInst>(str)) {
625+
// If the ApplyInst does not have such an operand, we bail with failure.
626+
if (ai->getNumOperands() < 2) return nullptr;
627+
628+
str = ai->getOperand(1);
629+
continue;
630+
}
631+
611632
// It is possible that we have a variable string, we want to reject it
612633
// as a non-constant value.
613634
return nullptr;

Diff for: test/TensorFlow/sends_recvs.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ public func test1RecvTensor() {
240240
// CHECK-NEXT: function_ref
241241
// CHECK-NEXT: [[A_HANDLE:%.*]] = apply
242242
// CHECK-NEXT: [[A_TENSOR:%.*]] = struct $Tensor<Float> ([[A_HANDLE]]
243-
// CHECK: // function_ref atariSim(_:)
243+
// CHECK: // function_ref {{.*}} atariSim(_:)
244244
// CHECK-NEXT: [[ATARI_FN:%.*]] = function_ref
245245
// CHECK-NEXT: [[B_TENSOR:%.*]] = apply [[ATARI_FN]]([[A_TENSOR]])
246246
// CHECK-NEXT: [[B_HANDLE:%.*]] = struct_extract [[B_TENSOR]] : $Tensor<Float>, #Tensor.handle

0 commit comments

Comments
 (0)