Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitd728d31

Browse files
authored
Merge pull request#1527 from rstudio/fixes
Fixes for JAX updates
2 parents0f4a4fa +616740f commitd728d31

File tree

11 files changed

+69
-15
lines changed

11 files changed

+69
-15
lines changed

‎NAMESPACE‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ S3method("[[",python.builtin.super)
1212
S3method("[[",python_builtin_super_getter)
1313
S3method(Arg,keras.src.backend.Tensor)
1414
S3method(Arg,keras.src.backend.common.keras_tensor.KerasTensor)
15+
S3method(Ops,jax._src.export.shape_poly._DimExpr)
1516
S3method(Summary,keras_shape)
1617
S3method(all,equal.numpy.ndarray)
1718
S3method(as.array,jax.Array)
@@ -40,6 +41,7 @@ S3method(as.numeric,keras.src.backend.common.variables.KerasVariable)
4041
S3method(base::all.equal,keras.src.backend.Tensor)
4142
S3method(base::all.equal,keras.src.backend.common.keras_tensor.KerasTensor)
4243
S3method(base::all.equal,keras.src.backend.common.variables.KerasVariable)
44+
S3method(base::as.array,PIL.Image.Image)
4345
S3method(compile,keras.src.models.model.Model)
4446
S3method(destructure,keras_shape)
4547
S3method(evaluate,keras.src.models.model.Model)

‎NEWS.md‎

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
- Added elastic deformation utilities for images:`layer_random_elastic_transform()`
2424
and the lower-level`op_image_elastic_transform()`.
2525

26+
- Added`as.array()` support for`PIL.Image.Image` objects.
27+
2628
- Transposed convolution utilities now follow the latest Keras API:
2729
`op_conv_transpose()` defaults`strides = 1` and the`layer_conv_*_transpose()`
2830
layers expose`output_padding` for precise shape control.
@@ -38,6 +40,11 @@
3840

3941
-`layer_layer_normalization()` removes the`rms_scaling` argument.
4042

