Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged

R fix #9999

Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 62 additions & 65 deletions R-package/R/viz.graph.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#' @importFrom stringr str_trim
#' @importFrom jsonlite fromJSON
#' @importFrom DiagrammeR create_graph
#' @importFrom DiagrammeR set_global_graph_attrs
#' @importFrom DiagrammeR add_global_graph_attrs
#' @importFrom DiagrammeR create_node_df
#' @importFrom DiagrammeR create_edge_df
Expand Down Expand Up @@ -63,93 +62,91 @@ graph.viz <- function(symbol, shape=NULL, direction="TD", type="graph", graph.wi
)
}

model_list<- fromJSON(symbol$as.json())
model_nodes<- model_list$nodes
model_nodes$id<- seq_len(nrow(model_nodes))-1
model_nodes$level<- model_nodes$ID
model_list <- fromJSON(symbol$as.json())
model_nodes <- model_list$nodes
model_nodes$id <- seq_len(nrow(model_nodes))-1
model_nodes$level <- model_nodes$ID

# extract IDs from string list
tuple_str <- function(str) vapply(str_extract_all(str, "\\d+"),
function(x) paste0(x, collapse="X"),
character(1))

### substitute op for heads
op_id<- sort(unique(model_list$heads[1,]+1))
op_null<- which(model_nodes$op=="null")
op_substitute<- intersect(op_id, op_null)
model_nodes$op[op_substitute]<- model_nodes$name[op_substitute]

model_nodes$color<- apply(model_nodes["op"], 1, get.color)
model_nodes$shape<- apply(model_nodes["op"], 1, get.shape)

label_paste <- paste0(
model_nodes$op,
"\n",
model_nodes$name,
"\n",
model_nodes$attr$num_hidden %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""),
model_nodes$attr$act_type %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""),
model_nodes$attr$pool_type %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""),
model_nodes$attr$kernel %>% tuple_str %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""),
" / ",
model_nodes$attr$stride %>% tuple_str %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""),
", ",
model_nodes$attr$num_filter %>% str_replace_na() %>% str_replace_all(pattern = "NA", "")
) %>%
op_id <- sort(unique(model_list$heads[1,]+1))
op_null <- which(model_nodes$op=="null")
op_substitute <- intersect(op_id, op_null)
model_nodes$op[op_substitute] <- model_nodes$name[op_substitute]

model_nodes$color <- apply(model_nodes["op"], 1, get.color)
model_nodes$shape <- apply(model_nodes["op"], 1, get.shape)

label_paste <- paste0(model_nodes$op,
"\n",
model_nodes$name,
"\n",
model_nodes$attr$num_hidden %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""),
model_nodes$attr$act_type %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""),
model_nodes$attr$pool_type %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""),
model_nodes$attr$kernel %>% tuple_str %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""),
" / ",
model_nodes$attr$stride %>% tuple_str %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""),
", ",
model_nodes$attr$num_filter %>% str_replace_na() %>% str_replace_all(pattern = "NA", "")) %>%
str_replace_all(pattern = "[^[:alnum:]]+$", "") %>%
str_trim

model_nodes$label<- label_paste
model_nodes$label <- label_paste

id.to.keep <- model_nodes$id[!model_nodes$op=="null"]
nodes_df <- model_nodes[model_nodes$id %in% id.to.keep, c("id", "label", "shape", "color")]

### remapping for DiagrammeR convention
nodes_df$id<- nodes_df$id
nodes_df$id_graph<- seq_len(nrow(nodes_df))
id_dic<- nodes_df$id_graph
names(id_dic)<- as.character(nodes_df$id)

edges_id<- model_nodes$id[lengths(model_nodes$inputs)!=0 & model_nodes$op!="null"]
edges_id<- id_dic[as.character(edges_id)]
edges<- model_nodes$inputs[lengths(model_nodes$inputs)!=0 & model_nodes$op!="null"]
edges<- sapply(edges, function(x)intersect(as.numeric(x[, 1]), id.to.keep), simplify = FALSE)
names(edges)<- edges_id

edges_df<- data.frame(
from=unlist(edges),
to=rep(names(edges), time=lengths(edges)),
arrows = "to",
color="black",
from_name_output=paste0(model_nodes$name[unlist(edges)+1], "_output"),
stringsAsFactors=FALSE)
edges_df$from<- id_dic[as.character(edges_df$from)]

nodes_df_new<- create_node_df(n = nrow(nodes_df), label=nodes_df$label, shape=nodes_df$shape, type="base", penwidth=2, color=nodes_df$color, style="filled",
fillcolor=adjustcolor(nodes_df$color, alpha.f = 1), fontcolor = "black")
edge_df_new<- create_edge_df(from = edges_df$from, to=edges_df$to, color="black", fontcolor = "black")

if (!is.null(shape)){
nodes_df$id <- nodes_df$id
nodes_df$id_graph <- seq_len(nrow(nodes_df))
id_dic <- nodes_df$id_graph
names(id_dic) <- as.character(nodes_df$id)

edges_id <- model_nodes$id[lengths(model_nodes$inputs)!=0 & model_nodes$op!="null"]
edges_id <- id_dic[as.character(edges_id)]
edges <- model_nodes$inputs[lengths(model_nodes$inputs)!=0 & model_nodes$op!="null"]
edges <- sapply(edges, function(x)intersect(as.numeric(x[, 1]), id.to.keep), simplify = FALSE)
names(edges) <- edges_id

edges_df <- data.frame(from=unlist(edges),
to=rep(names(edges), time=lengths(edges)),
arrows = "to",
color="black",
from_name_output=paste0(model_nodes$name[unlist(edges)+1], "_output"),
stringsAsFactors=FALSE)
edges_df$from <- id_dic[as.character(edges_df$from)]

nodes_df_new <- create_node_df(n = nrow(nodes_df), label=nodes_df$label, shape=nodes_df$shape, type="base", penwidth=2, color=nodes_df$color, style="filled",
fillcolor=adjustcolor(nodes_df$color, alpha.f = 1), fontcolor = "black")
edge_df_new <- create_edge_df(from = edges_df$from, to=edges_df$to, color="black", fontcolor = "black")

if (!is.null(shape)) {
if (is.list(shape)) {
edges_labels_raw<- symbol$get.internals()$infer.shape(shape)$out.shapes
} else edges_labels_raw<- symbol$get.internals()$infer.shape(list(data=shape))$out.shapes
if (!is.null(edges_labels_raw)){
edges_labels_raw <- symbol$get.internals()$infer.shape(shape)$out.shapes
} else edges_labels_raw <- symbol$get.internals()$infer.shape(list(data=shape))$out.shapes
if (!is.null(edges_labels_raw)) {
edge_label_str <- function(x) paste0(x, collapse="X")
edges_labels_raw<- vapply(edges_labels_raw, edge_label_str, character(1))
names(edges_labels_raw)[names(edges_labels_raw)=="data"]<- "data_output"
edge_df_new$label<- edges_labels_raw[edges_df$from_name_output]
edge_df_new$rel<- edge_df_new$label
edges_labels_raw <- vapply(edges_labels_raw, edge_label_str, character(1))
names(edges_labels_raw)[names(edges_labels_raw)=="data"] <- "data_output"
edge_df_new$label <- edges_labels_raw[edges_df$from_name_output]
edge_df_new$rel <- edge_df_new$label
}
}

graph<- create_graph(nodes_df = nodes_df_new, edges_df = edge_df_new, directed = TRUE) %>%
set_global_graph_attrs("layout", value = "dot", attr_type = "graph") %>%
graph <- create_graph(nodes_df = nodes_df_new, edges_df = edge_df_new, directed = TRUE, attr_theme = NULL) %>%
add_global_graph_attrs("layout", value = "dot", attr_type = "graph") %>%
add_global_graph_attrs("rankdir", value = direction, attr_type = "graph")

if (type=="vis"){
graph_render<- render_graph(graph = graph, output = "visNetwork", width = graph.width.px, height = graph.height.px) %>% visHierarchicalLayout(direction = direction, sortMethod = "directed")
graph_render <- render_graph(graph = graph, output = "visNetwork", width = graph.width.px, height = graph.height.px) %>%
visHierarchicalLayout(direction = direction, sortMethod = "directed")
} else {
graph_render<- render_graph(graph = graph, output = "graph", width = graph.width.px, height = graph.height.px)
graph_render <- render_graph(graph = graph, output = "graph", width = graph.width.px, height = graph.height.px)
}

return(graph_render)
Expand Down