From 2863c1f5e0fe97f1c7d2815747d92993201cb198 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Wed, 18 Sep 2024 13:18:08 +0000 Subject: [PATCH] =?UTF-8?q?zml/tensor:=20fix=20returned=20value=20in=20Ten?= =?UTF-8?q?sor.toMemory=20=E2=80=93=20ensure=20`=5Foutput=5Fmemory=5Fkind`?= =?UTF-8?q?=20is=20set=20correctly=20in=20the=20result.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- zml/tensor.zig | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/zml/tensor.zig b/zml/tensor.zig index b248ea2..79c2f2c 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -206,8 +206,6 @@ pub const Tensor = struct { const ctx = self.getContext(); const mlir_ctx = ctx.mlirCtx(); if (ctx.target() == .cpu) return self; - var res = self; - res._output_memory_kind = kind; const memory_kind = @tagName(kind.toPjrtMemory()); @@ -223,7 +221,9 @@ pub const Tensor = struct { .api_version = .original, }, &.{self.value().getType()}, mlir_ctx.location(@src())); - return _result(res._shape, op.result(0)); + var res = _result(self._shape, op.result(0)); + res._output_memory_kind = kind; + return res; }, .buffer_id => { var res = self;