43+
- Merging layers now capture`...` with tidy dots (fixes#1525).
44+
45+
- Fixed Ops on JAX`_DimExpr` so symbolic shapes survive arithmetic with R
46+
double scalars.
47+
4148
-`layer_reshape()` can now accept`-1` as a sentinel for an automatically calculated axis size.
4249

4350
-`layer_torch_module_wrapper()` gains an`output_shape` argument to help Keras

‎R/jax-methods.R‎

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,20 @@ type_sum.keras.src.backend.jax.core.JaxVariable <- type_sum.keras.src.backend.ja
9494

9595
# "keras.src.backend.Variable" too?
9696
# "keras.src.backend.common.variables.Variable" too?
97+
98+
#' @exportS3Method Ops jax._src.export.shape_poly._DimExpr
99+
Ops.jax._src.export.shape_poly._DimExpr<-function(e1,e2) {
100+
if (missing(e2)) {
101+
return(e1)
102+
}
103+
conv<-function(x) {
104+
if (is.double(x)&& isTRUE(all(x== suppressWarnings(as.integer(x))))) {
105+
storage.mode(x)<-"integer"
106+
}
107+
x
108+
}
109+
e1<- conv(e1)
110+
e2<- conv(e2)
111+
NextMethod()
112+
}
113+

‎R/layers-merging.R‎

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ function (inputs, ...)
5252
args<- capture_args(list(input_shape=normalize_shape,
5353
batch_size=as_integer,batch_input_shape=normalize_shape),
5454
ignore= c("...","inputs"))
55-
dots<- split_dots_named_unnamed(list(...))
55+
dots<- split_dots_named_unnamed(list2(...))
5656
if (missing(inputs))
5757
inputs<-NULL
5858
elseif (!is.null(inputs)&&!is.list(inputs))
@@ -116,7 +116,7 @@ function (inputs, ...)
116116
args<- capture_args(list(input_shape=normalize_shape,
117117
batch_size=as_integer,batch_input_shape=normalize_shape),
118118
ignore= c("...","inputs"))
119-
dots<- split_dots_named_unnamed(list(...))
119+
dots<- split_dots_named_unnamed(list2(...))
120120
if (missing(inputs))
121121
inputs<-NULL
122122
elseif (!is.null(inputs)&&!is.list(inputs))
@@ -177,7 +177,7 @@ function (inputs, ..., axis = -1L)
177177
args<- capture_args(list(axis=as_axis,input_shape=normalize_shape,
178178
batch_size=as_integer,batch_input_shape=normalize_shape),
179179
ignore= c("...","inputs"))
180-
dots<- split_dots_named_unnamed(list(...))
180+
dots<- split_dots_named_unnamed(list2(...))
181181
if (missing(inputs))
182182
inputs<-NULL
183183
elseif (!is.null(inputs)&&!is.list(inputs))
@@ -260,7 +260,7 @@ function (inputs, ..., axes, normalize = FALSE)
260260
args<- capture_args(list(axes=as_axis,input_shape=normalize_shape,
261261
batch_size=as_integer,batch_input_shape=normalize_shape),
262262
ignore= c("...","inputs"))
263-
dots<- split_dots_named_unnamed(list(...))
263+
dots<- split_dots_named_unnamed(list2(...))
264264
if (missing(inputs))
265265
inputs<-NULL
266266
elseif (!is.null(inputs)&&!is.list(inputs))
@@ -322,7 +322,7 @@ function (inputs, ...)
322322
args<- capture_args(list(input_shape=normalize_shape,
323323
batch_size=as_integer,batch_input_shape=normalize_shape),
324324
ignore= c("...","inputs"))
325-
dots<- split_dots_named_unnamed(list(...))
325+
dots<- split_dots_named_unnamed(list2(...))
326326
if (missing(inputs))
327327
inputs<-NULL
328328
elseif (!is.null(inputs)&&!is.list(inputs))
@@ -384,7 +384,7 @@ function (inputs, ...)
384384
args<- capture_args(list(input_shape=normalize_shape,
385385
batch_size=as_integer,batch_input_shape=normalize_shape),
386386
ignore= c("...","inputs"))
387-
dots<- split_dots_named_unnamed(list(...))
387+
dots<- split_dots_named_unnamed(list2(...))
388388
if (missing(inputs))
389389
inputs<-NULL
390390
elseif (!is.null(inputs)&&!is.list(inputs))
@@ -446,7 +446,7 @@ function (inputs, ...)
446446
args<- capture_args(list(input_shape=normalize_shape,
447447
batch_size=as_integer,batch_input_shape=normalize_shape),
448448
ignore= c("...","inputs"))
449-
dots<- split_dots_named_unnamed(list(...))
449+
dots<- split_dots_named_unnamed(list2(...))
450450
if (missing(inputs))
451451
inputs<-NULL
452452
elseif (!is.null(inputs)&&!is.list(inputs))
@@ -509,7 +509,7 @@ function (inputs, ...)
509509
args<- capture_args(list(input_shape=normalize_shape,
510510
batch_size=as_integer,batch_input_shape=normalize_shape),
511511
ignore= c("...","inputs"))
512-
dots<- split_dots_named_unnamed(list(...))
512+
dots<- split_dots_named_unnamed(list2(...))
513513
if (missing(inputs))
514514
inputs<-NULL
515515
elseif (!is.null(inputs)&&!is.list(inputs))

‎R/r-utils.R‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ split_dots_named_unnamed <- function(dots) {
2727
if (is.null(nms))
2828
return(list(unnamed=dots,named=list()))
2929
named<- nzchar(nms)
30-
list(unnamed=dots[!named],named=dots[named])
30+
unnamed<-dots[!named]
31+
names(unnamed)<-NULL
32+
list(unnamed=unnamed,named=dots[named])
3133
}
3234

3335
drop_nulls<-function(x,i=NULL) {

‎R/s3-methods.R‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,7 @@ py_to_r__keras.src.utils.tracking.TrackedSet <- function(x) import("builtins")$l
110110
# }
111111
# rm(list = c("generic", "cls"))
112112

113+
#' @exportS3Method base::as.array
114+
as.array.PIL.Image.Image<-function(x,...) {
115+
as.array(image_to_array(x,...))
116+
}

‎man/layer_discretization.Rd‎

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more aboutcustomizing how changed files appear on GitHub.

‎man/layer_tfsm.Rd‎

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more aboutcustomizing how changed files appear on GitHub.

‎man/op_angle.Rd‎

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more aboutcustomizing how changed files appear on GitHub.
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
test_that("DimExpr Ops keeps symbolic dims when R uses double scalars", {
2+
skip_if_not(reticulate::py_module_available("jax"))
3+
4+
export<-reticulate::import("jax.export",convert=FALSE)
5+
dim<-export$symbolic_shape("n")[[1]]
6+
7+
expr<-dim-1# 1 is a double in R; Ops method should coerce to int
8+
9+
expect_s3_class(expr,"jax._src.export.shape_poly._DimExpr")
10+
expect_match(reticulate::py_str(expr),"n - 1")
11+
expect_false(any(grepl("Array", class(expr))))
12+
})

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp