zml/tensor: fix returned value in Tensor.toMemory – ensure _output_memory_kind is set correctly in the result.

This commit is contained in:
Tarry Singh 2024-09-18 13:18:08 +00:00
parent fa4a8d8de4
commit 2863c1f5e0

View File

@ -206,8 +206,6 @@ pub const Tensor = struct {
const ctx = self.getContext();
const mlir_ctx = ctx.mlirCtx();
if (ctx.target() == .cpu) return self;
var res = self;
res._output_memory_kind = kind;
const memory_kind = @tagName(kind.toPjrtMemory());
@ -223,7 +221,9 @@ pub const Tensor = struct {
.api_version = .original,
}, &.{self.value().getType()}, mlir_ctx.location(@src()));
return _result(res._shape, op.result(0));
var res = _result(self._shape, op.result(0));
res._output_memory_kind = kind;
return res;
},
.buffer_id => {
var res = self;