zml/tensor: fix returned value in Tensor.toMemory – ensure _output_memory_kind is set correctly in the result.
This commit is contained in:
parent
fa4a8d8de4
commit
2863c1f5e0
@ -206,8 +206,6 @@ pub const Tensor = struct {
|
|||||||
const ctx = self.getContext();
|
const ctx = self.getContext();
|
||||||
const mlir_ctx = ctx.mlirCtx();
|
const mlir_ctx = ctx.mlirCtx();
|
||||||
if (ctx.target() == .cpu) return self;
|
if (ctx.target() == .cpu) return self;
|
||||||
var res = self;
|
|
||||||
res._output_memory_kind = kind;
|
|
||||||
|
|
||||||
const memory_kind = @tagName(kind.toPjrtMemory());
|
const memory_kind = @tagName(kind.toPjrtMemory());
|
||||||
|
|
||||||
@ -223,7 +221,9 @@ pub const Tensor = struct {
|
|||||||
.api_version = .original,
|
.api_version = .original,
|
||||||
}, &.{self.value().getType()}, mlir_ctx.location(@src()));
|
}, &.{self.value().getType()}, mlir_ctx.location(@src()));
|
||||||
|
|
||||||
return _result(res._shape, op.result(0));
|
var res = _result(self._shape, op.result(0));
|
||||||
|
res._output_memory_kind = kind;
|
||||||
|
return res;
|
||||||
},
|
},
|
||||||
.buffer_id => {
|
.buffer_id => {
|
||||||
var res = self;
|
var res = self;
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user